From 78e9d241f155188b229f3962fa3f00258a2e4da5 Mon Sep 17 00:00:00 2001 From: psr Date: Mon, 21 Oct 2024 15:44:41 +0530 Subject: [PATCH 1/9] added event loop for websocket qwatch --- .../commands/websocket/main_test.go | 10 +- .../commands/websocket/qwatch_test.go | 131 +++++++++++++++++- integration_tests/commands/websocket/setup.go | 33 +++-- internal/server/websocketServer.go | 120 ++++++++++++---- 4 files changed, 246 insertions(+), 48 deletions(-) diff --git a/integration_tests/commands/websocket/main_test.go b/integration_tests/commands/websocket/main_test.go index cc330a19d..cabf26798 100644 --- a/integration_tests/commands/websocket/main_test.go +++ b/integration_tests/commands/websocket/main_test.go @@ -16,29 +16,25 @@ func TestMain(m *testing.M) { slog.SetDefault(l) var wg sync.WaitGroup - // Run the test server - // This is a synchronous method, because internally it - // checks for available port and then forks a goroutine - // to start the server opts := TestServerOptions{ Port: testPort1, Logger: l, } ctx, cancel := context.WithCancel(context.Background()) + defer cancel() RunWebsocketServer(ctx, &wg, opts) // Wait for the server to start time.Sleep(2 * time.Second) - executor := NewWebsocketCommandExecutor() - // Run the test suite exitCode := m.Run() + // shut down gracefully + executor := NewWebsocketCommandExecutor() conn := executor.ConnectToServer() executor.FireCommand(conn, "abort") - cancel() wg.Wait() os.Exit(exitCode) } diff --git a/integration_tests/commands/websocket/qwatch_test.go b/integration_tests/commands/websocket/qwatch_test.go index 27c8009f0..2cbb93228 100644 --- a/integration_tests/commands/websocket/qwatch_test.go +++ b/integration_tests/commands/websocket/qwatch_test.go @@ -3,13 +3,13 @@ package websocket import ( "testing" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" ) func TestQWatch(t *testing.T) { exec := NewWebsocketCommandExecutor() conn := exec.ConnectToServer() - testCases := []struct { name string cmds []string @@ -30,8 +30,8 @@ func TestQWatch(t *testing.T) { // Add unregister test case to handle this scenario once qunwatch support is added { name: "Successful register", - cmds: []string{`Q.WATCH "SELECT $key, $value WHERE $key like 'test-key?'"`}, - expect: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-key?'", []interface{}{}}, + cmds: []string{`Q.WATCH "SELECT $key, $value WHERE $key like 'test-qwatch-key'"`}, + expect: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-qwatch-key'", []interface{}{}}, }, } @@ -51,3 +51,128 @@ func TestQWatch(t *testing.T) { }) } } + +func TestQWatchWithMultipleClients(t *testing.T) { + numOfClients := 10 + exec := NewWebsocketCommandExecutor() + clients := []*websocket.Conn{} + + for i := 0; i < numOfClients; i++ { + client := exec.ConnectToServer() + defer client.Close() + clients = append(clients, client) + } + + tc := struct { + cmd string + expect interface{} // immediate response after firing command + update interface{} // update received after value change + }{ + cmd: `Q.WATCH "SELECT $key, $value WHERE $key like 'test-multiple-clients-key'"`, + expect: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-multiple-clients-key'", []interface{}{}}, + update: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-multiple-clients-key'", []interface{}{[]interface{}{"test-multiple-clients-key", float64(1)}}}, + } + + for i := 0; i < numOfClients; i++ { + conn := clients[i] + // subscribe query + resp, err := exec.FireCommandAndReadResponse(conn, tc.cmd) + assert.Nil(t, err) + assert.ElementsMatch(t, tc.expect, resp, "Value mismatch for cmd %s", tc.cmd) + } + + // create a fresh client to update key + c := exec.ConnectToServer() + defer c.Close() + + // update key + resp, err := exec.FireCommandAndReadResponse(c, "SET test-multiple-clients-key 1") + assert.Nil(t, err) + assert.Equal(t, "OK", resp, "Value mismatch for cmd SET test-multiple-clients-key 1") + + // read and validate query updates + for i := 0; i < numOfClients; i++ { + resp, err := exec.ReadResponse(clients[i], tc.cmd) + assert.Nil(t, err) + assert.Equal(t, tc.update, resp, "Value mismatch for reading query update message") + } +} + +func TestQWatchWithMultipleClientsAndQueries(t *testing.T) { + numOfClients := 3 + exec := NewWebsocketCommandExecutor() + clients := []*websocket.Conn{} + + for i := 0; i < numOfClients; i++ { + client := exec.ConnectToServer() + defer client.Close() + clients = append(clients, client) + } + + tests := []struct { + cmd string + expect interface{} // immediate response after firing command + update interface{} // update received after value change + }{ + { + cmd: `Q.WATCH "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key1'"`, + expect: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key1'", []interface{}{}}, + update: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key1'", []interface{}{[]interface{}{"test-multiple-client-queries-key1", float64(1)}}}, + }, + { + cmd: `Q.WATCH "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key2'"`, + expect: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key2'", []interface{}{}}, + update: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key2'", []interface{}{[]interface{}{"test-multiple-client-queries-key2", float64(2)}}}, + }, + { + cmd: `Q.WATCH "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key3'"`, + expect: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key3'", []interface{}{}}, + update: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-multiple-client-queries-key3'", []interface{}{[]interface{}{"test-multiple-client-queries-key3", float64(3)}}}, + }, + } + + for i := 0; i < numOfClients; i++ { + conn := clients[i] + + for j := 0; j < len(tests); j++ { + // subscribe query + resp, err := exec.FireCommandAndReadResponse(conn, tests[j].cmd) + assert.Nil(t, err) + assert.ElementsMatch(t, tests[j].expect, resp, "Value mismatch for cmd %s", tests[j].cmd) + } + } + + // create a fresh client to update keys + c := exec.ConnectToServer() + defer c.Close() + + // update keys + resp, err := exec.FireCommandAndReadResponse(c, "SET test-multiple-client-queries-key1 1") + assert.Nil(t, err) + assert.Equal(t, "OK", resp, "Value mismatch for cmd SET test-multiple-client-queries-key1 1") + + resp, err = exec.FireCommandAndReadResponse(c, "SET test-multiple-client-queries-key2 2") + assert.Nil(t, err) + assert.Equal(t, "OK", resp, "Value mismatch for cmd SET test-multiple-client-queries-key2 1") + + resp, err = exec.FireCommandAndReadResponse(c, "SET test-multiple-client-queries-key3 3") + assert.Nil(t, err) + assert.Equal(t, "OK", resp, "Value mismatch for cmd SET test-multiple-client-queries-key3 3") + + // prepare expected updates array + want := []interface{}{} + for j := 0; j < len(tests); j++ { + want = append(want, tests[j].update) + } + + // read and validate query updates + for i := 0; i < numOfClients; i++ { + var respArr []interface{} + for j := 0; j < len(tests); j++ { + resp, err := exec.ReadResponse(clients[i], tests[0].cmd) + assert.Nil(t, err) + respArr = append(respArr, resp) + } + assert.ElementsMatch(t, want, respArr, "Value mismatch for reading query update message") + } +} diff --git a/integration_tests/commands/websocket/setup.go b/integration_tests/commands/websocket/setup.go index 485ca5183..27ec35f25 100644 --- a/integration_tests/commands/websocket/setup.go +++ b/integration_tests/commands/websocket/setup.go @@ -30,11 +30,6 @@ type TestServerOptions struct { Logger *slog.Logger } -type CommandExecutor interface { - FireCommand(cmd string) interface{} - Name() string -} - type WebsocketCommandExecutor struct { baseURL string websocketClient *http.Client @@ -96,8 +91,20 @@ func (e *WebsocketCommandExecutor) FireCommand(conn *websocket.Conn, cmd string) return nil } -func (e *WebsocketCommandExecutor) Name() string { - return "Websocket" +func (e *WebsocketCommandExecutor) ReadResponse(conn *websocket.Conn, cmd string) (interface{}, error) { + // read the response + _, resp, err := conn.ReadMessage() + if err != nil { + return nil, err + } + + // marshal to json + var respJSON interface{} + if err = json.Unmarshal(resp, &respJSON); err != nil { + return nil, fmt.Errorf("error unmarshaling response") + } + + return respJSON, nil } func RunWebsocketServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerOptions) { @@ -112,20 +119,20 @@ func RunWebsocketServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerO queryWatcherLocal := querymanager.NewQueryManager(opt.Logger) config.WebsocketPort = opt.Port testServer := server.NewWebSocketServer(shardManager, testPort1, opt.Logger) - shardManagerCtx, cancelShardManager := context.WithCancel(ctx) + setupCtx, cancelSetupCtx := context.WithCancel(ctx) // run shard manager wg.Add(1) go func() { defer wg.Done() - shardManager.Run(shardManagerCtx) + shardManager.Run(setupCtx) }() // run query manager wg.Add(1) go func() { defer wg.Done() - queryWatcherLocal.Run(ctx, watchChan) + queryWatcherLocal.Run(setupCtx, watchChan) }() // start websocket server @@ -133,12 +140,12 @@ func RunWebsocketServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerO go func() { defer wg.Done() srverr := testServer.Run(ctx) + cancelSetupCtx() if srverr != nil { - cancelShardManager() - if errors.Is(srverr, derrors.ErrAborted) { + if errors.Is(srverr, derrors.ErrAborted) || errors.Is(srverr, http.ErrServerClosed) { return } - logger.Debug("Websocket test server encountered an error: %v", slog.Any("error", srverr)) + logger.Debug("Websocket test server encountered an error", slog.Any("error", srverr)) } }() } diff --git a/internal/server/websocketServer.go b/internal/server/websocketServer.go index b6f1d1df6..16e2a118b 100644 --- a/internal/server/websocketServer.go +++ b/internal/server/websocketServer.go @@ -34,13 +34,23 @@ var unimplementedCommandsWebsocket = map[string]bool{ Qunwatch: true, } +type QuerySubscription struct { + Subscribe bool // true for subscribe, false for unsubscribe + Cmd *cmd.DiceDBCmd + ClientIdentifierID uint32 + Client *websocket.Conn +} + type WebsocketServer struct { shardManager *shard.ShardManager ioChan chan *ops.StoreResponse websocketServer *http.Server upgrader websocket.Upgrader + subscriptionChan chan QuerySubscription // to subscribe clients + subscribedClients sync.Map // to maintain records of subscribed clients qwatchResponseChan chan comm.QwatchResponse shutdownChan chan struct{} + mu sync.Mutex logger *slog.Logger } @@ -61,6 +71,8 @@ func NewWebSocketServer(shardManager *shard.ShardManager, port int, logger *slog ioChan: make(chan *ops.StoreResponse, 1000), websocketServer: srv, upgrader: upgrader, + subscriptionChan: make(chan QuerySubscription), + subscribedClients: sync.Map{}, qwatchResponseChan: make(chan comm.QwatchResponse), shutdownChan: make(chan struct{}), logger: logger, @@ -74,11 +86,23 @@ func (s *WebsocketServer) Run(ctx context.Context) error { var wg sync.WaitGroup var err error - websocketCtx, cancelWebsocket := context.WithCancel(ctx) - defer cancelWebsocket() + wsCtx, cancelWS := context.WithCancel(ctx) + defer cancelWS() s.shardManager.RegisterWorker("wsServer", s.ioChan, nil) + // start server + wg.Add(1) + go func() { + defer wg.Done() + s.logger.Info("Websocket Server running", slog.String("port", s.websocketServer.Addr[1:])) + err = s.websocketServer.ListenAndServe() + if err != nil && !errors.Is(err, http.ErrServerClosed) { + s.logger.Error("Error in Websocket Server", slog.Any("time", time.Now()), slog.Any("error", err)) + } + }() + + // shutdown server gracefully wg.Add(1) go func() { defer wg.Done() @@ -86,24 +110,27 @@ func (s *WebsocketServer) Run(ctx context.Context) error { case <-ctx.Done(): case <-s.shutdownChan: err = diceerrors.ErrAborted - s.logger.Debug("Shutting down Websocket Server", slog.Any("time", time.Now())) } - shutdownErr := s.websocketServer.Shutdown(websocketCtx) + shutdownErr := s.websocketServer.Shutdown(wsCtx) if shutdownErr != nil { - s.logger.Error("Websocket Server shutdown failed:", slog.Any("error", err)) + s.logger.Error("Websocket Server shutdown failed:", slog.Any("error", shutdownErr)) return } }() + // listen for Q.WATCH subscriptions wg.Add(1) go func() { defer wg.Done() - s.logger.Info("Websocket Server running", slog.String("port", s.websocketServer.Addr[1:])) - err = s.websocketServer.ListenAndServe() - if err != nil { - s.logger.Debug("Error in Websocket Server", slog.Any("time", time.Now()), slog.Any("error", err)) - } + s.listenForSubscriptions(wsCtx) + }() + + // process Q.WATCH updates + wg.Add(1) + go func() { + defer wg.Done() + s.processQwatchUpdates(wsCtx) }() wg.Wait() @@ -119,6 +146,8 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques // closing handshake defer func() { + s.mu.Lock() + defer s.mu.Unlock() _ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "close 1000 (normal)")) conn.Close() }() @@ -141,7 +170,7 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques if errors.Is(err, diceerrors.ErrEmptyCommand) { continue } else if err != nil { - if err := WriteResponseWithRetries(conn, []byte("error: parsing failed"), maxRetries); err != nil { + if err := s.writeResponseWithRetries(conn, []byte("error: parsing failed"), maxRetries); err != nil { s.logger.Debug(fmt.Sprintf("Error writing message: %v", err)) } continue @@ -154,7 +183,7 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques } if unimplementedCommandsWebsocket[diceDBCmd.Cmd] { - if err := WriteResponseWithRetries(conn, []byte("Command is not implemented with Websocket"), maxRetries); err != nil { + if err := s.writeResponseWithRetries(conn, []byte("Command is not implemented with Websocket"), maxRetries); err != nil { s.logger.Debug(fmt.Sprintf("Error writing message: %v", err)) } continue @@ -173,8 +202,14 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques clientIdentifierID := generateUniqueInt32(r) sp.Client = comm.NewHTTPQwatchClient(s.qwatchResponseChan, clientIdentifierID) - // start a goroutine for subsequent updates - go s.processQwatchUpdates(clientIdentifierID, conn, diceDBCmd) + // subscribe client for updates + event := QuerySubscription{ + Subscribe: true, + Cmd: diceDBCmd, + ClientIdentifierID: clientIdentifierID, + Client: conn, + } + s.subscriptionChan <- event } s.shardManager.GetShard(0).ReqChan <- sp @@ -185,18 +220,47 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques } } -func (s *WebsocketServer) processQwatchUpdates(clientIdentifierID uint32, conn *websocket.Conn, dicDBCmd *cmd.DiceDBCmd) { +func (s *WebsocketServer) listenForSubscriptions(ctx context.Context) { + for { + select { + case event := <-s.subscriptionChan: + if event.Subscribe { + s.subscribedClients.LoadOrStore(event.ClientIdentifierID, event.Client) + } + case <-s.shutdownChan: + return + case <-ctx.Done(): + return + } + } +} + +func (s *WebsocketServer) processQwatchUpdates(ctx context.Context) { for { select { case resp := <-s.qwatchResponseChan: - if resp.ClientIdentifierID == clientIdentifierID { - if err := s.processResponse(conn, dicDBCmd, resp); err != nil { - s.logger.Debug("Error writing response to client. Shutting down goroutine for q.watch updates", slog.Any("clientIdentifierID", clientIdentifierID), slog.Any("error", err)) - return - } + client, ok := s.subscribedClients.Load(resp.ClientIdentifierID) + if !ok { + s.logger.Error("message received but client not found", slog.Any("clientIdentifierID", resp.ClientIdentifierID)) + } + conn, ok := client.(*websocket.Conn) + if !ok { + s.logger.Error("error typecasting client to *websocket.Conn") + } + + dicDBCmd := &cmd.DiceDBCmd{ + Cmd: Qwatch, + Args: []string{}, + } + + if err := s.processResponse(conn, dicDBCmd, resp); err != nil { + s.logger.Debug("Error writing qwatch update to client", slog.Any("clientIdentifierID", resp.ClientIdentifierID), slog.Any("error", err)) + continue } case <-s.shutdownChan: return + case <-ctx.Done(): + return } } } @@ -216,7 +280,7 @@ func (s *WebsocketServer) processResponse(conn *websocket.Conn, diceDBCmd *cmd.D err = resp.EvalResponse.Error default: s.logger.Debug("Unsupported response type") - if err := WriteResponseWithRetries(conn, []byte("error: 500 Internal Server Error"), maxRetries); err != nil { + if err := s.writeResponseWithRetries(conn, []byte("error: 500 Internal Server Error"), maxRetries); err != nil { s.logger.Debug(fmt.Sprintf("Error writing message: %v", err)) return fmt.Errorf("error writing response: %v", err) } @@ -248,7 +312,7 @@ func (s *WebsocketServer) processResponse(conn *websocket.Conn, diceDBCmd *cmd.D responseValue, err = rp.DecodeOne() if err != nil { s.logger.Debug("Error decoding response", "error", err) - if err := WriteResponseWithRetries(conn, []byte("error: 500 Internal Server Error"), maxRetries); err != nil { + if err := s.writeResponseWithRetries(conn, []byte("error: 500 Internal Server Error"), maxRetries); err != nil { s.logger.Debug(fmt.Sprintf("Error writing message: %v", err)) return fmt.Errorf("error writing response: %v", err) } @@ -273,7 +337,7 @@ func (s *WebsocketServer) processResponse(conn *websocket.Conn, diceDBCmd *cmd.D respBytes, err := json.Marshal(responseValue) if err != nil { s.logger.Debug("Error marshaling json", "error", err) - if err := WriteResponseWithRetries(conn, []byte("error: marshaling json"), maxRetries); err != nil { + if err := s.writeResponseWithRetries(conn, []byte("error: marshaling json"), maxRetries); err != nil { s.logger.Debug(fmt.Sprintf("Error writing message: %v", err)) return fmt.Errorf("error writing response: %v", err) } @@ -281,8 +345,7 @@ func (s *WebsocketServer) processResponse(conn *websocket.Conn, diceDBCmd *cmd.D } // success - // Write response with retries for transient errors - if err := WriteResponseWithRetries(conn, respBytes, config.DiceConfig.WebSocket.MaxWriteResponseRetries); err != nil { + if err := s.writeResponseWithRetries(conn, respBytes, config.DiceConfig.WebSocket.MaxWriteResponseRetries); err != nil { s.logger.Debug(fmt.Sprintf("Error writing message: %v", err)) return fmt.Errorf("error writing response: %v", err) } @@ -290,6 +353,13 @@ func (s *WebsocketServer) processResponse(conn *websocket.Conn, diceDBCmd *cmd.D return nil } +func (s *WebsocketServer) writeResponseWithRetries(conn *websocket.Conn, text []byte, maxRetries int) error { + s.mu.Lock() + defer s.mu.Unlock() + return WriteResponseWithRetries(conn, text, maxRetries) +} + +// WriteResponseWithRetries wrties response with retries for transient errors func WriteResponseWithRetries(conn *websocket.Conn, text []byte, maxRetries int) error { for attempts := 0; attempts < maxRetries; attempts++ { // Set a write deadline From 7f2ae4a4c4b875485a8669a6db479711fa052c94 Mon Sep 17 00:00:00 2001 From: psr Date: Mon, 21 Oct 2024 20:44:39 +0530 Subject: [PATCH 2/9] added q.unwatch parsing logic in ParseWebsocketMessage --- internal/server/utils/redisCmdAdapter.go | 75 +++++++++++++++++++++--- 1 file changed, 67 insertions(+), 8 deletions(-) diff --git a/internal/server/utils/redisCmdAdapter.go b/internal/server/utils/redisCmdAdapter.go index ce19a33bc..fa73b18f8 100644 --- a/internal/server/utils/redisCmdAdapter.go +++ b/internal/server/utils/redisCmdAdapter.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "math" "net/http" "strconv" "strings" @@ -35,6 +36,7 @@ const ( ) const QWatch string = "Q.WATCH" +const QUnwatch string = "Q.UNWATCH" func ParseHTTPRequest(r *http.Request) (*cmd.DiceDBCmd, error) { commandParts := strings.Split(strings.TrimPrefix(r.URL.Path, "/"), "/") @@ -138,20 +140,29 @@ func ParseWebsocketMessage(msg []byte) (*cmd.DiceDBCmd, error) { // handle commands with args command = strings.ToUpper(cmdStr[:idx]) - cmdStr = cmdStr[idx+1:] + args := cmdStr[idx+1:] var cmdArr []string // args - // handle qwatch commands + + // handle q.watch command if command == QWatch { - // remove quotes from query string - cmdStr, err := strconv.Unquote(cmdStr) + arr, err := parseQWatchArgs(args) if err != nil { - return nil, fmt.Errorf("error parsing q.watch query: %v", err) + return nil, err } - cmdArr = []string{cmdStr} - } else { + cmdArr = arr + + // handle q.unwatch command + } else if command == QUnwatch { + arr, err := parseQUnwatchArgs(args) + if err != nil { + return nil, err + } + cmdArr = arr + // handle other commands - cmdArr = strings.Split(cmdStr, " ") + } else { + cmdArr = strings.Split(args, " ") } // if key prefix is empty for JSON.INGEST command @@ -166,6 +177,54 @@ func ParseWebsocketMessage(msg []byte) (*cmd.DiceDBCmd, error) { }, nil } +// parseQWatchArgs will parse Q.Watch args +// example: q.watch "query" +func parseQWatchArgs(args string) ([]string, error) { + // remove quotes from query string + query, err := strconv.Unquote(args) + if err != nil { + return nil, fmt.Errorf("error parsing q.watch query: %v", err) + } + return []string{query}, nil +} + +// parseQUnwatchArgs will parse Q.Unwatch args +// example: q.unwatch 615405144 "query" +func parseQUnwatchArgs(args string) ([]string, error) { + + // check if there are two args + idx := strings.Index(args, " \"") + if idx == -1 { + return nil, fmt.Errorf("error parsing q.unwatch args: clientID or query not found") + } + + // extract first arg + id, err := strconv.Atoi(args[:idx]) + if err != nil { + return nil, fmt.Errorf("invalid clientID") + } + + if id < 0 { + return nil, fmt.Errorf("clientID must be positive") + } + + if id > math.MaxUint32 { + return nil, fmt.Errorf("clientID must be less than 4294967295 (uint32)") + } + + // clientID can be safely converted to uint32 + clientID := strconv.Itoa(id) + + // remove quotes from query string + query := args[idx:] + query, err = strconv.Unquote(query) + if err != nil { + return nil, fmt.Errorf("error parsing q.unwatch query: %v", err) + } + + return []string{clientID, query}, nil +} + func processPriorityKeys(jsonBody map[string]interface{}, args *[]string) { for _, key := range getPriorityKeys() { if val, exists := jsonBody[key]; exists { From 397c80dd638009edd1925657f97fbf7846ca9ed1 Mon Sep 17 00:00:00 2001 From: psr Date: Mon, 21 Oct 2024 20:45:14 +0530 Subject: [PATCH 3/9] added unit tests for q.unwatch parsing logic --- internal/server/utils/redisCmdAdapter_test.go | 336 ++++++++++++------ 1 file changed, 219 insertions(+), 117 deletions(-) diff --git a/internal/server/utils/redisCmdAdapter_test.go b/internal/server/utils/redisCmdAdapter_test.go index 929b3b1f9..7f2b45d8c 100644 --- a/internal/server/utils/redisCmdAdapter_test.go +++ b/internal/server/utils/redisCmdAdapter_test.go @@ -237,166 +237,256 @@ func TestParseHTTPRequest(t *testing.T) { func TestParseWebsocketMessage(t *testing.T) { commands := []struct { - name string - message string - expectedCmd string - expectedArgs []string + name string + message string + expectedCmd string + expectedArgs []string + expectedError string }{ { - name: "Test SET command with nx flag", - message: "set k1 v1 nx", - expectedCmd: "SET", - expectedArgs: []string{"k1", "v1", "nx"}, + name: "SET command with nx flag", + message: "set k1 v1 nx", + expectedCmd: "SET", + expectedArgs: []string{"k1", "v1", "nx"}, + expectedError: "", }, { - name: "Test SET command with value as a map", - message: `set k0 {"k1":"v1"} nx`, - expectedCmd: "SET", - expectedArgs: []string{"k0", `{"k1":"v1"}`, "nx"}, + name: "Test SET command with value as a map", + message: `set k0 {"k1":"v1"} nx`, + expectedCmd: "SET", + expectedArgs: []string{"k0", `{"k1":"v1"}`, "nx"}, + expectedError: "", }, { - name: "Test SET command with value as an array", - message: `set k1 ["v1","v2","v3"] nx`, - expectedCmd: "SET", - expectedArgs: []string{"k1", `["v1","v2","v3"]`, "nx"}, + name: "Test SET command with value as an array", + message: `set k1 ["v1","v2","v3"] nx`, + expectedCmd: "SET", + expectedArgs: []string{"k1", `["v1","v2","v3"]`, "nx"}, + expectedError: "", }, { - name: "Test SET command with value as a map containing an array", - message: `set k1 {"k2":["v1","v2"]} nx`, - expectedCmd: "SET", - expectedArgs: []string{"k1", `{"k2":["v1","v2"]}`, "nx"}, + name: "Test SET command with value as a map containing an array", + message: `set k1 {"k2":["v1","v2"]} nx`, + expectedCmd: "SET", + expectedArgs: []string{"k1", `{"k2":["v1","v2"]}`, "nx"}, + expectedError: "", }, { - name: "Test SET command with value as a deeply nested map", - message: `set k1 {"k2":{"k3":{"k4":"value"}}} nx`, - expectedCmd: "SET", - expectedArgs: []string{"k1", `{"k2":{"k3":{"k4":"value"}}}`, "nx"}, + name: "Test SET command with value as a deeply nested map", + message: `set k1 {"k2":{"k3":{"k4":"value"}}} nx`, + expectedCmd: "SET", + expectedArgs: []string{"k1", `{"k2":{"k3":{"k4":"value"}}}`, "nx"}, + expectedError: "", }, { - name: "Test SET command with value as an array of maps", - message: `set k0 [{"k1":"v1"},{"k2":"v2"}] nx`, - expectedCmd: "SET", - expectedArgs: []string{"k0", `[{"k1":"v1"},{"k2":"v2"}]`, "nx"}, + name: "Test SET command with value as an array of maps", + message: `set k0 [{"k1":"v1"},{"k2":"v2"}] nx`, + expectedCmd: "SET", + expectedArgs: []string{"k0", `[{"k1":"v1"},{"k2":"v2"}]`, "nx"}, + expectedError: "", }, { - name: "Test GET command", - message: "get k1", - expectedCmd: "GET", - expectedArgs: []string{"k1"}, + name: "Test GET command", + message: "get k1", + expectedCmd: "GET", + expectedArgs: []string{"k1"}, + expectedError: "", }, { - name: "Test DEL command", - message: "del k1", - expectedCmd: "DEL", - expectedArgs: []string{"k1"}, + name: "Test DEL command", + message: "del k1", + expectedCmd: "DEL", + expectedArgs: []string{"k1"}, + expectedError: "", }, { - name: "Test DEL command with multiple keys", - message: `del k1 k2 k3`, - expectedCmd: "DEL", - expectedArgs: []string{"k1", "k2", "k3"}, + name: "Test DEL command with multiple keys", + message: `del k1 k2 k3`, + expectedCmd: "DEL", + expectedArgs: []string{"k1", "k2", "k3"}, + expectedError: "", }, { - name: "Test KEYS command", - message: "keys *", - expectedCmd: "KEYS", - expectedArgs: []string{"*"}, + name: "Test KEYS command", + message: "keys *", + expectedCmd: "KEYS", + expectedArgs: []string{"*"}, + expectedError: "", }, { - name: "Test MSET command", - message: "mset k1 v1 k2 v2", - expectedCmd: "MSET", - expectedArgs: []string{"k1", "v1", "k2", "v2"}, + name: "Test MSET command", + message: "mset k1 v1 k2 v2", + expectedCmd: "MSET", + expectedArgs: []string{"k1", "v1", "k2", "v2"}, + expectedError: "", }, { - name: "Test MSET command with options", - message: "mset k1 v1 k2 v2 nx", - expectedCmd: "MSET", - expectedArgs: []string{"k1", "v1", "k2", "v2", "nx"}, + name: "Test MSET command with options", + message: "mset k1 v1 k2 v2 nx", + expectedCmd: "MSET", + expectedArgs: []string{"k1", "v1", "k2", "v2", "nx"}, + expectedError: "", }, { - name: "Test SLEEP command", - message: "sleep 1", - expectedCmd: "SLEEP", - expectedArgs: []string{"1"}, + name: "Test SLEEP command", + message: "sleep 1", + expectedCmd: "SLEEP", + expectedArgs: []string{"1"}, + expectedError: "", }, { - name: "Test PING command", - message: "ping", - expectedCmd: "PING", - expectedArgs: nil, + name: "Test PING command", + message: "ping", + expectedCmd: "PING", + expectedArgs: nil, + expectedError: "", }, { - name: "Test EXPIRE command", - message: "expire k1 1", - expectedCmd: "EXPIRE", - expectedArgs: []string{"k1", "1"}, + name: "Test EXPIRE command", + message: "expire k1 1", + expectedCmd: "EXPIRE", + expectedArgs: []string{"k1", "1"}, + expectedError: "", }, { - name: "Test AUTH command", - message: "auth user password", - expectedCmd: "AUTH", - expectedArgs: []string{"user", "password"}, + name: "Test AUTH command", + message: "auth user password", + expectedCmd: "AUTH", + expectedArgs: []string{"user", "password"}, + expectedError: "", }, { - name: "Test LPUSH command", - message: "lpush k1 v1", - expectedCmd: "LPUSH", - expectedArgs: []string{"k1", "v1"}, + name: "Test LPUSH command", + message: "lpush k1 v1", + expectedCmd: "LPUSH", + expectedArgs: []string{"k1", "v1"}, + expectedError: "", }, { - name: "Test LPUSH command with multiple items", - message: `lpush k1 v1 v2 v3`, - expectedCmd: "LPUSH", - expectedArgs: []string{"k1", "v1", "v2", "v3"}, + name: "Test LPUSH command with multiple items", + message: `lpush k1 v1 v2 v3`, + expectedCmd: "LPUSH", + expectedArgs: []string{"k1", "v1", "v2", "v3"}, + expectedError: "", }, { - name: "Test JSON.ARRPOP command", - message: "json.arrpop k1 $ 1", - expectedCmd: "JSON.ARRPOP", - expectedArgs: []string{"k1", "$", "1"}, + name: "Test JSON.ARRPOP command", + message: "json.arrpop k1 $ 1", + expectedCmd: "JSON.ARRPOP", + expectedArgs: []string{"k1", "$", "1"}, + expectedError: "", }, { - name: "Test JSON.SET command", - message: `json.set k1 . {"field":"value"}`, - expectedCmd: "JSON.SET", - expectedArgs: []string{"k1", ".", `{"field":"value"}`}, + name: "Test JSON.SET command", + message: `json.set k1 . {"field":"value"}`, + expectedCmd: "JSON.SET", + expectedArgs: []string{"k1", ".", `{"field":"value"}`}, + expectedError: "", }, { - name: "Test JSON.GET command", - message: "json.get k1", - expectedCmd: "JSON.GET", - expectedArgs: []string{"k1"}, + name: "Test JSON.GET command", + message: "json.get k1", + expectedCmd: "JSON.GET", + expectedArgs: []string{"k1"}, + expectedError: "", }, { - name: "Test HSET command with JSON body", - message: "hset hashkey f1 v1", - expectedCmd: "HSET", - expectedArgs: []string{"hashkey", "f1", "v1"}, + name: "Test HSET command with JSON body", + message: "hset hashkey f1 v1", + expectedCmd: "HSET", + expectedArgs: []string{"hashkey", "f1", "v1"}, + expectedError: "", }, { - name: "Test JSON.INGEST command with key prefix", - message: `json.ingest gmtr_ $..field {"field":"value"}`, - expectedCmd: "JSON.INGEST", - expectedArgs: []string{"gmtr_", "$..field", `{"field":"value"}`}, + name: "Test JSON.INGEST command with key prefix", + message: `json.ingest gmtr_ $..field {"field":"value"}`, + expectedCmd: "JSON.INGEST", + expectedArgs: []string{"gmtr_", "$..field", `{"field":"value"}`}, + expectedError: "", }, { - name: "Test JSON.INGEST command without key prefix", - message: `json.ingest $..field {"field":"value"}`, - expectedCmd: "JSON.INGEST", - expectedArgs: []string{"", "$..field", `{"field":"value"}`}, + name: "Test JSON.INGEST command without key prefix", + message: `json.ingest $..field {"field":"value"}`, + expectedCmd: "JSON.INGEST", + expectedArgs: []string{"", "$..field", `{"field":"value"}`}, + expectedError: "", }, { - name: "Test simple Q.WATCH command", - message: "q.watch \"select $key, $value where $key like 'k?'\"", - expectedCmd: "Q.WATCH", - expectedArgs: []string{"select $key, $value where $key like 'k?'"}, + name: "invalid Q.WATCH no args", + message: "q.watch", + expectedCmd: "Q.WATCH", + expectedArgs: nil, + expectedError: "", }, { - name: "Test complex Q.WATCH command", - message: "q.watch \"SELECT $key, $value WHERE $key LIKE 'player:*' AND '$value.score' > 10 ORDER BY $value.score DESC LIMIT 5\"", - expectedCmd: "Q.WATCH", - expectedArgs: []string{"SELECT $key, $value WHERE $key LIKE 'player:*' AND '$value.score' > 10 ORDER BY $value.score DESC LIMIT 5"}, + name: "invalid Q.WATCH invalid query", + message: `q.watch \"select $key, $value where $key like 'k?'\"`, // backticks will retain escaped characters as it is + expectedCmd: "Q.WATCH", + expectedArgs: nil, + expectedError: "error parsing q.watch query: invalid syntax", + }, + { + name: "valid Q.WATCH simple query", + message: "q.watch \"select $key, $value where $key like 'k?'\"", + expectedCmd: "Q.WATCH", + expectedArgs: []string{"select $key, $value where $key like 'k?'"}, + expectedError: "", + }, + { + name: "valid Q.WATCH complex query", + message: "q.watch \"SELECT $key, $value WHERE $key LIKE 'player:*' AND '$value.score' > 10 ORDER BY $value.score DESC LIMIT 5\"", + expectedCmd: "Q.WATCH", + expectedArgs: []string{"SELECT $key, $value WHERE $key LIKE 'player:*' AND '$value.score' > 10 ORDER BY $value.score DESC LIMIT 5"}, + expectedError: "", + }, + { + name: "invalid Q.UNWATCH no args", + message: "q.unwatch", + expectedCmd: "Q.UNWATCH", + expectedArgs: nil, + expectedError: "", + }, + { + name: "invalid Q.UNWATCH clientID missing", + message: "q.unwatch \"select $key, $value where $key like 'k?'\"", + expectedCmd: "Q.UNWATCH", + expectedArgs: nil, + expectedError: "error parsing q.unwatch args: clientID or query not found", + }, + { + name: "invalid Q.UNWATCH query missing", + message: "q.unwatch 615405144", + expectedCmd: "Q.UNWATCH", + expectedArgs: nil, + expectedError: "error parsing q.unwatch args: clientID or query not found", + }, + { + name: "invalid Q.UNWATCH invalid clientID", + message: "q.unwatch 61abc5144 \"select $key, $value where $key like 'k?'\"", + expectedCmd: "Q.UNWATCH", + expectedArgs: nil, + expectedError: "invalid clientID", + }, + { + name: "invalid Q.UNWATCH negative clientID", + message: "q.unwatch -1 \"select $key, $value where $key like 'k?'\"", + expectedCmd: "Q.UNWATCH", + expectedArgs: nil, + expectedError: "clientID must be positive", + }, + { + name: "invalid Q.UNWATCH overflowing clientID", + message: "q.unwatch 4294967296 \"select $key, $value where $key like 'k?'\"", + expectedCmd: "Q.UNWATCH", + expectedArgs: nil, + expectedError: "clientID must be less than 4294967295 (uint32)", + }, + { + name: "invalid Q.UNWATCH invalid query", + message: "q.unwatch 615405144 \"select $key, $value where $key like 'k?'", // ending " missing for query + expectedCmd: "Q.UNWATCH", + expectedArgs: nil, + expectedError: "error parsing q.unwatch query: invalid syntax", }, } @@ -404,18 +494,30 @@ func TestParseWebsocketMessage(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // parse websocket message diceDBCmd, err := ParseWebsocketMessage([]byte(tc.message)) - assert.NoError(t, err) - expectedCmd := &cmd.DiceDBCmd{ - Cmd: tc.expectedCmd, - Args: tc.expectedArgs, - } + // error cases + if tc.expectedError != "" { + if err == nil { + t.Errorf("received nil error but expected error is: %v", tc.expectedError) + } + assert.Equal(t, tc.expectedError, err.Error()) + assert.Nil(t, tc.expectedArgs, "received error and not nil args") - // Check command match - assert.Equal(t, expectedCmd.Cmd, diceDBCmd.Cmd) + // non error cases + } else { + assert.NoError(t, err) + expectedCmd := &cmd.DiceDBCmd{ + Cmd: tc.expectedCmd, + Args: tc.expectedArgs, + } + + // Check command match + assert.Equal(t, expectedCmd.Cmd, diceDBCmd.Cmd) + + // Check arguments match, regardless of order + assert.ElementsMatch(t, expectedCmd.Args, diceDBCmd.Args, "The parsed arguments should match the expected arguments, ignoring order") + } - // Check arguments match, regardless of order - assert.ElementsMatch(t, expectedCmd.Args, diceDBCmd.Args, "The parsed arguments should match the expected arguments, ignoring order") }) } } From bd55871e325da8ad63ca1908f61d5717ce1b7fab Mon Sep 17 00:00:00 2001 From: psr Date: Tue, 22 Oct 2024 14:27:40 +0530 Subject: [PATCH 4/9] bug fix in parsing logic --- internal/server/utils/redisCmdAdapter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/server/utils/redisCmdAdapter.go b/internal/server/utils/redisCmdAdapter.go index fa73b18f8..228389a86 100644 --- a/internal/server/utils/redisCmdAdapter.go +++ b/internal/server/utils/redisCmdAdapter.go @@ -216,7 +216,7 @@ func parseQUnwatchArgs(args string) ([]string, error) { clientID := strconv.Itoa(id) // remove quotes from query string - query := args[idx:] + query := args[idx+1:] query, err = strconv.Unquote(query) if err != nil { return nil, fmt.Errorf("error parsing q.unwatch query: %v", err) From ab86b01e4ed8f133e82a2ba156de7fd0fed1d7a1 Mon Sep 17 00:00:00 2001 From: psr Date: Tue, 22 Oct 2024 14:28:17 +0530 Subject: [PATCH 5/9] adding websocketOp to evalQUNWATCH --- internal/eval/eval.go | 4 ++-- internal/eval/execute.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/eval/eval.go b/internal/eval/eval.go index 774ade1de..f81ffbed5 100644 --- a/internal/eval/eval.go +++ b/internal/eval/eval.go @@ -1832,7 +1832,7 @@ func EvalQWATCH(args []string, httpOp, websocketOp bool, client *comm.Client, st } // EvalQUNWATCH removes the specified key from the watch list for the caller client. -func EvalQUNWATCH(args []string, httpOp bool, client *comm.Client) []byte { +func EvalQUNWATCH(args []string, httpOp, websocketOp bool, client *comm.Client) []byte { if len(args) != 1 { return diceerrors.NewErrArity("Q.UNWATCH") } @@ -1841,7 +1841,7 @@ func EvalQUNWATCH(args []string, httpOp bool, client *comm.Client) []byte { return clientio.Encode(e, false) } - if httpOp { + if httpOp || websocketOp { querymanager.QuerySubscriptionChan <- querymanager.QuerySubscription{ Subscribe: false, Query: query, diff --git a/internal/eval/execute.go b/internal/eval/execute.go index 35cdda04e..42345b041 100644 --- a/internal/eval/execute.go +++ b/internal/eval/execute.go @@ -33,7 +33,7 @@ func ExecuteCommand(c *cmd.DiceDBCmd, client *comm.Client, store *dstore.Store, case "SUBSCRIBE", "Q.WATCH": return &EvalResponse{Result: EvalQWATCH(c.Args, httpOp, websocketOp, client, store), Error: nil} case "UNSUBSCRIBE", "Q.UNWATCH": - return &EvalResponse{Result: EvalQUNWATCH(c.Args, httpOp, client), Error: nil} + return &EvalResponse{Result: EvalQUNWATCH(c.Args, httpOp, websocketOp, client), Error: nil} case auth.Cmd: return &EvalResponse{Result: EvalAUTH(c.Args, client), Error: nil} case "ABORT": From 4f52c05988e51d6e8f073d51cd330adcd658d059 Mon Sep 17 00:00:00 2001 From: psr Date: Tue, 22 Oct 2024 14:28:59 +0530 Subject: [PATCH 6/9] added qunwatch to websocketServer --- internal/server/websocketServer.go | 111 ++++++++++++++++++++--------- 1 file changed, 76 insertions(+), 35 deletions(-) diff --git a/internal/server/websocketServer.go b/internal/server/websocketServer.go index 16e2a118b..b2ecacb71 100644 --- a/internal/server/websocketServer.go +++ b/internal/server/websocketServer.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "os" + "strconv" "sync" "syscall" "time" @@ -28,14 +29,9 @@ import ( const Qwatch = "Q.WATCH" const Qunwatch = "Q.UNWATCH" -const Subscribe = "SUBSCRIBE" - -var unimplementedCommandsWebsocket = map[string]bool{ - Qunwatch: true, -} type QuerySubscription struct { - Subscribe bool // true for subscribe, false for unsubscribe + Subscribe bool // true for subscribe, not used for unsubscribe Cmd *cmd.DiceDBCmd ClientIdentifierID uint32 Client *websocket.Conn @@ -165,12 +161,13 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques break } - // parse message to dice command + // parse message diceDBCmd, err := utils.ParseWebsocketMessage(msg) if errors.Is(err, diceerrors.ErrEmptyCommand) { continue } else if err != nil { - if err := s.writeResponseWithRetries(conn, []byte("error: parsing failed"), maxRetries); err != nil { + msg := fmt.Sprintf("error: parsing failed: %v", err) + if err := s.writeResponseWithRetries(conn, []byte(msg), maxRetries); err != nil { s.logger.Debug(fmt.Sprintf("Error writing message: %v", err)) } continue @@ -182,13 +179,6 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques break } - if unimplementedCommandsWebsocket[diceDBCmd.Cmd] { - if err := s.writeResponseWithRetries(conn, []byte("Command is not implemented with Websocket"), maxRetries); err != nil { - s.logger.Debug(fmt.Sprintf("Error writing message: %v", err)) - } - continue - } - // create request sp := &ops.StoreOp{ Cmd: diceDBCmd, @@ -197,21 +187,29 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques WebsocketOp: true, } - // handle q.watch commands - if diceDBCmd.Cmd == Qwatch || diceDBCmd.Cmd == Subscribe { - clientIdentifierID := generateUniqueInt32(r) - sp.Client = comm.NewHTTPQwatchClient(s.qwatchResponseChan, clientIdentifierID) - - // subscribe client for updates - event := QuerySubscription{ - Subscribe: true, - Cmd: diceDBCmd, - ClientIdentifierID: clientIdentifierID, - Client: conn, + // subscribe + if diceDBCmd.Cmd == Qwatch { + id := generateUniqueInt32(r) + sp.Client = comm.NewHTTPQwatchClient(s.qwatchResponseChan, id) + fmt.Println("id is: ", id) + s.subscribe(conn, diceDBCmd, id) + } + + // unsubscribe + if diceDBCmd.Cmd == Qunwatch { + s.unsubscribe(conn, diceDBCmd) + + id, err := strconv.Atoi(diceDBCmd.Args[0]) + if err != nil { + msg := fmt.Sprintf("invalid client id: %v. err: %v", diceDBCmd.Args[0], err) + s.writeResponseWithRetries(conn, []byte(msg), maxRetries) } - s.subscriptionChan <- event + sp.Client = comm.NewHTTPQwatchClient(s.qwatchResponseChan, uint32(id)) + + diceDBCmd.Args = diceDBCmd.Args[1:] // pop clientID } + // execute command s.shardManager.GetShard(0).ReqChan <- sp resp := <-s.ioChan if err := s.processResponse(conn, diceDBCmd, resp); err != nil { @@ -220,6 +218,37 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques } } +func (s *WebsocketServer) subscribe(conn *websocket.Conn, diceDBCmd *cmd.DiceDBCmd, id uint32) { + // subscribe client + event := QuerySubscription{ + Subscribe: true, + Cmd: diceDBCmd, + ClientIdentifierID: id, + Client: conn, + } + s.subscriptionChan <- event +} + +func (s *WebsocketServer) unsubscribe(conn *websocket.Conn, diceDBCmd *cmd.DiceDBCmd) { + maxRetries := config.DiceConfig.WebSocket.MaxWriteResponseRetries + // convert id + id, err := strconv.Atoi(diceDBCmd.Args[0]) + if err != nil { + msg := fmt.Sprintf("invalid client id: %v. err: %v", diceDBCmd.Args[0], err) + s.writeResponseWithRetries(conn, []byte(msg), maxRetries) + } + + // check if client exits + _, err = s.getClientById(uint32(id)) + if err != nil { + msg := fmt.Sprintf("error getting client: %v", err) + s.writeResponseWithRetries(conn, []byte(msg), maxRetries) + } + + // remove client + s.deleteClientById(uint32(id)) +} + func (s *WebsocketServer) listenForSubscriptions(ctx context.Context) { for { select { @@ -239,13 +268,9 @@ func (s *WebsocketServer) processQwatchUpdates(ctx context.Context) { for { select { case resp := <-s.qwatchResponseChan: - client, ok := s.subscribedClients.Load(resp.ClientIdentifierID) - if !ok { - s.logger.Error("message received but client not found", slog.Any("clientIdentifierID", resp.ClientIdentifierID)) - } - conn, ok := client.(*websocket.Conn) - if !ok { - s.logger.Error("error typecasting client to *websocket.Conn") + client, err := s.getClientById(resp.ClientIdentifierID) + if err != nil { + s.logger.Error("message received but client not found or invalid", slog.Any("error", err)) } dicDBCmd := &cmd.DiceDBCmd{ @@ -253,7 +278,7 @@ func (s *WebsocketServer) processQwatchUpdates(ctx context.Context) { Args: []string{}, } - if err := s.processResponse(conn, dicDBCmd, resp); err != nil { + if err := s.processResponse(client, dicDBCmd, resp); err != nil { s.logger.Debug("Error writing qwatch update to client", slog.Any("clientIdentifierID", resp.ClientIdentifierID), slog.Any("error", err)) continue } @@ -353,6 +378,22 @@ func (s *WebsocketServer) processResponse(conn *websocket.Conn, diceDBCmd *cmd.D return nil } +func (s *WebsocketServer) getClientById(id uint32) (*websocket.Conn, error) { + client, ok := s.subscribedClients.Load(id) + if !ok { + return nil, fmt.Errorf("client not found: %v", id) + } + conn, ok := client.(*websocket.Conn) + if !ok { + return nil, fmt.Errorf("error typecasting client") + } + return conn, nil +} + +func (s *WebsocketServer) deleteClientById(id uint32) { + s.subscribedClients.Delete(id) +} + func (s *WebsocketServer) writeResponseWithRetries(conn *websocket.Conn, text []byte, maxRetries int) error { s.mu.Lock() defer s.mu.Unlock() From d51ad50c3a68b946f22068af156710ad807c90a5 Mon Sep 17 00:00:00 2001 From: psr Date: Wed, 23 Oct 2024 13:16:34 +0530 Subject: [PATCH 7/9] minor fix --- internal/server/websocketServer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/server/websocketServer.go b/internal/server/websocketServer.go index 56b5a6fdb..5d62bc5a8 100644 --- a/internal/server/websocketServer.go +++ b/internal/server/websocketServer.go @@ -256,7 +256,7 @@ func (s *WebsocketServer) processQwatchUpdates(ctx context.Context) { } if err := s.processResponse(conn, dicDBCmd, resp); err != nil { - slog.Debug("Error writing response to client. Shutting down goroutine for q.watch updates", slog.Any("clientIdentifierID", clientIdentifierID), slog.Any("error", err)) + slog.Debug("Error writing response to client. Shutting down goroutine for q.watch updates", slog.Any("clientIdentifierID", resp.ClientIdentifierID), slog.Any("error", err)) continue } case <-s.shutdownChan: From b21c735ca9d226fd1c535fd90f58bed4e5c44b1c Mon Sep 17 00:00:00 2001 From: psr Date: Wed, 23 Oct 2024 13:32:08 +0530 Subject: [PATCH 8/9] bug fix --- internal/server/websocketServer.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/server/websocketServer.go b/internal/server/websocketServer.go index 5d62bc5a8..b28c3c482 100644 --- a/internal/server/websocketServer.go +++ b/internal/server/websocketServer.go @@ -244,10 +244,12 @@ func (s *WebsocketServer) processQwatchUpdates(ctx context.Context) { client, ok := s.subscribedClients.Load(resp.ClientIdentifierID) if !ok { slog.Debug("message received but client not found", slog.Any("clientIdentifierID", resp.ClientIdentifierID)) + continue } conn, ok := client.(*websocket.Conn) if !ok { slog.Debug("error typecasting client to *websocket.Conn") + continue } dicDBCmd := &cmd.DiceDBCmd{ From 2fb40f0d914c9bd317784392419b401f6a1a2ebc Mon Sep 17 00:00:00 2001 From: psr Date: Wed, 23 Oct 2024 18:49:45 +0530 Subject: [PATCH 9/9] bug fix --- internal/server/websocketServer.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/server/websocketServer.go b/internal/server/websocketServer.go index c638e56be..8aec62b6c 100644 --- a/internal/server/websocketServer.go +++ b/internal/server/websocketServer.go @@ -273,6 +273,7 @@ func (s *WebsocketServer) processQwatchUpdates(ctx context.Context) { client, err := s.getClientById(resp.ClientIdentifierID) if err != nil { slog.Debug("message received but client not found or invalid", slog.Any("error", err)) + continue } dicDBCmd := &cmd.DiceDBCmd{