diff --git a/internal/impl/io/output_http_server.go b/internal/impl/io/output_http_server.go index 8639bea51..467ab3451 100644 --- a/internal/impl/io/output_http_server.go +++ b/internal/impl/io/output_http_server.go @@ -389,57 +389,112 @@ func (h *httpServerOutput) streamHandler(w http.ResponseWriter, r *http.Request) } func (h *httpServerOutput) wsHandler(w http.ResponseWriter, r *http.Request) { - var err error - defer func() { - if err != nil { - http.Error(w, "Bad request", http.StatusBadRequest) - h.log.Warn("Websocket request failed: %v\n", err) - return - } - }() - - upgrader := websocket.Upgrader{} - - var ws *websocket.Conn - if ws, err = upgrader.Upgrade(w, r, nil); err != nil { - return - } - defer ws.Close() - - ctx, done := h.shutSig.SoftStopCtx(r.Context()) - defer done() - - for !h.shutSig.IsSoftStopSignalled() { - var ts message.Transaction - var open bool - - select { - case ts, open = <-h.transactions: - if !open { - go h.TriggerCloseNow() - return - } - case <-r.Context().Done(): - return - case <-h.shutSig.SoftStopChan(): - return - } - - var werr error - for _, msg := range message.GetAllBytes(ts.Payload) { - if werr = ws.WriteMessage(websocket.BinaryMessage, msg); werr != nil { - break - } - h.mWSBatchSent.Incr(1) - h.mWSSent.Incr(int64(batch.MessageCollapsedCount(ts.Payload))) - } - if werr != nil { - h.mWSError.Incr(1) - } - _ = ts.Ack(ctx, werr) - } + var err error + defer func() { + if err != nil { + http.Error(w, "Bad request", http.StatusBadRequest) + h.log.Warn("WebSocket request failed: %v", err) + return + } + }() + + upgrader := websocket.Upgrader{} + + // Upgrade the HTTP connection to a WebSocket connection + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + h.log.Warn("WebSocket upgrade failed: %v", err) + return + } + defer ws.Close() + + // Set up ping/pong handlers and deadlines + const ( + writeWait = 10 * time.Second + pongWait = 60 * time.Second + pingPeriod = (pongWait * 9) / 10 + ) + + ws.SetReadLimit(512) + if err := ws.SetReadDeadline(time.Now().Add(pongWait)); err != nil { + h.log.Warn("Failed to set read deadline: %v", err) + return + } + ws.SetPongHandler(func(string) error { + return ws.SetReadDeadline(time.Now().Add(pongWait)) + }) + + // Start a goroutine to read messages (to process control frames) + done := make(chan struct{}) + go func() { + defer close(done) + for { + _, _, err := ws.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + h.log.Warn("WebSocket read error: %v", err) + } + break + } + } + }() + + // Start ticker to send ping messages to the client periodically + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() + + ctx, doneCtx := h.shutSig.SoftStopCtx(r.Context()) + defer doneCtx() + + for !h.shutSig.IsSoftStopSignalled() { + select { + case ts, open := <-h.transactions: + if !open { + // If the transactions channel is closed, trigger server shutdown + go h.TriggerCloseNow() + return + } + // Write messages to the client + var writeErr error + for _, msg := range message.GetAllBytes(ts.Payload) { + if err := ws.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { + writeErr = err + break + } + if writeErr = ws.WriteMessage(websocket.BinaryMessage, msg); writeErr != nil { + break + } + h.mWSBatchSent.Incr(1) + h.mWSSent.Incr(int64(batch.MessageCollapsedCount(ts.Payload))) + } + if writeErr != nil { + h.mWSError.Incr(1) + _ = ts.Ack(ctx, writeErr) + return // Exit the loop on write error + } + _ = ts.Ack(ctx, nil) + case <-ticker.C: + // Send a ping message to the client + if err := ws.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { + h.log.Warn("Failed to set write deadline for ping: %v", err) + return + } + if err := ws.WriteMessage(websocket.PingMessage, nil); err != nil { + h.log.Warn("WebSocket ping error: %v", err) + return + } + case <-done: + // The read goroutine has exited, indicating the client has disconnected + h.log.Debug("WebSocket client disconnected") + return + case <-ctx.Done(): + // The context has been canceled (e.g., server is shutting down) + return + } + } } + func (h *httpServerOutput) Consume(ts <-chan message.Transaction) error { if h.transactions != nil { return component.ErrAlreadyStarted