Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Q.UNWATCH to Websockets #1180

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
10 changes: 3 additions & 7 deletions integration_tests/commands/websocket/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,24 @@ import (
func TestMain(m *testing.M) {
var wg sync.WaitGroup

// Run the test server
// This is a synchronous method, because internally it
// checks for available port and then forks a goroutine
// to start the server
opts := TestServerOptions{
Port: testPort1,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
RunWebsocketServer(ctx, &wg, opts)

// Wait for the server to start
time.Sleep(2 * time.Second)

executor := NewWebsocketCommandExecutor()

// Run the test suite
exitCode := m.Run()

// shut down gracefully
executor := NewWebsocketCommandExecutor()
conn := executor.ConnectToServer()
executor.FireCommand(conn, "abort")

cancel()
wg.Wait()
os.Exit(exitCode)
}
131 changes: 128 additions & 3 deletions integration_tests/commands/websocket/qwatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ package websocket
import (
"testing"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
)

func TestQWatch(t *testing.T) {
exec := NewWebsocketCommandExecutor()
conn := exec.ConnectToServer()

testCases := []struct {
name string
cmds []string
Expand All @@ -30,8 +30,8 @@ func TestQWatch(t *testing.T) {
// Add unregister test case to handle this scenario once qunwatch support is added
{
name: "Successful register",
cmds: []string{`Q.WATCH "SELECT $key, $value WHERE $key like 'test-key?'"`},
expect: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-key?'", []interface{}{}},
cmds: []string{`Q.WATCH "SELECT $key, $value WHERE $key like 'test-qwatch-key'"`},
expect: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-qwatch-key'", []interface{}{}},
},
}

Expand All @@ -51,3 +51,128 @@ func TestQWatch(t *testing.T) {
})
}
}

func TestQWatchWithMultipleClients(t *testing.T) {
numOfClients := 10
exec := NewWebsocketCommandExecutor()
clients := []*websocket.Conn{}

for i := 0; i < numOfClients; i++ {
client := exec.ConnectToServer()
defer client.Close()
clients = append(clients, client)
}

tc := struct {
cmd string
expect interface{} // immediate response after firing command
update interface{} // update received after value change
}{
cmd: `Q.WATCH "SELECT $key, $value WHERE $key like 'test-multiple-clients-key'"`,
expect: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-multiple-clients-key'", []interface{}{}},
update: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-multiple-clients-key'", []interface{}{[]interface{}{"test-multiple-clients-key", float64(1)}}},
}

for i := 0; i < numOfClients; i++ {
conn := clients[i]
// subscribe query
resp, err := exec.FireCommandAndReadResponse(conn, tc.cmd)
assert.Nil(t, err)
assert.ElementsMatch(t, tc.expect, resp, "Value mismatch for cmd %s", tc.cmd)
}

// create a fresh client to update key
c := exec.ConnectToServer()
defer c.Close()

// update key
resp, err := exec.FireCommandAndReadResponse(c, "SET test-multiple-clients-key 1")
assert.Nil(t, err)
assert.Equal(t, "OK", resp, "Value mismatch for cmd SET test-multiple-clients-key 1")

// read and validate query updates
for i := 0; i < numOfClients; i++ {
resp, err := exec.ReadResponse(clients[i], tc.cmd)
assert.Nil(t, err)
assert.Equal(t, tc.update, resp, "Value mismatch for reading query update message")
}
}

func TestQWatchWithMultipleClientsAndQueries(t *testing.T) {
numOfClients := 3
exec := NewWebsocketCommandExecutor()
clients := []*websocket.Conn{}

for i := 0; i < numOfClients; i++ {
client := exec.ConnectToServer()
defer client.Close()
clients = append(clients, client)
}

tests := []struct {
cmd string
expect interface{} // immediate response after firing command
update interface{} // update received after value change
}{
{
cmd: `Q.WATCH "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key1'"`,
expect: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key1'", []interface{}{}},
update: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key1'", []interface{}{[]interface{}{"test-multiple-client-queries-key1", float64(1)}}},
},
{
cmd: `Q.WATCH "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key2'"`,
expect: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key2'", []interface{}{}},
update: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key2'", []interface{}{[]interface{}{"test-multiple-client-queries-key2", float64(2)}}},
},
{
cmd: `Q.WATCH "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key3'"`,
expect: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key3'", []interface{}{}},
update: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key3'", []interface{}{[]interface{}{"test-multiple-client-queries-key3", float64(3)}}},
},
}

for i := 0; i < numOfClients; i++ {
conn := clients[i]

for j := 0; j < len(tests); j++ {
// subscribe query
resp, err := exec.FireCommandAndReadResponse(conn, tests[j].cmd)
assert.Nil(t, err)
assert.ElementsMatch(t, tests[j].expect, resp, "Value mismatch for cmd %s", tests[j].cmd)
}
}

// create a fresh client to update keys
c := exec.ConnectToServer()
defer c.Close()

// update keys
resp, err := exec.FireCommandAndReadResponse(c, "SET test-multiple-client-queries-key1 1")
assert.Nil(t, err)
assert.Equal(t, "OK", resp, "Value mismatch for cmd SET test-multiple-client-queries-key1 1")

resp, err = exec.FireCommandAndReadResponse(c, "SET test-multiple-client-queries-key2 2")
assert.Nil(t, err)
assert.Equal(t, "OK", resp, "Value mismatch for cmd SET test-multiple-client-queries-key2 1")

resp, err = exec.FireCommandAndReadResponse(c, "SET test-multiple-client-queries-key3 3")
assert.Nil(t, err)
assert.Equal(t, "OK", resp, "Value mismatch for cmd SET test-multiple-client-queries-key3 3")

// prepare expected updates array
want := []interface{}{}
for j := 0; j < len(tests); j++ {
want = append(want, tests[j].update)
}

// read and validate query updates
for i := 0; i < numOfClients; i++ {
var respArr []interface{}
for j := 0; j < len(tests); j++ {
resp, err := exec.ReadResponse(clients[i], tests[0].cmd)
assert.Nil(t, err)
respArr = append(respArr, resp)
}
assert.ElementsMatch(t, want, respArr, "Value mismatch for reading query update message")
}
}
33 changes: 20 additions & 13 deletions integration_tests/commands/websocket/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@ type TestServerOptions struct {
Port int
}

type CommandExecutor interface {
FireCommand(cmd string) interface{}
Name() string
}

type WebsocketCommandExecutor struct {
baseURL string
websocketClient *http.Client
Expand Down Expand Up @@ -95,8 +90,20 @@ func (e *WebsocketCommandExecutor) FireCommand(conn *websocket.Conn, cmd string)
return nil
}

func (e *WebsocketCommandExecutor) Name() string {
return "Websocket"
func (e *WebsocketCommandExecutor) ReadResponse(conn *websocket.Conn, cmd string) (interface{}, error) {
// read the response
_, resp, err := conn.ReadMessage()
if err != nil {
return nil, err
}

// marshal to json
var respJSON interface{}
if err = json.Unmarshal(resp, &respJSON); err != nil {
return nil, fmt.Errorf("error unmarshaling response")
}

return respJSON, nil
}

func RunWebsocketServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerOptions) {
Expand All @@ -110,33 +117,33 @@ func RunWebsocketServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerO
queryWatcherLocal := querymanager.NewQueryManager()
config.WebsocketPort = opt.Port
testServer := server.NewWebSocketServer(shardManager, testPort1)
shardManagerCtx, cancelShardManager := context.WithCancel(ctx)
setupCtx, cancelSetupCtx := context.WithCancel(ctx)

// run shard manager
wg.Add(1)
go func() {
defer wg.Done()
shardManager.Run(shardManagerCtx)
shardManager.Run(setupCtx)
}()

// run query manager
wg.Add(1)
go func() {
defer wg.Done()
queryWatcherLocal.Run(ctx, watchChan)
queryWatcherLocal.Run(setupCtx, watchChan)
}()

// start websocket server
wg.Add(1)
go func() {
defer wg.Done()
srverr := testServer.Run(ctx)
cancelSetupCtx()
if srverr != nil {
cancelShardManager()
if errors.Is(srverr, derrors.ErrAborted) {
if errors.Is(srverr, derrors.ErrAborted) || errors.Is(srverr, http.ErrServerClosed) {
return
}
slog.Debug("Websocket test server encountered an error: %v", slog.Any("error", srverr))
slog.Debug("Websocket test server encountered an error", slog.Any("error", srverr))
}
}()
}
4 changes: 2 additions & 2 deletions internal/eval/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -1848,7 +1848,7 @@ func EvalQWATCH(args []string, httpOp, websocketOp bool, client *comm.Client, st
}

// EvalQUNWATCH removes the specified key from the watch list for the caller client.
func EvalQUNWATCH(args []string, httpOp bool, client *comm.Client) []byte {
func EvalQUNWATCH(args []string, httpOp, websocketOp bool, client *comm.Client) []byte {
if len(args) != 1 {
return diceerrors.NewErrArity("Q.UNWATCH")
}
Expand All @@ -1857,7 +1857,7 @@ func EvalQUNWATCH(args []string, httpOp bool, client *comm.Client) []byte {
return clientio.Encode(e, false)
}

if httpOp {
if httpOp || websocketOp {
querymanager.QuerySubscriptionChan <- querymanager.QuerySubscription{
Subscribe: false,
Query: query,
Expand Down
2 changes: 1 addition & 1 deletion internal/eval/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func ExecuteCommand(c *cmd.DiceDBCmd, client *comm.Client, store *dstore.Store,
case "SUBSCRIBE", "Q.WATCH":
return &EvalResponse{Result: EvalQWATCH(c.Args, httpOp, websocketOp, client, store), Error: nil}
case "UNSUBSCRIBE", "Q.UNWATCH":
return &EvalResponse{Result: EvalQUNWATCH(c.Args, httpOp, client), Error: nil}
return &EvalResponse{Result: EvalQUNWATCH(c.Args, httpOp, websocketOp, client), Error: nil}
case auth.Cmd:
return &EvalResponse{Result: EvalAUTH(c.Args, client), Error: nil}
case "ABORT":
Expand Down
75 changes: 67 additions & 8 deletions internal/server/utils/redisCmdAdapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"errors"
"fmt"
"io"
"math"
"net/http"
"strconv"
"strings"
Expand Down Expand Up @@ -35,6 +36,7 @@
)

const QWatch string = "Q.WATCH"
const QUnwatch string = "Q.UNWATCH"

func ParseHTTPRequest(r *http.Request) (*cmd.DiceDBCmd, error) {
commandParts := strings.Split(strings.TrimPrefix(r.URL.Path, "/"), "/")
Expand Down Expand Up @@ -138,20 +140,29 @@

// handle commands with args
command = strings.ToUpper(cmdStr[:idx])
cmdStr = cmdStr[idx+1:]
args := cmdStr[idx+1:]

var cmdArr []string // args
// handle qwatch commands

// handle q.watch command
if command == QWatch {
// remove quotes from query string
cmdStr, err := strconv.Unquote(cmdStr)
arr, err := parseQWatchArgs(args)
if err != nil {
return nil, fmt.Errorf("error parsing q.watch query: %v", err)
return nil, err
}
cmdArr = []string{cmdStr}
} else {
cmdArr = arr

// handle q.unwatch command
} else if command == QUnwatch {
arr, err := parseQUnwatchArgs(args)
if err != nil {
return nil, err
}
cmdArr = arr

// handle other commands
cmdArr = strings.Split(cmdStr, " ")
} else {
cmdArr = strings.Split(args, " ")
}

// if key prefix is empty for JSON.INGEST command
Expand All @@ -166,6 +177,54 @@
}, nil
}

// parseQWatchArgs will parse Q.Watch args
// example: q.watch "query"
func parseQWatchArgs(args string) ([]string, error) {
// remove quotes from query string
query, err := strconv.Unquote(args)
if err != nil {
return nil, fmt.Errorf("error parsing q.watch query: %v", err)
}
return []string{query}, nil
}

// parseQUnwatchArgs will parse Q.Unwatch args
// example: q.unwatch 615405144 "query"
func parseQUnwatchArgs(args string) ([]string, error) {

Check failure on line 193 in internal/server/utils/redisCmdAdapter.go

View workflow job for this annotation

GitHub Actions / lint

empty-lines: extra empty line at the start of a block (revive)

// check if there are two args
idx := strings.Index(args, " \"")
if idx == -1 {
return nil, fmt.Errorf("error parsing q.unwatch args: clientID or query not found")
}

// extract first arg
id, err := strconv.Atoi(args[:idx])
if err != nil {
return nil, fmt.Errorf("invalid clientID")
}

if id < 0 {
return nil, fmt.Errorf("clientID must be positive")
}

if id > math.MaxUint32 {
return nil, fmt.Errorf("clientID must be less than 4294967295 (uint32)")
}

// clientID can be safely converted to uint32
clientID := strconv.Itoa(id)

// remove quotes from query string
query := args[idx+1:]
query, err = strconv.Unquote(query)
if err != nil {
return nil, fmt.Errorf("error parsing q.unwatch query: %v", err)
}

return []string{clientID, query}, nil
}

func processPriorityKeys(jsonBody map[string]interface{}, args *[]string) {
for _, key := range getPriorityKeys() {
if val, exists := jsonBody[key]; exists {
Expand Down
Loading
Loading