diff --git a/face/web-socket-listener.go b/face/web-socket-listener.go index 0a48b427..d2727ceb 100644 --- a/face/web-socket-listener.go +++ b/face/web-socket-listener.go @@ -33,7 +33,13 @@ type WebSocketListenerConfig struct { TLSKey string } -// URL returns server URL. +// WebSocketListener listens for incoming WebSockets connections. +type WebSocketListener struct { + server http.Server + upgrader websocket.Upgrader + localURI *defn.URI +} + func (cfg WebSocketListenerConfig) URL() *url.URL { addr := net.JoinHostPort(cfg.Bind, strconv.FormatUint(uint64(cfg.Port), 10)) u := &url.URL{ @@ -55,7 +61,6 @@ func (cfg WebSocketListenerConfig) String() string { return b.String() } -// NewWebSocketListener constructs a WebSocketListener. func NewWebSocketListener(cfg WebSocketListenerConfig) (*WebSocketListener, error) { localURI := cfg.URL() ret := &WebSocketListener{ @@ -80,31 +85,21 @@ func NewWebSocketListener(cfg WebSocketListenerConfig) (*WebSocketListener, erro return ret, nil } -// WebSocketListener listens for incoming WebSockets connections. -type WebSocketListener struct { - server http.Server - upgrader websocket.Upgrader - localURI *defn.URI -} - -var _ Listener = &WebSocketListener{} - func (l *WebSocketListener) String() string { return "WebSocketListener, " + l.localURI.String() } -// Run starts the WebSocket listener. func (l *WebSocketListener) Run() { l.server.Handler = http.HandlerFunc(l.handler) - var e error + var err error if l.server.TLSConfig == nil { - e = l.server.ListenAndServe() + err = l.server.ListenAndServe() } else { - e = l.server.ListenAndServeTLS("", "") + err = l.server.ListenAndServeTLS("", "") } - if !errors.Is(e, http.ErrServerClosed) { - core.LogFatal(l, "Unable to start listener: ", e) + if !errors.Is(err, http.ErrServerClosed) { + core.LogFatal(l, "Unable to start listener: ", err) } } @@ -122,7 +117,6 @@ func (l *WebSocketListener) handler(w http.ResponseWriter, r *http.Request) { MakeNDNLPLinkService(newTransport, options).Run(nil) } -// Close closes the WebSocketListener. func (l *WebSocketListener) Close() { core.LogInfo(l, "Stopping listener") l.server.Shutdown(context.TODO()) diff --git a/face/web-socket-transport.go b/face/web-socket-transport.go index 968043c3..24e118f9 100644 --- a/face/web-socket-transport.go +++ b/face/web-socket-transport.go @@ -8,8 +8,8 @@ package face import ( + "fmt" "net" - "strconv" "github.com/gorilla/websocket" "github.com/named-data/YaNFD/core" @@ -22,13 +22,8 @@ type WebSocketTransport struct { c *websocket.Conn } -var _ transport = &WebSocketTransport{} - -// NewWebSocketTransport creates a Unix stream transport. func NewWebSocketTransport(localURI *defn.URI, c *websocket.Conn) (t *WebSocketTransport) { remoteURI := defn.MakeWebSocketClientFaceURI(c.RemoteAddr()) - t = &WebSocketTransport{c: c} - t.running.Store(true) scope := defn.NonLocal ip := net.ParseIP(remoteURI.PathHost()) @@ -36,32 +31,35 @@ func NewWebSocketTransport(localURI *defn.URI, c *websocket.Conn) (t *WebSocketT scope = defn.Local } + t = &WebSocketTransport{c: c} t.makeTransportBase(remoteURI, localURI, PersistencyOnDemand, scope, defn.PointToPoint, defn.MaxNDNPacketSize) + t.running.Store(true) + return t } func (t *WebSocketTransport) String() string { - return "WebSocketTransport, FaceID=" + strconv.FormatUint(t.faceID, 10) + - ", RemoteURI=" + t.remoteURI.String() + ", LocalURI=" + t.localURI.String() + return fmt.Sprintf("WebSocketTransport, FaceID=%d, RemoteURI=%s, LocalURI=%s", t.faceID, t.remoteURI, t.localURI) } -// SetPersistency changes the persistency of the face. func (t *WebSocketTransport) SetPersistency(persistency Persistency) bool { return persistency == PersistencyOnDemand } -// GetSendQueueSize returns the current size of the send queue. func (t *WebSocketTransport) GetSendQueueSize() uint64 { return 0 } func (t *WebSocketTransport) sendFrame(frame []byte) { + if !t.running.Load() { + return + } + if len(frame) > t.MTU() { core.LogWarn(t, "Attempted to send frame larger than MTU - DROP") return } - core.LogDebug(t, "Sending frame of size ", len(frame)) e := t.c.WriteMessage(websocket.BinaryMessage, frame) if e != nil { core.LogWarn(t, "Unable to send on socket - DROP and Face DOWN") @@ -73,11 +71,18 @@ func (t *WebSocketTransport) sendFrame(frame []byte) { } func (t *WebSocketTransport) runReceive() { + defer t.Close() + for { mt, message, e := t.c.ReadMessage() if e != nil { - core.LogWarn(t, "Unable to read from socket (", e, ") - DROP and Face DOWN") - t.Close() + if websocket.IsCloseError(e) { + // gracefully closed + } else if websocket.IsUnexpectedCloseError(e) { + core.LogInfo(t, "WebSocket closed unexpectedly (", e, ") - DROP and Face DOWN") + } else { + core.LogWarn(t, "Unable to read from WebSocket (", e, ") - DROP and Face DOWN") + } return } @@ -86,18 +91,17 @@ func (t *WebSocketTransport) runReceive() { continue } - core.LogTrace(t, "Receive of size ", len(message)) - t.nInBytes += uint64(len(message)) - if len(message) > defn.MaxNDNPacketSize { core.LogWarn(t, "Received too much data without valid TLV block - DROP") continue } + t.nInBytes += uint64(len(message)) t.linkService.handleIncomingFrame(message) } } func (t *WebSocketTransport) Close() { + t.running.Store(false) t.c.Close() }