From db9bcf2cdaa6092c45b6e321ffb76800f07673c8 Mon Sep 17 00:00:00 2001 From: Antonio Nuno Monteiro Date: Mon, 31 Jul 2023 19:10:12 -0700 Subject: [PATCH] websocket: exposes "messages" instead of "frames" automatically handle continuation frames and fin bit --- examples/eio/echo_server_upgrade.ml | 4 +- examples/eio/eio_wscat.ml | 21 +++-- flake.lock | 14 ++-- lib/iOVec.ml | 18 +++++ lib/piaf.mli | 15 ++-- lib/ws.ml | 121 ++++++++++++++++++++++------ 6 files changed, 147 insertions(+), 46 deletions(-) diff --git a/examples/eio/echo_server_upgrade.ml b/examples/eio/echo_server_upgrade.ml index 4c690009..5c5a767a 100644 --- a/examples/eio/echo_server_upgrade.ml +++ b/examples/eio/echo_server_upgrade.ml @@ -3,9 +3,9 @@ open Piaf let connection_handler { Server.request; _ } = Response.Upgrade.websocket request ~f:(fun wsd -> - let frames = Ws.Descriptor.frames wsd in + let frames = Ws.Descriptor.messages wsd in Stream.iter - ~f:(fun (_opcode, frame) -> Ws.Descriptor.send_string wsd frame) + ~f:(fun (_opcode, frame) -> Ws.Descriptor.send_iovec wsd frame) frames) |> Result.get_ok diff --git a/examples/eio/eio_wscat.ml b/examples/eio/eio_wscat.ml index 91c4fa52..d11336fb 100644 --- a/examples/eio/eio_wscat.ml +++ b/examples/eio/eio_wscat.ml @@ -16,11 +16,17 @@ end let rec stdin_loop ~stdin buf wsd = let line = Eio.Buf_read.line buf in traceln "< %s" line; - if line = "exit" - then Ws.Descriptor.close wsd - else ( - Ws.Descriptor.send_string wsd line; - stdin_loop ~stdin buf wsd) + match line with + | "exit" -> Ws.Descriptor.close wsd + | "ping" -> + let application_data = + IOVec.make ~off:0 ~len:5 (Bigstringaf.of_string ~off:0 ~len:5 "hello") + in + Ws.Descriptor.send_ping ~application_data wsd; + stdin_loop ~stdin buf wsd + | line -> + Ws.Descriptor.send_stringy wsd line; + stdin_loop ~stdin buf wsd let request ~env ~sw host = let open Result in @@ -34,8 +40,9 @@ let request ~env ~sw host = Client.shutdown client) (fun () -> Stream.iter - ~f:(fun (_opcode, frame) -> Format.printf ">> %s@." frame) - (Ws.Descriptor.frames wsd)) + ~f:(fun (_opcode, { IOVec.buffer; off; len }) -> + Format.printf ">> %s@." (Bigstringaf.substring ~off ~len buffer)) + (Ws.Descriptor.messages wsd)) let setup_log ?style_renderer level = Fmt_tty.setup_std_outputs ?style_renderer (); diff --git a/flake.lock b/flake.lock index c2e9275a..f7ec1314 100644 --- a/flake.lock +++ b/flake.lock @@ -41,11 +41,11 @@ "nixpkgs": "nixpkgs_2" }, "locked": { - "lastModified": 1689135990, - "narHash": "sha256-QnUEHQ1QhRaUb2FcP5gWKFJk2aro66GaovRLV2XYT/E=", + "lastModified": 1690855714, + "narHash": "sha256-zYsI1x+8FGwqHv0WY2rTacw78UnIJOBMruw6YvqCMuQ=", "owner": "nix-ocaml", "repo": "nix-overlays", - "rev": "d2671fd208ec9f25143244bdbe0775c82a5f3473", + "rev": "1b268dd81727a71e01f1e8f59d9ae1251cba35a0", "type": "github" }, "original": { @@ -56,17 +56,17 @@ }, "nixpkgs_2": { "locked": { - "lastModified": 1689065019, - "narHash": "sha256-oCQM37FahwN5uEr5PQZZGDuGBe9TnIOIFcz4bGXvGdE=", + "lastModified": 1690803489, + "narHash": "sha256-TqdStgF+EA+kJ8PzVjOxa8HdM684CmZZz2ohlKq9j4A=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "8f7eabf5564b06eafc39d4a5c9fa442c06b3bd55", + "rev": "0d117d7ad5d590991d23ccc7bd88c4e627cccf17", "type": "github" }, "original": { "owner": "NixOS", "repo": "nixpkgs", - "rev": "8f7eabf5564b06eafc39d4a5c9fa442c06b3bd55", + "rev": "0d117d7ad5d590991d23ccc7bd88c4e627cccf17", "type": "github" } }, diff --git a/lib/iOVec.ml b/lib/iOVec.ml index 9a942abc..f4f161f6 100644 --- a/lib/iOVec.ml +++ b/lib/iOVec.ml @@ -40,3 +40,21 @@ let of_bytes bytes ~off ~len = let buffer = Bigstringaf.create len in Bigstringaf.blit_from_bytes bytes ~src_off:off buffer ~dst_off:0 ~len; { buffer; off; len } + +let concat = function + | [] -> make Bigstringaf.empty ~off:0 ~len:0 + | [ iovec ] -> iovec + | iovecs -> + let length = lengthv iovecs in + let result_buffer = Bigstringaf.create length in + let aux acc_off { buffer; off; len } = + Bigstringaf.unsafe_blit + buffer + ~src_off:off + result_buffer + ~dst_off:acc_off + ~len; + acc_off + len + in + ignore @@ List.fold_left aux 0 iovecs; + { buffer = result_buffer; off = 0; len = length } diff --git a/lib/piaf.mli b/lib/piaf.mli index 4c710f88..f06da0d3 100644 --- a/lib/piaf.mli +++ b/lib/piaf.mli @@ -70,7 +70,7 @@ module Headers : sig hold: - [to_list (of_list lst) = lst] - - [get (of_list \[("k", "v1"); ("k", "v2")\]) "k" = Some "v2"]. *) + - [get (of_list [("k", "v1"); ("k", "v2")]) "k" = Some "v2"]. *) val of_rev_list : (name * value) list -> t (** [of_list assoc] is a collection of header fields defined by the @@ -79,7 +79,7 @@ module Headers : sig following equations should hold: - [to_list (of_rev_list lst) = List.rev lst] - - [get (of_rev_list \[("k", "v1"); ("k", "v2")\]) "k" = Some "v1"]. *) + - [get (of_rev_list [("k", "v1"); ("k", "v2")]) "k" = Some "v1"]. *) val to_list : t -> (name * value) list (** [to_list t] is the association list of header fields contained in [t] in @@ -434,18 +434,23 @@ module Body : sig end module Ws : sig + module Message : sig + type t = Websocketaf.Websocket.Opcode.t * Bigstringaf.t IOVec.t + end + module Descriptor : sig type t - val frames : t -> (Websocketaf.Websocket.Opcode.t * string) Stream.t + val messages : t -> Message.t Stream.t (** Stream of incoming websocket messages (frames) *) + val send_iovec : t -> Bigstringaf.t IOVec.t -> unit val send_stream : t -> Bigstringaf.t IOVec.t Stream.t -> unit val send_string_stream : t -> string Stream.t -> unit val send_string : t -> string -> unit val send_bigstring : t -> ?off:int -> ?len:int -> Bigstringaf.t -> unit - val send_ping : t -> unit - val send_pong : t -> unit + val send_ping : ?application_data:Bigstringaf.t IOVec.t -> t -> unit + val send_pong : ?application_data:Bigstringaf.t IOVec.t -> t -> unit val flushed : t -> unit Eio.Promise.t val close : t -> unit val is_closed : t -> bool diff --git a/lib/ws.ml b/lib/ws.ml index 447749b2..e5474f92 100644 --- a/lib/ws.ml +++ b/lib/ws.ml @@ -39,35 +39,50 @@ let upgrade_request ~headers ~scheme ~nonce target = ~scheme (Websocketaf.Handshake.create_request ~nonce ~headers target) +module Opcode = struct + type t = Websocket.Opcode.t + + let to_string = function + | `Continuation -> "Continuation" + | `Text -> "Text" + | `Binary -> "Binary" + | `Connection_close -> "Connection_close" + | `Ping -> "Ping" + | `Pong -> "Pong" + | `Other code -> Format.asprintf "Custom: %x" code +end + +module Message = struct + type t = Websocket.Opcode.t * Bigstringaf.t IOVec.t +end + module Descriptor : sig type t - type frame = Websocket.Opcode.t * string - val create : frames:frame Stream.t -> Wsd.t -> t - val frames : t -> frame Stream.t + val create : messages:Message.t Stream.t -> Wsd.t -> t + val messages : t -> Message.t Stream.t + val send_iovec : t -> Bigstringaf.t IOVec.t -> unit val send_stream : t -> Bigstringaf.t IOVec.t Stream.t -> unit val send_string_stream : t -> string Stream.t -> unit val send_string : t -> string -> unit val send_bigstring : t -> ?off:int -> ?len:int -> Bigstringaf.t -> unit - val send_ping : t -> unit - val send_pong : t -> unit + val send_ping : ?application_data:Bigstringaf.t IOVec.t -> t -> unit + val send_pong : ?application_data:Bigstringaf.t IOVec.t -> t -> unit val flushed : t -> unit Promise.t val close : t -> unit val is_closed : t -> bool end = struct type t = { wsd : Wsd.t - ; frames : (Websocket.Opcode.t * string) Stream.t + ; messages : Message.t Stream.t } - type frame = Websocket.Opcode.t * string - - let create ~frames wsd = { wsd; frames } - let frames t = t.frames + let create ~messages wsd = { wsd; messages } + let messages t = t.messages - let send_bytes t ?(off = 0) ?len bytes = + let send_bytes t ?is_fin ?(opcode = `Binary) ?(off = 0) ?len bytes = let len = match len with Some l -> l | None -> Bytes.length bytes in - Wsd.send_bytes t.wsd ~kind:`Binary ~off ~len bytes + Wsd.send_bytes t.wsd ?is_fin ~kind:opcode ~off ~len bytes let send_iovec : t -> Bigstringaf.t IOVec.t -> unit = fun t iovec -> @@ -82,7 +97,8 @@ end = struct send_stream t stream | None -> () - let send_string t str = send_bytes t (Bytes.of_string str) + let send_string t str = + send_bytes t ~opcode:`Text (Bytes.unsafe_of_string str) let rec send_string_stream : t -> string Stream.t -> unit = fun t stream -> @@ -97,46 +113,101 @@ end = struct let len = match len with Some l -> l | None -> Bigstringaf.length bstr in Wsd.schedule t.wsd ~kind:`Binary ~off ~len bstr - let send_ping t = Wsd.send_ping t.wsd - let send_pong t = Wsd.send_pong t.wsd + let send_ping ?application_data t = Wsd.send_ping ?application_data t.wsd + let send_pong ?application_data t = Wsd.send_pong ?application_data t.wsd let flushed t = let p, u = Promise.create () in Wsd.flushed t.wsd (Promise.resolve u); p - let close t = Wsd.close (* ~code:`Normal_closure *) t.wsd + let close t = Wsd.close ~code:`Normal_closure t.wsd let is_closed t = Wsd.is_closed t.wsd end module Handler = struct let websocket_handler ~sw ~notify_wsd wsd = - let frames, push_to_frames = Stream.create 256 in - Promise.resolve notify_wsd (Descriptor.create ~frames wsd); - let frame ~opcode ~is_fin:_ ~len payload = - let len = Int64.of_int len in + let frameq = Queue.create () in + let messages, push_to_messages = Stream.create 256 in + Promise.resolve notify_wsd (Descriptor.create ~messages wsd); + + let frame ~opcode ~is_fin ~len payload = let { Body.stream; _ } = + let body_length = `Fixed (Int64.of_int len) in Body.Raw.to_stream (module Websocketaf.Payload : Body.Raw.Reader with type t = Websocketaf.Payload.t) - ~body_length:(`Fixed len) + ~body_length ~body_error:(`Msg "") ~on_eof:(fun t -> - match Websocketaf.Wsd.error_code wsd with + match Wsd.error_code wsd with | Some error -> t.error_received := Promise.create_resolved (error :> Error.t) | None -> ()) payload in Fiber.fork ~sw (fun () -> - let frame = Body.stream_to_string ~length:(`Fixed len) stream in - push_to_frames (Some (opcode, frame))) + match opcode with + | `Pong -> + (* From RFC6455§5.5.2: + * A Pong frame MAY be sent unsolicited. This serves as a + * unidirectional heartbeat. A response to an unsolicited Pong frame + * is not expected. *) + (* Drain any application data payload in the Pong frame. *) + Stream.drain stream + | `Ping -> + (* From RFC6455§5.5.3: + * Upon receipt of a Ping frame, an endpoint MUST send a Pong frame + * in response, unless it already received a Close frame. *) + let payload = + let iovecs = Stream.to_list stream in + IOVec.concat iovecs + in + Wsd.send_pong ~application_data:payload wsd + | `Text | `Binary -> + let frame = Stream.to_list stream in + (match is_fin with + | true -> + (* FIN bit set, just push to the stream. *) + push_to_messages (Some (opcode, IOVec.concat frame)) + | false -> + (* FIN bit not set, accumulate in the temp queue. *) + Queue.add (opcode, frame) frameq) + | `Continuation -> + let frame = Stream.to_list stream in + Queue.add (opcode, frame) frameq; + (match is_fin with + | true -> + (* FIN bit set, consume the queue. *) + let opcode, message = + let opcode, first_frame = + (* invariant: the queue is non-empty if this is a continuation + frame *) + Queue.take frameq + in + let other_frames = + Queue.to_seq frameq |> Seq.map snd |> List.of_seq + in + let all_frames = first_frame :: other_frames in + opcode, IOVec.concat (List.concat all_frames) + in + (* Clear the queue after assembling the full message. *) + Queue.clear frameq; + push_to_messages (Some (opcode, message)) + | false -> + (* FIN bit not set, keep accumulating in the temp queue. *) + ()) + | `Connection_close -> + let message = Stream.to_list stream |> List.hd in + push_to_messages (Some (opcode, message)) + | `Other _ -> + failwith "Custom WebSocket frame types not yet supported") in let eof () = Logs.info (fun m -> m "Websocket connection EOF"); Websocketaf.Wsd.close wsd; - push_to_frames None + push_to_messages None in { Websocketaf.Websocket_connection.frame; eof } end