Skip to content

Commit

Permalink
Sockets: Add HasAnyClientSockets()
Browse files Browse the repository at this point in the history
  • Loading branch information
stenzek committed Jul 6, 2024
1 parent b06fcef commit 1fd8d27
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
31 changes: 27 additions & 4 deletions src/util/sockets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,13 @@ void SocketMultiplexer::AddOpenSocket(std::shared_ptr<BaseSocket> socket)
std::unique_lock lock(m_open_sockets_lock);

DebugAssert(m_open_sockets.find(socket->GetDescriptor()) == m_open_sockets.end());
m_open_sockets.emplace(socket->GetDescriptor(), socket);
m_open_sockets.emplace(socket->GetDescriptor(), std::move(socket));
}

void SocketMultiplexer::AddClientSocket(std::shared_ptr<BaseSocket> socket)
{
AddOpenSocket(std::move(socket));
m_client_socket_count.fetch_add(1, std::memory_order_acq_rel);
}

void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket)
Expand Down Expand Up @@ -349,12 +355,29 @@ void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket)
m_poll_array_active_size = new_active_size;
}

void SocketMultiplexer::RemoveClientSocket(BaseSocket* socket)
{
DebugAssert(m_client_socket_count.load(std::memory_order_acquire) > 0);
m_client_socket_count.fetch_sub(1, std::memory_order_acq_rel);
RemoveOpenSocket(socket);
}

bool SocketMultiplexer::HasAnyOpenSockets()
{
std::unique_lock lock(m_open_sockets_lock);
return !m_open_sockets.empty();
}

bool SocketMultiplexer::HasAnyClientSockets()
{
return (GetClientSocketCount() > 0);
}

size_t SocketMultiplexer::GetClientSocketCount()
{
return m_client_socket_count.load(std::memory_order_acquire);
}

void SocketMultiplexer::CloseAll()
{
std::unique_lock lock(m_open_sockets_lock);
Expand Down Expand Up @@ -559,7 +582,7 @@ u32 StreamSocket::GetSocketProtocolForAddress(const SocketAddress& sa)
void StreamSocket::InitialSetup()
{
// register for notifications
m_multiplexer.AddOpenSocket(shared_from_this());
m_multiplexer.AddClientSocket(shared_from_this());
m_multiplexer.SetNotificationMask(this, m_descriptor, POLLIN);

// trigger connected notification
Expand Down Expand Up @@ -679,7 +702,7 @@ void StreamSocket::Close()
return;

m_multiplexer.SetNotificationMask(this, m_descriptor, 0);
m_multiplexer.RemoveOpenSocket(this);
m_multiplexer.RemoveClientSocket(this);
shutdown(m_descriptor, SD_BOTH);
closesocket(m_descriptor);
m_descriptor = INVALID_SOCKET;
Expand All @@ -701,7 +724,7 @@ void StreamSocket::CloseWithError()
error.SetSocket(error_code);

m_multiplexer.SetNotificationMask(this, m_descriptor, 0);
m_multiplexer.RemoveOpenSocket(this);
m_multiplexer.RemoveClientSocket(this);
closesocket(m_descriptor);
m_descriptor = INVALID_SOCKET;
m_connected = false;
Expand Down
11 changes: 10 additions & 1 deletion src/util/sockets.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
#include <memory>
#include <mutex>
#include <optional>
#include <unordered_map>
#include <span>
#include <unordered_map>

#ifdef _WIN32
using SocketDescriptor = uintptr_t;
Expand Down Expand Up @@ -108,6 +108,12 @@ class SocketMultiplexer final
// Returns true if any sockets are currently registered.
bool HasAnyOpenSockets();

// Returns true if any client sockets are currently connected.
bool HasAnyClientSockets();

// Returns the number of current client sockets.
size_t GetClientSocketCount();

// Close all sockets on this multiplexer.
void CloseAll();

Expand All @@ -127,7 +133,9 @@ class SocketMultiplexer final

// Tracking of open sockets.
void AddOpenSocket(std::shared_ptr<BaseSocket> socket);
void AddClientSocket(std::shared_ptr<BaseSocket> socket);
void RemoveOpenSocket(BaseSocket* socket);
void RemoveClientSocket(BaseSocket* socket);

// Register for notifications
void SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events);
Expand All @@ -143,6 +151,7 @@ class SocketMultiplexer final

std::mutex m_open_sockets_lock;
SocketMap m_open_sockets;
std::atomic_size_t m_client_socket_count{0};
};

template<class T>
Expand Down

0 comments on commit 1fd8d27

Please sign in to comment.