Skip to content

Commit

Permalink
fix: Better schema, always return valid JSON
Browse files Browse the repository at this point in the history
  • Loading branch information
erezrokah committed Jul 23, 2024
1 parent d4044fc commit a0b04e4
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 75 deletions.
12 changes: 0 additions & 12 deletions transformers/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,3 @@ func WithPrimaryKeyComponents(fields ...string) StructTransformerOption {
t.pkComponentFields = fields
}
}

func withCurrentJSONTypeSchemaDepth(depth int) StructTransformerOption {
return func(t *structTransformer) {
t.currentJSONTypeSchemaDepth = depth
}
}

func useArrowNullForNilColumnType() StructTransformerOption {
return func(t *structTransformer) {
t.useArrowNullForNilColumnType = true
}
}
94 changes: 42 additions & 52 deletions transformers/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ type structTransformer struct {
pkFieldsFound []string
pkComponentFields []string
pkComponentFieldsFound []string

currentJSONTypeSchemaDepth int
useArrowNullForNilColumnType bool
}

func isFieldStruct(reflectType reflect.Type) bool {
Expand Down Expand Up @@ -131,28 +128,11 @@ func (t *structTransformer) getColumnType(field reflect.StructField) (arrow.Data
return columnType, nil
}

func fieldTypeSchema(field arrow.Field) string {
if field.Type == arrow.Null {
return "any"
}
typeSchema, ok := field.Metadata.GetValue(schema.MetadataTypeSchema)
if !ok || typeSchema == "" {
typeSchema = field.Type.String()
}
return typeSchema
}

func structSchemaToJSON(s arrow.Schema) string {
fieldMap := make(map[string]any)
fieldCount := s.NumFields()
for i := 0; i < fieldCount; i++ {
field := s.Field(i)
fieldMap[field.Name] = fieldTypeSchema(field)
}
func structSchemaToJSON(s any) string {
b := new(bytes.Buffer)
encoder := json.NewEncoder(b)
encoder.SetEscapeHTML(false)
_ = encoder.Encode(fieldMap)
_ = encoder.Encode(s)
return strings.TrimSpace(b.String())
}

Expand All @@ -163,45 +143,60 @@ func normalizePointer(field reflect.StructField) reflect.Value {
return reflect.New(field.Type)
}

func (t *structTransformer) transformFieldToSchema(field reflect.StructField) string {
func (t *structTransformer) fieldToJSONSchema(field reflect.StructField, depth int) any {
transformInput := normalizePointer(field)
switch transformInput.Elem().Kind() {
case reflect.Struct:
table := &schema.Table{}
err := TransformWithStruct(
transformInput.Interface(),
WithNameTransformer(t.nameTransformer),
WithTypeTransformer(t.typeTransformer),
WithUnwrapAllEmbeddedStructs(),
withCurrentJSONTypeSchemaDepth(t.currentJSONTypeSchemaDepth+1),
useArrowNullForNilColumnType(),
)(table)
if err != nil {
return ""
fieldsMap := make(map[string]any)
fieldType := transformInput.Elem().Type()
for i := 0; i < fieldType.NumField(); i++ {
name, err := t.nameTransformer(fieldType.Field(i))
if err != nil {
continue
}
columnType, err := t.getColumnType(fieldType.Field(i))
if err != nil {
continue
}
if columnType == nil {
fieldsMap[name] = "any"
continue
}
if columnType == types.ExtensionTypes.JSON && depth < maxJSONTypeSchemaDepth {
fieldsMap[name] = t.fieldToJSONSchema(fieldType.Field(i), depth+1)
continue
}
if arrow.IsListLike(columnType.ID()) {
fieldsMap[name] = []any{columnType.(*arrow.ListType).Elem().String()}
continue
}
fieldsMap[name] = columnType.String()
}
return structSchemaToJSON(*table.ToArrowSchema())
return fieldsMap
case reflect.Map:
keySchema := t.transformFieldToSchema(reflect.StructField{
keySchema, ok := t.fieldToJSONSchema(reflect.StructField{
Type: field.Type.Key(),
})
if keySchema == "" {
}, depth+1).(string)
if keySchema == "" || !ok {
return ""
}
valueSchema := t.transformFieldToSchema(reflect.StructField{
valueSchema := t.fieldToJSONSchema(reflect.StructField{
Type: field.Type.Elem(),
})
}, depth+1)
if valueSchema == "" {
return ""
}
return fmt.Sprintf("map<%s, %s, items_nullable>", keySchema, valueSchema)
return map[string]any{
keySchema: valueSchema,
}
case reflect.Slice:
valueSchema := t.transformFieldToSchema(reflect.StructField{
valueSchema := t.fieldToJSONSchema(reflect.StructField{
Type: field.Type.Elem(),
})
}, depth+1)
if valueSchema == "" {
return ""
}
return fmt.Sprintf("list<%s, items_nullable>", valueSchema)
return []any{valueSchema}
}

columnType, err := t.getColumnType(field)
Expand All @@ -225,12 +220,7 @@ func (t *structTransformer) addColumnFromField(field reflect.StructField, parent
}

if columnType == nil {
// We usually ignore interfaces/any types but if we're trying to figure a JSON field schema
// we still need them to show up in the docs as `any`
if !t.useArrowNullForNilColumnType {
return nil
}
columnType = arrow.Null
return nil
}

path := field.Name
Expand Down Expand Up @@ -267,8 +257,8 @@ func (t *structTransformer) addColumnFromField(field reflect.StructField, parent
}

// Avoid infinite recursion
if columnType == types.ExtensionTypes.JSON && t.currentJSONTypeSchemaDepth < maxJSONTypeSchemaDepth {
column.TypeSchema = t.transformFieldToSchema(field)
if columnType == types.ExtensionTypes.JSON {
column.TypeSchema = structSchemaToJSON(t.fieldToJSONSchema(field, 0))
}

for _, pk := range t.pkFields {
Expand Down
18 changes: 7 additions & 11 deletions transformers/struct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ func TestJSONTypeSchema(t *testing.T) {
Tags map[string]string `json:"tags"`
}{},
want: map[string]string{
"tags": "map<utf8, utf8, items_nullable>",
"tags": `{"utf8":"utf8"}`,
},
},
{
Expand All @@ -504,7 +504,7 @@ func TestJSONTypeSchema(t *testing.T) {
} `json:"items"`
}{},
want: map[string]string{
"items": `list<{"name":"utf8"}, items_nullable>`,
"items": `[{"name":"utf8"}]`,
},
},
{
Expand All @@ -531,7 +531,7 @@ func TestJSONTypeSchema(t *testing.T) {
} `json:"item"`
}{},
want: map[string]string{
"item": `{"complex_items":"list<{\"name\":\"utf8\"}, items_nullable>","flat_items":"list<item: utf8, nullable>","name":"utf8","tags":"map<utf8, utf8, items_nullable>"}`,
"item": `{"complex_items":[{"name":"utf8"}],"flat_items":["utf8"],"name":"utf8","tags":{"utf8":"utf8"}}`,
},
},
{
Expand All @@ -548,7 +548,7 @@ func TestJSONTypeSchema(t *testing.T) {
} `json:"item"`
}{},
want: map[string]string{
"item": `{"complex_items":"list<{\"name\":\"utf8\"}, items_nullable>","flat_items":"list<item: utf8, nullable>","name":"utf8","tags":"map<utf8, utf8, items_nullable>"}`,
"item": `{"complex_items":[{"name":"utf8"}],"flat_items":["utf8"],"name":"utf8","tags":{"utf8":"utf8"}}`,
},
},
{
Expand All @@ -569,7 +569,7 @@ func TestJSONTypeSchema(t *testing.T) {
Tags map[string]any `json:"tags"`
}{},
want: map[string]string{
"tags": "map<utf8, any, items_nullable>",
"tags": `{"utf8":"any"}`,
},
},
{
Expand All @@ -578,7 +578,7 @@ func TestJSONTypeSchema(t *testing.T) {
Items []any `json:"items"`
}{},
want: map[string]string{
"items": `list<any, items_nullable>`,
"items": `["any"]`,
},
},
{
Expand All @@ -601,7 +601,7 @@ func TestJSONTypeSchema(t *testing.T) {
} `json:"level0"`
}{},
want: map[string]string{
"level0": "{\"level1\":\"{\\\"level2\\\":\\\"{\\\\\\\"level3\\\\\\\":\\\\\\\"{\\\\\\\\\\\\\\\"level4\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\"{\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"level5\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"json\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"}\\\\\\\\\\\\\\\"}\\\\\\\"}\\\"}\"}",
"level0": `{"level1":{"level2":{"level3":{"level4":{"level5":{"level6":"json"}}}}}}`,
},
},
}
Expand All @@ -619,10 +619,6 @@ func TestJSONTypeSchema(t *testing.T) {
}
for col, schema := range tt.want {
column := table.Column(col)
if column == nil {
t.Fatalf("column %q not found", col)
}

if diff := cmp.Diff(column.TypeSchema, schema); diff != "" {
t.Fatalf("table does not match expected. diff (-got, +want): %v", diff)
}
Expand Down

0 comments on commit a0b04e4

Please sign in to comment.