diff --git a/lib_eio/net.ml b/lib_eio/net.ml index d5207bd52..b7cac362d 100644 --- a/lib_eio/net.ml +++ b/lib_eio/net.ml @@ -269,3 +269,31 @@ let with_tcp_connect ?(timeout=Time.Timeout.none) ~host ~service t f = | exception (Exn.Io _ as ex) -> let bt = Printexc.get_raw_backtrace () in Exn.reraise_with_context ex bt "connecting to %S:%s" host service + +(* Run a server loop in a single domain. *) +let run_server_loop ~connections ~on_error ~stop listening_socket connection_handler = + Switch.run @@ fun sw -> + let rec accept () = + Semaphore.acquire connections; + accept_fork ~sw ~on_error listening_socket (fun conn addr -> + Fun.protect (fun () -> connection_handler conn addr) + ~finally:(fun () -> Semaphore.release connections) + ); + accept () + in + match stop with + | None -> accept () + | Some stop -> Fiber.first accept (fun () -> Promise.await stop) + +let run_server ?(max_connections=Int.max_int) ?(additional_domains) ?stop ~on_error listening_socket connection_handler : 'a = + if max_connections <= 0 then invalid_arg "max_connections"; + Switch.run @@ fun sw -> + let connections = Semaphore.make max_connections in + let run_server_loop () = run_server_loop ~connections ~on_error ~stop listening_socket connection_handler in + additional_domains |> Option.iter (fun (domain_mgr, domains) -> + if domains < 0 then invalid_arg "additional_domains"; + for _ = 1 to domains do + Fiber.fork ~sw (fun () -> Domain_manager.run domain_mgr (fun () -> ignore (run_server_loop () : 'a))) + done; + ); + run_server_loop () diff --git a/lib_eio/net.mli b/lib_eio/net.mli index 434d5672e..d9d51b6c0 100644 --- a/lib_eio/net.mli +++ b/lib_eio/net.mli @@ -191,11 +191,17 @@ val accept_fork : on_error:(exn -> unit) -> connection_handler -> unit -(** [accept_fork socket fn] accepts a connection and handles it in a new fiber. +(** [accept_fork ~sw ~on_error socket fn] accepts a connection and handles it in a new fiber. After accepting a connection to [socket], it runs [fn flow client_addr] in a new fiber. - [flow] will be closed when [fn] returns. *) + [flow] will be closed when [fn] returns. The new fiber is attached to [sw]. + + @param on_error Called if [connection_handler] raises an exception. + This is typically a good place to log the error and continue. + If the exception is an {!Eio.Io} error then the caller's address is added to it. + If you don't want to handle connection errors, + use [~on_error:raise] to cancel the caller's context. *) val accept_sub : sw:Switch.t -> @@ -205,6 +211,38 @@ val accept_sub : unit [@@deprecated "Use accept_fork instead"] +(** {2 Running Servers} *) + +val run_server : + ?max_connections:int -> + ?additional_domains:(#Domain_manager.t * int) -> + ?stop:'a Promise.t -> + on_error:(exn -> unit) -> + #listening_socket -> + connection_handler -> + 'a +(** [run_server ~on_error sock connection_handler] establishes a concurrent socket server [s]. + + It accepts incoming client connections on socket [sock] and handles them with {!accept_fork} + (see that for the description of [on_error] and [connection_handler]). + + {b Running a Parallel Server} + + By default [s] runs on a {e single} OCaml {!module:Domain}. However, if [additional_domains:(domain_mgr, domains)] + parameter is given, then [s] will spawn [domains] additional domains and run accept loops in those too. + In such cases you must ensure that [connection_handler] only accesses thread-safe values. + Note that having more than {!Domain.recommended_domain_count} domains in total is likely to result in bad performance. + + @param max_connections The maximum number of concurrent connections accepted by [s] at any time. + The default is [Int.max_int]. + @param stop Resolving this promise causes [s] to stop accepting new connections. + [run_server] will wait for all existing connections to finish and then return. + This is useful to upgrade a server without clients noticing. + To stop immediately, cancelling all connections, just cancel [s]'s fiber instead. + @param on_error Connection error handler (see {!accept_fork}). + @raise Invalid_argument if [max_connections <= 0]. + if [additional_domains = (domain_mgr, domains)] is used and [domains < 0]. *) + (** {2 Datagram Sockets} *) val datagram_socket : diff --git a/tests/network.md b/tests/network.md index a3ee943ff..0c2b36d1f 100644 --- a/tests/network.md +++ b/tests/network.md @@ -259,7 +259,7 @@ Extracting file descriptors from Eio objects: traceln "Listening socket has Unix FD: %b" (Eio_unix.FD.peek_opt server <> None); let have_client, have_server = Fiber.pair - (fun () -> + (fun () -> let flow = Eio.Net.connect ~sw net addr in (Eio_unix.FD.peek_opt flow <> None) ) @@ -710,3 +710,219 @@ TODO: This is wrong; see https://github.com/ocaml-multicore/eio/issues/342 - : unit = () ``` + +## run_server + +A simple connection handler for testing: +```ocaml +let handle_connection flow _addr = + let msg = read_all flow in + assert (msg = "Hi"); + Fiber.yield (); + Eio.Flow.copy_string "Bye" flow +``` + +A mock listening socket that allows acceping `n_clients` clients, each of which writes "Hi", +and then allows `n_domains` further attempts, none of which even complete: + +```ocaml +let mock_listener ~n_clients ~n_domains = + let make_flow i () = + if n_domains > 1 then Fiber.yield () (* Load balance *) + else Fiber.check (); + let flow = Eio_mock.Flow.make ("flow" ^ string_of_int i) in + Eio_mock.Flow.on_read flow [`Return "Hi"; `Raise End_of_file]; + flow, `Tcp (Eio.Net.Ipaddr.V4.loopback, 30000 + i) + in + let listening_socket = Eio_mock.Net.listening_socket "tcp/80" in + Eio_mock.Handler.set_handler listening_socket#on_accept Fiber.await_cancel; + Eio_mock.Net.on_accept listening_socket ( + List.init n_clients (fun i -> `Run (make_flow i)) @ + List.init n_domains (fun _ -> `Run Fiber.await_cancel) + ); + listening_socket +``` + +Start handling the connections, then begin a graceful shutdown, +allowing the connections to finish and then exiting: + +```ocaml +# Eio_mock.Backend.run @@ fun () -> + let listening_socket = mock_listener ~n_clients:3 ~n_domains:1 in + let stop, set_stop = Promise.create () in + Fiber.both + (fun () -> + Eio.Net.run_server listening_socket handle_connection + ~max_connections:10 + ~on_error:raise + ~stop + ) + (fun () -> + traceln "Begin graceful shutdown"; + Promise.resolve set_stop () + );; ++tcp/80: accepted connection from tcp:127.0.0.1:30000 ++flow0: read "Hi" ++tcp/80: accepted connection from tcp:127.0.0.1:30001 ++flow1: read "Hi" ++tcp/80: accepted connection from tcp:127.0.0.1:30002 ++flow2: read "Hi" ++Begin graceful shutdown ++flow0: wrote "Bye" ++flow0: closed ++flow1: wrote "Bye" ++flow1: closed ++flow2: wrote "Bye" ++flow2: closed +- : unit = () +``` + +Non-graceful shutdown, closing all connections still in progress: + +```ocaml +# Eio_mock.Backend.run @@ fun () -> + let listening_socket = mock_listener ~n_clients:3 ~n_domains:1 in + Fiber.both + (fun () -> + Eio.Net.run_server listening_socket handle_connection + ~max_connections:10 + ~on_error:raise + ) + (fun () -> failwith "Simulated error");; ++tcp/80: accepted connection from tcp:127.0.0.1:30000 ++flow0: read "Hi" ++tcp/80: accepted connection from tcp:127.0.0.1:30001 ++flow1: read "Hi" ++tcp/80: accepted connection from tcp:127.0.0.1:30002 ++flow2: read "Hi" ++flow0: closed ++flow1: closed ++flow2: closed +Exception: Failure "Simulated error". +``` + +To test support for multiple domains, we just run everything in one domain +to keep the output deterministic. We override `traceln` to log the (fake) +domain ID too: + +```ocaml +let with_domain_tracing id fn = + let traceln ?__POS__ fmt = + Eio.Private.Debug.default_traceln ?__POS__ ("[%d] " ^^ fmt) id + in + Fiber.with_binding Eio.Private.Debug.v#traceln { traceln } fn + +let fake_domain_mgr () = object (_ : #Eio.Domain_manager.t) + val mutable next_domain_id = 1 + + method run fn = + let self = next_domain_id in + next_domain_id <- next_domain_id + 1; + let cancelled, _ = Promise.create () in + with_domain_tracing self (fun () -> fn ~cancelled) + + method run_raw _ = assert false +end +``` + +Handling the connections with 3 domains, with a graceful shutdown: + +```ocaml +# Eio_mock.Backend.run @@ fun () -> + with_domain_tracing 0 @@ fun () -> + let n_domains = 3 in + let listening_socket = mock_listener ~n_clients:3 ~n_domains in + let stop, set_stop = Promise.create () in + Fiber.both + (fun () -> + Eio.Net.run_server listening_socket handle_connection + ~additional_domains:(fake_domain_mgr (), n_domains - 1) + ~max_connections:10 + ~on_error:raise + ~stop + ) + (fun () -> + Fiber.yield (); + traceln "Begin graceful shutdown"; + Promise.resolve set_stop () + );; ++[1] tcp/80: accepted connection from tcp:127.0.0.1:30000 ++[1] flow0: read "Hi" ++[2] tcp/80: accepted connection from tcp:127.0.0.1:30001 ++[2] flow1: read "Hi" ++[0] tcp/80: accepted connection from tcp:127.0.0.1:30002 ++[0] flow2: read "Hi" ++[0] Begin graceful shutdown ++[1] flow0: wrote "Bye" ++[1] flow0: closed ++[2] flow1: wrote "Bye" ++[2] flow1: closed ++[0] flow2: wrote "Bye" ++[0] flow2: closed +- : unit = () +``` + +Handling the connections with 3 domains, aborting immediately: + +```ocaml +# Eio_mock.Backend.run @@ fun () -> + with_domain_tracing 0 @@ fun () -> + let n_domains = 3 in + let listening_socket = mock_listener ~n_clients:3 ~n_domains in + Fiber.both + (fun () -> + Eio.Net.run_server listening_socket handle_connection + ~additional_domains:(fake_domain_mgr (), n_domains - 1) + ~max_connections:10 + ~on_error:raise + ) + (fun () -> Fiber.yield (); failwith "Simulated error");; ++[1] tcp/80: accepted connection from tcp:127.0.0.1:30000 ++[1] flow0: read "Hi" ++[2] tcp/80: accepted connection from tcp:127.0.0.1:30001 ++[2] flow1: read "Hi" ++[0] tcp/80: accepted connection from tcp:127.0.0.1:30002 ++[0] flow2: read "Hi" ++[1] flow0: closed ++[2] flow1: closed ++[0] flow2: closed +Exception: Failure "Simulated error". +``` + +Limiting to 2 concurrent connections: + +```ocaml +# Eio_mock.Backend.run @@ fun () -> + let listening_socket = mock_listener ~n_clients:10 ~n_domains:1 in + let stop, set_stop = Promise.create () in + Fiber.both + (fun () -> + Eio.Net.run_server listening_socket handle_connection + ~max_connections:2 + ~on_error:raise + ~stop + ) + (fun () -> + for _ = 1 to 2 do Fiber.yield () done; + traceln "Begin graceful shutdown"; + Promise.resolve set_stop () + );; ++tcp/80: accepted connection from tcp:127.0.0.1:30000 ++flow0: read "Hi" ++tcp/80: accepted connection from tcp:127.0.0.1:30001 ++flow1: read "Hi" ++flow0: wrote "Bye" ++flow0: closed ++flow1: wrote "Bye" ++flow1: closed ++tcp/80: accepted connection from tcp:127.0.0.1:30002 ++flow2: read "Hi" ++tcp/80: accepted connection from tcp:127.0.0.1:30003 ++flow3: read "Hi" ++Begin graceful shutdown ++flow2: wrote "Bye" ++flow2: closed ++flow3: wrote "Bye" ++flow3: closed +- : unit = () +```