From fad44f54bf0d1aff9a439d6624003582da8ab6cd Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Fri, 11 Sep 2020 16:56:23 +0200 Subject: [PATCH] move Tcpip_stack_socket to Tcpip_stack_socket.V4 provide Tcpip_stack_socket.V6 implementing Mirage_stack.V6 --- src/stack-unix/dune | 5 +- src/stack-unix/ipv4_socket.ml | 2 +- src/stack-unix/ipv6_socket.ml | 19 +- src/stack-unix/tcpip_stack_socket.ml | 310 +++++++++++++++++--------- src/stack-unix/tcpip_stack_socket.mli | 20 +- src/stack-unix/udpv6_socket.ml | 22 +- test/test_socket.ml | 2 +- 7 files changed, 242 insertions(+), 138 deletions(-) diff --git a/src/stack-unix/dune b/src/stack-unix/dune index f1dc1e1e8..4ed6f2591 100644 --- a/src/stack-unix/dune +++ b/src/stack-unix/dune @@ -50,5 +50,6 @@ (modules tcpip_stack_socket ipv4_socket ipv6_socket) (wrapped false) (libraries lwt.unix cstruct-lwt ipaddr.unix logs - tcpip.tcpv4-socket tcpip.udpv4-socket tcpip.ipv4 tcpip.ipv6 tcpip.icmpv4 - mirage-protocols mirage-stack)) + tcpip.tcpv4-socket tcpip.udpv4-socket tcpip.ipv4 + tcpip.tcpv6-socket tcpip.udpv6-socket tcpip.ipv6 + tcpip.icmpv4 mirage-protocols mirage-stack)) diff --git a/src/stack-unix/ipv4_socket.ml b/src/stack-unix/ipv4_socket.ml index 06616fd79..77e387f69 100644 --- a/src/stack-unix/ipv4_socket.ml +++ b/src/stack-unix/ipv4_socket.ml @@ -34,6 +34,6 @@ let connect _ = return_unit let input _ ~tcp:_ ~udp:_ ~default:_ _ = return_unit let write _ ?fragment:_ ?ttl:_ ?src:_ _ _ ?size:_ _ _ = fail (Failure "Not implemented") -let get_ip _ = [Ipaddr.V4.of_string_exn "0.0.0.0"] +let get_ip _ = [Ipaddr.V4.any] let src _ ~dst:_ = raise (Failure "Not implemented") let pseudoheader _ ?src:_ _ _ _ = raise (Failure "Not implemented") diff --git a/src/stack-unix/ipv6_socket.ml b/src/stack-unix/ipv6_socket.ml index ea8b00b40..4300ca7b4 100644 --- a/src/stack-unix/ipv6_socket.ml +++ b/src/stack-unix/ipv6_socket.ml @@ -17,15 +17,16 @@ open Lwt -type id = string -type ip = unit type t = unit type +'a io = 'a Lwt.t -type error = [ `Unimplemented | `Unknown of string ] +type error = Mirage_protocols.Ip.error type ipaddr = Ipaddr.V6.t type buffer = Cstruct.t type callback = src:ipaddr -> dst:ipaddr -> buffer -> unit io +let pp_error = Mirage_protocols.Ip.pp_error +let pp_ipaddr = Ipaddr.V6.pp + let mtu _ = 1500 - Ipv6_wire.sizeof_ipv6 let id _ = () @@ -33,14 +34,8 @@ let disconnect () = return_unit let connect () = return_unit let input _ ~tcp:_ ~udp:_ ~default:_ _ = return_unit -let allocate_frame _ ~dst:_ ~proto:_ = raise (Failure "Not implemented") -let write _ _ _ = fail (Failure "Not implemented") -let writev _ _ _ = fail (Failure "Not implemented") - -let get_ip _ = Ipaddr.V6.of_string_exn "::" -let set_ip _ _ = fail (Failure "Not implemented") -let get_ip_gateways _ = raise (Failure "Not implemented") -let set_ip_gateways _ _ = fail (Failure "Not implemented") +let write _ ?fragment:_ ?ttl:_ ?src:_ _ _ ?size:_ _ _ = fail (Failure "Not implemented") -let checksum _ _ = raise (Failure "Not implemented") +let get_ip _ = [Ipaddr.V6.unspecified] let src _ ~dst:_ = raise (Failure "Not implemented") +let pseudoheader _ ?src:_ _ _ _ = raise (Failure "Not implemented") diff --git a/src/stack-unix/tcpip_stack_socket.ml b/src/stack-unix/tcpip_stack_socket.ml index 94b909e79..8707c2f6c 100644 --- a/src/stack-unix/tcpip_stack_socket.ml +++ b/src/stack-unix/tcpip_stack_socket.ml @@ -19,112 +19,204 @@ open Lwt.Infix let src = Logs.Src.create "tcpip-stack-socket" ~doc:"Platform's native TCP/IP stack" module Log = (val Logs.src_log src : Logs.LOG) -type socket_ipv4_input = unit Lwt.t - -module type UDPV4_SOCKET = Mirage_protocols.UDP - with type ipinput = socket_ipv4_input - -module type TCPV4_SOCKET = Mirage_protocols.TCP - with type ipinput = socket_ipv4_input - -module Tcpv4 = Tcpv4_socket -module Udpv4 = Udpv4_socket - -module TCPV4 = Tcpv4_socket -module UDPV4 = Udpv4_socket -module IPV4 = Ipv4_socket - -type t = { - udpv4 : Udpv4.t; - tcpv4 : Tcpv4.t; -} - -let udpv4 { udpv4; _ } = udpv4 -let tcpv4 { tcpv4; _ } = tcpv4 -let ipv4 _ = None - -(* List of IP addresses to bind to *) -let configure _t addrs = - match addrs with - | [] -> Lwt.return_unit - | [ip] when (Ipaddr.V4.compare Ipaddr.V4.any ip) = 0 -> Lwt.return_unit - | l -> - let pp_iplist fmt l = Format.pp_print_list Ipaddr.V4.pp fmt l in - Log.warn (fun f -> f - "Manager: sockets currently bind to all available IPs. IPs %a were specified, but this will be ignored" pp_iplist l); - Lwt.return_unit - -let err_invalid_port p = Printf.sprintf "invalid port number (%d)" p - -let listen_udpv4 t ~port callback = - if port < 0 || port > 65535 then - raise (Invalid_argument (err_invalid_port port)) - else - (* FIXME: we should not ignore the result *) - Lwt.async (fun () -> - Udpv4.get_udpv4_listening_fd t.udpv4 port >>= fun fd -> - let buf = Cstruct.create 4096 in - let rec loop () = - (* TODO cancellation *) - Lwt.catch (fun () -> - Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) -> - let buf = Cstruct.sub buf 0 len in - (match sa with - | Lwt_unix.ADDR_INET (addr, src_port) -> - let src = Ipaddr_unix.V4.of_inet_addr_exn addr in - let dst = Ipaddr.V4.any in (* TODO *) - callback ~src ~dst ~src_port buf - | _ -> Lwt.return_unit)) - (fun exn -> - Log.warn (fun m -> m "exception %s in recvfrom" (Printexc.to_string exn)) ; - Lwt.return_unit) >>= fun () -> - loop () - in - loop ()) - -let listen_tcpv4 ?keepalive _t ~port callback = - if port < 0 || port > 65535 then - raise (Invalid_argument (err_invalid_port port)) - else - let fd = Lwt_unix.(socket PF_INET SOCK_STREAM 0) in - Lwt_unix.setsockopt fd Lwt_unix.SO_REUSEADDR true; - (* TODO: as elsewhere in the module, we bind all available addresses; it would be better not to do so if the user has requested it *) - let interface = Ipaddr_unix.V4.to_inet_addr Ipaddr.V4.any in - (* FIXME: we should not ignore the result *) - Lwt.async (fun () -> - Lwt_unix.bind fd (Lwt_unix.ADDR_INET (interface, port)) >>= fun () -> - Lwt_unix.listen fd 10; - (* TODO cancellation *) - let rec loop () = - Lwt.catch (fun () -> - Lwt_unix.accept fd >|= fun (afd, _) -> - (match keepalive with - | None -> () - | Some { Mirage_protocols.Keepalive.after; interval; probes } -> - Tcp_socket_options.enable_keepalive ~fd:afd ~after ~interval ~probes); - Lwt.async - (fun () -> - Lwt.catch - (fun () -> callback afd) - (fun exn -> - Log.warn (fun m -> m "error %s in callback" (Printexc.to_string exn)) ; - Lwt.return_unit))) - (fun exn -> - Log.warn (fun m -> m "error %s in accept" (Printexc.to_string exn)) ; - Lwt.return_unit) >>= fun () -> - loop () - in - loop ()) - -let listen _t = - let t, _ = Lwt.task () in - t (* TODO cancellation *) - -let connect ips udpv4 tcpv4 = - Log.info (fun f -> f "Manager: connect"); - let t = { tcpv4; udpv4 } in - Log.info (fun f -> f "Manager: configuring"); - configure t ips >|= fun () -> - t - -let disconnect _ = Lwt.return_unit +module V4 = struct + module TCPV4 = Tcpv4_socket + module UDPV4 = Udpv4_socket + module IPV4 = Ipv4_socket + + type t = { + udpv4 : UDPV4.t; + tcpv4 : TCPV4.t; + } + + let udpv4 { udpv4; _ } = udpv4 + let tcpv4 { tcpv4; _ } = tcpv4 + let ipv4 _ = None + + (* List of IP addresses to bind to *) + let configure _t addrs = + match addrs with + | [] -> Lwt.return_unit + | [ip] when (Ipaddr.V4.compare Ipaddr.V4.any ip) = 0 -> Lwt.return_unit + | l -> + let pp_iplist fmt l = Format.pp_print_list Ipaddr.V4.pp fmt l in + Log.warn (fun f -> f + "Manager: sockets currently bind to all available IPs. IPs %a were specified, but this will be ignored" pp_iplist l); + Lwt.return_unit + + let err_invalid_port p = Printf.sprintf "invalid port number (%d)" p + + let listen_udpv4 t ~port callback = + if port < 0 || port > 65535 then + raise (Invalid_argument (err_invalid_port port)) + else + (* FIXME: we should not ignore the result *) + Lwt.async (fun () -> + UDPV4.get_udpv4_listening_fd t.udpv4 port >>= fun fd -> + let buf = Cstruct.create 4096 in + let rec loop () = + (* TODO cancellation *) + Lwt.catch (fun () -> + Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) -> + let buf = Cstruct.sub buf 0 len in + (match sa with + | Lwt_unix.ADDR_INET (addr, src_port) -> + let src = Ipaddr_unix.V4.of_inet_addr_exn addr in + let dst = Ipaddr.V4.any in (* TODO *) + callback ~src ~dst ~src_port buf + | _ -> Lwt.return_unit)) + (fun exn -> + Log.warn (fun m -> m "exception %s in recvfrom" (Printexc.to_string exn)) ; + Lwt.return_unit) >>= fun () -> + loop () + in + loop ()) + + let listen_tcpv4 ?keepalive _t ~port callback = + if port < 0 || port > 65535 then + raise (Invalid_argument (err_invalid_port port)) + else + let fd = Lwt_unix.(socket PF_INET SOCK_STREAM 0) in + Lwt_unix.setsockopt fd Lwt_unix.SO_REUSEADDR true; + (* TODO: as elsewhere in the module, we bind all available addresses; it would be better not to do so if the user has requested it *) + let interface = Ipaddr_unix.V4.to_inet_addr Ipaddr.V4.any in + (* FIXME: we should not ignore the result *) + Lwt.async (fun () -> + Lwt_unix.bind fd (Lwt_unix.ADDR_INET (interface, port)) >>= fun () -> + Lwt_unix.listen fd 10; + (* TODO cancellation *) + let rec loop () = + Lwt.catch (fun () -> + Lwt_unix.accept fd >|= fun (afd, _) -> + (match keepalive with + | None -> () + | Some { Mirage_protocols.Keepalive.after; interval; probes } -> + Tcp_socket_options.enable_keepalive ~fd:afd ~after ~interval ~probes); + Lwt.async + (fun () -> + Lwt.catch + (fun () -> callback afd) + (fun exn -> + Log.warn (fun m -> m "error %s in callback" (Printexc.to_string exn)) ; + Lwt.return_unit))) + (fun exn -> + Log.warn (fun m -> m "error %s in accept" (Printexc.to_string exn)) ; + Lwt.return_unit) >>= fun () -> + loop () + in + loop ()) + + let listen _t = + let t, _ = Lwt.task () in + t (* TODO cancellation *) + + let connect ips udpv4 tcpv4 = + Log.info (fun f -> f "Manager: connect"); + let t = { tcpv4; udpv4 } in + Log.info (fun f -> f "Manager: configuring"); + configure t ips >|= fun () -> + t + + let disconnect _ = Lwt.return_unit +end + +module V6 = struct + module TCP = Tcpv6_socket + module UDP = Udpv6_socket + module IP = Ipv6_socket + + type t = { + udp : UDP.t; + tcp : TCP.t; + } + + let udp { udp; _ } = udp + let tcp { tcp; _ } = tcp + let ip _ = () + + let err_invalid_port p = Printf.sprintf "invalid port number (%d)" p + + let listen_udp t ~port callback = + if port < 0 || port > 65535 then + raise (Invalid_argument (err_invalid_port port)) + else + (* FIXME: we should not ignore the result *) + Lwt.async (fun () -> + UDP.get_udpv6_listening_fd t.udp port >>= fun fd -> + let buf = Cstruct.create 4096 in + let rec loop () = + (* TODO cancellation *) + Lwt.catch (fun () -> + Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) -> + let buf = Cstruct.sub buf 0 len in + (match sa with + | Lwt_unix.ADDR_INET (addr, src_port) -> + let src = Ipaddr_unix.V6.of_inet_addr_exn addr in + let dst = Ipaddr.V6.unspecified in (* TODO *) + callback ~src ~dst ~src_port buf + | _ -> Lwt.return_unit)) + (fun exn -> + Log.warn (fun m -> m "exception %s in recvfrom" (Printexc.to_string exn)) ; + Lwt.return_unit) >>= fun () -> + loop () + in + loop ()) + + let listen_tcp ?keepalive _t ~port callback = + if port < 0 || port > 65535 then + raise (Invalid_argument (err_invalid_port port)) + else + let fd = Lwt_unix.(socket PF_INET6 SOCK_STREAM 0) in + Lwt_unix.setsockopt fd Lwt_unix.SO_REUSEADDR true; + (* TODO: as elsewhere in the module, we bind all available addresses; it would be better not to do so if the user has requested it *) + let interface = Ipaddr_unix.V6.to_inet_addr Ipaddr.V6.unspecified in + (* FIXME: we should not ignore the result *) + Lwt.async (fun () -> + Lwt_unix.bind fd (Lwt_unix.ADDR_INET (interface, port)) >>= fun () -> + Lwt_unix.listen fd 10; + (* TODO cancellation *) + let rec loop () = + Lwt.catch (fun () -> + Lwt_unix.accept fd >|= fun (afd, _) -> + (match keepalive with + | None -> () + | Some { Mirage_protocols.Keepalive.after; interval; probes } -> + Tcp_socket_options.enable_keepalive ~fd:afd ~after ~interval ~probes); + Lwt.async + (fun () -> + Lwt.catch + (fun () -> callback afd) + (fun exn -> + Log.warn (fun m -> m "error %s in callback" (Printexc.to_string exn)) ; + Lwt.return_unit))) + (fun exn -> + Log.warn (fun m -> m "error %s in accept" (Printexc.to_string exn)) ; + Lwt.return_unit) >>= fun () -> + loop () + in + loop ()) + + let listen _t = + let t, _ = Lwt.task () in + t (* TODO cancellation *) + + (* List of IP addresses to bind to *) + let configure _t addrs = + match addrs with + | [] -> Lwt.return_unit + | [ip] when (Ipaddr.V6.compare Ipaddr.V6.unspecified ip) = 0 -> Lwt.return_unit + | l -> + let pp_iplist fmt l = Format.pp_print_list Ipaddr.V6.pp fmt l in + Log.warn (fun f -> f + "Manager: sockets currently bind to all available IPs. IPs %a were specified, but this will be ignored" pp_iplist l); + Lwt.return_unit + + let connect ips udp tcp = + Log.info (fun f -> f "Manager: connect"); + let t = { tcp; udp } in + Log.info (fun f -> f "Manager: configuring"); + configure t ips >|= fun () -> + t + + let disconnect _ = Lwt.return_unit +end diff --git a/src/stack-unix/tcpip_stack_socket.mli b/src/stack-unix/tcpip_stack_socket.mli index 23968e956..260f219cb 100644 --- a/src/stack-unix/tcpip_stack_socket.mli +++ b/src/stack-unix/tcpip_stack_socket.mli @@ -14,8 +14,18 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. *) -include Mirage_stack.V4 - with module UDPV4 = Udpv4_socket - and module TCPV4 = Tcpv4_socket - and module IPV4 = Ipv4_socket -val connect : Ipaddr.V4.t list -> Udpv4_socket.t -> Tcpv4_socket.t -> t Lwt.t +module V4 : sig + include Mirage_stack.V4 + with module UDPV4 = Udpv4_socket + and module TCPV4 = Tcpv4_socket + and module IPV4 = Ipv4_socket + val connect : Ipaddr.V4.t list -> Udpv4_socket.t -> Tcpv4_socket.t -> t Lwt.t +end + +module V6 : sig + include Mirage_stack.V6 + with module UDP = Udpv6_socket + and module TCP = Tcpv6_socket + and module IP = Ipv6_socket + val connect : Ipaddr.V6.t list -> Udpv6_socket.t -> Tcpv6_socket.t -> t Lwt.t +end diff --git a/src/stack-unix/udpv6_socket.ml b/src/stack-unix/udpv6_socket.ml index 23d0a8f14..922f31f30 100644 --- a/src/stack-unix/udpv6_socket.ml +++ b/src/stack-unix/udpv6_socket.ml @@ -41,9 +41,10 @@ let get_udpv6_listening_fd {listen_fds;interface} port = Lwt.return fd (** IO operation errors *) -type error = [ - | `Unknown of string (** an undiagnosed error *) -] +type error = [`Sendto_failed] + +let pp_error ppf = function + | `Sendto_failed -> Fmt.pf ppf "sendto failed to write any bytes" let connect (id:ip) = let t = @@ -67,12 +68,17 @@ let id { interface; _ } = let t, _ = Lwt.task () in t -let write ?source_port ~dest_ip ~dest_port t buf = +let write ?src_port ?ttl:_ttl ~dst ~dst_port t buf = let open Lwt_unix in - ( match source_port with + let rec write_to_fd fd buf = + Lwt_cstruct.sendto fd buf [] (ADDR_INET ((Ipaddr_unix.V6.to_inet_addr dst), dst_port)) + >>= function + | n when n = Cstruct.len buf -> Lwt.return @@ Ok () + | 0 -> Lwt.return @@ Error `Sendto_failed + | n -> write_to_fd fd (Cstruct.sub buf n (Cstruct.len buf - n)) (* keep trying *) + in + ( match src_port with | None -> get_udpv6_listening_fd t 0 | Some port -> get_udpv6_listening_fd t port ) >>= fun fd -> - Lwt_cstruct.sendto fd buf [] (ADDR_INET ((Ipaddr_unix.V6.to_inet_addr dest_ip), dest_port)) - >>= fun _ -> - return_unit + write_to_fd fd buf diff --git a/test/test_socket.ml b/test/test_socket.ml index af517694c..4b2c10e35 100644 --- a/test/test_socket.ml +++ b/test/test_socket.ml @@ -1,6 +1,6 @@ open Lwt.Infix -module Stack = Tcpip_stack_socket +module Stack = Tcpip_stack_socket.V4 module Time = Vnetif_common.Time type stack_stack = {