Skip to content

Commit

Permalink
sql/mysql: mysql implementation for schema.Normalizer
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m committed Feb 24, 2022
1 parent d8b6015 commit 30c43c0
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 0 deletions.
101 changes: 101 additions & 0 deletions sql/mysql/normalize.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright 2021-present The Atlas Authors. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package mysql

import (
"context"
"crypto/md5"
"fmt"
"time"

"ariga.io/atlas/sql/schema"
)

// NormalizeRealm returns the normal representation of the given database.
func (d *Driver) NormalizeRealm(ctx context.Context, r *schema.Realm) (nr *schema.Realm, err error) {
for _, s := range r.Schemas {
switch s.Name {
case "mysql", "information_schema", "performance_schema", "sys":
return nil, fmt.Errorf("sql/mysql: normalizing internal schema %q is not supported", s.Name)
}
}
var (
twins = make(map[string]string)
changes = make([]schema.Change, 0, len(r.Schemas))
reverse = make([]schema.Change, 0, len(r.Schemas))
opts = &schema.InspectRealmOption{
Schemas: make([]string, 0, len(r.Schemas)),
}
)
for _, s := range r.Schemas {
twin := twinName(s.Name)
twins[twin] = s.Name
s.Name = twin
opts.Schemas = append(opts.Schemas, s.Name)
// Skip adding the schema.IfNotExists clause
// to fail if the schema exists.
st := schema.New(twin).AddAttrs(s.Attrs...)
changes = append(changes, &schema.AddSchema{S: st})
reverse = append(reverse, &schema.DropSchema{S: st, Extra: []schema.Clause{&schema.IfExists{}}})
for _, t := range s.Tables {
// If objects are not strongly connected.
if t.Schema != s {
t.Schema = s
}
changes = append(changes, &schema.AddTable{T: t})
}
}
patch := func(r *schema.Realm) {
for _, s := range r.Schemas {
s.Name = twins[s.Name]
}
}
// Delete the twin resources, and return
// the source realm to its initial state.
defer func() {
patch(r)
uerr := d.ApplyChanges(ctx, reverse)
if err != nil {
err = fmt.Errorf("%w: %v", err, uerr)
}
err = uerr
}()
if err := d.ApplyChanges(ctx, changes); err != nil {
return nil, err
}
if nr, err = d.InspectRealm(ctx, opts); err != nil {
return nil, err
}
patch(nr)
return nr, nil
}

// NormalizeSchema returns the normal representation of the given database.
func (d *Driver) NormalizeSchema(ctx context.Context, s *schema.Schema) (*schema.Schema, error) {
r := &schema.Realm{}
if s.Realm != nil {
r.Attrs = s.Realm.Attrs
}
r.Schemas = append(r.Schemas, s)
nr, err := d.NormalizeRealm(ctx, r)
if err != nil {
return nil, err
}
ns, ok := nr.Schema(s.Name)
if !ok {
return nil, fmt.Errorf("sql/mysql: missing normalized schema %q", s.Name)
}
return ns, nil
}

const maxLen = 64

func twinName(name string) string {
twin := fmt.Sprintf("atlas_twin_%s_%d", name, time.Now().Unix())
if len(twin) <= maxLen {
return twin
}
return fmt.Sprintf("%s_%x", twin[:maxLen-33], md5.Sum([]byte(twin)))
}
62 changes: 62 additions & 0 deletions sql/mysql/normalize_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2021-present The Atlas Authors. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package mysql

import (
"context"
"strings"
"testing"

"ariga.io/atlas/sql/migrate"
"ariga.io/atlas/sql/schema"

"github.com/stretchr/testify/require"
)

func TestDriver_NormalizeRealm(t *testing.T) {
var (
apply = &mockApply{}
inspect = &mockInspect{
realm: schema.NewRealm(schema.New("test").SetCharset("utf8mb4")),
}
drv = &Driver{
Inspector: inspect,
PlanApplier: apply,
}
)
normal, err := drv.NormalizeRealm(context.Background(), schema.NewRealm(schema.New("test")))
require.NoError(t, err)
require.Equal(t, normal, inspect.realm)

require.Len(t, inspect.schemas, 1)
require.True(t, strings.HasPrefix(inspect.schemas[0], "atlas_twin_test_"))

require.Len(t, apply.changes, 2, "expect 2 calls (create and drop)")
require.Len(t, apply.changes[0], 1)
require.Equal(t, &schema.AddSchema{S: schema.New(inspect.schemas[0])}, apply.changes[0][0])
require.Len(t, apply.changes[1], 1)
require.Equal(t, &schema.DropSchema{S: schema.New(inspect.schemas[0]), Extra: []schema.Clause{&schema.IfExists{}}}, apply.changes[1][0])
}

type mockInspect struct {
schema.Inspector
schemas []string
realm *schema.Realm
}

func (m *mockInspect) InspectRealm(_ context.Context, opts *schema.InspectRealmOption) (*schema.Realm, error) {
m.schemas = append(m.schemas, opts.Schemas...)
return m.realm, nil
}

type mockApply struct {
migrate.PlanApplier
changes [][]schema.Change
}

func (m *mockApply) ApplyChanges(_ context.Context, changes []schema.Change) error {
m.changes = append(m.changes, changes)
return nil
}
12 changes: 12 additions & 0 deletions sql/schema/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,15 @@ type (
InspectRealm(ctx context.Context, opts *InspectRealmOption) (*Realm, error)
}
)

// Normalizer is the interface implemented by the different database drivers for
// "normalizing" schema objects. i.e. converting schema objects defined in natural
// form to their representation in the database. Thus, two schema objects are equal
// if their normal forms are equal.
type Normalizer interface {
// NormalizeSchema returns the normal representation of a schema.
NormalizeSchema(context.Context, *Schema) (*Schema, error)

// NormalizeRealm returns the normal representation of a database.
NormalizeRealm(context.Context, *Realm) (*Realm, error)
}

0 comments on commit 30c43c0

Please sign in to comment.