diff --git a/integration_tests/commands/websocket/get_test.go b/integration_tests/commands/websocket/get_test.go index 606b28f9e..98d1d854a 100644 --- a/integration_tests/commands/websocket/get_test.go +++ b/integration_tests/commands/websocket/get_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "gotest.tools/v3/assert" + "github.com/stretchr/testify/assert" ) func TestGet(t *testing.T) { @@ -19,9 +19,9 @@ func TestGet(t *testing.T) { }{ { name: "Get with expiration", - cmds: []string{"SET k v EX 4", "GET k", "GET k"}, + cmds: []string{"SET k v EX 1", "GET k", "GET k"}, expect: []interface{}{"OK", "v", "(nil)"}, - delays: []time.Duration{0, 0, 5 * time.Second}, + delays: []time.Duration{0, 0, 2 * time.Second}, }, } @@ -31,7 +31,8 @@ func TestGet(t *testing.T) { if tc.delays[i] > 0 { time.Sleep(tc.delays[i]) } - result := exec.FireCommand(conn, cmd) + result, err := exec.FireCommandAndReadResponse(conn, cmd) + assert.Nil(t, err) assert.Equal(t, tc.expect[i], result, "Value mismatch for cmd %s", cmd) } }) diff --git a/integration_tests/commands/websocket/helper.go b/integration_tests/commands/websocket/helper.go new file mode 100644 index 000000000..ca14c5b43 --- /dev/null +++ b/integration_tests/commands/websocket/helper.go @@ -0,0 +1,17 @@ +package websocket + +import ( + "testing" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" +) + +func DeleteKey(t *testing.T, conn *websocket.Conn, exec *WebsocketCommandExecutor, key string) { + cmd := "DEL " + key + resp, err := exec.FireCommandAndReadResponse(conn, cmd) + assert.Nil(t, err) + respFloat, ok := resp.(float64) + assert.True(t, ok, "error converting response to float64") + assert.True(t, respFloat == 1 || respFloat == 0, "unexpected response in %v: %v", cmd, resp) +} diff --git a/integration_tests/commands/websocket/hyperloglog_test.go b/integration_tests/commands/websocket/hyperloglog_test.go index fe3a27173..65e02dc04 100644 --- a/integration_tests/commands/websocket/hyperloglog_test.go +++ b/integration_tests/commands/websocket/hyperloglog_test.go @@ -3,7 +3,7 @@ package websocket import ( "testing" - "gotest.tools/v3/assert" + "github.com/stretchr/testify/assert" ) func TestHyperLogLogCommands(t *testing.T) { @@ -82,11 +82,12 @@ func TestHyperLogLogCommands(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { conn := exec.ConnectToServer() - exec.FireCommand(conn, "del k") + DeleteKey(t, conn, exec, "k") for i, cmd := range tc.commands { - result := exec.FireCommand(conn, cmd) - assert.DeepEqual(t, tc.expected[i], result) + result, err := exec.FireCommandAndReadResponse(conn, cmd) + assert.Nil(t, err) + assert.Equal(t, tc.expected[i], result) } }) } diff --git a/integration_tests/commands/websocket/json_test.go b/integration_tests/commands/websocket/json_test.go index 2a9b67965..dbc181505 100644 --- a/integration_tests/commands/websocket/json_test.go +++ b/integration_tests/commands/websocket/json_test.go @@ -1,8 +1,9 @@ package websocket import ( - "gotest.tools/v3/assert" "testing" + + "github.com/stretchr/testify/assert" ) func TestJSONClearOperations(t *testing.T) { @@ -10,10 +11,11 @@ func TestJSONClearOperations(t *testing.T) { conn := exec.ConnectToServer() defer conn.Close() - exec.FireCommand(conn, "DEL user") + DeleteKey(t, conn, exec, "user") defer func() { - resp := exec.FireCommand(conn, "DEL user") + resp, err := exec.FireCommandAndReadResponse(conn, "DEL user") + assert.Nil(t, err) assert.Equal(t, float64(1), resp) }() @@ -86,7 +88,8 @@ func TestJSONClearOperations(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { for i, cmd := range tc.commands { - result := exec.FireCommand(conn, cmd) + result, err := exec.FireCommandAndReadResponse(conn, cmd) + assert.Nil(t, err) assert.Equal(t, tc.expected[i], result) } }) @@ -99,10 +102,11 @@ func TestJsonStrlen(t *testing.T) { defer conn.Close() - exec.FireCommand(conn, "DEL doc") + DeleteKey(t, conn, exec, "doc") defer func() { - resp := exec.FireCommand(conn, "DEL doc") + resp, err := exec.FireCommandAndReadResponse(conn, "DEL doc") + assert.Nil(t, err) assert.Equal(t, float64(1), resp) }() @@ -172,12 +176,13 @@ func TestJsonStrlen(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { for i, cmd := range tc.commands { - result := exec.FireCommand(conn, cmd) + result, err := exec.FireCommandAndReadResponse(conn, cmd) + assert.Nil(t, err, "error: %v", err) stringResult, ok := result.(string) if ok { assert.Equal(t, tc.expected[i], stringResult) } else { - assert.Assert(t, arraysArePermutations(tc.expected[i].([]interface{}), result.([]interface{}))) + assert.True(t, arraysArePermutations(tc.expected[i].([]interface{}), result.([]interface{}))) } } }) @@ -189,7 +194,7 @@ func TestJsonObjLen(t *testing.T) { conn := exec.ConnectToServer() defer conn.Close() - exec.FireCommand(conn, "DEL obj") + DeleteKey(t, conn, exec, "obj") a := `{"name":"jerry","partner":{"name":"tom","language":["rust"]}}` b := `{"name":"jerry","partner":{"name":"tom","language":["rust"]},"partner2":{"name":"spike","language":["go","rust"]}}` @@ -197,7 +202,8 @@ func TestJsonObjLen(t *testing.T) { d := `["this","is","an","array"]` defer func() { - resp := exec.FireCommand(conn, "DEL obj") + resp, err := exec.FireCommandAndReadResponse(conn, "DEL obj") + assert.Nil(t, err) assert.Equal(t, float64(1), resp) }() @@ -259,13 +265,14 @@ func TestJsonObjLen(t *testing.T) { } for _, tcase := range testCases { - exec.FireCommand(conn, "DEL obj") + DeleteKey(t, conn, exec, "obj") t.Run(tcase.name, func(t *testing.T) { for i := 0; i < len(tcase.commands); i++ { cmd := tcase.commands[i] out := tcase.expected[i] - result := exec.FireCommand(conn, cmd) - assert.DeepEqual(t, out, result) + result, err := exec.FireCommandAndReadResponse(conn, cmd) + assert.Nil(t, err) + assert.Equal(t, out, result) } }) } diff --git a/integration_tests/commands/websocket/main_test.go b/integration_tests/commands/websocket/main_test.go index a85925541..eec1e6336 100644 --- a/integration_tests/commands/websocket/main_test.go +++ b/integration_tests/commands/websocket/main_test.go @@ -21,10 +21,12 @@ func TestMain(m *testing.M) { // checks for available port and then forks a goroutine // to start the server opts := TestServerOptions{ - Port: 8380, + Port: testPort1, Logger: l, } - RunWebsocketServer(context.Background(), &wg, opts) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + RunWebsocketServer(ctx, &wg, opts) // Wait for the server to start time.Sleep(2 * time.Second) @@ -37,7 +39,6 @@ func TestMain(m *testing.M) { // abort conn := executor.ConnectToServer() executor.FireCommand(conn, "abort") - executor.DisconnectServer(conn) wg.Wait() os.Exit(exitCode) diff --git a/integration_tests/commands/websocket/set_test.go b/integration_tests/commands/websocket/set_test.go index 964652995..9f31d4dde 100644 --- a/integration_tests/commands/websocket/set_test.go +++ b/integration_tests/commands/websocket/set_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "gotest.tools/v3/assert" + "github.com/stretchr/testify/assert" ) type TestCase struct { @@ -39,11 +39,13 @@ func TestSet(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { conn := exec.ConnectToServer() - exec.FireCommand(conn, "del k") + + DeleteKey(t, conn, exec, "k") for i, cmd := range tc.commands { - result := exec.FireCommand(conn, cmd) - assert.DeepEqual(t, tc.expected[i], result) + result, err := exec.FireCommandAndReadResponse(conn, cmd) + assert.Nil(t, err) + assert.Equal(t, tc.expected[i], result) } }) } @@ -124,11 +126,14 @@ func TestSetWithOptions(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { conn := exec.ConnectToServer() - exec.FireCommand(conn, "del k") - exec.FireCommand(conn, "del k1") - exec.FireCommand(conn, "del k2") + + DeleteKey(t, conn, exec, "k") + DeleteKey(t, conn, exec, "k1") + DeleteKey(t, conn, exec, "k2") + for i, cmd := range tc.commands { - result := exec.FireCommand(conn, cmd) + result, err := exec.FireCommandAndReadResponse(conn, cmd) + assert.Nil(t, err) assert.Equal(t, tc.expected[i], result) } }) @@ -143,32 +148,76 @@ func TestSetWithExat(t *testing.T) { t.Run("SET with EXAT", func(t *testing.T) { conn := exec.ConnectToServer() - exec.FireCommand(conn, "DEL k") - assert.Equal(t, "OK", exec.FireCommand(conn, fmt.Sprintf("SET k v EXAT %v", Etime)), "Value mismatch for cmd SET k v EXAT "+Etime) - assert.Equal(t, "v", exec.FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") - assert.Assert(t, exec.FireCommand(conn, "TTL k").(float64) <= 5, "Value mismatch for cmd TTL k") + + DeleteKey(t, conn, exec, "k") + + resp, err := exec.FireCommandAndReadResponse(conn, fmt.Sprintf("SET k v EXAT %v", Etime)) + assert.Nil(t, err) + assert.Equal(t, "OK", resp, "Value mismatch for cmd SET k v EXAT "+Etime) + + resp, err = exec.FireCommandAndReadResponse(conn, "GET k") + assert.Nil(t, err) + assert.Equal(t, "v", resp, "Value mismatch for cmd GET k") + + resp, err = exec.FireCommandAndReadResponse(conn, "TTL k") + assert.Nil(t, err) + respFloat, ok := resp.(float64) + assert.True(t, ok) + assert.True(t, respFloat <= 5, "Value mismatch for cmd TTL k") + time.Sleep(3 * time.Second) - assert.Assert(t, exec.FireCommand(conn, "TTL k").(float64) <= 3, "Value mismatch for cmd TTL k") + resp, err = exec.FireCommandAndReadResponse(conn, "TTL k") + assert.Nil(t, err) + respFloat, ok = resp.(float64) + assert.True(t, ok) + assert.True(t, respFloat <= 3, "Value mismatch for cmd TTL k") + time.Sleep(3 * time.Second) - assert.Equal(t, "(nil)", exec.FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") - assert.Equal(t, float64(-2), exec.FireCommand(conn, "TTL k"), "Value mismatch for cmd TTL k") + resp, err = exec.FireCommandAndReadResponse(conn, "GET k") + assert.Nil(t, err) + assert.Equal(t, "(nil)", resp, "Value mismatch for cmd GET k") + + resp, err = exec.FireCommandAndReadResponse(conn, "TTL k") + assert.Nil(t, err) + respFloat, ok = resp.(float64) + assert.True(t, ok) + assert.Equal(t, float64(-2), respFloat, "Value mismatch for cmd TTL k") }) t.Run("SET with invalid EXAT expires key immediately", func(t *testing.T) { conn := exec.ConnectToServer() - exec.FireCommand(conn, "DEL k") - assert.Equal(t, "OK", exec.FireCommand(conn, "SET k v EXAT "+BadTime), "Value mismatch for cmd SET k v EXAT "+BadTime) - assert.Equal(t, "(nil)", exec.FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") - assert.Equal(t, float64(-2), exec.FireCommand(conn, "TTL k"), "Value mismatch for cmd TTL k") + + DeleteKey(t, conn, exec, "k") + + resp, err := exec.FireCommandAndReadResponse(conn, "SET k v EXAT "+BadTime) + assert.Nil(t, err) + assert.Equal(t, "OK", resp, "Value mismatch for cmd SET k v EXAT "+BadTime) + + resp, err = exec.FireCommandAndReadResponse(conn, "GET k") + assert.Nil(t, err) + assert.Equal(t, "(nil)", resp, "Value mismatch for cmd GET k") + + resp, err = exec.FireCommandAndReadResponse(conn, "TTL k") + assert.Nil(t, err) + respFloat, ok := resp.(float64) + assert.True(t, ok) + assert.Equal(t, float64(-2), respFloat, "Value mismatch for cmd TTL k") }) t.Run("SET with EXAT and PXAT returns syntax error", func(t *testing.T) { conn := exec.ConnectToServer() - exec.FireCommand(conn, "DEL k") - assert.Equal(t, "ERR syntax error", exec.FireCommand(conn, "SET k v PXAT "+Etime+" EXAT "+Etime), "Value mismatch for cmd SET k v PXAT "+Etime+" EXAT "+Etime) - assert.Equal(t, "(nil)", exec.FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") + + DeleteKey(t, conn, exec, "k") + + resp, err := exec.FireCommandAndReadResponse(conn, "SET k v PXAT "+Etime+" EXAT "+Etime) + assert.Nil(t, err) + assert.Equal(t, "ERR syntax error", resp, "Value mismatch for cmd SET k v PXAT "+Etime+" EXAT "+Etime) + + resp, err = exec.FireCommandAndReadResponse(conn, "GET k") + assert.Nil(t, err) + assert.Equal(t, "(nil)", resp, "Value mismatch for cmd GET k") }) } @@ -185,14 +234,17 @@ func TestWithKeepTTLFlag(t *testing.T) { for i := 0; i < len(tcase.commands); i++ { cmd := tcase.commands[i] out := tcase.expected[i] - assert.Equal(t, out, exec.FireCommand(conn, cmd), "Value mismatch for cmd %s\n.", cmd) + + resp, err := exec.FireCommandAndReadResponse(conn, cmd) + assert.Nil(t, err) + assert.Equal(t, out, resp, "Value mismatch for cmd %s\n.", cmd) } } time.Sleep(2 * time.Second) - cmd := "GET k" out := "(nil)" - - assert.Equal(t, out, exec.FireCommand(conn, cmd), "Value mismatch for cmd %s\n.", cmd) + resp, err := exec.FireCommandAndReadResponse(conn, cmd) + assert.Nil(t, err) + assert.Equal(t, out, resp, "Value mismatch for cmd %s\n.", cmd) } diff --git a/integration_tests/commands/websocket/setup.go b/integration_tests/commands/websocket/setup.go index 247e7a2c5..f4292f6c1 100644 --- a/integration_tests/commands/websocket/setup.go +++ b/integration_tests/commands/websocket/setup.go @@ -4,7 +4,7 @@ import ( "context" "encoding/json" "errors" - "log" + "fmt" "log/slog" "net/http" "sync" @@ -18,7 +18,11 @@ import ( "github.com/gorilla/websocket" ) -const URL = "ws://localhost:8380" +const ( + URL = "ws://localhost:8380" + testPort1 = 8380 + testPort2 = 8381 +) type TestServerOptions struct { Port int @@ -60,32 +64,35 @@ func (e *WebsocketCommandExecutor) ConnectToServer() *websocket.Conn { return conn } -func (e *WebsocketCommandExecutor) FireCommand(conn *websocket.Conn, cmd string) interface{} { - command := []byte(cmd) - - // send request - err := conn.WriteMessage(websocket.TextMessage, command) +func (e *WebsocketCommandExecutor) FireCommandAndReadResponse(conn *websocket.Conn, cmd string) (interface{}, error) { + err := e.FireCommand(conn, cmd) if err != nil { - return nil + return nil, err } // read the response _, resp, err := conn.ReadMessage() if err != nil { - return nil + return nil, err } // marshal to json var respJSON interface{} if err = json.Unmarshal(resp, &respJSON); err != nil { - return nil + return nil, fmt.Errorf("error unmarshaling response") } - return respJSON + return respJSON, nil } -func (e *WebsocketCommandExecutor) DisconnectServer(conn *websocket.Conn) { - conn.Close() +func (e *WebsocketCommandExecutor) FireCommand(conn *websocket.Conn, cmd string) error { + // send request + err := conn.WriteMessage(websocket.TextMessage, []byte(cmd)) + if err != nil { + return err + } + + return nil } func (e *WebsocketCommandExecutor) Name() string { @@ -93,24 +100,25 @@ func (e *WebsocketCommandExecutor) Name() string { } func RunWebsocketServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerOptions) { + logger := opt.Logger config.DiceConfig.Network.IOBufferLength = 16 config.DiceConfig.Persistence.WriteAOFOnCleanup = false - // Initialize the WebsocketServer + // Initialize WebsocketServer globalErrChannel := make(chan error) watchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Performance.WatchChanBufSize) shardManager := shard.NewShardManager(1, watchChan, nil, globalErrChannel, opt.Logger) - config.WebsocketPort = opt.Port - testServer := server.NewWebSocketServer(shardManager, watchChan, opt.Logger) - + testServer := server.NewWebSocketServer(shardManager, watchChan, testPort1, opt.Logger) shardManagerCtx, cancelShardManager := context.WithCancel(ctx) + + // run shard manager wg.Add(1) go func() { defer wg.Done() shardManager.Run(shardManagerCtx) }() - // Start the server in a goroutine + // start websocket server wg.Add(1) go func() { defer wg.Done() @@ -120,7 +128,7 @@ func RunWebsocketServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerO if errors.Is(srverr, derrors.ErrAborted) { return } - log.Printf("Websocket test server encountered an error: %v", srverr) + logger.Debug("Websocket test server encountered an error: %v", slog.Any("error", srverr)) } }() } diff --git a/integration_tests/commands/websocket/websocket_write_retries_test.go b/integration_tests/commands/websocket/writeretry_test.go similarity index 94% rename from integration_tests/commands/websocket/websocket_write_retries_test.go rename to integration_tests/commands/websocket/writeretry_test.go index a295ed089..b0b270fc8 100644 --- a/integration_tests/commands/websocket/websocket_write_retries_test.go +++ b/integration_tests/commands/websocket/writeretry_test.go @@ -1,10 +1,10 @@ package websocket import ( + "fmt" "net" "net/http" "net/url" - "os" "sync" "testing" "time" @@ -14,7 +14,7 @@ import ( "github.com/stretchr/testify/assert" ) -var serverAddr = "localhost:12345" +var serverAddr = fmt.Sprintf("localhost:%v", testPort2) var once sync.Once func TestWriteResponseWithRetries_Success(t *testing.T) { @@ -70,10 +70,6 @@ func TestWriteResponseWithRetries_EAGAINRetry(t *testing.T) { assert.Equal(t, 2, retries) } -func newSyscallError(syscall string, err error) *os.SyscallError { - return &os.SyscallError{Syscall: syscall, Err: err} -} - func startWebSocketServer() { http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Upgrade(w, r, nil, 1024, 1024) diff --git a/internal/server/websocketServer.go b/internal/server/websocketServer.go index 0a269370f..52b764721 100644 --- a/internal/server/websocketServer.go +++ b/internal/server/websocketServer.go @@ -47,10 +47,10 @@ type WebsocketServer struct { shutdownChan chan struct{} } -func NewWebSocketServer(shardManager *shard.ShardManager, watchChan chan dstore.QueryWatchEvent, logger *slog.Logger) *WebsocketServer { +func NewWebSocketServer(shardManager *shard.ShardManager, watchChan chan dstore.QueryWatchEvent, port int, logger *slog.Logger) *WebsocketServer { mux := http.NewServeMux() srv := &http.Server{ - Addr: fmt.Sprintf(":%d", config.WebsocketPort), + Addr: fmt.Sprintf(":%d", port), Handler: mux, ReadHeaderTimeout: 5 * time.Second, } @@ -134,8 +134,12 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques // read incoming message _, msg, err := conn.ReadMessage() if err != nil { - writeResponse(conn, []byte("error: command reading failed")) - continue + // acceptable close errors + errs := []int{websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure} + if !websocket.IsCloseError(err, errs...) { + s.logger.Warn("failed to read message from client", slog.Any("error", err)) + } + break } // parse message to dice command diff --git a/main.go b/main.go index 16a863b8d..6e54f5419 100644 --- a/main.go +++ b/main.go @@ -207,7 +207,7 @@ func main() { }() } - websocketServer := server.NewWebSocketServer(shardManager, queryWatchChan, logr) + websocketServer := server.NewWebSocketServer(shardManager, queryWatchChan, config.WebsocketPort, logr) serverWg.Add(1) go func() { defer serverWg.Done()