diff --git a/indexer/postgres/internal/testdata/example_schema.go b/indexer/postgres/internal/testdata/example_schema.go index 396f08fbd1d2..a40d794034b7 100644 --- a/indexer/postgres/internal/testdata/example_schema.go +++ b/indexer/postgres/internal/testdata/example_schema.go @@ -36,7 +36,7 @@ func init() { AllKindsObject.ValueFields = append(AllKindsObject.ValueFields, field) } - ExampleSchema = schema.MustNewModuleSchema( + ExampleSchema = schema.MustCompileModuleSchema( AllKindsObject, SingletonObject, VoteObject, diff --git a/schema/decoding/decoding_test.go b/schema/decoding/decoding_test.go index 52b3f846ea48..7e1a08c92b7c 100644 --- a/schema/decoding/decoding_test.go +++ b/schema/decoding/decoding_test.go @@ -371,7 +371,7 @@ func (e exampleBankModule) subBalance(acct, denom string, amount uint64) error { func init() { var err error - exampleBankSchema, err = schema.NewModuleSchema(schema.ObjectType{ + exampleBankSchema, err = schema.CompileModuleSchema(schema.ObjectType{ Name: "balances", KeyFields: []schema.Field{ { @@ -435,7 +435,7 @@ type oneValueModule struct { func init() { var err error - oneValueModSchema, err = schema.NewModuleSchema(schema.ObjectType{ + oneValueModSchema, err = schema.CompileModuleSchema(schema.ObjectType{ Name: "item", ValueFields: []schema.Field{ {Name: "value", Kind: schema.StringKind}, diff --git a/schema/decoding/resolver_test.go b/schema/decoding/resolver_test.go index 144a1d0a8295..397b97bd6c33 100644 --- a/schema/decoding/resolver_test.go +++ b/schema/decoding/resolver_test.go @@ -10,7 +10,7 @@ import ( type modA struct{} func (m modA) ModuleCodec() (schema.ModuleCodec, error) { - modSchema, err := schema.NewModuleSchema(schema.ObjectType{Name: "A", KeyFields: []schema.Field{{Name: "field1", Kind: schema.StringKind}}}) + modSchema, err := schema.CompileModuleSchema(schema.ObjectType{Name: "A", KeyFields: []schema.Field{{Name: "field1", Kind: schema.StringKind}}}) if err != nil { return schema.ModuleCodec{}, err } @@ -22,7 +22,7 @@ func (m modA) ModuleCodec() (schema.ModuleCodec, error) { type modB struct{} func (m modB) ModuleCodec() (schema.ModuleCodec, error) { - modSchema, err := schema.NewModuleSchema(schema.ObjectType{Name: "B", KeyFields: []schema.Field{{Name: "field2", Kind: schema.StringKind}}}) + modSchema, err := schema.CompileModuleSchema(schema.ObjectType{Name: "B", KeyFields: []schema.Field{{Name: "field2", Kind: schema.StringKind}}}) if err != nil { return schema.ModuleCodec{}, err } @@ -44,7 +44,7 @@ var testResolver = ModuleSetDecoderResolver(moduleSet) func TestModuleSetDecoderResolver_IterateAll(t *testing.T) { objectTypes := map[string]bool{} err := testResolver.IterateAll(func(moduleName string, cdc schema.ModuleCodec) error { - cdc.Schema.Types(func(t schema.Type) bool { + cdc.Schema.AllTypes(func(t schema.Type) bool { objTyp, ok := t.(schema.ObjectType) if ok { objectTypes[objTyp.Name] = true diff --git a/schema/diff/diff_test.go b/schema/diff/diff_test.go index 867d3e1d660c..c193cbe01ba1 100644 --- a/schema/diff/diff_test.go +++ b/schema/diff/diff_test.go @@ -332,7 +332,7 @@ func TestCompareModuleSchemas(t *testing.T) { } func requireModuleSchema(t *testing.T, types ...schema.Type) schema.ModuleSchema { - s, err := schema.NewModuleSchema(types...) + s, err := schema.CompileModuleSchema(types...) if err != nil { t.Fatal(err) } diff --git a/schema/enum.go b/schema/enum.go index 942758033de8..4783ccff6fc6 100644 --- a/schema/enum.go +++ b/schema/enum.go @@ -44,7 +44,7 @@ func (e EnumType) TypeName() string { func (EnumType) isType() {} // Validate validates the enum definition. -func (e EnumType) Validate(Schema) error { +func (e EnumType) Validate(TypeSet) error { if !ValidateName(e.Name) { return fmt.Errorf("invalid enum definition name %q", e.Name) } diff --git a/schema/enum_test.go b/schema/enum_test.go index 332648accfec..387d03e2cfd6 100644 --- a/schema/enum_test.go +++ b/schema/enum_test.go @@ -108,7 +108,7 @@ func TestEnumDefinition_Validate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := tt.enum.Validate(EmptySchema{}) + err := tt.enum.Validate(EmptyTypeSet()) if tt.errContains == "" { if err != nil { t.Errorf("expected valid enum definition to pass validation, got: %v", err) diff --git a/schema/field.go b/schema/field.go index df6b6139cd12..af7374e367f7 100644 --- a/schema/field.go +++ b/schema/field.go @@ -20,7 +20,7 @@ type Field struct { } // Validate validates the field. -func (c Field) Validate(schema Schema) error { +func (c Field) Validate(typeSet TypeSet) error { // valid name if !ValidateName(c.Name) { return fmt.Errorf("invalid field name %q", c.Name) @@ -38,7 +38,7 @@ func (c Field) Validate(schema Schema) error { return fmt.Errorf("enum field %q must have a referenced type", c.Name) } - ty, ok := schema.LookupType(c.ReferencedType) + ty, ok := typeSet.LookupType(c.ReferencedType) if !ok { return fmt.Errorf("enum field %q references unknown type %q", c.Name, c.ReferencedType) } @@ -58,7 +58,7 @@ func (c Field) Validate(schema Schema) error { // ValidateValue validates that the value conforms to the field's kind and nullability. // Unlike Kind.ValidateValue, it also checks that the value conforms to the EnumType // if the field is an EnumKind. -func (c Field) ValidateValue(value interface{}, schema Schema) error { +func (c Field) ValidateValue(value interface{}, typeSet TypeSet) error { if value == nil { if !c.Nullable { return fmt.Errorf("field %q cannot be null", c.Name) @@ -72,7 +72,7 @@ func (c Field) ValidateValue(value interface{}, schema Schema) error { switch c.Kind { case EnumKind: - ty, ok := schema.LookupType(c.ReferencedType) + ty, ok := typeSet.LookupType(c.ReferencedType) if !ok { return fmt.Errorf("enum field %q references unknown type %q", c.Name, c.ReferencedType) } diff --git a/schema/field_test.go b/schema/field_test.go index d8873c5b9fb2..0756f5a7ca86 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -225,7 +225,7 @@ func TestFieldJSON(t *testing.T) { } } -var testEnumSchema = MustNewModuleSchema(EnumType{ +var testEnumSchema = MustCompileModuleSchema(EnumType{ Name: "enum", Values: []EnumValueDefinition{{Name: "a", Value: 1}, {Name: "b", Value: 2}}, }) diff --git a/schema/fields.go b/schema/fields.go index 19104e0fa340..d0bf1106dfec 100644 --- a/schema/fields.go +++ b/schema/fields.go @@ -4,16 +4,16 @@ import "fmt" // ValidateObjectKey validates that the value conforms to the set of fields as a Key in an ObjectUpdate. // See ObjectUpdate.Key for documentation on the requirements of such keys. -func ValidateObjectKey(keyFields []Field, value interface{}, schema Schema) error { - return validateFieldsValue(keyFields, value, schema) +func ValidateObjectKey(keyFields []Field, value interface{}, typeSet TypeSet) error { + return validateFieldsValue(keyFields, value, typeSet) } // ValidateObjectValue validates that the value conforms to the set of fields as a Value in an ObjectUpdate. // See ObjectUpdate.Value for documentation on the requirements of such values. -func ValidateObjectValue(valueFields []Field, value interface{}, schema Schema) error { +func ValidateObjectValue(valueFields []Field, value interface{}, typeSet TypeSet) error { valueUpdates, ok := value.(ValueUpdates) if !ok { - return validateFieldsValue(valueFields, value, schema) + return validateFieldsValue(valueFields, value, typeSet) } values := map[string]interface{}{} @@ -31,7 +31,7 @@ func ValidateObjectValue(valueFields []Field, value interface{}, schema Schema) continue } - if err := field.ValidateValue(v, schema); err != nil { + if err := field.ValidateValue(v, typeSet); err != nil { return err } @@ -45,13 +45,13 @@ func ValidateObjectValue(valueFields []Field, value interface{}, schema Schema) return nil } -func validateFieldsValue(fields []Field, value interface{}, schema Schema) error { +func validateFieldsValue(fields []Field, value interface{}, typeSet TypeSet) error { if len(fields) == 0 { return nil } if len(fields) == 1 { - return fields[0].ValidateValue(value, schema) + return fields[0].ValidateValue(value, typeSet) } values, ok := value.([]interface{}) @@ -63,7 +63,7 @@ func validateFieldsValue(fields []Field, value interface{}, schema Schema) error return fmt.Errorf("expected %d key fields, got %d values", len(fields), len(value.([]interface{}))) } for i, field := range fields { - if err := field.ValidateValue(values[i], schema); err != nil { + if err := field.ValidateValue(values[i], typeSet); err != nil { return err } } diff --git a/schema/fields_test.go b/schema/fields_test.go index b1789524f17d..57b303a669bc 100644 --- a/schema/fields_test.go +++ b/schema/fields_test.go @@ -56,7 +56,7 @@ func TestValidateForKeyFields(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := ValidateObjectKey(tt.keyFields, tt.key, EmptySchema{}) + err := ValidateObjectKey(tt.keyFields, tt.key, EmptyTypeSet()) if tt.errContains == "" { if err != nil { t.Fatalf("unexpected error: %v", err) @@ -128,7 +128,7 @@ func TestValidateForValueFields(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := ValidateObjectValue(tt.valueFields, tt.value, EmptySchema{}) + err := ValidateObjectValue(tt.valueFields, tt.value, EmptyTypeSet()) if tt.errContains == "" { if err != nil { t.Fatalf("unexpected error: %v", err) diff --git a/schema/module_schema.go b/schema/module_schema.go index 8fc92d9f9f65..dbd5365f5563 100644 --- a/schema/module_schema.go +++ b/schema/module_schema.go @@ -11,9 +11,9 @@ type ModuleSchema struct { types map[string]Type } -// NewModuleSchema constructs a new ModuleSchema and validates it. Any module schema returned without an error -// is guaranteed to be valid. -func NewModuleSchema(types ...Type) (ModuleSchema, error) { +// CompileModuleSchema compiles the types into a ModuleSchema and validates it. +// Any module schema returned without an error is guaranteed to be valid. +func CompileModuleSchema(types ...Type) (ModuleSchema, error) { typeMap := map[string]Type{} for _, typ := range types { @@ -34,10 +34,10 @@ func NewModuleSchema(types ...Type) (ModuleSchema, error) { return res, nil } -// MustNewModuleSchema constructs a new ModuleSchema and panics if it is invalid. +// MustCompileModuleSchema constructs a new ModuleSchema and panics if it is invalid. // This should only be used in test code or static initialization where it is safe to panic! -func MustNewModuleSchema(types ...Type) ModuleSchema { - sch, err := NewModuleSchema(types...) +func MustCompileModuleSchema(types ...Type) ModuleSchema { + sch, err := CompileModuleSchema(types...) if err != nil { panic(err) } @@ -79,7 +79,7 @@ func (s ModuleSchema) LookupType(name string) (Type, bool) { // Types calls the provided function for each type in the module schema and stops if the function returns false. // The types are iterated over in sorted order by name. This function is compatible with go 1.23 iterators. -func (s ModuleSchema) Types(f func(Type) bool) { +func (s ModuleSchema) AllTypes(f func(Type) bool) { keys := make([]string, 0, len(s.types)) for k := range s.types { keys = append(keys, k) @@ -94,7 +94,7 @@ func (s ModuleSchema) Types(f func(Type) bool) { // ObjectTypes iterators over all the object types in the schema in alphabetical order. func (s ModuleSchema) ObjectTypes(f func(ObjectType) bool) { - s.Types(func(t Type) bool { + s.AllTypes(func(t Type) bool { objTyp, ok := t.(ObjectType) if ok { return f(objTyp) @@ -105,7 +105,7 @@ func (s ModuleSchema) ObjectTypes(f func(ObjectType) bool) { // EnumTypes iterators over all the enum types in the schema in alphabetical order. func (s ModuleSchema) EnumTypes(f func(EnumType) bool) { - s.Types(func(t Type) bool { + s.AllTypes(func(t Type) bool { enumType, ok := t.(EnumType) if ok { return f(enumType) @@ -169,4 +169,6 @@ func (s *ModuleSchema) UnmarshalJSON(data []byte) error { return nil } -var _ Schema = ModuleSchema{} +func (ModuleSchema) isTypeSet() {} + +var _ TypeSet = ModuleSchema{} diff --git a/schema/module_schema_test.go b/schema/module_schema_test.go index 9c356f3457f7..a52085a60f5e 100644 --- a/schema/module_schema_test.go +++ b/schema/module_schema_test.go @@ -66,8 +66,8 @@ func TestModuleSchema_Validate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // because validate is called when calling NewModuleSchema, we just call NewModuleSchema - _, err := NewModuleSchema(tt.types...) + // because validate is called when calling CompileModuleSchema, we just call CompileModuleSchema + _, err := CompileModuleSchema(tt.types...) if tt.errContains == "" { if err != nil { t.Fatalf("unexpected error: %v", err) @@ -169,7 +169,7 @@ func TestModuleSchema_ValidateObjectUpdate(t *testing.T) { func requireModuleSchema(t *testing.T, types ...Type) ModuleSchema { t.Helper() - moduleSchema, err := NewModuleSchema(types...) + moduleSchema, err := CompileModuleSchema(types...) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -240,7 +240,7 @@ func TestModuleSchema_Types(t *testing.T) { moduleSchema := exampleSchema(t) var typeNames []string - moduleSchema.Types(func(typ Type) bool { + moduleSchema.AllTypes(func(typ Type) bool { typeNames = append(typeNames, typ.TypeName()) return true }) @@ -252,7 +252,7 @@ func TestModuleSchema_Types(t *testing.T) { typeNames = nil // scan just the first type and return false - moduleSchema.Types(func(typ Type) bool { + moduleSchema.AllTypes(func(typ Type) bool { typeNames = append(typeNames, typ.TypeName()) return false }) diff --git a/schema/object_type.go b/schema/object_type.go index c35f27c7226d..f961d06062e3 100644 --- a/schema/object_type.go +++ b/schema/object_type.go @@ -37,7 +37,7 @@ func (o ObjectType) TypeName() string { func (ObjectType) isType() {} // Validate validates the object type. -func (o ObjectType) Validate(schema Schema) error { +func (o ObjectType) Validate(typeSet TypeSet) error { if !ValidateName(o.Name) { return fmt.Errorf("invalid object type name %q", o.Name) } @@ -45,7 +45,7 @@ func (o ObjectType) Validate(schema Schema) error { fieldNames := map[string]bool{} for _, field := range o.KeyFields { - if err := field.Validate(schema); err != nil { + if err := field.Validate(typeSet); err != nil { return fmt.Errorf("invalid key field %q: %v", field.Name, err) //nolint:errorlint // false positive due to using go1.12 } @@ -64,7 +64,7 @@ func (o ObjectType) Validate(schema Schema) error { } for _, field := range o.ValueFields { - if err := field.Validate(schema); err != nil { + if err := field.Validate(typeSet); err != nil { return fmt.Errorf("invalid value field %q: %v", field.Name, err) //nolint:errorlint // false positive due to using go1.12 } @@ -82,12 +82,12 @@ func (o ObjectType) Validate(schema Schema) error { } // ValidateObjectUpdate validates that the update conforms to the object type. -func (o ObjectType) ValidateObjectUpdate(update ObjectUpdate, schema Schema) error { +func (o ObjectType) ValidateObjectUpdate(update ObjectUpdate, typeSet TypeSet) error { if o.Name != update.TypeName { return fmt.Errorf("object type name %q does not match update type name %q", o.Name, update.TypeName) } - if err := ValidateObjectKey(o.KeyFields, update.Key, schema); err != nil { + if err := ValidateObjectKey(o.KeyFields, update.Key, typeSet); err != nil { return fmt.Errorf("invalid key for object type %q: %v", update.TypeName, err) //nolint:errorlint // false positive due to using go1.12 } @@ -95,5 +95,5 @@ func (o ObjectType) ValidateObjectUpdate(update ObjectUpdate, schema Schema) err return nil } - return ValidateObjectValue(o.ValueFields, update.Value, schema) + return ValidateObjectValue(o.ValueFields, update.Value, typeSet) } diff --git a/schema/object_type_test.go b/schema/object_type_test.go index be6f00f24818..e2b68590241c 100644 --- a/schema/object_type_test.go +++ b/schema/object_type_test.go @@ -205,7 +205,7 @@ func TestObjectType_Validate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := tt.objectType.Validate(EmptySchema{}) + err := tt.objectType.Validate(EmptyTypeSet()) if tt.errContains == "" { if err != nil { t.Fatalf("unexpected error: %v", err) @@ -267,7 +267,7 @@ func TestObjectType_ValidateObjectUpdate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := tt.objectType.ValidateObjectUpdate(tt.object, EmptySchema{}) + err := tt.objectType.ValidateObjectUpdate(tt.object, EmptyTypeSet()) if tt.errContains == "" { if err != nil { t.Fatalf("unexpected error: %v", err) diff --git a/schema/testing/enum_test.go b/schema/testing/enum_test.go index 8c89ceb994d2..c961d85764a1 100644 --- a/schema/testing/enum_test.go +++ b/schema/testing/enum_test.go @@ -12,6 +12,6 @@ import ( func TestEnumType(t *testing.T) { rapid.Check(t, func(t *rapid.T) { enumType := EnumType().Draw(t, "enum") - require.NoError(t, enumType.Validate(schema.EmptySchema{})) + require.NoError(t, enumType.Validate(schema.EmptyTypeSet())) }) } diff --git a/schema/testing/example_schema.go b/schema/testing/example_schema.go index c22cf6fff3de..81461b91b97f 100644 --- a/schema/testing/example_schema.go +++ b/schema/testing/example_schema.go @@ -10,7 +10,7 @@ import ( // that can be used in reproducible unit testing and property based testing. var ExampleAppSchema = map[string]schema.ModuleSchema{ "all_kinds": mkAllKindsModule(), - "test_cases": schema.MustNewModuleSchema( + "test_cases": schema.MustCompileModuleSchema( schema.ObjectType{ Name: "Singleton", KeyFields: []schema.Field{}, @@ -138,7 +138,7 @@ func mkAllKindsModule() schema.ModuleSchema { types = append(types, typ) } - return schema.MustNewModuleSchema(types...) + return schema.MustCompileModuleSchema(types...) } func mkTestObjectType(kind schema.Kind) schema.ObjectType { diff --git a/schema/testing/field.go b/schema/testing/field.go index 5a5f7cadc81d..87154afec78e 100644 --- a/schema/testing/field.go +++ b/schema/testing/field.go @@ -19,8 +19,8 @@ var ( ) // FieldGen generates random Field's based on the validity criteria of fields. -func FieldGen(sch schema.Schema) *rapid.Generator[schema.Field] { - enumTypes := slices.DeleteFunc(slices.Collect(sch.Types), func(t schema.Type) bool { +func FieldGen(typeSet schema.TypeSet) *rapid.Generator[schema.Field] { + enumTypes := slices.DeleteFunc(slices.Collect(typeSet.AllTypes), func(t schema.Type) bool { _, ok := t.(schema.EnumType) return !ok }) @@ -50,16 +50,16 @@ func FieldGen(sch schema.Schema) *rapid.Generator[schema.Field] { } // KeyFieldGen generates random key fields based on the validity criteria of key fields. -func KeyFieldGen(sch schema.Schema) *rapid.Generator[schema.Field] { - return FieldGen(sch).Filter(func(f schema.Field) bool { +func KeyFieldGen(typeSet schema.TypeSet) *rapid.Generator[schema.Field] { + return FieldGen(typeSet).Filter(func(f schema.Field) bool { return !f.Nullable && f.Kind.ValidKeyKind() }) } // FieldValueGen generates random valid values for the field, aiming to exercise the full range of possible // values for the field. -func FieldValueGen(field schema.Field, sch schema.Schema) *rapid.Generator[any] { - gen := baseFieldValue(field, sch) +func FieldValueGen(field schema.Field, typeSet schema.TypeSet) *rapid.Generator[any] { + gen := baseFieldValue(field, typeSet) if field.Nullable { return rapid.OneOf(gen, rapid.Just[any](nil)).AsAny() @@ -68,7 +68,7 @@ func FieldValueGen(field schema.Field, sch schema.Schema) *rapid.Generator[any] return gen } -func baseFieldValue(field schema.Field, sch schema.Schema) *rapid.Generator[any] { +func baseFieldValue(field schema.Field, typeSet schema.TypeSet) *rapid.Generator[any] { switch field.Kind { case schema.StringKind: return rapid.StringOf(rapid.Rune().Filter(func(r rune) bool { @@ -113,7 +113,7 @@ func baseFieldValue(field schema.Field, sch schema.Schema) *rapid.Generator[any] case schema.AddressKind: return rapid.SliceOfN(rapid.Byte(), 20, 64).AsAny() case schema.EnumKind: - typ, found := sch.LookupType(field.ReferencedType) + typ, found := typeSet.LookupType(field.ReferencedType) enumTyp, ok := typ.(schema.EnumType) if !found || !ok { panic(fmt.Errorf("enum type %q not found", field.ReferencedType)) @@ -128,18 +128,18 @@ func baseFieldValue(field schema.Field, sch schema.Schema) *rapid.Generator[any] } // ObjectKeyGen generates a value that is valid for the provided object key fields. -func ObjectKeyGen(keyFields []schema.Field, sch schema.Schema) *rapid.Generator[any] { +func ObjectKeyGen(keyFields []schema.Field, typeSet schema.TypeSet) *rapid.Generator[any] { if len(keyFields) == 0 { return rapid.Just[any](nil) } if len(keyFields) == 1 { - return FieldValueGen(keyFields[0], sch) + return FieldValueGen(keyFields[0], typeSet) } gens := make([]*rapid.Generator[any], len(keyFields)) for i, field := range keyFields { - gens[i] = FieldValueGen(field, sch) + gens[i] = FieldValueGen(field, typeSet) } return rapid.Custom(func(t *rapid.T) any { @@ -156,7 +156,7 @@ func ObjectKeyGen(keyFields []schema.Field, sch schema.Schema) *rapid.Generator[ // are valid for insertion (in the case forUpdate is false) or for update (in the case forUpdate is true). // Values that are for update may skip some fields in a ValueUpdates instance whereas values for insertion // will always contain all values. -func ObjectValueGen(valueFields []schema.Field, forUpdate bool, sch schema.Schema) *rapid.Generator[any] { +func ObjectValueGen(valueFields []schema.Field, forUpdate bool, typeSet schema.TypeSet) *rapid.Generator[any] { if len(valueFields) == 0 { // if we have no value fields, always return nil return rapid.Just[any](nil) @@ -164,7 +164,7 @@ func ObjectValueGen(valueFields []schema.Field, forUpdate bool, sch schema.Schem gens := make([]*rapid.Generator[any], len(valueFields)) for i, field := range valueFields { - gens[i] = FieldValueGen(field, sch) + gens[i] = FieldValueGen(field, typeSet) } return rapid.Custom(func(t *rapid.T) any { // return ValueUpdates 50% of the time diff --git a/schema/testing/field_test.go b/schema/testing/field_test.go index 91510f0f223a..f83807ddfe85 100644 --- a/schema/testing/field_test.go +++ b/schema/testing/field_test.go @@ -27,7 +27,7 @@ var checkFieldValue = func(t *rapid.T) { require.NoError(t, field.ValidateValue(fieldValue, testEnumSchema)) } -var testEnumSchema = schema.MustNewModuleSchema(schema.EnumType{ +var testEnumSchema = schema.MustCompileModuleSchema(schema.EnumType{ Name: "test_enum", Values: []schema.EnumValueDefinition{{Name: "a", Value: 1}, {Name: "b", Value: 2}}, }) diff --git a/schema/testing/module_schema.go b/schema/testing/module_schema.go index 7dcce247c48a..c874ac73a9f1 100644 --- a/schema/testing/module_schema.go +++ b/schema/testing/module_schema.go @@ -11,7 +11,7 @@ func ModuleSchemaGen() *rapid.Generator[schema.ModuleSchema] { enumTypesGen := distinctTypes(EnumType()) return rapid.Custom(func(t *rapid.T) schema.ModuleSchema { enumTypes := enumTypesGen.Draw(t, "enumTypes") - tempSchema, err := schema.NewModuleSchema(enumTypes...) + tempSchema, err := schema.CompileModuleSchema(enumTypes...) if err != nil { t.Fatal(err) } @@ -19,7 +19,7 @@ func ModuleSchemaGen() *rapid.Generator[schema.ModuleSchema] { objectTypes := distinctTypes(ObjectTypeGen(tempSchema)).Draw(t, "objectTypes") allTypes := append(enumTypes, objectTypes...) - modSchema, err := schema.NewModuleSchema(allTypes...) + modSchema, err := schema.CompileModuleSchema(allTypes...) if err != nil { t.Fatal(err) } diff --git a/schema/testing/object.go b/schema/testing/object.go index 85588d2ecd24..3e7665aa4bc5 100644 --- a/schema/testing/object.go +++ b/schema/testing/object.go @@ -8,12 +8,12 @@ import ( ) // ObjectTypeGen generates random ObjectType's based on the validity criteria of object types. -func ObjectTypeGen(sch schema.Schema) *rapid.Generator[schema.ObjectType] { - keyFieldsGen := rapid.SliceOfNDistinct(KeyFieldGen(sch), 1, 6, func(f schema.Field) string { +func ObjectTypeGen(typeSet schema.TypeSet) *rapid.Generator[schema.ObjectType] { + keyFieldsGen := rapid.SliceOfNDistinct(KeyFieldGen(typeSet), 1, 6, func(f schema.Field) string { return f.Name }) - valueFieldsGen := rapid.SliceOfNDistinct(FieldGen(sch), 1, 12, func(f schema.Field) string { + valueFieldsGen := rapid.SliceOfNDistinct(FieldGen(typeSet), 1, 12, func(f schema.Field) string { return f.Name }) @@ -21,7 +21,7 @@ func ObjectTypeGen(sch schema.Schema) *rapid.Generator[schema.ObjectType] { typ := schema.ObjectType{ Name: NameGen.Filter(func(s string) bool { // filter out names that already exist in the schema - _, found := sch.LookupType(s) + _, found := typeSet.LookupType(s) return !found }).Draw(t, "name"), } @@ -56,13 +56,13 @@ func hasDuplicateFieldNames(typeNames map[string]bool, fields []schema.Field) bo } // ObjectInsertGen generates object updates that are valid for insertion. -func ObjectInsertGen(objectType schema.ObjectType, sch schema.Schema) *rapid.Generator[schema.ObjectUpdate] { - return ObjectUpdateGen(objectType, nil, sch) +func ObjectInsertGen(objectType schema.ObjectType, typeSet schema.TypeSet) *rapid.Generator[schema.ObjectUpdate] { + return ObjectUpdateGen(objectType, nil, typeSet) } // ObjectUpdateGen generates object updates that are valid for updates using the provided state map as a source // of valid existing keys. -func ObjectUpdateGen(objectType schema.ObjectType, state *btree.Map[string, schema.ObjectUpdate], sch schema.Schema) *rapid.Generator[schema.ObjectUpdate] { +func ObjectUpdateGen(objectType schema.ObjectType, state *btree.Map[string, schema.ObjectUpdate], sch schema.TypeSet) *rapid.Generator[schema.ObjectUpdate] { keyGen := ObjectKeyGen(objectType.KeyFields, sch).Filter(func(key interface{}) bool { // filter out keys that exist in the state if state != nil { diff --git a/schema/testing/statesim/object_coll.go b/schema/testing/statesim/object_coll.go index e59fa0dae44e..e2c026a898a2 100644 --- a/schema/testing/statesim/object_coll.go +++ b/schema/testing/statesim/object_coll.go @@ -14,16 +14,16 @@ import ( type ObjectCollection struct { options Options objectType schema.ObjectType - sch schema.Schema + typeSet schema.TypeSet objects *btree.Map[string, schema.ObjectUpdate] updateGen *rapid.Generator[schema.ObjectUpdate] valueFieldIndices map[string]int } // NewObjectCollection creates a new ObjectCollection for the given object type. -func NewObjectCollection(objectType schema.ObjectType, options Options, sch schema.Schema) *ObjectCollection { +func NewObjectCollection(objectType schema.ObjectType, options Options, typeSet schema.TypeSet) *ObjectCollection { objects := &btree.Map[string, schema.ObjectUpdate]{} - updateGen := schematesting.ObjectUpdateGen(objectType, objects, sch) + updateGen := schematesting.ObjectUpdateGen(objectType, objects, typeSet) valueFieldIndices := make(map[string]int, len(objectType.ValueFields)) for i, field := range objectType.ValueFields { valueFieldIndices[field.Name] = i @@ -32,7 +32,7 @@ func NewObjectCollection(objectType schema.ObjectType, options Options, sch sche return &ObjectCollection{ options: options, objectType: objectType, - sch: sch, + typeSet: typeSet, objects: objects, updateGen: updateGen, valueFieldIndices: valueFieldIndices, @@ -45,7 +45,7 @@ func (o *ObjectCollection) ApplyUpdate(update schema.ObjectUpdate) error { return fmt.Errorf("update type name %q does not match object type name %q", update.TypeName, o.objectType.Name) } - err := o.objectType.ValidateObjectUpdate(update, o.sch) + err := o.objectType.ValidateObjectUpdate(update, o.typeSet) if err != nil { return err } diff --git a/schema/type.go b/schema/type.go index 0398fa295db6..f525e84b817b 100644 --- a/schema/type.go +++ b/schema/type.go @@ -7,30 +7,45 @@ type Type interface { TypeName() string // Validate validates the type. - Validate(Schema) error + Validate(TypeSet) error // isType is a private method that ensures that only types in this package can be marked as types. isType() } -// Schema represents something that has types and allows them to be looked up by name. +// TypeSet represents something that has types and allows them to be looked up by name. // Currently, the only implementation is ModuleSchema. -type Schema interface { +type TypeSet interface { // LookupType looks up a type by name. LookupType(name string) (Type, bool) - // Types calls the given function for each type in the schema. - Types(f func(Type) bool) + // AllTypes calls the given function for each type in the type set. + // This function is compatible with go 1.23 iterators and can be used like this: + // for t := range types.AllTypes { + // // do something with t + // } + AllTypes(f func(Type) bool) + + // isTypeSet is a private method that ensures that only types in this package can be marked as type sets. + isTypeSet() } -// EmptySchema is a schema that contains no types. +// EmptyTypeSet is a schema that contains no types. // It can be used in Validate methods when there is no schema needed or available. -type EmptySchema struct{} +func EmptyTypeSet() TypeSet { + return emptyTypeSetInst +} + +var emptyTypeSetInst = emptyTypeSet{} -// LookupType always returns false because there are no types in an EmptySchema. -func (EmptySchema) LookupType(name string) (Type, bool) { +type emptyTypeSet struct{} + +// LookupType always returns false because there are no types in an EmptyTypeSet. +func (emptyTypeSet) LookupType(string) (Type, bool) { return nil, false } -// Types does nothing because there are no types in an EmptySchema. -func (EmptySchema) Types(f func(Type) bool) {} +// Types does nothing because there are no types in an EmptyTypeSet. +func (emptyTypeSet) AllTypes(func(Type) bool) {} + +func (emptyTypeSet) isTypeSet() {}