diff --git a/libos/include/libos_handle.h b/libos/include/libos_handle.h index 8cb06da987..6cf9658a87 100644 --- a/libos/include/libos_handle.h +++ b/libos/include/libos_handle.h @@ -56,6 +56,7 @@ struct libos_pipe_handle { enum libos_sock_state { SOCK_NEW, SOCK_BOUND, + SOCK_CONNECTING, SOCK_CONNECTED, SOCK_LISTENING, }; @@ -69,7 +70,7 @@ enum libos_sock_state { * stream reads (see the comment in `do_recvmsg` in "libos/src/sys/libos_socket.c"). * Access to `force_nonblocking_users_count` is protected by the lock of the handle wrapping this * struct. - * `pal_handle` should be accessed using atomic operations. + * `pal_handle` and `connecting_in_progress` should be accessed using atomic operations. * If you need to take both `recv_lock` and `lock`, take the former first. */ struct libos_sock_handle { @@ -98,6 +99,8 @@ struct libos_sock_handle { uint64_t sendtimeout_us; uint64_t receivetimeout_us; unsigned int last_error; + /* This field is an atomically-accessed version of the SOCK_CONNECTING state (for perf). */ + bool connecting_in_progress; /* This field denotes whether the socket was ever bound. */ bool was_bound; /* This field indicates if the socket is ready for read-like operations (`recv`/`read` or diff --git a/libos/include/libos_socket.h b/libos/include/libos_socket.h index 8b7d5789bf..4c696976b8 100644 --- a/libos/include/libos_socket.h +++ b/libos/include/libos_socket.h @@ -49,7 +49,7 @@ struct libos_sock_ops { * * Must be called with `handle->info.sock.lock` taken. */ - int (*connect)(struct libos_handle* handle, void* addr, size_t addrlen); + int (*connect)(struct libos_handle* handle, void* addr, size_t addrlen, bool* out_inprogress); /*! * \brief Disconnect a previously connected handle. @@ -121,6 +121,7 @@ struct libos_sock_ops { struct libos_handle* get_new_socket_handle(int family, int type, int protocol, bool is_nonblocking); +void check_connect_inprogress_on_poll(struct libos_handle* handle, bool error_event); extern struct libos_sock_ops sock_unix_ops; extern struct libos_sock_ops sock_ip_ops; diff --git a/libos/src/net/ip.c b/libos/src/net/ip.c index 9ddb5547b9..1a9e595b16 100644 --- a/libos/src/net/ip.c +++ b/libos/src/net/ip.c @@ -208,7 +208,7 @@ static int accept(struct libos_handle* handle, bool is_nonblocking, return 0; } -static int connect(struct libos_handle* handle, void* addr, size_t addrlen) { +static int connect(struct libos_handle* handle, void* addr, size_t addrlen, bool* out_inprogress) { struct libos_sock_handle* sock = &handle->info.sock; assert(locked(&sock->lock)); @@ -221,9 +221,8 @@ static int connect(struct libos_handle* handle, void* addr, size_t addrlen) { linux_to_pal_sockaddr(addr, &pal_remote_addr); struct pal_socket_addr pal_local_addr; - /* XXX: this connect is always blocking (regardless of actual setting of nonblockingness on - * `sock->pal_handle`. See also the comment in tcp connect implementation in Linux PAL. */ - ret = PalSocketConnect(sock->pal_handle, &pal_remote_addr, &pal_local_addr); + bool inprogress; + ret = PalSocketConnect(sock->pal_handle, &pal_remote_addr, &pal_local_addr, &inprogress); if (ret < 0) { return ret == -PAL_ERROR_CONNFAILED ? -ECONNREFUSED : pal_to_unix_errno(ret); } @@ -235,6 +234,7 @@ static int connect(struct libos_handle* handle, void* addr, size_t addrlen) { assert(!sock->was_bound); pal_to_linux_sockaddr(&pal_local_addr, &sock->local_addr, &sock->local_addrlen); } + *out_inprogress = inprogress; return 0; } @@ -245,7 +245,9 @@ static int disconnect(struct libos_handle* handle) { struct pal_socket_addr pal_ip_addr = { .domain = PAL_DISCONNECT, }; - int ret = PalSocketConnect(sock->pal_handle, &pal_ip_addr, /*local_addr=*/NULL); + bool inprogress_unused; + int ret = PalSocketConnect(sock->pal_handle, &pal_ip_addr, /*local_addr=*/NULL, + &inprogress_unused); return pal_to_unix_errno(ret); } diff --git a/libos/src/net/unix.c b/libos/src/net/unix.c index f65292713e..a638eb0b5f 100644 --- a/libos/src/net/unix.c +++ b/libos/src/net/unix.c @@ -213,7 +213,7 @@ static int accept(struct libos_handle* handle, bool is_nonblocking, return 0; } -static int connect(struct libos_handle* handle, void* addr, size_t addrlen) { +static int connect(struct libos_handle* handle, void* addr, size_t addrlen, bool* out_inprogress) { struct libos_sock_handle* sock = &handle->info.sock; assert(locked(&sock->lock)); @@ -260,6 +260,7 @@ static int connect(struct libos_handle* handle, void* addr, size_t addrlen) { } interrupt_epolls(handle); + *out_inprogress = false; return 0; } diff --git a/libos/src/sys/libos_epoll.c b/libos/src/sys/libos_epoll.c index 093efe43ba..fab6965eee 100644 --- a/libos/src/sys/libos_epoll.c +++ b/libos/src/sys/libos_epoll.c @@ -28,6 +28,7 @@ #include "libos_pollable_event.h" #include "libos_refcount.h" #include "libos_signal.h" +#include "libos_socket.h" #include "libos_table.h" #include "libos_thread.h" #include "libos_types.h" @@ -673,6 +674,12 @@ static int do_epoll_wait(int epfd, struct epoll_event* events, int maxevents, in this_item_events |= items[i]->events & (EPOLLOUT | EPOLLWRNORM); } + if (items[i]->handle->type == TYPE_SOCK && + (pal_ret_events[i] & (PAL_WAIT_READ | PAL_WAIT_WRITE))) { + bool error_event = !!(pal_ret_events[i] & (PAL_WAIT_ERROR | PAL_WAIT_HANG_UP)); + check_connect_inprogress_on_poll(items[i]->handle, error_event); + } + if (!this_item_events) { /* This handle is not interested in events that were detected - epoll item was * probably updated asynchronously. */ diff --git a/libos/src/sys/libos_poll.c b/libos/src/sys/libos_poll.c index 64af5e1b53..734b7a7242 100644 --- a/libos/src/sys/libos_poll.c +++ b/libos/src/sys/libos_poll.c @@ -12,6 +12,7 @@ #include "libos_internal.h" #include "libos_lock.h" #include "libos_signal.h" +#include "libos_socket.h" #include "libos_table.h" #include "libos_thread.h" #include "libos_utils.h" @@ -212,6 +213,12 @@ static long do_poll(struct pollfd* fds, size_t fds_len, uint64_t* timeout_us) { if (ret_events[i] & PAL_WAIT_WRITE) fds[i].revents |= fds[i].events & (POLLOUT | POLLWRNORM); + if (libos_handles[i]->type == TYPE_SOCK && + (ret_events[i] & (PAL_WAIT_READ | PAL_WAIT_WRITE))) { + bool error_event = !!(ret_events[i] & (PAL_WAIT_ERROR | PAL_WAIT_HANG_UP)); + check_connect_inprogress_on_poll(libos_handles[i], error_event); + } + if (fds[i].revents) ret_events_count++; } diff --git a/libos/src/sys/libos_socket.c b/libos/src/sys/libos_socket.c index d5acab1105..e56ae1dddb 100644 --- a/libos/src/sys/libos_socket.c +++ b/libos/src/sys/libos_socket.c @@ -15,7 +15,7 @@ #include "linux_abi/errors.h" /* - * Sockets can be in 4 states: NEW, BOUND, LISTENING and CONNECTED. + * Sockets can be in 5 states: NEW, BOUND, LISTENING, CONNECTING and CONNECTED. * * +------------------+ * | | @@ -24,17 +24,19 @@ * +--> NEW --------------------> BOUND -------------> LISTEN --------------+ * | | | ^ new socket * | | | | | - * | | | | | - * | | connect() | | disconnect() | - * | | | | (if it was bound) | - * | | connect() | | | - * | | | | | - * | | V | | - * | +---------------------> CONNECTED <--------------------------------+ - * | | - * | disconnect() | - * | (if it was not bound) | - * +------------------------------+ + * | | | +------------------------+ | + * | | connect() | disconnect() | | + * | | | (if it was bound) | | + * | | connect() | | | + * | | | select()/poll()/ | | + * | | V epoll() | | + * | +---------------------> CONNECTING ---------------> CONNECTED <----+ + * | (only for | + * | non-blocking sockets) | + * | | + * | disconnect() | + * | (if it was not bound) | + * +-----------------------------------------------------------+ * */ @@ -84,6 +86,44 @@ struct libos_handle* get_new_socket_handle(int family, int type, int protocol, return handle; } +void check_connect_inprogress_on_poll(struct libos_handle* handle, bool error_event) { + /* + * Special case of a non-blocking socket that is INPROGRESS (connecting): must check if error or + * success of connecting. If error, then set SO_ERROR (last_error). If success, then move to + * SOCK_CONNECTED state and clear SO_ERROR. See `man 2 connect`, EINPROGRESS case. + * + * We first fetch `connecting_in_progress` instead of a proper lock on the handle to speed up + * the common case of an already-connected socket doing recv/send. + */ + assert(handle->type == TYPE_SOCK); + + bool inprog = __atomic_load_n(&handle->info.sock.connecting_in_progress, __ATOMIC_ACQUIRE); + if (!inprog) + return; + + struct libos_sock_handle* sock = &handle->info.sock; + lock(&sock->lock); + + if (sock->state != SOCK_CONNECTING) { + /* data race: another thread could have done another select/poll on this socket and + * modified the state; there's nothing left to be done */ + goto out; + } + + if (error_event) { + sock->last_error = ECONNREFUSED; + goto out; + } + + sock->last_error = 0; + sock->can_be_read = true; + sock->can_be_written = true; + __atomic_store_n(&sock->connecting_in_progress, false, __ATOMIC_RELEASE); + sock->state = SOCK_CONNECTED; +out: + unlock(&sock->lock); +} + long libos_syscall_socket(int family, int type, int protocol) { switch (family) { case AF_UNIX: @@ -212,7 +252,8 @@ long libos_syscall_socketpair(int family, int type, int protocol, int* sv) { unlock(&sock2->lock); goto out; } - ret = sock2->ops->connect(handle2, &addr, sizeof(addr)); + bool inprogress_unused; + ret = sock2->ops->connect(handle2, &addr, sizeof(addr), &inprogress_unused); if (ret < 0) { unlock(&sock2->lock); goto out; @@ -491,6 +532,7 @@ long libos_syscall_connect(int fd, void* addr, int _addrlen) { switch (sock->state) { case SOCK_NEW: case SOCK_BOUND: + case SOCK_CONNECTING: case SOCK_CONNECTED: break; default: @@ -498,6 +540,12 @@ long libos_syscall_connect(int fd, void* addr, int _addrlen) { goto out; } + if (sock->state == SOCK_CONNECTING) { + assert(handle->flags & O_NONBLOCK); + ret = -EALREADY; + goto out; + } + if (sock->state == SOCK_CONNECTED) { unsigned short addr_family; if (addrlen < sizeof(addr_family)) { @@ -539,16 +587,24 @@ long libos_syscall_connect(int fd, void* addr, int _addrlen) { goto out; } - ret = sock->ops->connect(handle, addr, addrlen); + bool inprogress; + ret = sock->ops->connect(handle, addr, addrlen, &inprogress); maybe_epoll_et_trigger(handle, ret, /*in=*/false, /*was_partial=*/false); if (ret < 0) { goto out; } - sock->state = SOCK_CONNECTED; - sock->can_be_read = true; - sock->can_be_written = true; - ret = 0; + if (inprogress) { + sock->state = SOCK_CONNECTING; + __atomic_store_n(&sock->connecting_in_progress, true, __ATOMIC_RELEASE); + sock->last_error = EINPROGRESS; + ret = -((int)sock->last_error); + } else { + sock->state = SOCK_CONNECTED; + sock->can_be_read = true; + sock->can_be_written = true; + ret = 0; + } out: if (ret == -EINTR) { @@ -636,9 +692,14 @@ ssize_t do_sendmsg(struct libos_handle* handle, struct iovec* iov, size_t iov_le } lock(&sock->lock); + if (sock->state == SOCK_CONNECTING) { + unlock(&sock->lock); + return -EAGAIN; + } + bool has_sendtimeout_set = !!sock->sendtimeout_us; - ret = -sock->last_error; + ret = -((ssize_t)sock->last_error); sock->last_error = 0; if (!ret && !sock->can_be_written) { @@ -800,8 +861,14 @@ ssize_t do_recvmsg(struct libos_handle* handle, struct iovec* iov, size_t iov_le struct libos_sock_handle* sock = &handle->info.sock; lock(&sock->lock); + if (sock->state == SOCK_CONNECTING) { + unlock(&sock->lock); + return -EAGAIN; + } + bool has_recvtimeout_set = !!sock->receivetimeout_us; - ret = -sock->last_error; + + ret = -((ssize_t)sock->last_error); sock->last_error = 0; unlock(&sock->lock); diff --git a/libos/test/regression/meson.build b/libos/test/regression/meson.build index 90df8a6cb6..cbd15b4443 100644 --- a/libos/test/regression/meson.build +++ b/libos/test/regression/meson.build @@ -141,6 +141,7 @@ tests = { 'syscall_restart': {}, 'sysfs_common': {}, 'tcp_ancillary': {}, + 'tcp_einprogress': {}, 'tcp_ipv6_v6only': {}, 'tcp_msg_peek': {}, 'udp': {}, diff --git a/libos/test/regression/tcp_einprogress.c b/libos/test/regression/tcp_einprogress.c new file mode 100644 index 0000000000..f1f3555a21 --- /dev/null +++ b/libos/test/regression/tcp_einprogress.c @@ -0,0 +1,176 @@ +#define _GNU_SOURCE +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" + +#define ERR(msg, args...) \ + errx(1, "%d: " msg, __LINE__, ##args) + +#define TIMEOUT_MS 1000 /* 1s; increase to e.g. 10s for manual tests */ +#define PORT 12345 /* nothing must be bound to this port! */ + +static void usage(const char* prog_name) { + fprintf(stderr, "usage: %s poll|epoll\n", prog_name); + fprintf(stderr, "(use 127.0.0.1 for responsive peer and 10.255.255.255 for unresponsive " + "peer)\n"); +} + +int main(int argc, const char** argv) { + int ret; + + if (argc != 3) { + usage(argv[0]); + return 1; + } + + if (strcmp(argv[2], "poll") && strcmp(argv[2], "epoll")) { + usage(argv[0]); + fprintf(stderr, "error: second argument not recognized (only 'poll'/'epoll' allowed)\n"); + return 1; + } + + int s = CHECK(socket(AF_INET, SOCK_STREAM, 0)); + + int flags = CHECK(fcntl(s, F_GETFL, 0)); + CHECK(fcntl(s, F_SETFL, flags | O_NONBLOCK)); + + struct sockaddr_in sa = { + .sin_family = AF_INET, + .sin_port = htons(PORT), + }; + if (inet_aton(argv[1], &sa.sin_addr) != 1) + ERR("inet_aton failed"); + + ret = connect(s, (void*)&sa, sizeof(sa)); + if (ret != -1) + ERR("connect unexpectedly succeeded"); + if (errno != EINPROGRESS && errno != ECONNREFUSED) + ERR("connect didn't fail with EINPROGRESS or ECONNREFUSED but with %s", strerror(errno)); + + if (errno == ECONNREFUSED) { + /* boring case without EINPROGRESS (aka blocking connect) */ + puts("TEST OK (no EINPROGRESS)"); + CHECK(close(s)); + return 0; + } + assert(errno == EINPROGRESS); + + struct sockaddr_in sa_local; + socklen_t addrlen_local = sizeof(sa_local); + ret = getsockname(s, (struct sockaddr*)&sa_local, &addrlen_local); + if (ret < 0) + ERR("[after EINPROGRESS] getsockname failed with %s", strerror(errno)); + printf("local address %s:%hu\n", inet_ntoa(sa_local.sin_addr), ntohs(sa_local.sin_port)); + fflush(stdout); + + ret = connect(s, (void*)&sa, sizeof(sa)); + if (ret != -1) { + ERR("[after EINPROGRESS] second connect unexpectedly succeeded"); + } + if (errno != EALREADY && errno != ECONNREFUSED) { + ERR("[after EINPROGRESS] second connect didn't fail with EALREADY or ECONNREFUSED but with" + " %s", strerror(errno)); + } + + if (errno == ECONNREFUSED) { + /* another boring case with EINPROGRESS but a quick response */ + puts("TEST OK (quick response)"); + CHECK(close(s)); + return 0; + } + assert(errno == EALREADY); + + struct sockaddr_in sa_peer; + socklen_t addrlen_peer = sizeof(sa_peer); + ret = getpeername(s, (struct sockaddr*)&sa_peer, &addrlen_peer); + if (ret != -1) { + ERR("[after EINPROGRESS] expected getpeername to fail but it succeeded"); + } + if (errno != ENOTCONN) { + ERR("[after EINPROGRESS] expected getpeername to fail with ENOTCONN but failed with %s", + strerror(errno)); + } + + char dummy_buf[3] = "hi"; + ssize_t bytes = send(s, dummy_buf, sizeof(dummy_buf), /*flags=*/0); + if (bytes != -1) { + ERR("[after EINPROGRESS] expected send to fail but it succeeded"); + } + if (errno != EAGAIN) { + ERR("[after EINPROGRESS] expected send to fail with EAGAIN but failed with %s", + strerror(errno)); + } + + bytes = recv(s, dummy_buf, sizeof(dummy_buf), /*flags=*/0); + if (bytes != -1) { + ERR("[after EINPROGRESS] expected recv to fail but it succeeded"); + } + if (errno != EAGAIN) { + ERR("[after EINPROGRESS] expected recv to fail with EAGAIN but failed with %s", + strerror(errno)); + } + + /* test can be run with "poll" or "epoll" cmdline arg: we test POLLOUT for the poll case and + * EPOLLIN for the epoll case (no reason other than to test both write and read events) */ + bool timedout = false; + bool poll_event_happened = false; + if (strcmp(argv[2], "poll") == 0) { + struct pollfd infds[] = { + {.fd = s, .events = POLLOUT}, + }; + ret = CHECK(poll(infds, 1, TIMEOUT_MS)); + if (ret == 0) + timedout = true; + else + poll_event_happened = !!(infds[0].revents & POLLOUT); + + } else { + int epfd = CHECK(epoll_create(/*size=*/1)); + struct epoll_event event = { .events = EPOLLIN }; + CHECK(epoll_ctl(epfd, EPOLL_CTL_ADD, s, &event)); + struct epoll_event out_event = { 0 }; + ret = CHECK(epoll_wait(epfd, &out_event, /*max_events=*/1, TIMEOUT_MS)); + CHECK(close(epfd)); + if (ret == 0) + timedout = true; + else + poll_event_happened = !!(out_event.events & EPOLLIN); + } + + /* one interesting case -- remote peer is completely unresponsive */ + if (timedout) { + puts("TEST OK (connection timed out)"); + CHECK(close(s)); + return 0; + } + + /* the most interesting case -- remote peer not unresponsive but very slow */ + if (!poll_event_happened) { + ERR("[after EINPROGRESS] polling didn't return %s on connecting socket", + strcmp(argv[2], "poll") == 0 ? "POLLOUT" : "EPOLLIN"); + } + + int so_error; + socklen_t optlen = sizeof(so_error); + CHECK(getsockopt(s, SOL_SOCKET, SO_ERROR, &so_error, &optlen)); + if (optlen != sizeof(so_error) || so_error != ECONNREFUSED) { + ERR("[after EINPROGRESS] expected SO_ERROR to be ECONNREFUSED but it is %s", + strerror(so_error)); + } + + puts("TEST OK (connection refused after initial EINPROGRESS)"); + CHECK(close(s)); + return 0; +} diff --git a/libos/test/regression/test_libos.py b/libos/test/regression/test_libos.py index 57b73feb1a..2994a2e0f4 100644 --- a/libos/test/regression/test_libos.py +++ b/libos/test/regression/test_libos.py @@ -1411,6 +1411,26 @@ def test_301_socket_tcp_ancillary(self): stdout, _ = self.run_binary(['tcp_ancillary']) self.assertIn('TEST OK', stdout) + # Two tests for a responsive peer: first connect() returns EINPROGRESS, then poll/epoll + # immediately returns because the connection is quickly refused + def test_305_socket_tcp_einprogress_responsive_poll(self): + stdout, _ = self.run_binary(['tcp_einprogress', '127.0.0.1', 'poll']) + self.assertIn('TEST OK (connection refused after initial EINPROGRESS)', stdout) + + def test_306_socket_tcp_einprogress_responsive_epoll(self): + stdout, _ = self.run_binary(['tcp_einprogress', '127.0.0.1', 'epoll']) + self.assertIn('TEST OK (connection refused after initial EINPROGRESS)', stdout) + + # Two tests for an unresponsive peer: first connect() returns EINPROGRESS, then poll/epoll + # times out because the connection cannot be established + def test_307_socket_tcp_einprogress_unresponsive_poll(self): + stdout, _ = self.run_binary(['tcp_einprogress', '10.255.255.255', 'poll']) + self.assertIn('TEST OK (connection timed out)', stdout) + + def test_308_socket_tcp_einprogress_unresponsive_epoll(self): + stdout, _ = self.run_binary(['tcp_einprogress', '10.255.255.255', 'epoll']) + self.assertIn('TEST OK (connection timed out)', stdout) + def test_310_socket_tcp_ipv6_v6only(self): stdout, _ = self.run_binary(['tcp_ipv6_v6only'], timeout=50) self.assertIn('test completed successfully', stdout) diff --git a/libos/test/regression/tests.toml b/libos/test/regression/tests.toml index 5a91b9dc34..1c8f858ba8 100644 --- a/libos/test/regression/tests.toml +++ b/libos/test/regression/tests.toml @@ -118,6 +118,7 @@ manifests = [ "syscall_restart", "sysfs_common", "tcp_ancillary", + "tcp_einprogress", "tcp_ipv6_v6only", "tcp_msg_peek", "toml_parsing", diff --git a/libos/test/regression/tests_musl.toml b/libos/test/regression/tests_musl.toml index 83d3d64715..f0803fb384 100644 --- a/libos/test/regression/tests_musl.toml +++ b/libos/test/regression/tests_musl.toml @@ -119,6 +119,7 @@ manifests = [ "syscall_restart", "sysfs_common", "tcp_ancillary", + "tcp_einprogress", "tcp_ipv6_v6only", "tcp_msg_peek", "toml_parsing", diff --git a/pal/include/pal/pal.h b/pal/include/pal/pal.h index dd5cce5fb2..259082fb76 100644 --- a/pal/include/pal/pal.h +++ b/pal/include/pal/pal.h @@ -595,13 +595,15 @@ int PalSocketAccept(PAL_HANDLE handle, pal_stream_options_t options, PAL_HANDLE* * \param addr Address to connect to. * \param[out] out_local_addr On success contains the local address of the socket. * Can be NULL, to ignore the result. + * \param[out] out_inprogress On success, returns true in special case of an in-progress connection + * on a non-blocking socket. * * \returns 0 on success, negative error code on failure. * * Can also be used to disconnect the socket, if #PAL_DISCONNECT is passed in \p addr. */ int PalSocketConnect(PAL_HANDLE handle, struct pal_socket_addr* addr, - struct pal_socket_addr* out_local_addr); + struct pal_socket_addr* out_local_addr, bool* out_inprogress); /*! * \brief Send data. diff --git a/pal/include/pal_internal.h b/pal/include/pal_internal.h index c7ca593ec2..66fbe0b171 100644 --- a/pal/include/pal_internal.h +++ b/pal/include/pal_internal.h @@ -119,7 +119,7 @@ struct socket_ops { int (*accept)(PAL_HANDLE handle, pal_stream_options_t options, PAL_HANDLE* out_client, struct pal_socket_addr* out_client_addr, struct pal_socket_addr* out_local_addr); int (*connect)(PAL_HANDLE handle, struct pal_socket_addr* addr, - struct pal_socket_addr* out_local_addr); + struct pal_socket_addr* out_local_addr, bool* out_inprogress); int (*send)(PAL_HANDLE handle, struct iovec* iov, size_t iov_len, size_t* out_size, struct pal_socket_addr* addr, bool force_nonblocking); int (*recv)(PAL_HANDLE handle, struct iovec* iov, size_t iov_len, size_t* out_size, @@ -190,7 +190,7 @@ int _PalSocketAccept(PAL_HANDLE handle, pal_stream_options_t options, PAL_HANDLE struct pal_socket_addr* out_client_addr, struct pal_socket_addr* out_local_addr); int _PalSocketConnect(PAL_HANDLE handle, struct pal_socket_addr* addr, - struct pal_socket_addr* out_local_addr); + struct pal_socket_addr* out_local_addr, bool* out_inprogress); int _PalSocketSend(PAL_HANDLE handle, struct iovec* iov, size_t iov_len, size_t* out_size, struct pal_socket_addr* addr, bool force_nonblocking); int _PalSocketRecv(PAL_HANDLE handle, struct iovec* iov, size_t iov_len, size_t* out_total_size, diff --git a/pal/regression/send_handle.c b/pal/regression/send_handle.c index 9f13168862..26807347d0 100644 --- a/pal/regression/send_handle.c +++ b/pal/regression/send_handle.c @@ -118,8 +118,9 @@ static void do_parent(void) { CHECK(PalSendHandle(child_process, handle)); PalObjectDestroy(handle); + bool connect_inprogress_unused; CHECK(PalSocketCreate(PAL_IPV4, PAL_SOCKET_TCP, /*options=*/0, &handle)); - CHECK(PalSocketConnect(handle, &addr, /*local_addr=*/NULL)); + CHECK(PalSocketConnect(handle, &addr, /*local_addr=*/NULL, &connect_inprogress_unused)); recv_and_check(handle, PAL_TYPE_SOCKET); PalObjectDestroy(handle); @@ -138,7 +139,7 @@ static void do_parent(void) { PalObjectDestroy(handle); CHECK(PalSocketCreate(PAL_IPV6, PAL_SOCKET_UDP, /*options=*/0, &handle)); - CHECK(PalSocketConnect(handle, &addr, /*local_addr=*/NULL)); + CHECK(PalSocketConnect(handle, &addr, /*local_addr=*/NULL, &connect_inprogress_unused)); write_msg(handle, PAL_TYPE_SOCKET); PalObjectDestroy(handle); diff --git a/pal/src/host/linux-sgx/enclave_ocalls.c b/pal/src/host/linux-sgx/enclave_ocalls.c index dec1776c68..c0c8dfc59f 100644 --- a/pal/src/host/linux-sgx/enclave_ocalls.c +++ b/pal/src/host/linux-sgx/enclave_ocalls.c @@ -1426,7 +1426,8 @@ int ocall_connect(int domain, int type, int protocol, int ipv6_v6only, const str return retval; } -int ocall_connect_simple(int fd, struct sockaddr_storage* addr, size_t* addrlen) { +int ocall_connect_simple(int fd, bool nonblocking, struct sockaddr_storage* addr, size_t* addrlen, + bool* out_inprogress) { int ret; void* old_ustack = sgx_prepare_ustack(); struct ocall_connect_simple* ocall_connect_args; @@ -1454,12 +1455,25 @@ int ocall_connect_simple(int fd, struct sockaddr_storage* addr, size_t* addrlen) ret = sgx_exitless_ocall(OCALL_CONNECT_SIMPLE, ocall_connect_args); } while (ret == -EINTR); + bool inprogress = false; + if (ret == -EINPROGRESS) { + if (!nonblocking) { + /* EINPROGRESS can be returned only on non-blocking sockets */ + ret = -EPERM; + goto out; + } + /* POSIX/Linux have an unusual semantics for EINPROGRESS: the connect operation is + * considered successful, but the return value is -EINPROGRESS error code. We don't want to + * replicate this oddness in Gramine, so we return `0` and set a special variable. */ + inprogress = true; + ret = 0; + } + if (ret < 0) { if (ret != -EACCES && ret != -EPERM && ret != -EADDRINUSE && ret != -EADDRNOTAVAIL && ret != -EAFNOSUPPORT && ret != -EAGAIN && ret != -EALREADY && ret != -EBADF - && ret != -ECONNREFUSED && ret != -EINPROGRESS && ret != -EISCONN - && ret != -ENETUNREACH && ret != -ENOTSOCK && ret != -EPROTOTYPE - && ret != -ETIMEDOUT) { + && ret != -ECONNREFUSED && ret != -EISCONN && ret != -ENETUNREACH + && ret != -ENOTSOCK && ret != -EPROTOTYPE && ret != -ETIMEDOUT) { ret = -EPERM; } goto out; @@ -1476,6 +1490,7 @@ int ocall_connect_simple(int fd, struct sockaddr_storage* addr, size_t* addrlen) goto out; } *addrlen = new_addrlen; + *out_inprogress = inprogress; ret = 0; out: diff --git a/pal/src/host/linux-sgx/enclave_ocalls.h b/pal/src/host/linux-sgx/enclave_ocalls.h index d2c726f16e..1cacfcd46b 100644 --- a/pal/src/host/linux-sgx/enclave_ocalls.h +++ b/pal/src/host/linux-sgx/enclave_ocalls.h @@ -63,7 +63,8 @@ int ocall_accept(int sockfd, struct sockaddr* addr, size_t* addrlen, struct sock int ocall_connect(int domain, int type, int protocol, int ipv6_v6only, const struct sockaddr* addr, size_t addrlen, struct sockaddr* bind_addr, size_t* bind_addrlen); -int ocall_connect_simple(int fd, struct sockaddr_storage* addr, size_t* addrlen); +int ocall_connect_simple(int fd, bool nonblocking, struct sockaddr_storage* addr, size_t* addrlen, + bool* out_inprogress); ssize_t ocall_recv(int sockfd, struct iovec* buf, size_t iov_len, void* addr, size_t* addrlenptr, void* control, size_t* controllenptr, unsigned int flags); diff --git a/pal/src/host/linux-sgx/host_ocalls.c b/pal/src/host/linux-sgx/host_ocalls.c index c3c2a76591..4e63e704fd 100644 --- a/pal/src/host/linux-sgx/host_ocalls.c +++ b/pal/src/host/linux-sgx/host_ocalls.c @@ -506,40 +506,23 @@ static long sgx_ocall_connect_simple(void* args) { struct ocall_connect_simple* ocall_connect_args = args; int ret = DO_SYSCALL_INTERRUPTIBLE(connect, ocall_connect_args->fd, ocall_connect_args->addr, (int)ocall_connect_args->addrlen); - if (ret < 0) { - /* XXX: Non blocking socket. Currently there is no way of notifying LibOS of successful or - * failed connection, so we have to block and wait. */ - if (ret != -EINPROGRESS) { - return ret; - } - struct pollfd pfd = { - .fd = ocall_connect_args->fd, - .events = POLLOUT, - }; - ret = DO_SYSCALL(poll, &pfd, 1, /*timeout=*/-1); - if (ret != 1 || pfd.revents == 0) { - return ret < 0 ? ret : -EINVAL; - } - int val = 0; - unsigned int len = sizeof(val); - ret = DO_SYSCALL(getsockopt, ocall_connect_args->fd, SOL_SOCKET, SO_ERROR, &val, &len); - if (ret < 0 || val < 0) { - return ret < 0 ? ret : -EINVAL; - } - if (val) { - return -val; - } - /* Connect succeeded. */ + if (ret < 0 && ret != -EINPROGRESS) { + return ret; } + /* Connect succeeded or in progress (EINPROGRESS); in both cases retrieve local name -- host + * Linux binds the socket to address even in case of EINPROGRESS. */ int addrlen = sizeof(*ocall_connect_args->addr); - ret = DO_SYSCALL(getsockname, ocall_connect_args->fd, ocall_connect_args->addr, &addrlen); - if (ret < 0) { - return ret; + int getsockname_ret = DO_SYSCALL(getsockname, ocall_connect_args->fd, ocall_connect_args->addr, + &addrlen); + if (getsockname_ret < 0) { + /* This should never happen, but we have to handle it somehow. */ + return getsockname_ret; } - ocall_connect_args->addrlen = addrlen; - return 0; + + assert(ret == 0 || ret == -EINPROGRESS); + return ret; } static long sgx_ocall_recv(void* args) { diff --git a/pal/src/host/linux-sgx/pal_sockets.c b/pal/src/host/linux-sgx/pal_sockets.c index 1daa3b9403..ccc7f44219 100644 --- a/pal/src/host/linux-sgx/pal_sockets.c +++ b/pal/src/host/linux-sgx/pal_sockets.c @@ -255,7 +255,7 @@ static int tcp_accept(PAL_HANDLE handle, pal_stream_options_t options, PAL_HANDL } static int connect(PAL_HANDLE handle, struct pal_socket_addr* addr, - struct pal_socket_addr* out_local_addr) { + struct pal_socket_addr* out_local_addr, bool* out_inprogress) { assert(handle->hdr.type == PAL_TYPE_SOCKET); if (addr->domain != PAL_DISCONNECT && addr->domain != handle->sock.domain) { return -PAL_ERROR_INVAL; @@ -266,18 +266,21 @@ static int connect(PAL_HANDLE handle, struct pal_socket_addr* addr, pal_to_linux_sockaddr(addr, &sa_storage, &linux_addrlen); assert(linux_addrlen <= INT_MAX); - int ret = ocall_connect_simple(handle->sock.fd, &sa_storage, &linux_addrlen); + bool inprogress; + int ret = ocall_connect_simple(handle->sock.fd, handle->sock.is_nonblocking, &sa_storage, + &linux_addrlen, &inprogress); if (ret < 0) { return unix_to_pal_error(ret); } if (out_local_addr) { - ret = verify_ip_addr(handle->sock.domain, &sa_storage, linux_addrlen); - if (ret < 0) { - return ret; + int verify_ret = verify_ip_addr(handle->sock.domain, &sa_storage, linux_addrlen); + if (verify_ret < 0) { + return verify_ret; } linux_to_pal_sockaddr(&sa_storage, out_local_addr); } + *out_inprogress = inprogress; return 0; } @@ -668,11 +671,11 @@ int _PalSocketAccept(PAL_HANDLE handle, pal_stream_options_t options, PAL_HANDLE } int _PalSocketConnect(PAL_HANDLE handle, struct pal_socket_addr* addr, - struct pal_socket_addr* out_local_addr) { + struct pal_socket_addr* out_local_addr, bool* out_inprogress) { if (!handle->sock.ops->connect) { return -PAL_ERROR_NOTSUPPORT; } - return handle->sock.ops->connect(handle, addr, out_local_addr); + return handle->sock.ops->connect(handle, addr, out_local_addr, out_inprogress); } int _PalSocketSend(PAL_HANDLE handle, struct iovec* iov, size_t iov_len, size_t* out_size, diff --git a/pal/src/host/linux/pal_sockets.c b/pal/src/host/linux/pal_sockets.c index 75a68f3c5a..8f59ba9d2a 100644 --- a/pal/src/host/linux/pal_sockets.c +++ b/pal/src/host/linux/pal_sockets.c @@ -263,7 +263,7 @@ static int tcp_accept(PAL_HANDLE handle, pal_stream_options_t options, PAL_HANDL } static int connect(PAL_HANDLE handle, struct pal_socket_addr* addr, - struct pal_socket_addr* out_local_addr) { + struct pal_socket_addr* out_local_addr, bool* out_inprogress) { assert(handle->hdr.type == PAL_TYPE_SOCKET); if (addr->domain != PAL_DISCONNECT && addr->domain != handle->sock.domain) { return -PAL_ERROR_INVAL; @@ -275,40 +275,25 @@ static int connect(PAL_HANDLE handle, struct pal_socket_addr* addr, assert(linux_addrlen <= INT_MAX); int ret = DO_SYSCALL(connect, handle->sock.fd, &sa_storage, (int)linux_addrlen); - if (ret < 0) { - /* XXX: Non blocking socket. Currently there is no way of notifying LibOS of successful or - * failed connection, so we have to block and wait. */ - if (ret != -EINPROGRESS) { - return unix_to_pal_error(ret); - } - struct pollfd pfd = { - .fd = handle->sock.fd, - .events = POLLOUT, - }; - ret = DO_SYSCALL(poll, &pfd, 1, /*timeout=*/-1); - if (ret != 1 || pfd.revents == 0) { - return ret < 0 ? unix_to_pal_error(ret) : -PAL_ERROR_INVAL; - } - int val = 0; - unsigned int len = sizeof(val); - ret = DO_SYSCALL(getsockopt, handle->sock.fd, SOL_SOCKET, SO_ERROR, &val, &len); - if (ret < 0 || val < 0) { - return ret < 0 ? unix_to_pal_error(ret) : -PAL_ERROR_INVAL; - } - if (val) { - return unix_to_pal_error(-val); - } - /* Connect succeeded. */ + if (ret < 0 && ret != -EINPROGRESS) { + return unix_to_pal_error(ret); } + /* Connect succeeded or in progress (EINPROGRESS); in both cases retrieve local name -- host + * Linux binds the socket to address even in case of EINPROGRESS */ if (out_local_addr) { - ret = do_getsockname(handle->sock.fd, &sa_storage); - if (ret < 0) { + int getsockname_ret = do_getsockname(handle->sock.fd, &sa_storage); + if (getsockname_ret < 0) { /* This should never happen, but we have to handle it somehow. */ - return ret; + return getsockname_ret; } linux_to_pal_sockaddr(&sa_storage, out_local_addr); } + + /* POSIX/Linux have an unusual semantics for EINPROGRESS: the connect operation is considered + * successful, but the return value is -EINPROGRESS error code. We don't want to replicate this + * oddness in Gramine, so we return `0` and set a special variable. */ + *out_inprogress = (ret == -EINPROGRESS); return 0; } @@ -721,11 +706,11 @@ int _PalSocketAccept(PAL_HANDLE handle, pal_stream_options_t options, PAL_HANDLE } int _PalSocketConnect(PAL_HANDLE handle, struct pal_socket_addr* addr, - struct pal_socket_addr* out_local_addr) { + struct pal_socket_addr* out_local_addr, bool* out_inprogress) { if (!handle->sock.ops->connect) { return -PAL_ERROR_NOTSUPPORT; } - return handle->sock.ops->connect(handle, addr, out_local_addr); + return handle->sock.ops->connect(handle, addr, out_local_addr, out_inprogress); } int _PalSocketSend(PAL_HANDLE handle, struct iovec* iov, size_t iov_len, size_t* out_size, diff --git a/pal/src/host/skeleton/pal_sockets.c b/pal/src/host/skeleton/pal_sockets.c index aeacefb2c4..a098c330dd 100644 --- a/pal/src/host/skeleton/pal_sockets.c +++ b/pal/src/host/skeleton/pal_sockets.c @@ -25,7 +25,7 @@ int _PalSocketAccept(PAL_HANDLE handle, pal_stream_options_t options, PAL_HANDLE } int _PalSocketConnect(PAL_HANDLE handle, struct pal_socket_addr* addr, - struct pal_socket_addr* local_addr) { + struct pal_socket_addr* out_local_addr, bool* out_inprogress) { return -PAL_ERROR_NOTIMPLEMENTED; } diff --git a/pal/src/pal_sockets.c b/pal/src/pal_sockets.c index 53b2f964c3..346e3b1aaf 100644 --- a/pal/src/pal_sockets.c +++ b/pal/src/pal_sockets.c @@ -29,9 +29,9 @@ int PalSocketAccept(PAL_HANDLE handle, pal_stream_options_t options, PAL_HANDLE* } int PalSocketConnect(PAL_HANDLE handle, struct pal_socket_addr* addr, - struct pal_socket_addr* local_addr) { + struct pal_socket_addr* out_local_addr, bool* out_inprogress) { assert(handle->hdr.type == PAL_TYPE_SOCKET); - return _PalSocketConnect(handle, addr, local_addr); + return _PalSocketConnect(handle, addr, out_local_addr, out_inprogress); } int PalSocketSend(PAL_HANDLE handle, struct iovec* iov, size_t iov_len, size_t* out_size,