Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LibOS,PAL] Support EINPROGRESS on non-blocking sockets connect #1643

Merged
merged 1 commit into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion libos/include/libos_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ struct libos_pipe_handle {
enum libos_sock_state {
SOCK_NEW,
SOCK_BOUND,
SOCK_CONNECTING,
SOCK_CONNECTED,
SOCK_LISTENING,
};
Expand All @@ -69,7 +70,7 @@ enum libos_sock_state {
* stream reads (see the comment in `do_recvmsg` in "libos/src/sys/libos_socket.c").
* Access to `force_nonblocking_users_count` is protected by the lock of the handle wrapping this
* struct.
* `pal_handle` should be accessed using atomic operations.
* `pal_handle` and `connecting_in_progress` should be accessed using atomic operations.
* If you need to take both `recv_lock` and `lock`, take the former first.
*/
struct libos_sock_handle {
Expand Down Expand Up @@ -98,6 +99,8 @@ struct libos_sock_handle {
uint64_t sendtimeout_us;
uint64_t receivetimeout_us;
unsigned int last_error;
/* This field is an atomically-accessed version of the SOCK_CONNECTING state (for perf). */
bool connecting_in_progress;
/* This field denotes whether the socket was ever bound. */
bool was_bound;
/* This field indicates if the socket is ready for read-like operations (`recv`/`read` or
Expand Down
3 changes: 2 additions & 1 deletion libos/include/libos_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct libos_sock_ops {
*
* Must be called with `handle->info.sock.lock` taken.
*/
int (*connect)(struct libos_handle* handle, void* addr, size_t addrlen);
int (*connect)(struct libos_handle* handle, void* addr, size_t addrlen, bool* out_inprogress);

/*!
* \brief Disconnect a previously connected handle.
Expand Down Expand Up @@ -121,6 +121,7 @@ struct libos_sock_ops {

struct libos_handle* get_new_socket_handle(int family, int type, int protocol,
bool is_nonblocking);
void check_connect_inprogress_on_poll(struct libos_handle* handle, bool error_event);

extern struct libos_sock_ops sock_unix_ops;
extern struct libos_sock_ops sock_ip_ops;
Expand Down
12 changes: 7 additions & 5 deletions libos/src/net/ip.c
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ static int accept(struct libos_handle* handle, bool is_nonblocking,
return 0;
}

static int connect(struct libos_handle* handle, void* addr, size_t addrlen) {
static int connect(struct libos_handle* handle, void* addr, size_t addrlen, bool* out_inprogress) {
struct libos_sock_handle* sock = &handle->info.sock;
assert(locked(&sock->lock));

Expand All @@ -221,9 +221,8 @@ static int connect(struct libos_handle* handle, void* addr, size_t addrlen) {
linux_to_pal_sockaddr(addr, &pal_remote_addr);
struct pal_socket_addr pal_local_addr;

/* XXX: this connect is always blocking (regardless of actual setting of nonblockingness on
* `sock->pal_handle`. See also the comment in tcp connect implementation in Linux PAL. */
ret = PalSocketConnect(sock->pal_handle, &pal_remote_addr, &pal_local_addr);
bool inprogress;
ret = PalSocketConnect(sock->pal_handle, &pal_remote_addr, &pal_local_addr, &inprogress);
if (ret < 0) {
return ret == -PAL_ERROR_CONNFAILED ? -ECONNREFUSED : pal_to_unix_errno(ret);
}
Expand All @@ -235,6 +234,7 @@ static int connect(struct libos_handle* handle, void* addr, size_t addrlen) {
assert(!sock->was_bound);
pal_to_linux_sockaddr(&pal_local_addr, &sock->local_addr, &sock->local_addrlen);
}
*out_inprogress = inprogress;
return 0;
}

Expand All @@ -245,7 +245,9 @@ static int disconnect(struct libos_handle* handle) {
struct pal_socket_addr pal_ip_addr = {
.domain = PAL_DISCONNECT,
};
int ret = PalSocketConnect(sock->pal_handle, &pal_ip_addr, /*local_addr=*/NULL);
bool inprogress_unused;
int ret = PalSocketConnect(sock->pal_handle, &pal_ip_addr, /*local_addr=*/NULL,
&inprogress_unused);
return pal_to_unix_errno(ret);
}

Expand Down
3 changes: 2 additions & 1 deletion libos/src/net/unix.c
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ static int accept(struct libos_handle* handle, bool is_nonblocking,
return 0;
}

static int connect(struct libos_handle* handle, void* addr, size_t addrlen) {
static int connect(struct libos_handle* handle, void* addr, size_t addrlen, bool* out_inprogress) {
struct libos_sock_handle* sock = &handle->info.sock;
assert(locked(&sock->lock));

Expand Down Expand Up @@ -260,6 +260,7 @@ static int connect(struct libos_handle* handle, void* addr, size_t addrlen) {
}

interrupt_epolls(handle);
*out_inprogress = false;
return 0;
}

Expand Down
7 changes: 7 additions & 0 deletions libos/src/sys/libos_epoll.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "libos_pollable_event.h"
#include "libos_refcount.h"
#include "libos_signal.h"
#include "libos_socket.h"
#include "libos_table.h"
#include "libos_thread.h"
#include "libos_types.h"
Expand Down Expand Up @@ -673,6 +674,12 @@ static int do_epoll_wait(int epfd, struct epoll_event* events, int maxevents, in
this_item_events |= items[i]->events & (EPOLLOUT | EPOLLWRNORM);
}

if (items[i]->handle->type == TYPE_SOCK &&
(pal_ret_events[i] & (PAL_WAIT_READ | PAL_WAIT_WRITE))) {
bool error_event = !!(pal_ret_events[i] & (PAL_WAIT_ERROR | PAL_WAIT_HANG_UP));
check_connect_inprogress_on_poll(items[i]->handle, error_event);
}

if (!this_item_events) {
/* This handle is not interested in events that were detected - epoll item was
* probably updated asynchronously. */
Expand Down
7 changes: 7 additions & 0 deletions libos/src/sys/libos_poll.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "libos_internal.h"
#include "libos_lock.h"
#include "libos_signal.h"
#include "libos_socket.h"
#include "libos_table.h"
#include "libos_thread.h"
#include "libos_utils.h"
Expand Down Expand Up @@ -212,6 +213,12 @@ static long do_poll(struct pollfd* fds, size_t fds_len, uint64_t* timeout_us) {
if (ret_events[i] & PAL_WAIT_WRITE)
fds[i].revents |= fds[i].events & (POLLOUT | POLLWRNORM);

if (libos_handles[i]->type == TYPE_SOCK &&
(ret_events[i] & (PAL_WAIT_READ | PAL_WAIT_WRITE))) {
bool error_event = !!(ret_events[i] & (PAL_WAIT_ERROR | PAL_WAIT_HANG_UP));
check_connect_inprogress_on_poll(libos_handles[i], error_event);
}

if (fds[i].revents)
ret_events_count++;
}
Expand Down
107 changes: 87 additions & 20 deletions libos/src/sys/libos_socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "linux_abi/errors.h"

/*
* Sockets can be in 4 states: NEW, BOUND, LISTENING and CONNECTED.
* Sockets can be in 5 states: NEW, BOUND, LISTENING, CONNECTING and CONNECTED.
*
* +------------------+
* | |
Expand All @@ -24,17 +24,19 @@
* +--> NEW --------------------> BOUND -------------> LISTEN --------------+
* | | | ^ new socket
* | | | | |
* | | | | |
* | | connect() | | disconnect() |
* | | | | (if it was bound) |
* | | connect() | | |
* | | | | |
* | | V | |
* | +---------------------> CONNECTED <--------------------------------+
* | |
* | disconnect() |
* | (if it was not bound) |
* +------------------------------+
* | | | +------------------------+ |
* | | connect() | disconnect() | |
* | | | (if it was bound) | |
* | | connect() | | |
* | | | select()/poll()/ | |
* | | V epoll() | |
* | +---------------------> CONNECTING ---------------> CONNECTED <----+
* | (only for |
* | non-blocking sockets) |
* | |
* | disconnect() |
* | (if it was not bound) |
* +-----------------------------------------------------------+
*
*/

Expand Down Expand Up @@ -84,6 +86,44 @@ struct libos_handle* get_new_socket_handle(int family, int type, int protocol,
return handle;
}

void check_connect_inprogress_on_poll(struct libos_handle* handle, bool error_event) {
/*
* Special case of a non-blocking socket that is INPROGRESS (connecting): must check if error or
* success of connecting. If error, then set SO_ERROR (last_error). If success, then move to
* SOCK_CONNECTED state and clear SO_ERROR. See `man 2 connect`, EINPROGRESS case.
*
* We first fetch `connecting_in_progress` instead of a proper lock on the handle to speed up
* the common case of an already-connected socket doing recv/send.
*/
assert(handle->type == TYPE_SOCK);

bool inprog = __atomic_load_n(&handle->info.sock.connecting_in_progress, __ATOMIC_ACQUIRE);
if (!inprog)
return;

struct libos_sock_handle* sock = &handle->info.sock;
lock(&sock->lock);

if (sock->state != SOCK_CONNECTING) {
/* data race: another thread could have done another select/poll on this socket and
* modified the state; there's nothing left to be done */
goto out;
}

if (error_event) {
sock->last_error = ECONNREFUSED;
goto out;
}

sock->last_error = 0;
sock->can_be_read = true;
sock->can_be_written = true;
__atomic_store_n(&sock->connecting_in_progress, false, __ATOMIC_RELEASE);
sock->state = SOCK_CONNECTED;
out:
unlock(&sock->lock);
}

long libos_syscall_socket(int family, int type, int protocol) {
switch (family) {
case AF_UNIX:
Expand Down Expand Up @@ -212,7 +252,8 @@ long libos_syscall_socketpair(int family, int type, int protocol, int* sv) {
unlock(&sock2->lock);
goto out;
}
ret = sock2->ops->connect(handle2, &addr, sizeof(addr));
bool inprogress_unused;
ret = sock2->ops->connect(handle2, &addr, sizeof(addr), &inprogress_unused);
if (ret < 0) {
unlock(&sock2->lock);
goto out;
Expand Down Expand Up @@ -491,13 +532,20 @@ long libos_syscall_connect(int fd, void* addr, int _addrlen) {
switch (sock->state) {
case SOCK_NEW:
case SOCK_BOUND:
case SOCK_CONNECTING:
case SOCK_CONNECTED:
break;
default:
ret = -EINVAL;
goto out;
}

if (sock->state == SOCK_CONNECTING) {
assert(handle->flags & O_NONBLOCK);
ret = -EALREADY;
goto out;
}

if (sock->state == SOCK_CONNECTED) {
unsigned short addr_family;
if (addrlen < sizeof(addr_family)) {
Expand Down Expand Up @@ -539,16 +587,24 @@ long libos_syscall_connect(int fd, void* addr, int _addrlen) {
goto out;
}

ret = sock->ops->connect(handle, addr, addrlen);
bool inprogress;
ret = sock->ops->connect(handle, addr, addrlen, &inprogress);
maybe_epoll_et_trigger(handle, ret, /*in=*/false, /*was_partial=*/false);
if (ret < 0) {
goto out;
}

sock->state = SOCK_CONNECTED;
sock->can_be_read = true;
sock->can_be_written = true;
ret = 0;
if (inprogress) {
sock->state = SOCK_CONNECTING;
__atomic_store_n(&sock->connecting_in_progress, true, __ATOMIC_RELEASE);
sock->last_error = EINPROGRESS;
ret = -((int)sock->last_error);
} else {
sock->state = SOCK_CONNECTED;
sock->can_be_read = true;
sock->can_be_written = true;
ret = 0;
}

out:
if (ret == -EINTR) {
Expand Down Expand Up @@ -636,9 +692,14 @@ ssize_t do_sendmsg(struct libos_handle* handle, struct iovec* iov, size_t iov_le
}

lock(&sock->lock);
if (sock->state == SOCK_CONNECTING) {
unlock(&sock->lock);
return -EAGAIN;
}

bool has_sendtimeout_set = !!sock->sendtimeout_us;

ret = -sock->last_error;
ret = -((ssize_t)sock->last_error);
sock->last_error = 0;

if (!ret && !sock->can_be_written) {
Expand Down Expand Up @@ -800,8 +861,14 @@ ssize_t do_recvmsg(struct libos_handle* handle, struct iovec* iov, size_t iov_le
struct libos_sock_handle* sock = &handle->info.sock;

lock(&sock->lock);
if (sock->state == SOCK_CONNECTING) {
unlock(&sock->lock);
return -EAGAIN;
}

bool has_recvtimeout_set = !!sock->receivetimeout_us;
ret = -sock->last_error;

ret = -((ssize_t)sock->last_error);
sock->last_error = 0;
unlock(&sock->lock);

Expand Down
1 change: 1 addition & 0 deletions libos/test/regression/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ tests = {
'syscall_restart': {},
'sysfs_common': {},
'tcp_ancillary': {},
'tcp_einprogress': {},
'tcp_ipv6_v6only': {},
'tcp_msg_peek': {},
'udp': {},
Expand Down
Loading