diff --git a/.gitignore b/.gitignore index 62c8935..e0caea8 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -.idea/ \ No newline at end of file +.idea/ +vendor/ \ No newline at end of file diff --git a/expect.go b/expect.go index 249b7fb..717da6a 100644 --- a/expect.go +++ b/expect.go @@ -5,7 +5,6 @@ import ( "reflect" "sync" "time" - "unsafe" "github.com/redis/go-redis/v9" ) @@ -386,7 +385,7 @@ func inflow(cmd redis.Cmder, key string, val interface{}) { if !v.IsValid() { panic(fmt.Sprintf("cmd did not find key '%s'", key)) } - v = reflect.NewAt(v.Type(), unsafe.Pointer(v.UnsafeAddr())).Elem() + v = reflect.NewAt(v.Type(), v.Addr().UnsafePointer()).Elem() setVal := reflect.ValueOf(val) if v.Kind() != reflect.Interface && setVal.Kind() != v.Kind() { diff --git a/go.mod b/go.mod index fc472ef..c627be7 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,9 @@ -module github.com/go-redis/redismock/v9 +module github.com/descope/redismock/v9 go 1.18 require ( + github.com/go-redis/redismock/v9 v9.0.3 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.25.0 github.com/redis/go-redis/v9 v9.0.3 diff --git a/go.sum b/go.sum index 6a81ac5..67bf04e 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= +github.com/go-redis/redismock/v9 v9.0.3 h1:mtHQi2l51lCmXIbTRTqb1EiHYe9tL5Yk5oorlSJJqR0= +github.com/go-redis/redismock/v9 v9.0.3/go.mod h1:F6tJRfnU8R/NZ0E+Gjvoluk14MqMC5ueSZX6vVQypc0= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= diff --git a/mock.go b/mock.go index c8c28b9..32cd9b6 100644 --- a/mock.go +++ b/mock.go @@ -41,12 +41,17 @@ func NewClientMock() (*redis.Client, ClientMock) { return m.client.(*redis.Client), m } +func NewClientMockWithHooks(hooks ...redis.Hook) (*redis.Client, ClientMock) { + m := newMock(redisClient, hooks...) + return m.client.(*redis.Client), m +} + func NewClusterMock() (*redis.ClusterClient, ClusterClientMock) { m := newMock(redisCluster) return m.client.(*redis.ClusterClient), m } -func newMock(typ redisClientType) *mock { +func newMock(typ redisClientType, hooks ...redis.Hook) *mock { m := &mock{ ctx: context.Background(), clientType: typ, @@ -59,6 +64,9 @@ func newMock(typ redisClientType) *mock { factory := redis.NewClient(opt) client := redis.NewClient(opt) factory.AddHook(nilHook{}) + for i := range hooks { + client.AddHook(hooks[i]) + } client.AddHook(redisClientHook{fn: m.process}) m.factory = factory @@ -263,13 +271,20 @@ func (m *mock) match(expect expectation, cmd redis.Cmder) error { func (m *mock) compare(isRegexp bool, expect, cmd interface{}) error { expr, ok := expect.(string) if isRegexp && ok { - cmdValue := fmt.Sprint(cmd) + var cmdValue1 string + var cmdValue2 string + if bCmd, ok := cmd.([]byte); ok { + cmdValue1 = string(bCmd) + cmdValue2 = fmt.Sprint(cmd) + } else { + cmdValue1 = fmt.Sprint(cmd) + } re, err := regexp.Compile(expr) if err != nil { return err } - if !re.MatchString(cmdValue) { - return fmt.Errorf("args not match, expectation regular: '%s', but gave: '%s'", expr, cmdValue) + if !re.MatchString(cmdValue1) && !re.MatchString(cmdValue2) { + return fmt.Errorf("args not match, expectation regular: '%s', but gave: '%s', and: '%s", expr, cmdValue1, cmdValue2) } } else if !reflect.DeepEqual(expect, cmd) { return fmt.Errorf("args not `DeepEqual`, expectation: '%+v', but gave: '%+v'", expect, cmd) @@ -774,6 +789,13 @@ func (m *mock) ExpectGet(key string) *ExpectedString { return e } +func (m *mock) ExpectGetDel(key string) *ExpectedString { + e := &ExpectedString{} + e.cmd = m.factory.GetDel(m.ctx, key) + m.pushExpect(e) + return e +} + func (m *mock) ExpectGetRange(key string, start, end int64) *ExpectedString { e := &ExpectedString{} e.cmd = m.factory.GetRange(m.ctx, key, start, end) @@ -795,13 +817,6 @@ func (m *mock) ExpectGetEx(key string, expiration time.Duration) *ExpectedString return e } -func (m *mock) ExpectGetDel(key string) *ExpectedString { - e := &ExpectedString{} - e.cmd = m.factory.GetDel(m.ctx, key) - m.pushExpect(e) - return e -} - func (m *mock) ExpectIncr(key string) *ExpectedInt { e := &ExpectedInt{} e.cmd = m.factory.Incr(m.ctx, key)