Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rpc): ignore sigpipe on linux #7319

Merged
merged 1 commit into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Unreleased
----------

- RPC: Ignore SIGPIPE when clients suddenly disconnect on OSX (#7299, partially
fixes #6879, @rgrinberg)
- RPC: Ignore SIGPIPE when clients suddenly disconnect (#7299, #7319, fixes
#6879, @rgrinberg)

- Always clean up the UI on exit. (#7271, fixes #7142 @rgrinberg)

Expand Down
78 changes: 61 additions & 17 deletions src/csexp_rpc/csexp_rpc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,13 @@ module Session = struct
type state =
| Closed
| Open of
{ out_channel : out_channel
{ out_buf : Io_buffer.t
; fd : Unix.file_descr
; (* A mutex for modifying [out_buf].

Needed as long as we use threads for async IO. Once we switch to
event based IO, we won't need this mutex anymore *)
write_mutex : Mutex.t
; in_channel : in_channel
; writer : Worker.t
; reader : Worker.t
Expand All @@ -139,13 +145,22 @@ module Session = struct
; mutable state : state
}

let create in_channel out_channel =
let create fd in_channel =
let id = Id.gen () in
if debug then
Log.info [ Pp.textf "RPC created new session %d" (Id.to_int id) ];
let* reader = Worker.create () in
let+ writer = Worker.create () in
let state = Open { in_channel; out_channel; reader; writer } in
let state =
Open
{ fd
; in_channel
; out_buf = Io_buffer.create ~size:8192
; write_mutex = Mutex.create ()
; writer
; reader
}
in
{ id; state }

let string_of_packet = function
Expand All @@ -159,10 +174,11 @@ module Session = struct
let close t =
match t.state with
| Closed -> ()
| Open { in_channel = _; out_channel; reader; writer } ->
| Open { write_mutex = _; fd = _; in_channel; out_buf = _; reader; writer }
->
Worker.stop reader;
Worker.stop writer;
close_out_noerr out_channel;
close_in_noerr in_channel;
t.state <- Closed

let read t =
Expand Down Expand Up @@ -203,6 +219,30 @@ module Session = struct
debug res;
res

external send : Unix.file_descr -> Bytes.t -> int -> int -> int = "dune_send"

let write = if Sys.linux then send else Unix.single_write

let rec csexp_write_loop fd out_buf token write_mutex =
Mutex.lock write_mutex;
if Io_buffer.flushed out_buf token then Mutex.unlock write_mutex
else
(* We always make sure to try and write the entire buffer.
This should minimize the amount of [write] calls we need
to do *)
let written =
let bytes = Io_buffer.bytes out_buf in
let pos = Io_buffer.pos out_buf in
let len = Io_buffer.length out_buf in
try write fd bytes pos len
with exn ->
Mutex.unlock write_mutex;
reraise exn
in
Io_buffer.read out_buf written;
Mutex.unlock write_mutex;
Comment on lines +233 to +243
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let written =
let bytes = Io_buffer.bytes out_buf in
let pos = Io_buffer.pos out_buf in
let len = Io_buffer.length out_buf in
try write fd bytes pos len
with exn ->
Mutex.unlock write_mutex;
reraise exn
in
Io_buffer.read out_buf written;
Mutex.unlock write_mutex;
let bytes = Io_buffer.bytes out_buf in
let pos = Io_buffer.pos out_buf in
let len = Io_buffer.length out_buf in
match write fd bytes pos len with
| written ->
Io_buffer.read out_buf written;
Mutex.unlock write_mutex;
csexp_write_loop fd out_buf token write_mutex
| exception exn ->
Mutex.unlock write_mutex;
reraise exn

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A second question is that we repeatedly lock/unlock the mutex on each iteration of the write loop; is this better than holding on to the mutex until the writing is done? (eg by using a local function)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better because it's allowing callers to refill the write buffer on every send. Less write calls and more mutex locks sounds like a better trade off. Anyway, the mutex will be gone soon so I suggest we don't dwell on it.

csexp_write_loop fd out_buf token write_mutex

let write t sexps =
if debug then
Log.info
Expand All @@ -218,21 +258,23 @@ module Session = struct
| Some sexps ->
Code_error.raise "attempting to write to a closed channel"
[ ("sexp", Dyn.(list Sexp.to_dyn) sexps) ])
| Open { writer; out_channel; _ } -> (
| Open { writer; fd; out_buf; write_mutex; _ } -> (
match sexps with
| None ->
(try
Unix.shutdown
(Unix.descr_of_out_channel out_channel)
Unix.SHUTDOWN_ALL
(* TODO this hack is temporary until we get rid of dune rpc init *)
Unix.shutdown fd Unix.SHUTDOWN_ALL
with Unix.Unix_error (_, _, _) -> ());
close t;
Fiber.return ()
| Some sexps -> (
let+ res =
Mutex.lock write_mutex;
Io_buffer.write_csexps out_buf sexps;
let flush_token = Io_buffer.flush_token out_buf in
Mutex.unlock write_mutex;
Worker.task writer ~f:(fun () ->
List.iter sexps ~f:(Csexp.to_channel out_channel);
flush out_channel)
csexp_write_loop fd out_buf flush_token write_mutex)
in
match res with
| Ok () -> ()
Expand Down Expand Up @@ -327,8 +369,7 @@ module Server = struct
Transport.accept transport
|> Option.map ~f:(fun client ->
let in_ = Unix.in_channel_of_descr client in
let out = Unix.out_channel_of_descr client in
(in_, out)))
(client, in_)))
in
let loop () =
let* accept = accept () in
Expand All @@ -349,8 +390,8 @@ module Server = struct
accepted."
];
Fiber.return None
| Ok (Some (in_, out)) ->
let+ session = Session.create in_ out in
| Ok (Some (fd, in_)) ->
let+ session = Session.create fd in_ in
Some session
in
Fiber.Stream.In.create loop
Expand Down Expand Up @@ -413,9 +454,8 @@ module Client = struct
let transport = Transport.create t.sockaddr in
t.transport <- Some transport;
let client = Transport.connect transport in
let out = Unix.out_channel_of_descr client in
let in_ = Unix.in_channel_of_descr client in
(in_, out))
(client, in_))
in
Worker.stop async;
match task with
Expand All @@ -433,3 +473,7 @@ module Client = struct

let stop t = Option.iter t.transport ~f:Transport.close
end

module Private = struct
module Io_buffer = Io_buffer
end
4 changes: 4 additions & 0 deletions src/csexp_rpc/csexp_rpc.mli
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,7 @@ module Server : sig

val listening_address : t -> Unix.sockaddr
end

module Private : sig
module Io_buffer : module type of Io_buffer
end
31 changes: 31 additions & 0 deletions src/csexp_rpc/csexp_rpc_stubs.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,34 @@ CAMLprim value dune_pthread_chdir(value unit) {
}

#endif

#if __linux__

#include <caml/threads.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>

CAMLprim value dune_send(value v_fd, value v_bytes, value v_pos, value v_len) {
CAMLparam4(v_fd, v_bytes, v_pos, v_len);
int len = Long_val(v_len);
if (len > UNIX_BUFFER_SIZE) {
len = UNIX_BUFFER_SIZE;
}
int pos = Long_val(v_pos);
int fd = Int_val(v_fd);
char iobuf[UNIX_BUFFER_SIZE];
memmove(iobuf, &Byte(v_bytes, pos), len);
caml_release_runtime_system();
int ret = send(fd, iobuf, len, MSG_NOSIGNAL);
caml_acquire_runtime_system();
if (ret == -1) {
uerror("send", Nothing);
};
CAMLreturn(Val_int(ret));
}
#else
CAMLprim value dune_send(value v_fd, value v_bytes, value v_pos, value v_len) {
caml_invalid_argument("sendmsg without sigpipe only available on linux");
}
#endif
1 change: 1 addition & 0 deletions src/csexp_rpc/dune
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
dune_util
csexp
fiber
threads.posix
(re_export unix))
(foreign_stubs
(language c)
Expand Down
98 changes: 98 additions & 0 deletions src/csexp_rpc/io_buffer.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
open Stdune

type t =
{ mutable bytes : Bytes.t (* underlying bytes *)
; (* the position we can start reading from (until [pos_w]) *)
mutable pos_r : int
; (* the position we can start writing to (until [Bytes.length bytes - 1]) *)
mutable pos_w : int
; (* total number of bytes written to this buffer. 2^63 bytes should be
enough for anybody *)
mutable total_written : int
}

type flush_token = int

(* We can't use [Out_channel] for writes on Linux because we want to disable
sigpipes. Eventually we'll move to event based IO and ditch the threads,
so we'll need this anyway *)

let create ~size =
{ bytes = Bytes.create size; pos_r = 0; pos_w = 0; total_written = 0 }

let length t = t.pos_w - t.pos_r

let max_buffer_size = 65536

let maybe_resize_to_fit t write_size =
let buf_len = Bytes.length t.bytes in
let capacity = buf_len - t.pos_w in
if capacity < write_size then (
rgrinberg marked this conversation as resolved.
Show resolved Hide resolved
let bytes =
let new_size =
let needed = buf_len + write_size - capacity in
max (min max_buffer_size (buf_len * 2)) needed
in
Bytes.create new_size
in
let len = length t in
Bytes.blit ~src:t.bytes ~src_pos:t.pos_r ~dst:bytes ~dst_pos:0 ~len;
t.bytes <- bytes;
t.pos_w <- len;
t.pos_r <- 0)

let write_char_exn t c =
assert (t.pos_w < Bytes.length t.bytes);
Bytes.set t.bytes t.pos_w c;
rgrinberg marked this conversation as resolved.
Show resolved Hide resolved
t.pos_w <- t.pos_w + 1

let write_string_exn t src =
assert (t.pos_w < Bytes.length t.bytes);
let len = String.length src in
rgrinberg marked this conversation as resolved.
Show resolved Hide resolved
Bytes.blit_string ~src ~src_pos:0 ~dst:t.bytes ~dst_pos:t.pos_w ~len;
t.pos_w <- t.pos_w + len

let read t len =
let pos_r = t.pos_r + len in
if pos_r > t.pos_w then
Code_error.raise "not enough bytes in buffer"
[ ("len", Dyn.int len); ("length", Dyn.int (length t)) ];
t.pos_r <- pos_r;
t.total_written <- t.total_written + len

let flush_token t = t.total_written + length t

let flushed t token = t.total_written >= token

let write_csexps =
let rec loop t (csexp : Csexp.t) =
match csexp with
| Atom str ->
write_string_exn t (string_of_int (String.length str));
write_char_exn t ':';
write_string_exn t str
| List e ->
write_char_exn t '(';
List.iter ~f:(loop t) e;
write_char_exn t ')'
in
fun t csexps ->
let length =
List.fold_left csexps ~init:0 ~f:(fun acc csexp ->
acc + Csexp.serialised_length csexp)
in
maybe_resize_to_fit t length;
List.iter ~f:(loop t) csexps

let pos t = t.pos_r

let bytes t = t.bytes

let to_dyn ({ bytes; pos_r; pos_w; total_written } as t) =
let open Dyn in
record
[ ("total_written", int total_written)
; ("contents", string (Bytes.sub_string bytes ~pos:pos_r ~len:(length t)))
; ("pos_w", int pos_w)
; ("pos_r", int pos_r)
]
33 changes: 33 additions & 0 deletions src/csexp_rpc/io_buffer.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
(** A resizable IO buffer *)

type t

val to_dyn : t -> Dyn.t

(** create a new io buffer *)
val create : size:int -> t

(** [read t n] reads [n] bytes *)
val read : t -> int -> unit

(** [write t csexps] write [csexps] to [t] while resizing [t] as necessary *)
val write_csexps : t -> Csexp.t list -> unit

(** a flush token is used to determine when a write has been completely flushed *)
type flush_token

(** [flush_token t] will be flushed whenever everything in [t] will be written *)
val flush_token : t -> flush_token

(** [flushed t token] will return [true] once all the data that was present in
[t] when [token] was created will be written *)
val flushed : t -> flush_token -> bool

(** underlying raw buffer *)
val bytes : t -> Bytes.t

(** [pos t] in [bytes t] to read *)
val pos : t -> int

(** [length t] the number of bytes to read [bytes t] *)
val length : t -> int
Loading