diff --git a/lib/server.ml b/lib/server.ml index 3fea5d5..3836fce 100644 --- a/lib/server.ml +++ b/lib/server.ml @@ -176,26 +176,33 @@ module Command = struct type connection_handler = Server_intf.connection_handler + module Shutdown_resolver = struct + type t = unit -> unit + + let empty = Fun.id, Hashtbl.create 0 + end + type nonrec t = - { sockets : + { (* types like [_ array] mean per domain * listening address *) + sockets : Eio_unix.Net.listening_socket_ty Eio_unix.Net.listening_socket list - ; shutdown_resolvers : (unit -> unit) list + ; shutdown_resolvers : Shutdown_resolver.t array ; client_sockets : ( int , Eio_unix.Net.stream_socket_ty Eio_unix.Net.stream_socket ) Hashtbl.t - list + array ; clock : float Eio.Time.clock_ty r ; shutdown_timeout : float } let shutdown = let length sockets = - List.fold_left (fun acc item -> Hashtbl.length item + acc) 0 sockets + Array.fold_left (fun acc item -> Hashtbl.length item + acc) 0 sockets in fun { sockets; shutdown_resolvers; client_sockets; clock; shutdown_timeout } -> Logs.info (fun m -> m "Starting server teardown..."); - List.iter (fun resolver -> resolver ()) shutdown_resolvers; + Array.iter (fun resolver -> resolver ()) shutdown_resolvers; (* Close the server sockets to stop accepting new connections *) List.iter Eio.Net.close sockets; (* Wait for [shutdown_timeout] seconds before shutting down client @@ -214,7 +221,7 @@ module Command = struct Eio.Time.sleep clock shutdown_timeout; (* Shut down all client sockets after the shutdown timeout has elapsed. *) - List.iter + Array.iter (fun client_sockets -> Hashtbl.iter (fun _ client_socket -> @@ -227,40 +234,42 @@ module Command = struct client_sockets); Logs.info (fun m -> m "Server teardown finished") - let accept_loop ~sw ~socket ~client_sockets connection_handler = - let released_p, released_u = Promise.create () in - Fiber.fork ~sw (fun () -> - let id = ref 0 in - while not (Promise.is_resolved released_p) do - Fiber.first - (fun () -> Promise.await released_p) - (fun () -> - Eio.Net.accept_fork - socket - ~sw - ~on_error:(fun exn -> - let bt = Printexc.get_backtrace () in - Format.eprintf "sheesh: %s %s @." (Printexc.to_string exn) bt; - Logs.err (fun m -> - m "Error in connection handler: %s" (Printexc.to_string exn))) - (fun socket addr -> - Switch.run (fun sw -> - let () = - let connection_id = - let cid = !id in - incr id; - cid - in - Hashtbl.replace client_sockets connection_id socket; - Switch.on_release sw (fun () -> - Hashtbl.remove client_sockets connection_id) - in - connection_handler ~sw socket addr))) - done); - fun () -> Promise.resolve released_u () - - let listen - ~sw + let listen = + let accept_loop ~sw ~listening_socket ~client_sockets connection_handler = + let accept = + let id = ref 0 in + let rec accept () = + Eio.Net.accept_fork + listening_socket + ~sw + ~on_error:(fun exn -> + let bt = Printexc.get_backtrace () in + Logs.err (fun m -> + m + "Error in connection handler: %s@\n%s" + (Printexc.to_string exn) + bt)) + (fun socket addr -> + Switch.run (fun sw -> + let connection_id = + let cid = !id in + incr id; + cid + in + Hashtbl.replace client_sockets connection_id socket; + Switch.on_release sw (fun () -> + Hashtbl.remove client_sockets connection_id); + connection_handler ~sw socket addr)); + accept () + in + accept + in + let released_p, released_u = Promise.create () in + Fiber.fork ~sw (fun () -> + Fiber.first (fun () -> Promise.await released_p) accept); + fun () -> Promise.resolve released_u () + in + fun ~sw ~address ~backlog ~reuse_addr @@ -268,45 +277,55 @@ module Command = struct ~domains ~shutdown_timeout env - connection_handler - = - let socket = - let network = Eio.Stdenv.net env in - Eio.Net.listen ~reuse_addr ~reuse_port ~backlog ~sw network address - in - let resolvers = ref [] in - let resolver_mutex = Eio.Mutex.create () in - let all_started, resolve_all_started = Promise.create () in - for idx = 0 to domains - 1 do - Eio.Fiber.fork ~sw (fun () -> - let is_last_domain = idx = domains - 1 in - let run_accept_loop () = + connection_handler -> + let listening_socket = + let network = Eio.Stdenv.net env in + Eio.Net.listen ~reuse_addr ~reuse_port ~backlog ~sw network address + in + let resolvers = Array.make domains Shutdown_resolver.empty in + let started_domains = Eio.Semaphore.make domains in + let run_accept_loop = + let resolver_mutex = Eio.Mutex.create () in + fun idx -> Switch.run (fun sw -> - let resolver, client_sockets = + let resolver = let client_sockets = Hashtbl.create 256 in - ( accept_loop ~sw ~client_sockets ~socket connection_handler - , client_sockets ) + let resolver = + accept_loop + ~sw + ~client_sockets + ~listening_socket + connection_handler + in + resolver, client_sockets in Eio.Mutex.lock resolver_mutex; - resolvers := (resolver, client_sockets) :: !resolvers; + resolvers.(idx) <- resolver; Eio.Mutex.unlock resolver_mutex; - if is_last_domain then Promise.resolve resolve_all_started ()) - in - (* Last domain starts on the main thread. *) - if is_last_domain - then run_accept_loop () + Eio.Semaphore.acquire started_domains) + in + for idx = 0 to domains - 1 do + let run_accept_loop () = run_accept_loop idx in + if idx = domains - 1 + then + (* Last domain starts on the main thread. *) + Eio.Fiber.fork ~sw run_accept_loop else - let domain_mgr = Eio.Stdenv.domain_mgr env in - Eio.Domain_manager.run domain_mgr run_accept_loop) - done; - Promise.await all_started; - Logs.info (fun m -> m "Server listening on %a" Eio.Net.Sockaddr.pp address); - { sockets = [ socket ] - ; shutdown_resolvers = List.map fst !resolvers - ; client_sockets = List.map snd !resolvers - ; clock = Eio.Stdenv.clock env - ; shutdown_timeout - } + Eio.Fiber.fork ~sw (fun () -> + let domain_mgr = Eio.Stdenv.domain_mgr env in + Eio.Domain_manager.run domain_mgr run_accept_loop) + done; + while Eio.Semaphore.get_value started_domains > 0 do + Fiber.yield () + done; + Logs.info (fun m -> + m "Server listening on %a" Eio.Net.Sockaddr.pp address); + { sockets = [ listening_socket ] + ; shutdown_resolvers = Array.map fst resolvers + ; client_sockets = Array.map snd resolvers + ; clock = Eio.Stdenv.clock env + ; shutdown_timeout + } let start ~sw env server = let { config; _ } = server in @@ -345,8 +364,11 @@ module Command = struct in { sockets = https_command.sockets @ command.sockets ; shutdown_resolvers = - command.shutdown_resolvers @ https_command.shutdown_resolvers - ; client_sockets = command.client_sockets @ https_command.client_sockets + Array.append + command.shutdown_resolvers + https_command.shutdown_resolvers + ; client_sockets = + Array.append command.client_sockets https_command.client_sockets ; clock ; shutdown_timeout = config.shutdown_timeout }