Skip to content

Commit

Permalink
DiceDB#761: Added ZADD options implementation - XX|NX|CH|INCR| LT|GT …
Browse files Browse the repository at this point in the history
…according to Redis source code (DiceDB#1262)
  • Loading branch information
rushabhk04 authored Nov 12, 2024
1 parent 125f075 commit 9f9b441
Show file tree
Hide file tree
Showing 8 changed files with 1,957 additions and 35 deletions.
713 changes: 713 additions & 0 deletions integration_tests/commands/http/zset_test.go

Large diffs are not rendered by default.

536 changes: 536 additions & 0 deletions integration_tests/commands/resp/zset_test.go

Large diffs are not rendered by default.

538 changes: 538 additions & 0 deletions integration_tests/commands/websocket/zset_test.go

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions internal/eval/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ const (
NX string = "NX"
GT string = "GT"
LT string = "LT"
CH string = "CH"
INCR string = "INCR"
KeepTTL string = "KEEPTTL"
Sync string = "SYNC"
Async string = "ASYNC"
Expand Down
4 changes: 2 additions & 2 deletions internal/eval/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -2438,10 +2438,10 @@ func evalGEOADD(args []string, store *dstore.Store) []byte {
// Parse options
for startIdx < len(args) {
option := strings.ToUpper(args[startIdx])
if option == "NX" {
if option == NX {
nx = true
startIdx++
} else if option == "XX" {
} else if option == XX {
xx = true
startIdx++
} else {
Expand Down
173 changes: 153 additions & 20 deletions internal/eval/store_eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -695,32 +695,94 @@ func evalGETRANGE(args []string, store *dstore.Store) *EvalResponse {
// reinserted at the right position to ensure the correct ordering.
// If key does not exist, a new sorted set with the specified members as sole members is created.
func evalZADD(args []string, store *dstore.Store) *EvalResponse {
if len(args) < 3 || len(args)%2 == 0 {
// if length of command is 3, throw error as it is not possible
if len(args) < 3 {
return &EvalResponse{
Result: nil,
Error: diceerrors.ErrWrongArgumentCount("ZADD"),
}
}

key := args[0]
obj := store.Get(key)
var sortedSet *sortedset.Set
sortedSet, err := getOrCreateSortedSet(store, key)
if err != nil {
return &EvalResponse{
Result: nil,
Error: err,
}
}
// flags parsing
flags, nextIndex := parseFlags(args[1:])
if nextIndex >= len(args) || (len(args)-nextIndex)%2 != 0 {
return &EvalResponse{
Result: nil,
Error: diceerrors.ErrWrongArgumentCount("ZADD"),
}
}
// only valid flags works
if err := validateFlagsAndArgs(args[nextIndex:], flags); err != nil {
return &EvalResponse{
Result: nil,
Error: err,
}
}
// all processing takes place here
return processMembersWithFlags(args[nextIndex:], sortedSet, store, key, flags)
}

if obj != nil {
var err []byte
sortedSet, err = sortedset.FromObject(obj)
if err != nil {
return &EvalResponse{
Result: nil,
Error: diceerrors.ErrWrongTypeOperation,
}
// parseFlags identifies and parses the flags used in ZADD.
func parseFlags(args []string) (parsedFlags map[string]bool, nextIndex int) {
parsedFlags = map[string]bool{
NX: false,
XX: false,
LT: false,
GT: false,
CH: false,
INCR: false,
}
for i := 0; i < len(args); i++ {
switch strings.ToUpper(args[i]) {
case NX:
parsedFlags[NX] = true
case XX:
parsedFlags[XX] = true
case LT:
parsedFlags[LT] = true
case GT:
parsedFlags[GT] = true
case CH:
parsedFlags[CH] = true
case INCR:
parsedFlags[INCR] = true
default:
return parsedFlags, i + 1
}
} else {
sortedSet = sortedset.New()
}

added := 0
for i := 1; i < len(args); i += 2 {
return parsedFlags, len(args) + 1
}

// only valid combination of options works
func validateFlagsAndArgs(args []string, flags map[string]bool) error {
if len(args)%2 != 0 {
return diceerrors.ErrGeneral("syntax error")
}
if flags[NX] && flags[XX] {
return diceerrors.ErrGeneral("xx and nx options at the same time are not compatible")
}
if (flags[GT] && flags[NX]) || (flags[LT] && flags[NX]) || (flags[GT] && flags[LT]) {
return diceerrors.ErrGeneral("gt and LT and NX options at the same time are not compatible")
}
if flags[INCR] && len(args)/2 > 1 {
return diceerrors.ErrGeneral("incr option supports a single increment-element pair")
}
return nil
}

// processMembersWithFlags processes the members and scores while handling flags.
func processMembersWithFlags(args []string, sortedSet *sortedset.Set, store *dstore.Store, key string, flags map[string]bool) *EvalResponse {
added, updated := 0, 0

for i := 0; i < len(args); i += 2 {
scoreStr := args[i]
member := args[i+1]

Expand All @@ -732,22 +794,93 @@ func evalZADD(args []string, store *dstore.Store) *EvalResponse {
}
}

currentScore, exists := sortedSet.Get(member)

// If INCR is used, increment the score first
if flags[INCR] {
if exists {
score += currentScore
} else {
score = 0.0 + score
}

// Now check GT and LT conditions based on the incremented score and return accordingly
if (flags[GT] && exists && score <= currentScore) ||
(flags[LT] && exists && score >= currentScore) {
return &EvalResponse{
Result: nil,
Error: nil,
}
}
}

// Check if the member should be skipped based on NX or XX flags
if shouldSkipMember(score, currentScore, exists, flags) {
continue
}

// Insert or update the member in the sorted set
wasInserted := sortedSet.Upsert(score, member)

if wasInserted {
added += 1
if wasInserted && !exists {
added++
} else if exists && score != currentScore {
updated++
}

// If INCR is used, exit after processing one score-member pair
if flags[INCR] {
return &EvalResponse{
Result: score,
Error: nil,
}
}
}

obj = store.NewObj(sortedSet, -1, object.ObjTypeSortedSet, object.ObjEncodingBTree)
store.Put(key, obj, dstore.WithPutCmd(dstore.ZAdd))
// Store the updated sorted set in the store
storeUpdatedSet(store, key, sortedSet)

if flags[CH] {
return &EvalResponse{
Result: added + updated,
Error: nil,
}
}

// Return only the count of added members
return &EvalResponse{
Result: added,
Error: nil,
}
}

// shouldSkipMember determines if a member should be skipped based on flags.
func shouldSkipMember(score, currentScore float64, exists bool, flags map[string]bool) bool {
useNX, useXX, useLT, useGT := flags[NX], flags[XX], flags[LT], flags[GT]

return (useNX && exists) || (useXX && !exists) ||
(exists && useLT && score >= currentScore) ||
(exists && useGT && score <= currentScore)
}

// storeUpdatedSet stores the updated sorted set in the store.
func storeUpdatedSet(store *dstore.Store, key string, sortedSet *sortedset.Set) {
store.Put(key, store.NewObj(sortedSet, -1, object.ObjTypeSortedSet, object.ObjEncodingBTree), dstore.WithPutCmd(dstore.ZAdd))
}

// getOrCreateSortedSet fetches the sorted set if it exists, otherwise creates a new one.
func getOrCreateSortedSet(store *dstore.Store, key string) (*sortedset.Set, error) {
obj := store.Get(key)
if obj != nil {
sortedSet, err := sortedset.FromObject(obj)
if err != nil {
return nil, diceerrors.ErrWrongTypeOperation
}
return sortedSet, nil
}
return sortedset.New(), nil
}

// The ZCOUNT command in DiceDB counts the number of members in a sorted set at the specified key
// whose scores fall within a given range. The command takes three arguments: the key of the sorted set
// the minimum score, and the maximum score.
Expand Down
14 changes: 7 additions & 7 deletions internal/store/constants.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package store

const (
Set string = "SET"
Del string = "DEL"
Get string = "GET"
Rename string = "RENAME"
ZAdd string = "ZADD"
ZRange string = "ZRANGE"
PFADD string = "PFADD"
Set string = "SET"
Del string = "DEL"
Get string = "GET"
Rename string = "RENAME"
ZAdd string = "ZADD"
ZRange string = "ZRANGE"
PFADD string = "PFADD"
PFCOUNT string = "PFCOUNT"
PFMERGE string = "PFMERGE"
)
12 changes: 6 additions & 6 deletions internal/watchmanager/watch_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ type (

var (
affectedCmdMap = map[string]map[string]struct{}{
dstore.Set: {dstore.Get: struct{}{}},
dstore.Del: {dstore.Get: struct{}{}},
dstore.Rename: {dstore.Get: struct{}{}},
dstore.ZAdd: {dstore.ZRange: struct{}{}},
dstore.PFADD: {dstore.PFCOUNT: struct{}{}},
dstore.PFMERGE:{dstore.PFCOUNT: struct{}{}},
dstore.Set: {dstore.Get: struct{}{}},
dstore.Del: {dstore.Get: struct{}{}},
dstore.Rename: {dstore.Get: struct{}{}},
dstore.ZAdd: {dstore.ZRange: struct{}{}},
dstore.PFADD: {dstore.PFCOUNT: struct{}{}},
dstore.PFMERGE: {dstore.PFCOUNT: struct{}{}},
}
)

Expand Down

0 comments on commit 9f9b441

Please sign in to comment.