diff --git a/app/clustering.go b/app/clustering.go index c580568c..a08b756c 100644 --- a/app/clustering.go +++ b/app/clustering.go @@ -670,7 +670,7 @@ func (a *App) unassignTarget(ctx context.Context, name string, serviceID string) } rsp, err := client.Do(req) if err != nil { - rsp.Body.Close() + // don't close the body here since Body will be nil a.Logger.Printf("failed HTTP request: %v", err) continue } diff --git a/app/metrics.go b/app/metrics.go index ec3c042f..7c6a7aa6 100644 --- a/app/metrics.go +++ b/app/metrics.go @@ -69,7 +69,7 @@ func (a *App) startClusterMetrics() { if err != nil { a.Logger.Printf("failed to get leader key: %v", err) } - if leader[leaderKey] == a.Config.InstanceName { + if leader[leaderKey] == a.Config.Clustering.InstanceName { clusterIsLeader.Set(1) } else { clusterIsLeader.Set(0) @@ -84,7 +84,7 @@ func (a *App) startClusterMetrics() { } numLockedNodes := 0 for _, v := range lockedNodes { - if v == a.Config.InstanceName { + if v == a.Config.Clustering.InstanceName { numLockedNodes++ } } diff --git a/go.mod b/go.mod index 6aee0b5e..1e496669 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/fsnotify/fsnotify v1.6.0 github.com/fullstorydev/grpcurl v1.8.7 github.com/go-redis/redis/v8 v8.11.5 + github.com/go-redsync/redsync/v4 v4.10.0 github.com/go-resty/resty/v2 v2.7.0 github.com/google/go-cmp v0.5.9 github.com/google/uuid v1.3.1 @@ -44,6 +45,7 @@ require ( github.com/prometheus/client_golang v1.16.0 github.com/prometheus/client_model v0.4.0 github.com/prometheus/prometheus v0.45.0 + github.com/redis/go-redis/v9 v9.2.1 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.15.0 @@ -88,6 +90,7 @@ require ( github.com/go-openapi/jsonpointer v0.19.6 // indirect github.com/go-openapi/jsonreference v0.20.2 // indirect github.com/go-openapi/swag v0.22.3 // indirect + github.com/gomodule/redigo v2.0.0+incompatible // indirect github.com/google/gnostic v0.6.9 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/google/s2a-go v0.1.4 // indirect diff --git a/go.sum b/go.sum index 6df42ae5..a0d4fd4e 100644 --- a/go.sum +++ b/go.sum @@ -236,6 +236,10 @@ github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+Ce github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/bufbuild/protocompile v0.5.1 h1:mixz5lJX4Hiz4FpqFREJHIXLfaLBntfaJv1h+/jS+Qg= github.com/bufbuild/protocompile v0.5.1/go.mod h1:G5iLmavmF4NsYtpZFvE3B/zFch2GIY8+wjsYLR/lc40= github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= @@ -407,8 +411,14 @@ github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTM github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= +github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= +github.com/go-redis/redis/v7 v7.4.0 h1:7obg6wUoj05T0EpY0o8B59S9w5yeMWql7sw2kwNW1x4= +github.com/go-redis/redis/v7 v7.4.0/go.mod h1:JDNMw23GTyLNC4GZu9njt15ctBQVn7xjRfnwdHj/Dcg= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-redsync/redsync/v4 v4.10.0 h1:hTeAak4C73mNBQSTq6KCKDFaiIlfC+z5yTTl8fCJuBs= +github.com/go-redsync/redsync/v4 v4.10.0/go.mod h1:ZfayzutkgeBmEmBlUR3j+rF6kN44UUGtEdfzhBFZTPc= github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY= github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= @@ -470,6 +480,8 @@ github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golangci/lint-1 v0.0.0-20181222135242-d2cdd8c08219/go.mod h1:/X8TswGSh1pIozq4ZwCfxS0WA5JGXguxk94ar/4c87Y= +github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= +github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= @@ -1011,6 +1023,10 @@ github.com/prometheus/prometheus v0.45.0/go.mod h1:jC5hyO8ItJBnDWGecbEucMyXjzxGv github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/redis/go-redis/v9 v9.2.1 h1:WlYJg71ODF0dVspZZCpYmoF1+U1Jjk9Rwd7pq6QmlCg= +github.com/redis/go-redis/v9 v9.2.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= +github.com/redis/rueidis v1.0.19 h1:s65oWtotzlIFN8eMPhyYwxlwLR1lUdhza2KtWprKYSo= +github.com/redis/rueidis v1.0.19/go.mod h1:8B+r5wdnjwK3lTFml5VtxjzGOQAC+5UmujoD12pDrEo= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= @@ -1083,6 +1099,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203 h1:QVqDTf3h2WHt08YuiTGPZLls0Wq99X9bWd0Q5ZSBesM= +github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203/go.mod h1:oqN97ltKNihBbwlX8dLpwxCl3+HnXKV/R0e+sRLd9C8= github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8= github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= diff --git a/lockers/all/all.go b/lockers/all/all.go index d4fab594..e2d99a0f 100644 --- a/lockers/all/all.go +++ b/lockers/all/all.go @@ -11,4 +11,5 @@ package all import ( _ "github.com/openconfig/gnmic/lockers/consul_locker" _ "github.com/openconfig/gnmic/lockers/k8s_locker" + _ "github.com/openconfig/gnmic/lockers/redis_locker" ) diff --git a/lockers/locker.go b/lockers/locker.go index 535ee311..bd9a6b03 100644 --- a/lockers/locker.go +++ b/lockers/locker.go @@ -17,27 +17,44 @@ import ( "github.com/mitchellh/mapstructure" ) -var ( - ErrCanceled = errors.New("canceled") -) +var ErrCanceled = errors.New("canceled") type Locker interface { + // Init initialises the locker data, with the given configuration read from flags/files. Init(context.Context, map[string]interface{}, ...Option) error + // Stop is called when the locker instance is called. It should unlock all aquired locks. + Stop() error + SetLogger(*log.Logger) + + // This is the Target locking logic. + // Lock acquires a lock on given key. Lock(context.Context, string, []byte) (bool, error) + // KeepLock maintains the lock on the target. KeepLock(context.Context, string) (chan struct{}, chan error) + // IsLocked replys if the target given as string is currently locked or not. IsLocked(context.Context, string) (bool, error) + // Unlock unlocks the target log. Unlock(context.Context, string) error + // This is the instance registration logic. + + // Register registers this instance in the registry. It must also maintain the registration (called in a goroutine from the main). ServiceRegistration.ID contains the ID of the service to register. Register(context.Context, *ServiceRegistration) error + // Deregister removes this instance from the registry. This looks like it's not called. Deregister(string) error - List(context.Context, string) (map[string]string, error) + // GetServices must return the gnmic instances. GetServices(ctx context.Context, serviceName string, tags []string) ([]*Service, error) + // WatchServices must push all existing discovered gnmic instances + // into the provided channel. WatchServices(ctx context.Context, serviceName string, tags []string, ch chan<- []*Service, dur time.Duration) error - Stop() error - SetLogger(*log.Logger) + // Mixed registration/target lock functions + + // List returns all locks that start with prefix string, + // indexed by the lock name. Could be target locks or leader lock. It must return a map of matching keys to instance name. + List(ctx context.Context, prefix string) (map[string]string, error) } type Initializer func() Locker @@ -55,6 +72,7 @@ func WithLogger(logger *log.Logger) Option { var LockerTypes = []string{ "consul", "k8s", + "redis", } func Register(name string, initFn Initializer) { diff --git a/lockers/redis_locker/redis_locker.go b/lockers/redis_locker/redis_locker.go new file mode 100644 index 00000000..a2da63ba --- /dev/null +++ b/lockers/redis_locker/redis_locker.go @@ -0,0 +1,257 @@ +package redis_locker + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "sync" + "time" + + "github.com/go-redsync/redsync/v4" + "github.com/go-redsync/redsync/v4/redis/goredis/v9" + "github.com/openconfig/gnmic/lockers" + "github.com/openconfig/gnmic/utils" + goredislib "github.com/redis/go-redis/v9" +) + +const ( + defaultLeaseDuration = 10 * time.Second + defaultRetryTimer = 2 * time.Second + defaultPollTimer = 10 * time.Second + loggingPrefix = "[redis_locker] " +) + +func init() { + lockers.Register("redis", func() lockers.Locker { + return &redisLocker{ + Cfg: &config{}, + m: new(sync.RWMutex), + acquiredLocks: make(map[string]*redsync.Mutex), + attemptingLocks: make(map[string]*redsync.Mutex), + registerLock: make(map[string]context.CancelFunc), + logger: log.New(io.Discard, loggingPrefix, utils.DefaultLoggingFlags), + } + }) +} + +type redisLocker struct { + Cfg *config + logger *log.Logger + m *sync.RWMutex + acquiredLocks map[string]*redsync.Mutex + attemptingLocks map[string]*redsync.Mutex + registerLock map[string]context.CancelFunc + + client goredislib.UniversalClient + redisLocker *redsync.Redsync +} + +type config struct { + Servers []string `mapstructure:"servers,omitempty" json:"servers,omitempty"` + MasterName string `mapstructure:"master-name,omitempty" json:"master-name,omitempty"` + Password string `mapstructure:"password,omitempty" json:"password,omitempty"` + LeaseDuration time.Duration `mapstructure:"lease-duration,omitempty" json:"lease-duration,omitempty"` + RenewPeriod time.Duration `mapstructure:"renew-period,omitempty" json:"renew-period,omitempty"` + RetryTimer time.Duration `mapstructure:"retry-timer,omitempty" json:"retry-timer,omitempty"` + PollTimer time.Duration `mapstructure:"poll-timer,omitempty" json:"poll-timer,omitempty"` + Debug bool `mapstructure:"debug,omitempty" json:"debug,omitempty"` +} + +func (k *redisLocker) Init(ctx context.Context, cfg map[string]interface{}, opts ...lockers.Option) error { + err := lockers.DecodeConfig(cfg, k.Cfg) + if err != nil { + return err + } + for _, opt := range opts { + opt(k) + } + err = k.setDefaults() + if err != nil { + return err + } + k.client = goredislib.NewUniversalClient(&goredislib.UniversalOptions{ + Addrs: k.Cfg.Servers, + MasterName: k.Cfg.MasterName, + Password: k.Cfg.Password, + }) + if err := k.client.Ping(ctx).Err(); err != nil { + return fmt.Errorf("cannot contact redis server: %w", err) + } + k.redisLocker = redsync.New(goredis.NewPool(k.client)) + return nil +} + +func (k *redisLocker) Lock(ctx context.Context, key string, val []byte) (bool, error) { + if k.Cfg.Debug { + k.logger.Printf("attempting to lock=%s", key) + } + mu := k.redisLocker.NewMutex( + key, + redsync.WithGenValueFunc(func() (string, error) { + rand, err := k.genRandValue() + if err != nil { + return "", err + } + return fmt.Sprintf("%s-%s", val, rand), nil + }), + redsync.WithExpiry(k.Cfg.LeaseDuration), + ) + k.m.Lock() + k.attemptingLocks[key] = mu + k.m.Unlock() + defer func() { + k.m.Lock() + defer k.m.Unlock() + delete(k.attemptingLocks, key) + }() + + for { + select { + case <-ctx.Done(): + return false, ctx.Err() + default: + err := mu.LockContext(ctx) + if err != nil { + switch err.(type) { + case *redsync.ErrTaken: + if k.Cfg.Debug { + k.logger.Printf("lock already taken lock=%s: %v", key, err) + } + return false, nil + default: + return false, fmt.Errorf("failed to acquire lock=%s: %w", key, err) + } + } + + k.m.Lock() + k.acquiredLocks[key] = mu + k.m.Unlock() + return true, nil + } + } +} + +func (k *redisLocker) KeepLock(ctx context.Context, key string) (chan struct{}, chan error) { + doneChan := make(chan struct{}) + errChan := make(chan error) + + go func() { + defer close(doneChan) + ticker := time.NewTicker(k.Cfg.RenewPeriod) + k.m.RLock() + lock, ok := k.acquiredLocks[key] + k.m.RUnlock() + for { + select { + case <-ctx.Done(): + errChan <- ctx.Err() + return + case <-doneChan: + return + case <-ticker.C: + if !ok { + errChan <- fmt.Errorf("unable to maintain lock %q: not found in acquiredlocks", key) + return + } + ok, err := lock.ExtendContext(ctx) + if err != nil { + errChan <- err + return + } + if !ok { + errChan <- fmt.Errorf("could not keep lock") + return + } + + } + } + }() + return doneChan, errChan +} + +func (k *redisLocker) Unlock(ctx context.Context, key string) error { + k.m.Lock() + defer k.m.Unlock() + if lock, ok := k.acquiredLocks[key]; ok { + delete(k.acquiredLocks, key) + ok, err := lock.Unlock() + if err != nil { + return err + } + if !ok { + return fmt.Errorf("failed to unlock lock %s", key) + } + } + if lock, ok := k.attemptingLocks[key]; ok { + delete(k.attemptingLocks, key) + _, err := lock.Unlock() + if err != nil { + return err + } + } + return nil +} + +func (k *redisLocker) Stop() error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + keys := []string{} + k.m.RLock() + for key := range k.acquiredLocks { + keys = append(keys, key) + } + k.m.RUnlock() + for _, key := range keys { + k.Unlock(ctx, key) + } + return k.Deregister("") +} + +func (k *redisLocker) SetLogger(logger *log.Logger) { + if logger != nil && k.logger != nil { + k.logger.SetOutput(logger.Writer()) + k.logger.SetFlags(logger.Flags()) + } +} + +// helpers + +func (k *redisLocker) setDefaults() error { + if k.Cfg.LeaseDuration <= 0 { + k.Cfg.LeaseDuration = defaultLeaseDuration + } + if k.Cfg.RenewPeriod <= 0 || k.Cfg.RenewPeriod >= k.Cfg.LeaseDuration { + k.Cfg.RenewPeriod = k.Cfg.LeaseDuration / 2 + } + if k.Cfg.RetryTimer <= 0 { + k.Cfg.RetryTimer = defaultRetryTimer + } + if k.Cfg.PollTimer <= 0 { + k.Cfg.PollTimer = defaultPollTimer + } + return nil +} + +func (k *redisLocker) String() string { + b, err := json.Marshal(k.Cfg) + if err != nil { + return "" + } + return string(b) +} + +// genRandValue is required to generate a random value +// so that the redislock algorithm works properly +// especially in multi-server setups. +func (k *redisLocker) genRandValue() (string, error) { + b := make([]byte, 16) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(b), nil +} diff --git a/lockers/redis_locker/redis_registration.go b/lockers/redis_locker/redis_registration.go new file mode 100644 index 00000000..3154127a --- /dev/null +++ b/lockers/redis_locker/redis_registration.go @@ -0,0 +1,309 @@ +package redis_locker + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "time" + + "github.com/go-redsync/redsync/v4" + "github.com/openconfig/gnmic/lockers" + goredis "github.com/redis/go-redis/v9" +) + +// defaultWatchTimeout +const defaultWatchTimeout = 10 * time.Second + +// redisRegistration represents a gnmic endpoint in redis. +// It's serialised in the redis value to allow recovering +// it during service discovery. +type redisRegistration struct { + ID string + Address string + Port int + Tags []string + Rand string +} + +func (k *redisLocker) Register(ctx context.Context, s *lockers.ServiceRegistration) error { + ctx, cancel := context.WithCancel(ctx) + k.m.Lock() + k.registerLock[s.ID] = cancel + k.m.Unlock() + if k.Cfg.Debug { + k.logger.Printf("locking service=%s", s.ID) + } + mutex := k.redisLocker.NewMutex( + fmt.Sprintf("%s-%s", s.Name, s.ID), + redsync.WithGenValueFunc(func() (string, error) { + rand, err := k.genRandValue() + if err != nil { + return "", err + } + reg := &redisRegistration{ + ID: s.ID, + Address: s.Address, + Port: s.Port, + Tags: s.Tags, + Rand: rand, + } + val, err := json.Marshal(reg) + if err != nil { + return "", err + } + return string(val), nil + }), + redsync.WithExpiry(s.TTL), + ) + + err := mutex.LockContext(ctx) + if err != nil { + return fmt.Errorf("failed to lock service=%s, %w", s.ID, err) + } + + ticker := time.NewTicker(s.TTL / 2) + defer ticker.Stop() + for { + select { + case <-ticker.C: + ok, err := mutex.ExtendContext(ctx) + if err != nil { + return fmt.Errorf("failed to extend lock for service=%s: %w", s.ID, err) + } + if !ok { + return fmt.Errorf("could not extend lock for service=%s", s.ID) + } + case <-ctx.Done(): + mutex.Unlock() + return nil + } + } +} + +func (k *redisLocker) Deregister(s string) error { + k.m.Lock() + defer k.m.Unlock() + for sid, lockCancel := range k.registerLock { + if k.Cfg.Debug { + k.logger.Printf("unlocking service=%s", sid) + } + lockCancel() + delete(k.registerLock, sid) + } + return nil +} + +func (k *redisLocker) WatchServices(ctx context.Context, serviceName string, tags []string, sChan chan<- []*lockers.Service, watchTimeout time.Duration) error { + if watchTimeout <= 0 { + watchTimeout = defaultWatchTimeout + } + var err error + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + if k.Cfg.Debug { + k.logger.Printf("(re)starting watch service=%q", serviceName) + } + err = k.watch(ctx, serviceName, tags, sChan, watchTimeout) + if err != nil { + k.logger.Printf("watch ended with error: %s", err) + time.Sleep(k.Cfg.RetryTimer) + continue + } + + time.Sleep(k.Cfg.PollTimer) + } + } +} + +func (k *redisLocker) watch(ctx context.Context, serviceName string, tags []string, sChan chan<- []*lockers.Service, watchTimeout time.Duration) error { + // timeoutSeconds := int64(watchTimeout.Seconds()) + // TODO: implement watch + services, err := k.GetServices(ctx, serviceName, tags) + if err != nil { + return err + } + + sChan <- services + return nil +} + +func (k *redisLocker) getBatchOfKeys(ctx context.Context, key string, batchSize int64, cursor uint64) (uint64, map[string]*goredis.StringCmd, error) { + keys, cursor, err := k.client.Scan( + ctx, + cursor, + key, + batchSize, + ).Result() + if err != nil { + return 0, nil, fmt.Errorf("failed to scan keys: %w", err) + } + + results := map[string]*goredis.StringCmd{} + _, err = k.client.Pipelined(ctx, func(p goredis.Pipeliner) error { + for _, k := range keys { + results[k] = p.Get(ctx, k) + } + return nil + }) + + if err != nil { + return cursor, nil, fmt.Errorf("error getting contents of keys") + } + + return cursor, results, nil +} + +func (k *redisLocker) GetServices(ctx context.Context, serviceName string, tags []string) ([]*lockers.Service, error) { + var pageSize int64 = 50 + var cursor uint64 + var err error + var cmds map[string]*goredis.StringCmd + discoveredServiceRegistrations := []*redisRegistration{} + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + // to select all gnmic instances, matching the given prefix + cursor, cmds, err = k.getBatchOfKeys( + ctx, + fmt.Sprintf("%s-*", serviceName), + pageSize, + cursor, + ) + + if err != nil { + return nil, err + } + for _, cmd := range cmds { + bytesVal, err := cmd.Bytes() + if err != nil { + // key removed from redis + // could be that it has expired + // doesn't make a difference, we skip it + continue + } + serviceRegistration := &redisRegistration{} + if err := json.Unmarshal(bytesVal, serviceRegistration); err != nil { + // we don't have the data we expect + // skip it + continue + } + + discoveredServiceRegistrations = append( + discoveredServiceRegistrations, + serviceRegistration, + ) + } + // termination condition for redis scan + if cursor == 0 { + if k.Cfg.Debug { + k.logger.Printf("got %d services from redis", len(discoveredServiceRegistrations)) + } + // convert discovered servicesRegistrations to services + discoveredServices := make([]*lockers.Service, len(discoveredServiceRegistrations)) + for i, registration := range discoveredServiceRegistrations { + // match the required tags + if !matchTags(registration.Tags, tags) { + continue + } + discoveredServices[i] = &lockers.Service{ + ID: registration.ID, + Tags: registration.Tags, + Address: fmt.Sprintf( + "%s:%d", + registration.Address, + registration.Port, + ), + } + } + return discoveredServices, nil + } + } + } +} + +func (k *redisLocker) IsLocked(ctx context.Context, key string) (bool, error) { + count, err := k.client.Exists(ctx, key).Result() + if err != nil { + return false, fmt.Errorf("error during redis query: %w", err) + } + + if count > 0 { + return true, nil + } + return false, nil +} + +func (k *redisLocker) List(ctx context.Context, prefix string) (map[string]string, error) { + var cursor uint64 + var err error + var cmds map[string]*goredis.StringCmd + data := map[string]string{} + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + cursor, cmds, err = k.getBatchOfKeys( + ctx, + fmt.Sprintf("%s*", prefix), + 100, + cursor, + ) + if err != nil { + return nil, fmt.Errorf("failed to fetch from redis: %w", err) + } + if k.Cfg.Debug { + k.logger.Printf( + "got %d keys from redis for prefix=%s", + len(cmds), + prefix, + ) + } + for key, cmd := range cmds { + bytesVal, err := cmd.Bytes() + if err != nil { + // key removed from redis + // could be that it has expired + // doesn't make a difference, we skip it + continue + } + // we add a random string at the end of the value for redis + // redlock algorithm, so we need to remove it here + lastIndex := bytes.LastIndex(bytesVal, []byte("-")) + // if it's not there, we skip the key + if lastIndex < 0 { + continue + } + data[key] = string(bytesVal[:lastIndex]) + } + + if cursor == 0 { + return data, nil + } + } +} + +func matchTags(tags, wantedTags []string) bool { + if wantedTags == nil { + return true + } + tagsMap := map[string]struct{}{} + + for _, t := range tags { + tagsMap[t] = struct{}{} + } + + for _, wt := range wantedTags { + if _, ok := tagsMap[wt]; !ok { + return false + } + } + return true +}