Skip to content

Commit

Permalink
Marshal/Unmarshal missing UDT fields as null instead of failing in un…
Browse files Browse the repository at this point in the history
…safe mode

We can't return an error in case a field is added to the UDT,
otherwise existing code would break by simply altering the UDT in the
database. For extra fields at the end of the UDT put nulls to be in
line with gocql, but also python-driver and java-driver.

In gocql it was fixed in scylladb/gocql@d2ed1bb
  • Loading branch information
dkropachev authored and sylwiaszunejko committed Jun 25, 2024
1 parent c6f942a commit 6a60650
Show file tree
Hide file tree
Showing 3 changed files with 341 additions and 15 deletions.
325 changes: 323 additions & 2 deletions iterx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package gocqlx_test

import (
"math/big"
"reflect"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -48,6 +49,328 @@ type FullNamePtrUDT struct {
*FullName
}

func diff(t *testing.T, expected, got interface{}) {
t.Helper()

if d := cmp.Diff(expected, got, diffOpts); d != "" {
t.Errorf("got %+v expected %+v, diff: %s", got, expected, d)
}
}

var diffOpts = cmpopts.IgnoreUnexported(big.Int{}, inf.Dec{})

func TestIterxUDT(t *testing.T) {
session := gocqlxtest.CreateSession(t)
t.Cleanup(func() {
session.Close()
})

if err := session.ExecStmt(`CREATE TYPE gocqlx_test.UDTTest_Full (first text, second text)`); err != nil {
t.Fatal("create type:", err)
}

if err := session.ExecStmt(`CREATE TABLE gocqlx_test.udt_table (
testuuid timeuuid PRIMARY KEY,
testudt gocqlx_test.UDTTest_Full
)`); err != nil {
t.Fatal("create table:", err)
}

type Full struct {
First string
Second string
}

type Part struct {
First string
}

type Extra struct {
First string
Second string
Third string
}

type FullUDT struct {
gocqlx.UDT
Full
}

type PartUDT struct {
gocqlx.UDT
Part
}

type ExtraUDT struct {
gocqlx.UDT
Extra
}

type FullUDTPtr struct {
gocqlx.UDT
*Full
}

type PartUDTPtr struct {
gocqlx.UDT
*Part
}

type ExtraUDTPtr struct {
gocqlx.UDT
*Extra
}

full := FullUDT{
Full: Full{
First: "John",
Second: "Doe",
},
}

makeStruct := func(testuuid gocql.UUID, insert interface{}) interface{} {
b := reflect.New(reflect.StructOf([]reflect.StructField{
{
Name: "TestUUID",
Type: reflect.TypeOf(gocql.UUID{}),
},
{
Name: "TestUDT",
Type: reflect.TypeOf(insert),
},
})).Interface()
reflect.ValueOf(b).Elem().FieldByName("TestUUID").Set(reflect.ValueOf(testuuid))
reflect.ValueOf(b).Elem().FieldByName("TestUDT").Set(reflect.ValueOf(insert))
return b
}

tcases := []struct {
name string
insert interface{}
expected interface{}
expectedOnDB FullUDT
}{
{
name: "exact-match",
insert: full,
expectedOnDB: full,
expected: full,
},
{
name: "exact-match-ptr",
insert: FullUDTPtr{
Full: &Full{
First: "John",
Second: "Doe",
},
},
expectedOnDB: full,
expected: FullUDTPtr{
Full: &Full{
First: "John",
Second: "Doe",
},
},
},
{
name: "extra-field",
insert: ExtraUDT{
Extra: Extra{
First: "John",
Second: "Doe",
Third: "Smith",
},
},
expectedOnDB: full,
expected: ExtraUDT{
Extra: Extra{
First: "John",
Second: "Doe",
Third: "", // Since the UDT has only 2 fields, the third field should be empty
},
},
},
{
name: "extra-field-ptr",
insert: ExtraUDTPtr{
Extra: &Extra{
First: "John",
Second: "Doe",
Third: "Smith",
},
},
expectedOnDB: full,
expected: ExtraUDTPtr{
Extra: &Extra{
First: "John",
Second: "Doe",
Third: "", // Since the UDT has only 2 fields, the third field should be empty
},
},
},
{
name: "absent-field",
insert: PartUDT{
Part: Part{
First: "John",
},
},
expectedOnDB: FullUDT{
Full: Full{
First: "John",
Second: "",
},
},
expected: PartUDT{
Part: Part{
First: "John",
},
},
},
{
name: "absent-field-ptr",
insert: PartUDTPtr{
Part: &Part{
First: "John",
},
},
expectedOnDB: FullUDT{
Full: Full{
First: "John",
Second: "",
},
},
expected: PartUDTPtr{
Part: &Part{
First: "John",
},
},
},
}

const insertStmt = `INSERT INTO udt_table (testuuid, testudt) VALUES (?, ?)`
const deleteStmt = `DELETE FROM udt_table WHERE testuuid = ?`

for _, tc := range tcases {
t.Run(tc.name, func(t *testing.T) {
testuuid := gocql.TimeUUID()

if reflect.TypeOf(tc.insert) != reflect.TypeOf(tc.expected) {
t.Fatalf("insert and expectedOnDB must have the same type")
}

t.Cleanup(func() {
session.Query(deleteStmt, nil).Bind(testuuid).ExecRelease() // nolint:errcheck
})

t.Run("insert-bind", func(t *testing.T) {
if err := session.Query(insertStmt, nil).Unsafe().Bind(
testuuid,
tc.insert,
).ExecRelease(); err != nil {
t.Fatal(err.Error())
}

// Make sure the UDT was inserted correctly
v := FullUDT{}
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(&v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expectedOnDB, v)
})

t.Run("scan", func(t *testing.T) {
v := reflect.New(reflect.TypeOf(tc.expected)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Scan(v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface())
})

t.Run("get", func(t *testing.T) {
v := reflect.New(reflect.TypeOf(tc.expected)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface())
})

t.Run("delete", func(t *testing.T) {
if err := session.Query(deleteStmt, nil).Bind(
testuuid,
).ExecRelease(); err != nil {
t.Fatal(err.Error())
}
})

t.Run("insert-bind-struct", func(t *testing.T) {
b := makeStruct(testuuid, tc.insert)
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().BindStruct(b).ExecRelease(); err != nil {
t.Fatal(err.Error())
}

// Make sure the UDT was inserted correctly
v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, &tc.expectedOnDB, v)
})

t.Run("insert-bind-struct-map", func(t *testing.T) {
t.Run("empty-map", func(t *testing.T) {
b := makeStruct(testuuid, tc.insert)
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
BindStructMap(b, nil).ExecRelease(); err != nil {
t.Fatal(err.Error())
}

// Make sure the UDT was inserted correctly
v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, &tc.expectedOnDB, v)
})

t.Run("empty-struct", func(t *testing.T) {
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
BindStructMap(struct{}{}, map[string]interface{}{
"test_uuid": testuuid,
"test_udt": tc.insert,
}).ExecRelease(); err != nil {
t.Fatal(err.Error())
}

// Make sure the UDT was inserted correctly
v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, &tc.expectedOnDB, v)
})
})

t.Run("insert-bind-map", func(t *testing.T) {
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
BindMap(map[string]interface{}{
"test_uuid": testuuid,
"test_udt": tc.insert,
}).ExecRelease(); err != nil {
t.Fatal(err.Error())
}

// Make sure the UDT was inserted correctly
v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, &tc.expectedOnDB, v)
})
})
}
}

func TestIterxStruct(t *testing.T) {
session := gocqlxtest.CreateSession(t)
defer session.Close()
Expand Down Expand Up @@ -153,8 +476,6 @@ func TestIterxStruct(t *testing.T) {
t.Fatal("insert:", err)
}

diffOpts := cmpopts.IgnoreUnexported(big.Int{}, inf.Dec{})

const stmt = `SELECT * FROM struct_table`

t.Run("get", func(t *testing.T) {
Expand Down
7 changes: 7 additions & 0 deletions queryx.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,13 @@ func (q *Queryx) Bind(v ...interface{}) *Queryx {
return q
}

// Scan executes the query, copies the columns of the first selected
// row into the values pointed at by dest and discards the rest. If no rows
// were selected, ErrNotFound is returned.
func (q *Queryx) Scan(v ...interface{}) error {
return q.Query.Scan(udtWrapSlice(q.Mapper, q.unsafe, v)...)
}

// Err returns any binding errors.
func (q *Queryx) Err() error {
return q.err
Expand Down
24 changes: 11 additions & 13 deletions udt.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,26 +39,24 @@ func makeUDT(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) udt {

func (u udt) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) {
value, ok := u.field[name]

var data []byte
var err error
if ok {
data, err = gocql.Marshal(info, value.Interface())
if err != nil {
return nil, err
}
return gocql.Marshal(info, value.Interface())
}

return data, err
if u.unsafe {
return nil, nil
}
return nil, fmt.Errorf("missing name %q in %s", name, u.value.Type())
}

func (u udt) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error {
value, ok := u.field[name]
if !ok && !u.unsafe {
return fmt.Errorf("missing name %q in %s", name, u.value.Type())
if ok {
return gocql.Unmarshal(info, data, value.Addr().Interface())
}

return gocql.Unmarshal(info, data, value.Addr().Interface())
if u.unsafe {
return nil
}
return fmt.Errorf("missing name %q in %s", name, u.value.Type())
}

// udtWrapValue adds UDT wrapper if needed.
Expand Down

0 comments on commit 6a60650

Please sign in to comment.