Skip to content

Commit

Permalink
[LibOS,PAL] Support EINPROGRESS on non-blocking sockets connect
Browse files Browse the repository at this point in the history
Previously, Gramine transformed `connect()` of non-blocking sockets into
a blocking operation and thus never returned -EINPROGRESS. This led to
the connect operation being very slow (waiting for a host timeout) if a
remote peer is unresponsive.

This commit fixes this and adds a LibOS regression test (4 variants to
test poll/epoll and responsive/unresponsive peer).

Signed-off-by: Dmitrii Kuvaiskii <dmitrii.kuvaiskii@intel.com>
  • Loading branch information
Dmitrii Kuvaiskii committed Nov 23, 2023
1 parent 1ea3e60 commit a2e7d95
Show file tree
Hide file tree
Showing 22 changed files with 384 additions and 107 deletions.
5 changes: 4 additions & 1 deletion libos/include/libos_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ struct libos_pipe_handle {
enum libos_sock_state {
SOCK_NEW,
SOCK_BOUND,
SOCK_CONNECTING,
SOCK_CONNECTED,
SOCK_LISTENING,
};
Expand All @@ -69,7 +70,7 @@ enum libos_sock_state {
* stream reads (see the comment in `do_recvmsg` in "libos/src/sys/libos_socket.c").
* Access to `force_nonblocking_users_count` is protected by the lock of the handle wrapping this
* struct.
* `pal_handle` should be accessed using atomic operations.
* `pal_handle` and `connecting_in_progress` should be accessed using atomic operations.
* If you need to take both `recv_lock` and `lock`, take the former first.
*/
struct libos_sock_handle {
Expand Down Expand Up @@ -98,6 +99,8 @@ struct libos_sock_handle {
uint64_t sendtimeout_us;
uint64_t receivetimeout_us;
unsigned int last_error;
/* This field is an atomically-accessed version of the SOCK_CONNECTING state (for perf). */
bool connecting_in_progress;
/* This field denotes whether the socket was ever bound. */
bool was_bound;
/* This field indicates if the socket is ready for read-like operations (`recv`/`read` or
Expand Down
3 changes: 2 additions & 1 deletion libos/include/libos_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct libos_sock_ops {
*
* Must be called with `handle->info.sock.lock` taken.
*/
int (*connect)(struct libos_handle* handle, void* addr, size_t addrlen);
int (*connect)(struct libos_handle* handle, void* addr, size_t addrlen, bool* out_inprogress);

/*!
* \brief Disconnect a previously connected handle.
Expand Down Expand Up @@ -121,6 +121,7 @@ struct libos_sock_ops {

struct libos_handle* get_new_socket_handle(int family, int type, int protocol,
bool is_nonblocking);
void check_connect_inprogress_on_poll(struct libos_handle* handle, bool error_event);

extern struct libos_sock_ops sock_unix_ops;
extern struct libos_sock_ops sock_ip_ops;
Expand Down
12 changes: 7 additions & 5 deletions libos/src/net/ip.c
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ static int accept(struct libos_handle* handle, bool is_nonblocking,
return 0;
}

static int connect(struct libos_handle* handle, void* addr, size_t addrlen) {
static int connect(struct libos_handle* handle, void* addr, size_t addrlen, bool* out_inprogress) {
struct libos_sock_handle* sock = &handle->info.sock;
assert(locked(&sock->lock));

Expand All @@ -221,9 +221,8 @@ static int connect(struct libos_handle* handle, void* addr, size_t addrlen) {
linux_to_pal_sockaddr(addr, &pal_remote_addr);
struct pal_socket_addr pal_local_addr;

/* XXX: this connect is always blocking (regardless of actual setting of nonblockingness on
* `sock->pal_handle`. See also the comment in tcp connect implementation in Linux PAL. */
ret = PalSocketConnect(sock->pal_handle, &pal_remote_addr, &pal_local_addr);
bool inprogress;
ret = PalSocketConnect(sock->pal_handle, &pal_remote_addr, &pal_local_addr, &inprogress);
if (ret < 0) {
return ret == -PAL_ERROR_CONNFAILED ? -ECONNREFUSED : pal_to_unix_errno(ret);
}
Expand All @@ -235,6 +234,7 @@ static int connect(struct libos_handle* handle, void* addr, size_t addrlen) {
assert(!sock->was_bound);
pal_to_linux_sockaddr(&pal_local_addr, &sock->local_addr, &sock->local_addrlen);
}
*out_inprogress = inprogress;
return 0;
}

Expand All @@ -245,7 +245,9 @@ static int disconnect(struct libos_handle* handle) {
struct pal_socket_addr pal_ip_addr = {
.domain = PAL_DISCONNECT,
};
int ret = PalSocketConnect(sock->pal_handle, &pal_ip_addr, /*local_addr=*/NULL);
bool inprogress_unused;
int ret = PalSocketConnect(sock->pal_handle, &pal_ip_addr, /*local_addr=*/NULL,
&inprogress_unused);
return pal_to_unix_errno(ret);
}

Expand Down
3 changes: 2 additions & 1 deletion libos/src/net/unix.c
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ static int accept(struct libos_handle* handle, bool is_nonblocking,
return 0;
}

static int connect(struct libos_handle* handle, void* addr, size_t addrlen) {
static int connect(struct libos_handle* handle, void* addr, size_t addrlen, bool* out_inprogress) {
struct libos_sock_handle* sock = &handle->info.sock;
assert(locked(&sock->lock));

Expand Down Expand Up @@ -260,6 +260,7 @@ static int connect(struct libos_handle* handle, void* addr, size_t addrlen) {
}

interrupt_epolls(handle);
*out_inprogress = false;
return 0;
}

Expand Down
7 changes: 7 additions & 0 deletions libos/src/sys/libos_epoll.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "libos_pollable_event.h"
#include "libos_refcount.h"
#include "libos_signal.h"
#include "libos_socket.h"
#include "libos_table.h"
#include "libos_thread.h"
#include "libos_types.h"
Expand Down Expand Up @@ -673,6 +674,12 @@ static int do_epoll_wait(int epfd, struct epoll_event* events, int maxevents, in
this_item_events |= items[i]->events & (EPOLLOUT | EPOLLWRNORM);
}

if (items[i]->handle->type == TYPE_SOCK &&
(pal_ret_events[i] & (PAL_WAIT_READ | PAL_WAIT_WRITE))) {
bool error_event = !!(pal_ret_events[i] & (PAL_WAIT_ERROR | PAL_WAIT_HANG_UP));
check_connect_inprogress_on_poll(items[i]->handle, error_event);
}

if (!this_item_events) {
/* This handle is not interested in events that were detected - epoll item was
* probably updated asynchronously. */
Expand Down
7 changes: 7 additions & 0 deletions libos/src/sys/libos_poll.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "libos_internal.h"
#include "libos_lock.h"
#include "libos_signal.h"
#include "libos_socket.h"
#include "libos_table.h"
#include "libos_thread.h"
#include "libos_utils.h"
Expand Down Expand Up @@ -212,6 +213,12 @@ static long do_poll(struct pollfd* fds, size_t fds_len, uint64_t* timeout_us) {
if (ret_events[i] & PAL_WAIT_WRITE)
fds[i].revents |= fds[i].events & (POLLOUT | POLLWRNORM);

if (libos_handles[i]->type == TYPE_SOCK &&
(ret_events[i] & (PAL_WAIT_READ | PAL_WAIT_WRITE))) {
bool error_event = !!(ret_events[i] & (PAL_WAIT_ERROR | PAL_WAIT_HANG_UP));
check_connect_inprogress_on_poll(libos_handles[i], error_event);
}

if (fds[i].revents)
ret_events_count++;
}
Expand Down
107 changes: 87 additions & 20 deletions libos/src/sys/libos_socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "linux_abi/errors.h"

/*
* Sockets can be in 4 states: NEW, BOUND, LISTENING and CONNECTED.
* Sockets can be in 5 states: NEW, BOUND, LISTENING, CONNECTING and CONNECTED.
*
* +------------------+
* | |
Expand All @@ -24,17 +24,19 @@
* +--> NEW --------------------> BOUND -------------> LISTEN --------------+
* | | | ^ new socket
* | | | | |
* | | | | |
* | | connect() | | disconnect() |
* | | | | (if it was bound) |
* | | connect() | | |
* | | | | |
* | | V | |
* | +---------------------> CONNECTED <--------------------------------+
* | |
* | disconnect() |
* | (if it was not bound) |
* +------------------------------+
* | | | +------------------------+ |
* | | connect() | disconnect() | |
* | | | (if it was bound) | |
* | | connect() | | |
* | | | select()/poll()/ | |
* | | V epoll() | |
* | +---------------------> CONNECTING ---------------> CONNECTED <----+
* | (only for |
* | non-blocking sockets) |
* | |
* | disconnect() |
* | (if it was not bound) |
* +-----------------------------------------------------------+
*
*/

Expand Down Expand Up @@ -84,6 +86,44 @@ struct libos_handle* get_new_socket_handle(int family, int type, int protocol,
return handle;
}

void check_connect_inprogress_on_poll(struct libos_handle* handle, bool error_event) {
/*
* Special case of a non-blocking socket that is INPROGRESS (connecting): must check if error or
* success of connecting. If error, then set SO_ERROR (last_error). If success, then move to
* SOCK_CONNECTED state and clear SO_ERROR. See `man 2 connect`, EINPROGRESS case.
*
* We first fetch `connecting_in_progress` instead of a proper lock on the handle to speed up
* the common case of an already-connected socket doing recv/send.
*/
assert(handle->type == TYPE_SOCK);

bool inprog = __atomic_load_n(&handle->info.sock.connecting_in_progress, __ATOMIC_ACQUIRE);
if (!inprog)
return;

struct libos_sock_handle* sock = &handle->info.sock;
lock(&sock->lock);

if (sock->state != SOCK_CONNECTING) {
/* data race: another thread could have done another select/poll on this socket and
* modified the state; there's nothing left to be done */
goto out;
}

if (error_event) {
sock->last_error = ECONNREFUSED;
goto out;
}

sock->last_error = 0;
sock->can_be_read = true;
sock->can_be_written = true;
__atomic_store_n(&sock->connecting_in_progress, false, __ATOMIC_RELEASE);
sock->state = SOCK_CONNECTED;
out:
unlock(&sock->lock);
}

long libos_syscall_socket(int family, int type, int protocol) {
switch (family) {
case AF_UNIX:
Expand Down Expand Up @@ -212,7 +252,8 @@ long libos_syscall_socketpair(int family, int type, int protocol, int* sv) {
unlock(&sock2->lock);
goto out;
}
ret = sock2->ops->connect(handle2, &addr, sizeof(addr));
bool inprogress_unused;
ret = sock2->ops->connect(handle2, &addr, sizeof(addr), &inprogress_unused);
if (ret < 0) {
unlock(&sock2->lock);
goto out;
Expand Down Expand Up @@ -491,13 +532,20 @@ long libos_syscall_connect(int fd, void* addr, int _addrlen) {
switch (sock->state) {
case SOCK_NEW:
case SOCK_BOUND:
case SOCK_CONNECTING:
case SOCK_CONNECTED:
break;
default:
ret = -EINVAL;
goto out;
}

if (sock->state == SOCK_CONNECTING) {
assert(handle->flags & O_NONBLOCK);
ret = -EALREADY;
goto out;
}

if (sock->state == SOCK_CONNECTED) {
unsigned short addr_family;
if (addrlen < sizeof(addr_family)) {
Expand Down Expand Up @@ -539,16 +587,24 @@ long libos_syscall_connect(int fd, void* addr, int _addrlen) {
goto out;
}

ret = sock->ops->connect(handle, addr, addrlen);
bool inprogress;
ret = sock->ops->connect(handle, addr, addrlen, &inprogress);
maybe_epoll_et_trigger(handle, ret, /*in=*/false, /*was_partial=*/false);
if (ret < 0) {
goto out;
}

sock->state = SOCK_CONNECTED;
sock->can_be_read = true;
sock->can_be_written = true;
ret = 0;
if (inprogress) {
sock->state = SOCK_CONNECTING;
__atomic_store_n(&sock->connecting_in_progress, true, __ATOMIC_RELEASE);
sock->last_error = EINPROGRESS;
ret = -((int)sock->last_error);
} else {
sock->state = SOCK_CONNECTED;
sock->can_be_read = true;
sock->can_be_written = true;
ret = 0;
}

out:
if (ret == -EINTR) {
Expand Down Expand Up @@ -636,9 +692,14 @@ ssize_t do_sendmsg(struct libos_handle* handle, struct iovec* iov, size_t iov_le
}

lock(&sock->lock);
if (sock->state == SOCK_CONNECTING) {
unlock(&sock->lock);
return -EAGAIN;
}

bool has_sendtimeout_set = !!sock->sendtimeout_us;

ret = -sock->last_error;
ret = -((ssize_t)sock->last_error);
sock->last_error = 0;

if (!ret && !sock->can_be_written) {
Expand Down Expand Up @@ -800,8 +861,14 @@ ssize_t do_recvmsg(struct libos_handle* handle, struct iovec* iov, size_t iov_le
struct libos_sock_handle* sock = &handle->info.sock;

lock(&sock->lock);
if (sock->state == SOCK_CONNECTING) {
unlock(&sock->lock);
return -EAGAIN;
}

bool has_recvtimeout_set = !!sock->receivetimeout_us;
ret = -sock->last_error;

ret = -((ssize_t)sock->last_error);
sock->last_error = 0;
unlock(&sock->lock);

Expand Down
1 change: 1 addition & 0 deletions libos/test/regression/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ tests = {
'syscall_restart': {},
'sysfs_common': {},
'tcp_ancillary': {},
'tcp_einprogress': {},
'tcp_ipv6_v6only': {},
'tcp_msg_peek': {},
'udp': {},
Expand Down
Loading

0 comments on commit a2e7d95

Please sign in to comment.