Skip to content

Commit

Permalink
Sockets: Use epoll on Linux
Browse files Browse the repository at this point in the history
  • Loading branch information
stenzek committed Jul 21, 2024
1 parent 7880087 commit ad374ef
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 18 deletions.
134 changes: 116 additions & 18 deletions src/util/sockets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ using nfds_t = ULONG;
#include <sys/un.h>
#include <unistd.h>

#ifdef __linux__
#include <sys/epoll.h>
#endif

#define ioctlsocket ioctl
#define closesocket close
#define WSAEWOULDBLOCK EAGAIN
Expand Down Expand Up @@ -227,16 +231,42 @@ SocketMultiplexer::~SocketMultiplexer()
{
CloseAll();

#ifdef __linux__
if (m_epoll_fd >= 0)
close(m_epoll_fd);
#else
if (m_poll_array)
std::free(m_poll_array);
#endif
}

std::unique_ptr<SocketMultiplexer> SocketMultiplexer::Create(Error* error)
{
if (!PlatformMisc::InitializeSocketSupport(error))
return {};
std::unique_ptr<SocketMultiplexer> ret;
if (PlatformMisc::InitializeSocketSupport(error))
{
ret = std::unique_ptr<SocketMultiplexer>(new SocketMultiplexer());
if (!ret->Initialize(error))
ret.reset();
}

return ret;
}

bool SocketMultiplexer::Initialize(Error* error)
{
#ifdef __linux__
m_epoll_fd = epoll_create1(0);
if (m_epoll_fd < 0)
{
Error::SetErrno(error, "epoll_create1() failed: ", errno);
return false;
}

return std::unique_ptr<SocketMultiplexer>(new SocketMultiplexer());
return true;
#else
return true;
#endif
}

std::shared_ptr<ListenSocket> SocketMultiplexer::InternalCreateListenSocket(const SocketAddress& address,
Expand Down Expand Up @@ -325,8 +355,13 @@ std::shared_ptr<StreamSocket> SocketMultiplexer::InternalConnectStreamSocket(con

void SocketMultiplexer::AddOpenSocket(std::shared_ptr<BaseSocket> socket)
{
std::unique_lock lock(m_open_sockets_lock);
#ifdef __linux__
struct epoll_event ev = {.events = 0u, .data = {.fd = socket->GetDescriptor()}};
if (epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD, socket->GetDescriptor(), &ev) != 0) [[unlikely]]
ERROR_LOG("epoll_ctl() to add socket failed: {}", Error::CreateErrno(errno).GetDescription());
#endif

std::unique_lock lock(m_open_sockets_lock);
DebugAssert(m_open_sockets.find(socket->GetDescriptor()) == m_open_sockets.end());
m_open_sockets.emplace(socket->GetDescriptor(), std::move(socket));
}
Expand All @@ -339,27 +374,29 @@ void SocketMultiplexer::AddClientSocket(std::shared_ptr<BaseSocket> socket)

void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket)
{
#ifdef _DEBUG
{
std::unique_lock lock(m_poll_array_lock);
for (size_t i = 0; i < m_poll_array_active_size; i++)
{
pollfd& pfd = m_poll_array[i];
DebugAssert(pfd.fd != socket->GetDescriptor());
}
}
#endif

std::unique_lock lock(m_open_sockets_lock);
const auto iter = m_open_sockets.find(socket->GetDescriptor());
Assert(iter != m_open_sockets.end());
m_open_sockets.erase(iter);

#ifdef __linux__
if (epoll_ctl(m_epoll_fd, EPOLL_CTL_DEL, socket->GetDescriptor(), nullptr) != 0) [[unlikely]]
ERROR_LOG("epoll_ctl() to remove socket failed: {}", Error::CreateErrno(errno).GetDescription());
#else
#ifdef _DEBUG
for (size_t i = 0; i < m_poll_array_active_size; i++)
{
pollfd& pfd = m_poll_array[i];
DebugAssert(pfd.fd != socket->GetDescriptor());
}
#endif

// Update size.
size_t new_active_size = 0;
for (size_t i = 0; i < m_poll_array_active_size; i++)
new_active_size = (m_poll_array[i].fd != INVALID_SOCKET) ? (i + 1) : new_active_size;
m_poll_array_active_size = new_active_size;
#endif
}

void SocketMultiplexer::RemoveClientSocket(BaseSocket* socket)
Expand Down Expand Up @@ -400,6 +437,11 @@ void SocketMultiplexer::CloseAll()

void SocketMultiplexer::SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events)
{
#ifdef __linux__
struct epoll_event ev = {.events = events, .data = {.fd = descriptor}};
if (epoll_ctl(m_epoll_fd, EPOLL_CTL_MOD, descriptor, &ev) != 0) [[unlikely]]
ERROR_LOG("epoll_ctl() for events 0x{:x} failed: {}", events, Error::CreateErrno(errno).GetDescription());
#else
std::unique_lock lock(m_poll_array_lock);
size_t free_slot = m_poll_array_active_size;
for (size_t i = 0; i < m_poll_array_active_size; i++)
Expand Down Expand Up @@ -440,10 +482,64 @@ void SocketMultiplexer::SetNotificationMask(BaseSocket* socket, SocketDescriptor

m_poll_array[free_slot] = {.fd = descriptor, .events = static_cast<short>(events), .revents = 0};
m_poll_array_active_size = free_slot + 1;
#endif
}

bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
{
#ifdef __linux__
constexpr int MAX_EVENTS = 128;
struct epoll_event events[MAX_EVENTS];

const int nevents = epoll_wait(m_epoll_fd, events, MAX_EVENTS, static_cast<int>(milliseconds));
if (nevents <= 0)
return false;

// find sockets that triggered, we use an array here so we can avoid holding the lock, and if a socket disconnects
using PendingSocketPair = std::pair<std::shared_ptr<BaseSocket>, u32>;
PendingSocketPair* triggered_sockets =
reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * static_cast<size_t>(nevents)));
size_t num_triggered_sockets = 0;
{
std::unique_lock open_lock(m_open_sockets_lock);
for (int i = 0; i < nevents; i++)
{
const epoll_event& ev = events[i];
const auto iter = m_open_sockets.find(ev.data.fd);
if (iter == m_open_sockets.end()) [[unlikely]]
{
ERROR_LOG("Attempting to look up unknown socket {}, this should never happen.", ev.data.fd);
continue;
}

// we add a reference here in case the read kills it with a write pending, or something like that
new (&triggered_sockets[num_triggered_sockets++]) PendingSocketPair(iter->second->shared_from_this(), ev.events);
}
}

// fire events
for (size_t i = 0; i < num_triggered_sockets; i++)
{
PendingSocketPair& psp = triggered_sockets[i];

// fire events
if (psp.second & (EPOLLRDHUP | EPOLLHUP | EPOLLERR))
{
psp.first->OnHangupEvent();
}
else
{
if (psp.second & EPOLLIN)
psp.first->OnReadEvent();
if (psp.second & EPOLLOUT)
psp.first->OnWriteEvent();
}

psp.first.~shared_ptr();
}

return true;
#else
std::unique_lock lock(m_poll_array_lock);
if (m_poll_array_active_size == 0)
return false;
Expand All @@ -454,7 +550,8 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)

// find sockets that triggered, we use an array here so we can avoid holding the lock, and if a socket disconnects
using PendingSocketPair = std::pair<std::shared_ptr<BaseSocket>, u32>;
PendingSocketPair* triggered_sockets = reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * res));
PendingSocketPair* triggered_sockets =
reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * static_cast<size_t>(res)));
size_t num_triggered_sockets = 0;
{
std::unique_lock open_lock(m_open_sockets_lock);
Expand All @@ -467,7 +564,7 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
const auto iter = m_open_sockets.find(pfd.fd);
if (iter == m_open_sockets.end()) [[unlikely]]
{
ERROR_LOG("Attempting to look up known socket {}, this should never happen.", pfd.fd);
ERROR_LOG("Attempting to look up unknown socket {}, this should never happen.", pfd.fd);
continue;
}

Expand All @@ -481,7 +578,7 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
lock.unlock();

// fire events
for (u32 i = 0; i < num_triggered_sockets; i++)
for (size_t i = 0; i < num_triggered_sockets; i++)
{
PendingSocketPair& psp = triggered_sockets[i];

Expand All @@ -502,6 +599,7 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
}

return true;
#endif
}

ListenSocket::ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor,
Expand Down
7 changes: 7 additions & 0 deletions src/util/sockets.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ class SocketMultiplexer final
// Hide the constructor.
SocketMultiplexer();

// Initialization.
bool Initialize(Error* error);

// Tracking of open sockets.
void AddOpenSocket(std::shared_ptr<BaseSocket> socket);
void AddClientSocket(std::shared_ptr<BaseSocket> socket);
Expand All @@ -148,10 +151,14 @@ class SocketMultiplexer final
// We store the fd in the struct to avoid the cache miss reading the object.
using SocketMap = std::unordered_map<SocketDescriptor, std::shared_ptr<BaseSocket>>;

#ifdef __linux__
int m_epoll_fd = -1;
#else
std::mutex m_poll_array_lock;
pollfd* m_poll_array = nullptr;
size_t m_poll_array_active_size = 0;
size_t m_poll_array_max_size = 0;
#endif

std::mutex m_open_sockets_lock;
SocketMap m_open_sockets;
Expand Down

0 comments on commit ad374ef

Please sign in to comment.