Skip to content

Commit

Permalink
fix(rpc): ignore sigpipe on linux (#7319)
Browse files Browse the repository at this point in the history
Signed-off-by: Rudi Grinberg <me@rgrinberg.com>
  • Loading branch information
rgrinberg authored Mar 17, 2023
1 parent 695a450 commit 8cdeba0
Show file tree
Hide file tree
Showing 8 changed files with 298 additions and 19 deletions.
4 changes: 2 additions & 2 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Unreleased
----------

- RPC: Ignore SIGPIPE when clients suddenly disconnect on OSX (#7299, partially
fixes #6879, @rgrinberg)
- RPC: Ignore SIGPIPE when clients suddenly disconnect (#7299, #7319, fixes
#6879, @rgrinberg)

- Always clean up the UI on exit. (#7271, fixes #7142 @rgrinberg)

Expand Down
78 changes: 61 additions & 17 deletions src/csexp_rpc/csexp_rpc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,13 @@ module Session = struct
type state =
| Closed
| Open of
{ out_channel : out_channel
{ out_buf : Io_buffer.t
; fd : Unix.file_descr
; (* A mutex for modifying [out_buf].
Needed as long as we use threads for async IO. Once we switch to
event based IO, we won't need this mutex anymore *)
write_mutex : Mutex.t
; in_channel : in_channel
; writer : Worker.t
; reader : Worker.t
Expand All @@ -139,13 +145,22 @@ module Session = struct
; mutable state : state
}

let create in_channel out_channel =
let create fd in_channel =
let id = Id.gen () in
if debug then
Log.info [ Pp.textf "RPC created new session %d" (Id.to_int id) ];
let* reader = Worker.create () in
let+ writer = Worker.create () in
let state = Open { in_channel; out_channel; reader; writer } in
let state =
Open
{ fd
; in_channel
; out_buf = Io_buffer.create ~size:8192
; write_mutex = Mutex.create ()
; writer
; reader
}
in
{ id; state }

let string_of_packet = function
Expand All @@ -159,10 +174,11 @@ module Session = struct
let close t =
match t.state with
| Closed -> ()
| Open { in_channel = _; out_channel; reader; writer } ->
| Open { write_mutex = _; fd = _; in_channel; out_buf = _; reader; writer }
->
Worker.stop reader;
Worker.stop writer;
close_out_noerr out_channel;
close_in_noerr in_channel;
t.state <- Closed

let read t =
Expand Down Expand Up @@ -203,6 +219,30 @@ module Session = struct
debug res;
res

external send : Unix.file_descr -> Bytes.t -> int -> int -> int = "dune_send"

let write = if Sys.linux then send else Unix.single_write

let rec csexp_write_loop fd out_buf token write_mutex =
Mutex.lock write_mutex;
if Io_buffer.flushed out_buf token then Mutex.unlock write_mutex
else
(* We always make sure to try and write the entire buffer.
This should minimize the amount of [write] calls we need
to do *)
let written =
let bytes = Io_buffer.bytes out_buf in
let pos = Io_buffer.pos out_buf in
let len = Io_buffer.length out_buf in
try write fd bytes pos len
with exn ->
Mutex.unlock write_mutex;
reraise exn
in
Io_buffer.read out_buf written;
Mutex.unlock write_mutex;
csexp_write_loop fd out_buf token write_mutex

let write t sexps =
if debug then
Log.info
Expand All @@ -218,21 +258,23 @@ module Session = struct
| Some sexps ->
Code_error.raise "attempting to write to a closed channel"
[ ("sexp", Dyn.(list Sexp.to_dyn) sexps) ])
| Open { writer; out_channel; _ } -> (
| Open { writer; fd; out_buf; write_mutex; _ } -> (
match sexps with
| None ->
(try
Unix.shutdown
(Unix.descr_of_out_channel out_channel)
Unix.SHUTDOWN_ALL
(* TODO this hack is temporary until we get rid of dune rpc init *)
Unix.shutdown fd Unix.SHUTDOWN_ALL
with Unix.Unix_error (_, _, _) -> ());
close t;
Fiber.return ()
| Some sexps -> (
let+ res =
Mutex.lock write_mutex;
Io_buffer.write_csexps out_buf sexps;
let flush_token = Io_buffer.flush_token out_buf in
Mutex.unlock write_mutex;
Worker.task writer ~f:(fun () ->
List.iter sexps ~f:(Csexp.to_channel out_channel);
flush out_channel)
csexp_write_loop fd out_buf flush_token write_mutex)
in
match res with
| Ok () -> ()
Expand Down Expand Up @@ -327,8 +369,7 @@ module Server = struct
Transport.accept transport
|> Option.map ~f:(fun client ->
let in_ = Unix.in_channel_of_descr client in
let out = Unix.out_channel_of_descr client in
(in_, out)))
(client, in_)))
in
let loop () =
let* accept = accept () in
Expand All @@ -349,8 +390,8 @@ module Server = struct
accepted."
];
Fiber.return None
| Ok (Some (in_, out)) ->
let+ session = Session.create in_ out in
| Ok (Some (fd, in_)) ->
let+ session = Session.create fd in_ in
Some session
in
Fiber.Stream.In.create loop
Expand Down Expand Up @@ -413,9 +454,8 @@ module Client = struct
let transport = Transport.create t.sockaddr in
t.transport <- Some transport;
let client = Transport.connect transport in
let out = Unix.out_channel_of_descr client in
let in_ = Unix.in_channel_of_descr client in
(in_, out))
(client, in_))
in
Worker.stop async;
match task with
Expand All @@ -433,3 +473,7 @@ module Client = struct

let stop t = Option.iter t.transport ~f:Transport.close
end

module Private = struct
module Io_buffer = Io_buffer
end
4 changes: 4 additions & 0 deletions src/csexp_rpc/csexp_rpc.mli
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,7 @@ module Server : sig

val listening_address : t -> Unix.sockaddr
end

module Private : sig
module Io_buffer : module type of Io_buffer
end
31 changes: 31 additions & 0 deletions src/csexp_rpc/csexp_rpc_stubs.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,34 @@ CAMLprim value dune_pthread_chdir(value unit) {
}

#endif

#if __linux__

#include <caml/threads.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>

CAMLprim value dune_send(value v_fd, value v_bytes, value v_pos, value v_len) {
CAMLparam4(v_fd, v_bytes, v_pos, v_len);
int len = Long_val(v_len);
if (len > UNIX_BUFFER_SIZE) {
len = UNIX_BUFFER_SIZE;
}
int pos = Long_val(v_pos);
int fd = Int_val(v_fd);
char iobuf[UNIX_BUFFER_SIZE];
memmove(iobuf, &Byte(v_bytes, pos), len);
caml_release_runtime_system();
int ret = send(fd, iobuf, len, MSG_NOSIGNAL);
caml_acquire_runtime_system();
if (ret == -1) {
uerror("send", Nothing);
};
CAMLreturn(Val_int(ret));
}
#else
CAMLprim value dune_send(value v_fd, value v_bytes, value v_pos, value v_len) {
caml_invalid_argument("sendmsg without sigpipe only available on linux");
}
#endif
1 change: 1 addition & 0 deletions src/csexp_rpc/dune
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
dune_util
csexp
fiber
threads.posix
(re_export unix))
(foreign_stubs
(language c)
Expand Down
98 changes: 98 additions & 0 deletions src/csexp_rpc/io_buffer.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
open Stdune

type t =
{ mutable bytes : Bytes.t (* underlying bytes *)
; (* the position we can start reading from (until [pos_w]) *)
mutable pos_r : int
; (* the position we can start writing to (until [Bytes.length bytes - 1]) *)
mutable pos_w : int
; (* total number of bytes written to this buffer. 2^63 bytes should be
enough for anybody *)
mutable total_written : int
}

type flush_token = int

(* We can't use [Out_channel] for writes on Linux because we want to disable
sigpipes. Eventually we'll move to event based IO and ditch the threads,
so we'll need this anyway *)

let create ~size =
{ bytes = Bytes.create size; pos_r = 0; pos_w = 0; total_written = 0 }

let length t = t.pos_w - t.pos_r

let max_buffer_size = 65536

let maybe_resize_to_fit t write_size =
let buf_len = Bytes.length t.bytes in
let capacity = buf_len - t.pos_w in
if capacity < write_size then (
let bytes =
let new_size =
let needed = buf_len + write_size - capacity in
max (min max_buffer_size (buf_len * 2)) needed
in
Bytes.create new_size
in
let len = length t in
Bytes.blit ~src:t.bytes ~src_pos:t.pos_r ~dst:bytes ~dst_pos:0 ~len;
t.bytes <- bytes;
t.pos_w <- len;
t.pos_r <- 0)

let write_char_exn t c =
assert (t.pos_w < Bytes.length t.bytes);
Bytes.set t.bytes t.pos_w c;
t.pos_w <- t.pos_w + 1

let write_string_exn t src =
assert (t.pos_w < Bytes.length t.bytes);
let len = String.length src in
Bytes.blit_string ~src ~src_pos:0 ~dst:t.bytes ~dst_pos:t.pos_w ~len;
t.pos_w <- t.pos_w + len

let read t len =
let pos_r = t.pos_r + len in
if pos_r > t.pos_w then
Code_error.raise "not enough bytes in buffer"
[ ("len", Dyn.int len); ("length", Dyn.int (length t)) ];
t.pos_r <- pos_r;
t.total_written <- t.total_written + len

let flush_token t = t.total_written + length t

let flushed t token = t.total_written >= token

let write_csexps =
let rec loop t (csexp : Csexp.t) =
match csexp with
| Atom str ->
write_string_exn t (string_of_int (String.length str));
write_char_exn t ':';
write_string_exn t str
| List e ->
write_char_exn t '(';
List.iter ~f:(loop t) e;
write_char_exn t ')'
in
fun t csexps ->
let length =
List.fold_left csexps ~init:0 ~f:(fun acc csexp ->
acc + Csexp.serialised_length csexp)
in
maybe_resize_to_fit t length;
List.iter ~f:(loop t) csexps

let pos t = t.pos_r

let bytes t = t.bytes

let to_dyn ({ bytes; pos_r; pos_w; total_written } as t) =
let open Dyn in
record
[ ("total_written", int total_written)
; ("contents", string (Bytes.sub_string bytes ~pos:pos_r ~len:(length t)))
; ("pos_w", int pos_w)
; ("pos_r", int pos_r)
]
33 changes: 33 additions & 0 deletions src/csexp_rpc/io_buffer.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
(** A resizable IO buffer *)

type t

val to_dyn : t -> Dyn.t

(** create a new io buffer *)
val create : size:int -> t

(** [read t n] reads [n] bytes *)
val read : t -> int -> unit

(** [write t csexps] write [csexps] to [t] while resizing [t] as necessary *)
val write_csexps : t -> Csexp.t list -> unit

(** a flush token is used to determine when a write has been completely flushed *)
type flush_token

(** [flush_token t] will be flushed whenever everything in [t] will be written *)
val flush_token : t -> flush_token

(** [flushed t token] will return [true] once all the data that was present in
[t] when [token] was created will be written *)
val flushed : t -> flush_token -> bool

(** underlying raw buffer *)
val bytes : t -> Bytes.t

(** [pos t] in [bytes t] to read *)
val pos : t -> int

(** [length t] the number of bytes to read [bytes t] *)
val length : t -> int
Loading

0 comments on commit 8cdeba0

Please sign in to comment.