From 6a60650668e48c9878eae4ed83d8c5c329fe03e8 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Fri, 14 Jun 2024 11:28:30 -0400 Subject: [PATCH] Marshal/Unmarshal missing UDT fields as null instead of failing in unsafe mode We can't return an error in case a field is added to the UDT, otherwise existing code would break by simply altering the UDT in the database. For extra fields at the end of the UDT put nulls to be in line with gocql, but also python-driver and java-driver. In gocql it was fixed in https://github.com/scylladb/gocql/commit/d2ed1bb74f3118a83a352e9ce912be765001efa4 --- iterx_test.go | 325 +++++++++++++++++++++++++++++++++++++++++++++++++- queryx.go | 7 ++ udt.go | 24 ++-- 3 files changed, 341 insertions(+), 15 deletions(-) diff --git a/iterx_test.go b/iterx_test.go index d8652a0..96de405 100644 --- a/iterx_test.go +++ b/iterx_test.go @@ -9,6 +9,7 @@ package gocqlx_test import ( "math/big" + "reflect" "strings" "testing" "time" @@ -48,6 +49,328 @@ type FullNamePtrUDT struct { *FullName } +func diff(t *testing.T, expected, got interface{}) { + t.Helper() + + if d := cmp.Diff(expected, got, diffOpts); d != "" { + t.Errorf("got %+v expected %+v, diff: %s", got, expected, d) + } +} + +var diffOpts = cmpopts.IgnoreUnexported(big.Int{}, inf.Dec{}) + +func TestIterxUDT(t *testing.T) { + session := gocqlxtest.CreateSession(t) + t.Cleanup(func() { + session.Close() + }) + + if err := session.ExecStmt(`CREATE TYPE gocqlx_test.UDTTest_Full (first text, second text)`); err != nil { + t.Fatal("create type:", err) + } + + if err := session.ExecStmt(`CREATE TABLE gocqlx_test.udt_table ( + testuuid timeuuid PRIMARY KEY, + testudt gocqlx_test.UDTTest_Full + )`); err != nil { + t.Fatal("create table:", err) + } + + type Full struct { + First string + Second string + } + + type Part struct { + First string + } + + type Extra struct { + First string + Second string + Third string + } + + type FullUDT struct { + gocqlx.UDT + Full + } + + type PartUDT struct { + gocqlx.UDT + Part + } + + type ExtraUDT struct { + gocqlx.UDT + Extra + } + + type FullUDTPtr struct { + gocqlx.UDT + *Full + } + + type PartUDTPtr struct { + gocqlx.UDT + *Part + } + + type ExtraUDTPtr struct { + gocqlx.UDT + *Extra + } + + full := FullUDT{ + Full: Full{ + First: "John", + Second: "Doe", + }, + } + + makeStruct := func(testuuid gocql.UUID, insert interface{}) interface{} { + b := reflect.New(reflect.StructOf([]reflect.StructField{ + { + Name: "TestUUID", + Type: reflect.TypeOf(gocql.UUID{}), + }, + { + Name: "TestUDT", + Type: reflect.TypeOf(insert), + }, + })).Interface() + reflect.ValueOf(b).Elem().FieldByName("TestUUID").Set(reflect.ValueOf(testuuid)) + reflect.ValueOf(b).Elem().FieldByName("TestUDT").Set(reflect.ValueOf(insert)) + return b + } + + tcases := []struct { + name string + insert interface{} + expected interface{} + expectedOnDB FullUDT + }{ + { + name: "exact-match", + insert: full, + expectedOnDB: full, + expected: full, + }, + { + name: "exact-match-ptr", + insert: FullUDTPtr{ + Full: &Full{ + First: "John", + Second: "Doe", + }, + }, + expectedOnDB: full, + expected: FullUDTPtr{ + Full: &Full{ + First: "John", + Second: "Doe", + }, + }, + }, + { + name: "extra-field", + insert: ExtraUDT{ + Extra: Extra{ + First: "John", + Second: "Doe", + Third: "Smith", + }, + }, + expectedOnDB: full, + expected: ExtraUDT{ + Extra: Extra{ + First: "John", + Second: "Doe", + Third: "", // Since the UDT has only 2 fields, the third field should be empty + }, + }, + }, + { + name: "extra-field-ptr", + insert: ExtraUDTPtr{ + Extra: &Extra{ + First: "John", + Second: "Doe", + Third: "Smith", + }, + }, + expectedOnDB: full, + expected: ExtraUDTPtr{ + Extra: &Extra{ + First: "John", + Second: "Doe", + Third: "", // Since the UDT has only 2 fields, the third field should be empty + }, + }, + }, + { + name: "absent-field", + insert: PartUDT{ + Part: Part{ + First: "John", + }, + }, + expectedOnDB: FullUDT{ + Full: Full{ + First: "John", + Second: "", + }, + }, + expected: PartUDT{ + Part: Part{ + First: "John", + }, + }, + }, + { + name: "absent-field-ptr", + insert: PartUDTPtr{ + Part: &Part{ + First: "John", + }, + }, + expectedOnDB: FullUDT{ + Full: Full{ + First: "John", + Second: "", + }, + }, + expected: PartUDTPtr{ + Part: &Part{ + First: "John", + }, + }, + }, + } + + const insertStmt = `INSERT INTO udt_table (testuuid, testudt) VALUES (?, ?)` + const deleteStmt = `DELETE FROM udt_table WHERE testuuid = ?` + + for _, tc := range tcases { + t.Run(tc.name, func(t *testing.T) { + testuuid := gocql.TimeUUID() + + if reflect.TypeOf(tc.insert) != reflect.TypeOf(tc.expected) { + t.Fatalf("insert and expectedOnDB must have the same type") + } + + t.Cleanup(func() { + session.Query(deleteStmt, nil).Bind(testuuid).ExecRelease() // nolint:errcheck + }) + + t.Run("insert-bind", func(t *testing.T) { + if err := session.Query(insertStmt, nil).Unsafe().Bind( + testuuid, + tc.insert, + ).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := FullUDT{} + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(&v); err != nil { + t.Fatal(err.Error()) + } + diff(t, tc.expectedOnDB, v) + }) + + t.Run("scan", func(t *testing.T) { + v := reflect.New(reflect.TypeOf(tc.expected)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Scan(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface()) + }) + + t.Run("get", func(t *testing.T) { + v := reflect.New(reflect.TypeOf(tc.expected)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface()) + }) + + t.Run("delete", func(t *testing.T) { + if err := session.Query(deleteStmt, nil).Bind( + testuuid, + ).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + }) + + t.Run("insert-bind-struct", func(t *testing.T) { + b := makeStruct(testuuid, tc.insert) + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().BindStruct(b).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, &tc.expectedOnDB, v) + }) + + t.Run("insert-bind-struct-map", func(t *testing.T) { + t.Run("empty-map", func(t *testing.T) { + b := makeStruct(testuuid, tc.insert) + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe(). + BindStructMap(b, nil).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, &tc.expectedOnDB, v) + }) + + t.Run("empty-struct", func(t *testing.T) { + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe(). + BindStructMap(struct{}{}, map[string]interface{}{ + "test_uuid": testuuid, + "test_udt": tc.insert, + }).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, &tc.expectedOnDB, v) + }) + }) + + t.Run("insert-bind-map", func(t *testing.T) { + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe(). + BindMap(map[string]interface{}{ + "test_uuid": testuuid, + "test_udt": tc.insert, + }).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, &tc.expectedOnDB, v) + }) + }) + } +} + func TestIterxStruct(t *testing.T) { session := gocqlxtest.CreateSession(t) defer session.Close() @@ -153,8 +476,6 @@ func TestIterxStruct(t *testing.T) { t.Fatal("insert:", err) } - diffOpts := cmpopts.IgnoreUnexported(big.Int{}, inf.Dec{}) - const stmt = `SELECT * FROM struct_table` t.Run("get", func(t *testing.T) { diff --git a/queryx.go b/queryx.go index 331512c..9b67f9b 100644 --- a/queryx.go +++ b/queryx.go @@ -215,6 +215,13 @@ func (q *Queryx) Bind(v ...interface{}) *Queryx { return q } +// Scan executes the query, copies the columns of the first selected +// row into the values pointed at by dest and discards the rest. If no rows +// were selected, ErrNotFound is returned. +func (q *Queryx) Scan(v ...interface{}) error { + return q.Query.Scan(udtWrapSlice(q.Mapper, q.unsafe, v)...) +} + // Err returns any binding errors. func (q *Queryx) Err() error { return q.err diff --git a/udt.go b/udt.go index 63551f7..2785151 100644 --- a/udt.go +++ b/udt.go @@ -39,26 +39,24 @@ func makeUDT(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) udt { func (u udt) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) { value, ok := u.field[name] - - var data []byte - var err error if ok { - data, err = gocql.Marshal(info, value.Interface()) - if err != nil { - return nil, err - } + return gocql.Marshal(info, value.Interface()) } - - return data, err + if u.unsafe { + return nil, nil + } + return nil, fmt.Errorf("missing name %q in %s", name, u.value.Type()) } func (u udt) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error { value, ok := u.field[name] - if !ok && !u.unsafe { - return fmt.Errorf("missing name %q in %s", name, u.value.Type()) + if ok { + return gocql.Unmarshal(info, data, value.Addr().Interface()) } - - return gocql.Unmarshal(info, data, value.Addr().Interface()) + if u.unsafe { + return nil + } + return fmt.Errorf("missing name %q in %s", name, u.value.Type()) } // udtWrapValue adds UDT wrapper if needed.