From 20853165b412badd719ee3b023226fe02be1d4ad Mon Sep 17 00:00:00 2001 From: Antonio Nuno Monteiro Date: Sun, 15 Oct 2023 21:57:09 -0700 Subject: [PATCH] refactor: tighten bindings --- lib/body.ml | 26 ++++---- lib/client.ml | 128 +++++++++++++++++++++------------------- lib/connection.ml | 100 +++++++++++++++++-------------- lib/form.ml | 18 +++--- lib/http2.ml | 2 +- lib/http_impl.ml | 94 ++++++++++++++--------------- lib/http_server_impl.ml | 23 ++++---- lib/openssl.ml | 56 +++++++++++------- lib/posix.ml | 6 +- lib/request.ml | 8 ++- lib/request_info.ml | 31 ++++++++++ lib/response.ml | 44 +++++++------- lib/server.ml | 112 ++++++++++++++++++----------------- lib/util.ml | 2 +- 14 files changed, 358 insertions(+), 292 deletions(-) diff --git a/lib/body.ml b/lib/body.ml index ad947f6b..e1cbade3 100644 --- a/lib/body.ml +++ b/lib/body.ml @@ -161,10 +161,12 @@ let sendfile ?length path = (* TODO: accept buffer for I/O, so that caller can pool buffers? *) let stream_of_fd ?on_close fd = - let { Unix.st_size = length; _ } = - Eio_unix.run_in_systhread (fun () -> Unix.fstat fd) + let remaining = + let { Unix.st_size = length; _ } = + Eio_unix.run_in_systhread (fun () -> Unix.fstat fd) + in + Atomic.make length in - let remaining = Atomic.make length in Stream.from ~f:(fun () -> let current = Atomic.get remaining in if current = 0 @@ -343,15 +345,6 @@ module Raw = struct let rec read_fn () = let t = Lazy.force t in let p, u = Promise.create () in - let on_read_direct buffer ~off ~len = - total_len := Int64.add !total_len (Int64.of_int len); - Promise.resolve u (Some (IOVec.make buffer ~off ~len)) - and on_read_with_yield buffer ~off ~len = - total_len := Int64.add !total_len (Int64.of_int len); - Fiber.yield (); - Promise.resolve u (Some (IOVec.make buffer ~off ~len)) - in - t.read_counter <- t.read_counter + 1; let on_eof () = Option.iter (fun f -> f t) on_eof; Reader.close body; @@ -367,12 +360,21 @@ module Raw = struct Promise.resolve u None in let on_read = + let on_read_direct buffer ~off ~len = + total_len := Int64.add !total_len (Int64.of_int len); + Promise.resolve u (Some (IOVec.make buffer ~off ~len)) + and on_read_with_yield buffer ~off ~len = + total_len := Int64.add !total_len (Int64.of_int len); + Fiber.yield (); + Promise.resolve u (Some (IOVec.make buffer ~off ~len)) + in if t.read_counter > 128 then ( t.read_counter <- 0; on_read_with_yield) else on_read_direct in + t.read_counter <- t.read_counter + 1; Reader.schedule_read body ~on_eof ~on_read; Fiber.first (fun () -> Promise.await p) diff --git a/lib/client.ml b/lib/client.ml index d8c64cf5..69ff8f20 100644 --- a/lib/client.ml +++ b/lib/client.ml @@ -73,8 +73,8 @@ let create_https_connection ~sw ~config ~conn_info ~uri fd = let*! { ssl = ssl_client; ssl_ctx } = Openssl.connect ~config ~hostname:host fd in - let ssl_socket = Eio_ssl.Context.ssl_socket ssl_ctx in let (module Https), version = + let ssl_socket = Eio_ssl.Context.ssl_socket ssl_ctx in match Ssl.get_negotiated_alpn_protocol ssl_socket with | None -> Logs.warn (fun m -> @@ -127,9 +127,11 @@ let create_https_connection ~sw ~config ~conn_info ~uri fd = ssl_client let open_connection ~sw ~config ~uri env conn_info = - let clock = Eio.Stdenv.clock env in - let network = Eio.Stdenv.net env in - let*! socket = Connection.connect ~sw ~clock ~network ~config conn_info in + let*! socket = + let clock = Eio.Stdenv.clock env in + let network = Eio.Stdenv.net env in + Connection.connect ~sw ~clock ~network ~config conn_info + in (if config.Config.tcp_nodelay then let fd = Eio_unix.Resource.fd_opt socket |> Option.get in @@ -320,28 +322,27 @@ let make_request_info = let { Connection_info.host; scheme; _ } = info in let is_h2c_upgrade = is_h2c_upgrade ~config ~version ~scheme in - let h2_settings = H2.Settings.to_base64 (Config.to_http2_settings config) in - let canonical_headers = - (* Important that this doesn't shadow the labeled `headers` argument - * above. We need the original headers as issued by the caller in order to - * reproduce them e.g. when following redirects. *) - let headers = - let open Headers in - if is_h2c_upgrade - then - (Well_known.connection, "Upgrade, HTTP2-Settings") - :: (Well_known.upgrade, "h2c") - :: ("HTTP2-Settings", Result.get_ok h2_settings) - :: headers - else headers - in - Headers.canonicalize_headers - ~version - ~host - ~body_length:body.Body.length - headers - in let request = + let canonical_headers = + let h2_settings = + H2.Settings.to_base64 (Config.to_http2_settings config) + in + let headers = + let open Headers in + if is_h2c_upgrade + then + (Well_known.connection, "Upgrade, HTTP2-Settings") + :: (Well_known.upgrade, "h2c") + :: ("HTTP2-Settings", Result.get_ok h2_settings) + :: headers + else headers + in + Headers.canonicalize_headers + ~version + ~host + ~body_length:body.Body.length + headers + in Request.create ~meth ~version @@ -398,28 +399,29 @@ let rec send_request_and_handle_response ()); if Status.is_permanent_redirection response.status then conn.uri <- new_uri; - let target = Uri.path_and_query new_uri in - (* From RFC7231§6.4: - * Note: In HTTP/1.0, the status codes 301 (Moved Permanently) and 302 - * (Found) were defined for the first type of redirect ([RFC1945], - * Section 9.3). Early user agents split on whether the method applied - * to the redirect target would be the same as the original request or - * would be rewritten as GET. Although HTTP originally defined the former - * semantics for 301 and 302 (to match its original implementation at - * CERN), and defined 303 (See Other) to match the latter semantics, - * prevailing practice gradually converged on the latter semantics for - * 301 and 302 as well. The first revision of HTTP/1.1 added 307 - * (Temporary Redirect) to indicate the former semantics without being - * impacted by divergent practice. Over 10 years later, most user agents - * still do method rewriting for 301 and 302; therefore, this - * specification makes that behavior conformant when the original request - * is POST. *) - let meth' = - match meth, response.status with - | `POST, (`Found | `Moved_permanently) -> `GET - | _ -> meth - in let request_info' = + let target = Uri.path_and_query new_uri in + (* From RFC7231§6.4: + * Note: In HTTP/1.0, the status codes 301 (Moved Permanently) and + * 302 (Found) were defined for the first type of redirect + * ([RFC1945], Section 9.3). Early user agents split on whether + * the method applied to the redirect target would be the same as + * the original request or would be rewritten as GET. Although + * HTTP originally defined the former semantics for 301 and 302 (to + * match its original implementation at CERN), and defined 303 (See + * Other) to match the latter semantics, prevailing practice + * gradually converged on the latter semantics for 301 and 302 as + * well. The first revision of HTTP/1.1 added 307 (Temporary + * Redirect) to indicate the former semantics without being + * impacted by divergent practice. Over 10 years later, most user + * agents still do method rewriting for 301 and 302; therefore, + * this specification makes that behavior conformant when the + * original request is POST. *) + let meth' = + match meth, response.status with + | `POST, (`Found | `Moved_permanently) -> `GET + | _ -> meth + in make_request_info t ~remaining_redirects:(remaining_redirects - 1) @@ -452,8 +454,10 @@ let call t ~meth ?(headers = []) ?(body = Body.empty) target = match reused with | Error #Error.client as err -> err | Ok _ -> - let headers = t.config.default_headers @ headers in - let request_info = make_request_info t ~meth ~headers ~body target in + let request_info = + let headers = t.config.default_headers @ headers in + make_request_info t ~meth ~headers ~body target + in let (Connection.Conn conn) = t.conn in conn.persistent <- Request.persistent_connection request_info.request; send_request_and_handle_response t ~body request_info @@ -480,20 +484,22 @@ let ws_upgrade : -> (Ws.Descriptor.t, [> Error.client ]) result = fun t ?(headers = []) target -> - let (Conn { info; _ }) = t.conn in - (* From RFC6455§4.1: - * The value of this header field MUST be a nonce consisting of a randomly - * selected 16-byte value that has been base64-encoded (see Section 4 of - * [RFC4648]). The nonce MUST be selected randomly for each connection. *) - let nonce = Openssl.random_string 16 in - let request = - Ws.upgrade_request - ~headers:(Httpaf.Headers.of_list headers) - ~scheme:info.scheme - ~nonce - target + let*! response = + let request = + let (Conn { info; _ }) = t.conn in + (* From RFC6455§4.1: + * The value of this header field MUST be a nonce consisting of a randomly + * selected 16-byte value that has been base64-encoded (see Section 4 of + * [RFC4648]). The nonce MUST be selected randomly for each connection. *) + let nonce = Openssl.random_string 16 in + Ws.upgrade_request + ~headers:(Httpaf.Headers.of_list headers) + ~scheme:info.scheme + ~nonce + target + in + send t request in - let*! response = send t request in match Body.drain response.body with | Error #Error.t as err -> err | Ok () -> Http_impl.upgrade_connection ~sw:t.sw t.conn diff --git a/lib/connection.ml b/lib/connection.ml index db84aac2..d6589fe5 100644 --- a/lib/connection.ml +++ b/lib/connection.ml @@ -35,51 +35,61 @@ module Version = Httpaf.Version module Logs = (val Logging.setup ~src:"piaf.connection" ~doc:"Piaf Connection module") -let resolve_host env ~config ~port hostname : (_, [> Error.client ]) result = - let clock = Eio.Stdenv.clock env in - let network = Eio.Stdenv.net env in - match - Eio.Time.with_timeout_exn clock config.Config.connect_timeout (fun () -> - Eio.Net.getaddrinfo_stream ~service:(string_of_int port) network hostname) - with - | [] -> - Error - (`Connect_error (Format.asprintf "Can't resolve hostname: %s" hostname)) - | xs -> - (match config.Config.prefer_ip_version with - | `Both -> - let order_v4v6 = Eio.Net.Ipaddr.fold ~v4:(fun _ -> -1) ~v6:(fun _ -> 1) in - Ok - (* Sort IPv4 ahead of IPv6 for compatibility. *) - (List.sort - (fun a1 a2 -> - match a1, a2 with - | `Unix s1, `Unix s2 -> String.compare s1 s2 - | `Tcp (ip1, _), `Tcp (ip2, _) -> - compare (order_v4v6 ip1) (order_v4v6 ip2) - | `Unix _, `Tcp _ -> 1 - | `Tcp _, `Unix _ -> -1) - xs) - | `V4 -> - Ok - (List.filter - (function - | `Tcp (ip, _) -> - Eio.Net.Ipaddr.fold ~v4:(fun _ -> true) ~v6:(fun _ -> false) ip - | `Unix _ -> true) - xs) - | `V6 -> - Ok - (List.filter - (function - | `Tcp (ip, _) -> - Eio.Net.Ipaddr.fold ~v4:(fun _ -> false) ~v6:(fun _ -> true) ip - | `Unix _ -> true) - xs)) - | exception Eio.Time.Timeout -> - Error - (`Connect_error - (Format.asprintf "Timed out resolving hostname: %s" hostname)) +let resolve_host = + let order_v4v6 = Eio.Net.Ipaddr.fold ~v4:(fun _ -> -1) ~v6:(fun _ -> 1) in + fun (env : Eio_unix.Stdenv.base) ~config ~port hostname -> + let clock = Eio.Stdenv.clock env in + let network = Eio.Stdenv.net env in + match + Eio.Time.with_timeout_exn clock config.Config.connect_timeout (fun () -> + Eio.Net.getaddrinfo_stream + ~service:(string_of_int port) + network + hostname) + with + | [] -> + Error + (`Connect_error (Format.asprintf "Can't resolve hostname: %s" hostname)) + | xs -> + (match config.Config.prefer_ip_version with + | `Both -> + Ok + (* Sort IPv4 ahead of IPv6 for compatibility. *) + (List.sort + (fun a1 a2 -> + match a1, a2 with + | `Unix s1, `Unix s2 -> String.compare s1 s2 + | `Tcp (ip1, _), `Tcp (ip2, _) -> + compare (order_v4v6 ip1) (order_v4v6 ip2) + | `Unix _, `Tcp _ -> 1 + | `Tcp _, `Unix _ -> -1) + xs) + | `V4 -> + Ok + (List.filter + (function + | `Tcp (ip, _) -> + Eio.Net.Ipaddr.fold + ~v4:(fun _ -> true) + ~v6:(fun _ -> false) + ip + | `Unix _ -> true) + xs) + | `V6 -> + Ok + (List.filter + (function + | `Tcp (ip, _) -> + Eio.Net.Ipaddr.fold + ~v4:(fun _ -> false) + ~v6:(fun _ -> true) + ip + | `Unix _ -> true) + xs)) + | exception Eio.Time.Timeout -> + Error + (`Connect_error + (Format.asprintf "Timed out resolving hostname: %s" hostname)) module Info = struct (* This represents information that changes from connection to connection, diff --git a/lib/form.ml b/lib/form.ml index 8ed2b7d8..5e22ef48 100644 --- a/lib/form.ml +++ b/lib/form.ml @@ -56,15 +56,17 @@ module Multipart = struct let stream (* , _or_error *) = Body.to_stream request.body in let kvs, push_to_kvs = Stream.create 128 in let emit name stream = push_to_kvs (Some (name, stream)) in - let+! multipart = - Multipart.parse_multipart_form - ~content_type - ~max_chunk_size - ~emit - ~finish:(fun () -> push_to_kvs None) - stream + let+! multipart_fields = + let+! multipart = + Multipart.parse_multipart_form + ~content_type + ~max_chunk_size + ~emit + ~finish:(fun () -> push_to_kvs None) + stream + in + Multipart.result_fields multipart in - let multipart_fields = Multipart.result_fields multipart in Stream.map ~f:(fun (name, stream) -> let name = Option.get name in diff --git a/lib/http2.ml b/lib/http2.ml index 4bdc1ede..77bffcf1 100644 --- a/lib/http2.ml +++ b/lib/http2.ml @@ -282,7 +282,7 @@ module HTTP : Http_intf.HTTP2 with type scheme = Scheme.http = struct ~error_handler:(make_client_error_handler error_handler `Connection) (response_handler, response_error_handler) in - Stdlib.Result.map + Result.map (fun connection -> (* Perform the runtime upgrade -- stop speaking HTTP/1.1, start * speaking HTTP/2 by feeding Gluten the `H2.Client_connection` diff --git a/lib/http_impl.ml b/lib/http_impl.ml index 0a7bea69..8505b38b 100644 --- a/lib/http_impl.ml +++ b/lib/http_impl.ml @@ -58,11 +58,11 @@ let create_connection let connection_error_received, notify_connection_error_received = Promise.create () in - let error_handler = make_error_handler notify_connection_error_received in - let connection, runtime = - Http_impl.Client.create_connection ~config ~error_handler ~sw socket - in let conn = + let error_handler = make_error_handler notify_connection_error_received in + let connection, runtime = + Http_impl.Client.create_connection ~config ~error_handler ~sw socket + in Connection.Conn { impl = (module Http_impl) ; connection @@ -142,25 +142,27 @@ let send_request : let module Client = Http.Client in let module Bodyw = Http.Body.Writer in let response_received, notify_response = Promise.create () in - let response_handler response = Promise.resolve notify_response response in let error_received, notify_error = Promise.create () in - let error_handler = make_error_handler notify_error in - Logs.info (fun m -> - m "@[Sending request:@]@]@;<0 2>@[%a@]@." Request.pp_hum request); - let flush_headers_immediately = - match body.contents with - | `Sendfile _ -> true - | _ -> config.flush_headers_immediately - in - let request_body = - Http.Client.request - connection - ~flush_headers_immediately - ~error_handler - ~response_handler - request - in Fiber.fork ~sw (fun () -> + Logs.info (fun m -> + m "@[Sending request:@]@]@;<0 2>@[%a@]@." Request.pp_hum request); + let request_body = + let error_handler = make_error_handler notify_error in + let response_handler response = + Promise.resolve notify_response response + in + let flush_headers_immediately = + match body.contents with + | `Sendfile _ -> true + | _ -> config.flush_headers_immediately + in + Http.Client.request + connection + ~flush_headers_immediately + ~error_handler + ~response_handler + request + in match body.contents with | `Empty _ -> Bodyw.close request_body | `String s -> @@ -206,11 +208,11 @@ let upgrade_connection : let wsd_received, notify_wsd = Promise.create () in let error_received, notify_error = Promise.create () in - let error_handler _wsd error = - Promise.resolve notify_error (error :> Error.client) - in Logs.info (fun m -> m "Upgrading connection to the Websocket protocol"); let ws_conn = + let error_handler _wsd error = + Promise.resolve notify_error (error :> Error.client) + in Websocketaf.Client_connection.create ~error_handler (Ws.Handler.websocket_handler ~sw ~notify_wsd) @@ -250,37 +252,37 @@ let create_h2c_connection | `HTTP -> let (module Http2) = (module Http2.HTTP : Http_intf.HTTP2) in let response_received, notify_response_received = Promise.create () in - let response_handler response = - Promise.resolve notify_response_received response - in let connection_error_received, notify_error_received = Promise.create () in - let error_handler = make_error_handler notify_error_received in let response_error_received, notify_response_error_received = Promise.create () in - let response_error_handler = - make_error_handler notify_response_error_received + let result = + let response_handler response = + Promise.resolve notify_response_received response + in + let error_handler = make_error_handler notify_error_received in + let response_error_handler = + make_error_handler notify_response_error_received + in + Http2.Client.create_h2c + ~config + ~http_request:(Request.to_http1 http_request) + ~error_handler + (response_handler, response_error_handler) + runtime in - (match - Http2.Client.create_h2c - ~config - ~http_request:(Request.to_http1 http_request) - ~error_handler - (response_handler, response_error_handler) - runtime - with + (match result with | Ok connection -> Logs.info (fun m -> m "Connection state changed (HTTP/2 confirmed)"); (* Doesn't write the body by design. The server holds on to the HTTP/1.1 body * that was sent as part of the upgrade. *) - let result = - handle_response - ~sw - response_received - response_error_received - connection_error_received - in - (match result with + (match + handle_response + ~sw + response_received + response_error_received + connection_error_received + with | Ok response -> let connection = Connection.Conn diff --git a/lib/http_server_impl.ml b/lib/http_server_impl.ml index 81071508..23ce0aac 100644 --- a/lib/http_server_impl.ml +++ b/lib/http_server_impl.ml @@ -66,6 +66,9 @@ let do_sendfile : fun (module Http) ~src_fd ~fd ~report_exn response_body -> let fd = Option.get (Eio_unix.Resource.fd_opt fd) in Eio_unix.Fd.use_exn "sendfile" fd (fun fd -> + (* Flush everything to the wire before calling `sendfile`, as we're gonna + bypass the http/af runtime and write bytes to the file descriptor + directly. *) Http.Body.Writer.flush response_body (fun () -> match Posix.sendfile @@ -226,17 +229,15 @@ let handle_error : response_body | `HTTPS -> failwith "sendfile is not supported in HTTPS connections") in - try - Logs.warn (fun m -> - m - "Error handler called with error: %a%a" - Error.pp_hum - error - (Format.pp_print_option (fun fmt request -> - Format.fprintf fmt "; Request: @?%a" Request.pp_hum request)) - request); - error_handler client_address ?request ~respond error - with + Logs.warn (fun m -> + m + "Error handler called with error: %a%a" + Error.pp_hum + error + (Format.pp_print_option (fun fmt request -> + Format.fprintf fmt "; Request: @?%a" Request.pp_hum request)) + request); + try error_handler client_address ?request ~respond error with | exn -> Logs.err (fun m -> let raw_backtrace = Printexc.get_raw_backtrace () in diff --git a/lib/openssl.ml b/lib/openssl.ml index a64e2dc6..ecded846 100644 --- a/lib/openssl.ml +++ b/lib/openssl.ml @@ -164,14 +164,17 @@ let version_to_ssl = function | TLSv1_3 -> TLSv1_3 let protocols_to_disable min max = - let f = - match min, max with - | Versions.TLS.Any, _ -> fun x -> Versions.TLS.compare x max > 0 - | _, Versions.TLS.Any -> fun x -> Versions.TLS.compare x min < 0 - | _ -> - fun x -> Versions.TLS.compare x min < 0 || Versions.TLS.compare x max > 0 + let protocols, _ = + let f = + match min, max with + | Versions.TLS.Any, _ -> fun x -> Versions.TLS.compare x max > 0 + | _, Versions.TLS.Any -> fun x -> Versions.TLS.compare x min < 0 + | _ -> + fun x -> + Versions.TLS.compare x min < 0 || Versions.TLS.compare x max > 0 + in + List.partition f Versions.TLS.ordered in - let protocols, _ = List.partition f Versions.TLS.ordered in protocols module Error = struct @@ -268,7 +271,6 @@ let setup_client_ctx ~hostname fd = - let alpn_protocols = Versions.ALPN.protocols_of_version max_http_version in match Ssl.( create_context @@ -284,10 +286,15 @@ let setup_client_ctx Ssl.set_max_protocol_version ctx (Versions.TLS.to_max_version max_tls_version); - List.iter - (fun proto -> Logs.info (fun m -> m "ALPN: offering %s" proto)) - alpn_protocols; - Ssl.set_context_alpn_protos ctx alpn_protocols; + let () = + let alpn_protocols = + Versions.ALPN.protocols_of_version max_http_version + in + List.iter + (fun proto -> Logs.info (fun m -> m "ALPN: offering %s" proto)) + alpn_protocols; + Ssl.set_context_alpn_protos ctx alpn_protocols + in (* Use the server's preferences rather than the client's *) Ssl.honor_cipher_order ctx; let*! () = @@ -302,9 +309,8 @@ let setup_client_ctx let ssl_ctx = Eio_ssl.Context.create ~ctx fd in let ssl_sock = Eio_ssl.Context.ssl_socket ssl_ctx in (* If hostname is an IP address, check that instead of the hostname *) - let ipaddr = Ipaddr.of_string hostname in - (match ipaddr with - | Ok ipadr -> Ssl.set_ip ssl_sock (Ipaddr.to_string ipadr) + (match Ipaddr.of_string hostname with + | Ok ipaddr -> Ssl.set_ip ssl_sock (Ipaddr.to_string ipaddr) | _ -> Ssl.set_client_SNI_hostname ssl_sock hostname; (* https://wiki.openssl.org/index.php/Hostname_validation *) @@ -409,7 +415,6 @@ let setup_server_ctx } ~max_http_version = - let alpn_protocols = Versions.ALPN.protocols_of_version max_http_version in match Ssl.( create_context @@ -425,6 +430,7 @@ let setup_server_ctx (protocols_to_disable min_tls_version max_tls_version) in Ssl.disable_protocols ctx disabled_protocols; + let alpn_protocols = Versions.ALPN.protocols_of_version max_http_version in List.iter (fun proto -> Logs.info (fun m -> m "ALPN: offering %s" proto)) alpn_protocols; @@ -446,8 +452,10 @@ let setup_server_ctx (* assumes an `accept`ed socket *) let get_negotiated_alpn_protocol ssl_ctx = - let ssl_socket = Eio_ssl.Context.ssl_socket ssl_ctx in - match Ssl.get_negotiated_alpn_protocol ssl_socket with + match + let ssl_socket = Eio_ssl.Context.ssl_socket ssl_ctx in + Ssl.get_negotiated_alpn_protocol ssl_socket + with | Some "http/1.1" -> Versions.HTTP.HTTP_1_1 | Some "h2" -> HTTP_2 | None (* Unable to negotiate a protocol *) | Some _ -> @@ -470,14 +478,18 @@ let accept ~timeout fd = - let*! ctx = setup_server_ctx ~config ~max_http_version in - let ssl_ctx = Eio_ssl.Context.create ~ctx fd in + let*! ssl_ctx = + let+! ctx = setup_server_ctx ~config ~max_http_version in + Eio_ssl.Context.create ~ctx fd + in match Eio.Time.with_timeout clock timeout (fun () -> Ok (Eio_ssl.accept ssl_ctx)) with | Ok ssl_server -> - let alpn_version = get_negotiated_alpn_protocol ssl_ctx in - Ok { socket = ssl_server; alpn_version } + Ok + { socket = ssl_server + ; alpn_version = get_negotiated_alpn_protocol ssl_ctx + } | Error `Timeout -> Result.error (`Connect_error diff --git a/lib/posix.ml b/lib/posix.ml index 27005e50..959298b9 100644 --- a/lib/posix.ml +++ b/lib/posix.ml @@ -39,11 +39,7 @@ let sendfile ~dst_fd raw_write_body = - (* Flush everything to the wire before calling `sendfile`, as we're gonna - bypass the http/af runtime and write bytes to the file descriptor - directly. *) - let sent_ret = Sendfile.sendfile ~src:src_fd dst_fd in - match sent_ret with + match Sendfile.sendfile ~src:src_fd dst_fd with | Ok sent -> (* NOTE(anmonteiro): we don't need to * `Gluten.Server.report_write_result` here given that we put diff --git a/lib/request.ml b/lib/request.ml index 7936091c..7692fbbf 100644 --- a/lib/request.ml +++ b/lib/request.ml @@ -39,9 +39,11 @@ type t = } let uri { scheme; target; headers; version; _ } = - let host = Headers.host ~version headers in - let scheme = Scheme.to_string scheme in - let uri = Uri.with_uri ~host ~scheme:(Some scheme) (Uri.of_string target) in + let uri = + let scheme = Scheme.to_string scheme in + let host = Headers.host ~version headers in + Uri.with_uri ~host ~scheme:(Some scheme) (Uri.of_string target) + in Uri.canonicalize uri let create ~scheme ~version ?(headers = Headers.empty) ~meth ~body target = diff --git a/lib/request_info.ml b/lib/request_info.ml index cf4281b7..f99ff1ce 100644 --- a/lib/request_info.ml +++ b/lib/request_info.ml @@ -1,3 +1,34 @@ +(*---------------------------------------------------------------------------- + * Copyright (c) 2022-2023, António Nuno Monteiro + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + *---------------------------------------------------------------------------*) + type t = { scheme : Scheme.t ; version : Versions.HTTP.t diff --git a/lib/response.ml b/lib/response.ml index 9769b27e..0871f967 100644 --- a/lib/response.ml +++ b/lib/response.ml @@ -61,18 +61,14 @@ let of_stream ?version ?headers ~body status = create ?version ?headers ~body:(Body.of_stream body) status let sendfile ?version ?(headers = Headers.empty) path = - let mime = Magic_mime.lookup path in + let+! body = Body.sendfile path in let headers = + let mime = Magic_mime.lookup path in Headers.(add_unless_exists headers Well_known.content_type mime) in - let+! body = Body.sendfile path in create ?version ~headers ~body `OK let copy_file ?version ?(headers = Headers.empty) path = - let mime = Magic_mime.lookup path in - let headers = - Headers.(add_unless_exists headers Well_known.content_type mime) - in let*! fd = try Eio_unix.run_in_systhread (fun () -> @@ -82,6 +78,10 @@ let copy_file ?version ?(headers = Headers.empty) path = | exn -> Result.error (`Exn exn) in + let headers = + let mime = Magic_mime.lookup path in + Headers.(add_unless_exists headers Well_known.content_type mime) + in let stream = Body.stream_of_fd fd in Ok (create @@ -105,25 +105,25 @@ module Upgrade = struct match request.version with | HTTP_1_0 | HTTP_2 -> Error `Upgrade_not_supported | HTTP_1_1 -> - let wsd_received, notify_wsd = Promise.create () in - let _error_received, notify_error = Promise.create () in - let upgrade_handler ~sw upgrade = - let error_handler _wsd error = - Promise.resolve notify_error (error :> Error.client) - in - - let ws_conn = - Websocketaf.Server_connection.create_websocket - ~error_handler - (Ws.Handler.websocket_handler ~sw ~notify_wsd) - in - Fiber.fork ~sw (fun () -> f (Promise.await wsd_received)); - upgrade (Gluten.make (module Websocketaf.Server_connection) ws_conn) + let upgrade_handler = + let wsd_received, notify_wsd = Promise.create () in + let _error_received, notify_error = Promise.create () in + fun ~sw upgrade -> + let error_handler _wsd error = + Promise.resolve notify_error (error :> Error.client) + in + + let ws_conn = + Websocketaf.Server_connection.create_websocket + ~error_handler + (Ws.Handler.websocket_handler ~sw ~notify_wsd) + in + Fiber.fork ~sw (fun () -> f (Promise.await wsd_received)); + upgrade (Gluten.make (module Websocketaf.Server_connection) ws_conn) in - let httpaf_headers = Headers.to_http1 request.headers in - (match + let httpaf_headers = Headers.to_http1 request.headers in Websocketaf.Handshake.upgrade_headers ~sha1:Openssl.sha1 ~request_method:request.meth diff --git a/lib/server.ml b/lib/server.ml index fcec1428..3fea5d51 100644 --- a/lib/server.ml +++ b/lib/server.ml @@ -74,62 +74,58 @@ let is_requesting_h2c_upgrade ~config ~version ~scheme headers = | _ -> false let do_h2c_upgrade ~sw ~fd ~request_body server = - let { config; error_handler; handler } = server in - let upgrade_handler ~sw:_ client_address (request : Request.t) upgrade = - let http_request = - Httpaf.Request.create - ~headers: - (Httpaf.Headers.of_rev_list (Headers.to_rev_list request.headers)) - request.meth - request.target - in - let connection = - Result.get_ok - (Http2.HTTP.Server.create_h2c_connection_handler - ~config - ~sw - ~fd - ~error_handler - ~http_request - ~request_body - ~client_address - handler) - in - upgrade (Gluten.make (module H2.Server_connection) connection) + let upgrade_handler = + let { config; error_handler; handler } = server in + fun ~sw:_ client_address (request : Request.t) upgrade -> + let http_request = + Httpaf.Request.create + ~headers: + (Httpaf.Headers.of_rev_list (Headers.to_rev_list request.headers)) + request.meth + request.target + in + let connection = + Result.get_ok + (Http2.HTTP.Server.create_h2c_connection_handler + ~config + ~sw + ~fd + ~error_handler + ~http_request + ~request_body + ~client_address + handler) + in + upgrade (Gluten.make (module H2.Server_connection) connection) in - let request_handler { request; ctx = { Request_info.client_address; _ } } = + fun { request; ctx = { Request_info.client_address; _ } } -> let headers = Headers.( of_list [ Well_known.connection, "Upgrade"; Well_known.upgrade, "h2c" ]) in Response.Upgrade.generic ~headers (upgrade_handler client_address request) - in - request_handler let http_connection_handler t : connection_handler = - let { error_handler; handler; config } = t in let (module Http) = - match config.max_http_version, config.h2c_upgrade with + match t.config.max_http_version, t.config.h2c_upgrade with | HTTP_2, true | (HTTP_1_0 | HTTP_1_1), _ -> (module Http1.HTTP : Http_intf.HTTP) | HTTP_2, false -> (module Http2.HTTP : Http_intf.HTTP) in fun ~sw socket client_address -> - let request_handler - ({ request; ctx = { Request_info.client_address = _; scheme; _ } } as - ctx) - = - match - is_requesting_h2c_upgrade - ~config - ~version:request.version - ~scheme - request.headers - with + let { error_handler; handler; config } = t in + let request_handler ctx = + let { request = { version; headers; body; _ } + ; ctx = { Request_info.scheme; _ } + } + = + ctx + in + match is_requesting_h2c_upgrade ~config ~version ~scheme headers with | false -> handler ctx | true -> let h2c_handler = - let request_body = Body.to_list request.body in + let request_body = Body.to_list body in do_h2c_upgrade ~sw ~fd:socket ~request_body t in h2c_handler ctx @@ -249,14 +245,16 @@ module Command = struct m "Error in connection handler: %s" (Printexc.to_string exn))) (fun socket addr -> Switch.run (fun sw -> - let connection_id = - let cid = !id in - incr id; - cid + let () = + let connection_id = + let cid = !id in + incr id; + cid + in + Hashtbl.replace client_sockets connection_id socket; + Switch.on_release sw (fun () -> + Hashtbl.remove client_sockets connection_id) in - Hashtbl.replace client_sockets connection_id socket; - Switch.on_release sw (fun () -> - Hashtbl.remove client_sockets connection_id); connection_handler ~sw socket addr))) done); fun () -> Promise.resolve released_u () @@ -272,9 +270,8 @@ module Command = struct env connection_handler = - let domain_mgr = Eio.Stdenv.domain_mgr env in - let network = Eio.Stdenv.net env in let socket = + let network = Eio.Stdenv.net env in Eio.Net.listen ~reuse_addr ~reuse_port ~backlog ~sw network address in let resolvers = ref [] in @@ -285,9 +282,10 @@ module Command = struct let is_last_domain = idx = domains - 1 in let run_accept_loop () = Switch.run (fun sw -> - let client_sockets = Hashtbl.create 256 in - let resolver = - accept_loop ~sw ~client_sockets ~socket connection_handler + let resolver, client_sockets = + let client_sockets = Hashtbl.create 256 in + ( accept_loop ~sw ~client_sockets ~socket connection_handler + , client_sockets ) in Eio.Mutex.lock resolver_mutex; resolvers := (resolver, client_sockets) :: !resolvers; @@ -297,7 +295,9 @@ module Command = struct (* Last domain starts on the main thread. *) if is_last_domain then run_accept_loop () - else Eio.Domain_manager.run domain_mgr run_accept_loop) + else + let domain_mgr = Eio.Stdenv.domain_mgr env in + Eio.Domain_manager.run domain_mgr run_accept_loop) done; Promise.await all_started; Logs.info (fun m -> m "Server listening on %a" Eio.Net.Sockaddr.pp address); @@ -310,10 +310,9 @@ module Command = struct let start ~sw env server = let { config; _ } = server in - let clock = Eio.Stdenv.clock env in (* TODO(anmonteiro): config option to listen only in HTTPS? *) - let connection_handler = http_connection_handler server in let command = + let connection_handler = http_connection_handler server in listen ~sw ~address:config.address @@ -328,8 +327,11 @@ module Command = struct match config.https with | None -> command | Some https -> - let connection_handler = https_connection_handler ~clock ~https server in + let clock = Eio.Stdenv.clock env in let https_command = + let connection_handler = + https_connection_handler ~clock ~https server + in listen ~sw ~address:https.address diff --git a/lib/util.ml b/lib/util.ml index 72847403..09cebf29 100644 --- a/lib/util.ml +++ b/lib/util.ml @@ -38,8 +38,8 @@ module Uri = struct | None -> raise (Failure "host_exn") let parse_with_base_uri ~scheme ~uri location = - let location_uri = Uri.of_string location in let new_uri = + let location_uri = Uri.of_string location in match Uri.host location_uri with | Some _ -> location_uri | None ->