From 363fd500b7a965e636230be52c987233d875ff92 Mon Sep 17 00:00:00 2001 From: Leonardo Araujo Date: Fri, 29 Mar 2024 00:50:27 -0300 Subject: [PATCH] refactor: removed ip blacklist exception --- include/base_exceptions.hpp | 9 --------- include/socket.hpp | 2 +- include/unix_socket.hpp | 2 +- include/win_socket.hpp | 2 +- src/teapot.cpp | 18 +++++++++--------- src/unix_socket.cpp | 11 +++++------ src/win_socket.cpp | 13 +++++++------ 7 files changed, 24 insertions(+), 33 deletions(-) diff --git a/include/base_exceptions.hpp b/include/base_exceptions.hpp index 69f8833..6e1e454 100644 --- a/include/base_exceptions.hpp +++ b/include/base_exceptions.hpp @@ -54,15 +54,6 @@ namespace tpt : BaseException(message, errorCode) {} }; - class IPBlackListedException : public BaseException - { - public: - IPBlackListedException( - const std::string &message = "IP is blacklisted", - int errorCode = -1) - : BaseException(message, errorCode) {} - }; - class SocketCreationException : public BaseException { public: diff --git a/include/socket.hpp b/include/socket.hpp index e0ab77b..79e87ab 100644 --- a/include/socket.hpp +++ b/include/socket.hpp @@ -15,7 +15,7 @@ namespace tpt public: virtual void bindSocket() = 0; virtual void listenToConnections() = 0; - virtual void acceptConnection(SOCKET &client_socket, void *client_address) = 0; + virtual bool acceptConnection(SOCKET &client_socket, void *client_address) = 0; virtual ssize_t receiveData(SOCKET client_socket, char *buffer, unsigned int buffer_size) = 0; virtual void sendData(SOCKET client_socket, const void *buffer, unsigned int buffer_size, int flags) = 0; virtual void closeSocket() = 0; diff --git a/include/unix_socket.hpp b/include/unix_socket.hpp index d8c73bd..7c51d6c 100644 --- a/include/unix_socket.hpp +++ b/include/unix_socket.hpp @@ -42,7 +42,7 @@ namespace tpt std::string getClientIp(); virtual void bindSocket() override; virtual void listenToConnections() override; - virtual void acceptConnection(SOCKET &client_socket, void *client_address) override; + virtual bool acceptConnection(SOCKET &client_socket, void *client_address) override; virtual ssize_t receiveData(SOCKET client_socket, char *buffer, unsigned int buffer_size) override; virtual void sendData(SOCKET client_socket, const void *buffer, unsigned int buffer_size, int flags) override; virtual void closeSocket() override; diff --git a/include/win_socket.hpp b/include/win_socket.hpp index 3dbf714..ea4ea41 100644 --- a/include/win_socket.hpp +++ b/include/win_socket.hpp @@ -46,7 +46,7 @@ namespace tpt std::string getClientIp(); virtual void bindSocket() override; virtual void listenToConnections() override; - virtual void acceptConnection(SOCKET &client_socket, void *client_address) override; + virtual bool acceptConnection(SOCKET &client_socket, void *client_address) override; virtual ssize_t receiveData(SOCKET client_socket, char *buffer, unsigned int buffer_size) override; virtual void sendData(SOCKET client_socket, const void *buffer, unsigned int buffer_size, int flags) override; virtual void closeSocket() override; diff --git a/src/teapot.cpp b/src/teapot.cpp index 6df7f20..d421e98 100644 --- a/src/teapot.cpp +++ b/src/teapot.cpp @@ -195,7 +195,7 @@ Teapot::Teapot(std::string ip_address, unsigned int port, unsigned int max_conne this->cors_middleware = CORSMiddleware(); this->security_middleware = SecurityMiddleware(); this->logger = ConsoleLogger(); - // Conditional compilation based on the operating system + #ifdef __linux__ this->socket = tpt::UnixSocket(this->logger, this->ip_address, this->port, this->max_connections); #endif @@ -229,14 +229,14 @@ void Teapot::run() try { - socket.acceptConnection(client_socket, client_addr); - auto res = std::async(std::launch::async, &Teapot::requestHandler, this, client_socket); - // std::jthread th(&Teapot::requestHandler, this, client_socket); - } - catch (IPBlackListedException &e) - { - std::cout << e.what(); - this->socket.closeSocket(client_socket); + if (socket.acceptConnection(client_socket, client_addr)) + { + auto res = std::async(std::launch::async, &Teapot::requestHandler, this, client_socket); + } + else + { + this->socket.closeSocket(client_socket); + } } catch (SocketAcceptException &) { diff --git a/src/unix_socket.cpp b/src/unix_socket.cpp index 5ae30b3..81eb402 100644 --- a/src/unix_socket.cpp +++ b/src/unix_socket.cpp @@ -52,7 +52,7 @@ void UnixSocket::listenToConnections() } } -void UnixSocket::acceptConnection(SOCKET &client_socket, void *client_address) +bool UnixSocket::acceptConnection(SOCKET &client_socket, void *client_address) { struct sockaddr_storage client_addr_storage; socklen_t client_addr_size = sizeof(client_addr_storage); @@ -64,22 +64,19 @@ void UnixSocket::acceptConnection(SOCKET &client_socket, void *client_address) throw SocketAcceptException(); } - // Assuming client_address is meant to store the result if (client_address != nullptr) { std::memcpy(client_address, &client_addr_storage, client_addr_size); } - char ip_str[INET6_ADDRSTRLEN] = {0}; // Large enough for both IPv4 and IPv6 + char ip_str[INET6_ADDRSTRLEN] = {0}; if (client_addr_storage.ss_family == AF_INET) { - // IPv4 struct sockaddr_in *addr_in = (struct sockaddr_in *)&client_addr_storage; inet_ntop(AF_INET, &addr_in->sin_addr, ip_str, INET_ADDRSTRLEN); } else if (client_addr_storage.ss_family == AF_INET6) { - // IPv6 struct sockaddr_in6 *addr_in6 = (struct sockaddr_in6 *)&client_addr_storage; inet_ntop(AF_INET6, &addr_in6->sin6_addr, ip_str, INET6_ADDRSTRLEN); } @@ -90,11 +87,13 @@ void UnixSocket::acceptConnection(SOCKET &client_socket, void *client_address) { if (this->client_ip == it) { - throw IPBlackListedException(); + return false; } } this->client_sockets.push_back(client_socket); + + return true; } ssize_t UnixSocket::receiveData(SOCKET client_socket, char *buffer, unsigned int buffer_size) diff --git a/src/win_socket.cpp b/src/win_socket.cpp index f399d6d..55c21f8 100644 --- a/src/win_socket.cpp +++ b/src/win_socket.cpp @@ -56,7 +56,7 @@ void WinSocket::listenToConnections() } } -void WinSocket::acceptConnection(SOCKET &client_socket, void *client_address) +bool WinSocket::acceptConnection(SOCKET &client_socket, void *client_address) { struct sockaddr_storage client_addr_storage; int client_addr_size = sizeof(client_addr_storage); @@ -68,22 +68,19 @@ void WinSocket::acceptConnection(SOCKET &client_socket, void *client_address) throw SocketAcceptException("error accepting connections", WSAGetLastError()); } - // Assuming client_address is meant to store the result if (client_address != nullptr) { std::memcpy(client_address, &client_addr_storage, client_addr_size); } - char ip_str[INET6_ADDRSTRLEN] = {0}; // Large enough for both IPv4 and IPv6 + char ip_str[INET6_ADDRSTRLEN] = {0}; if (client_addr_storage.ss_family == AF_INET) { - // IPv4 struct sockaddr_in *addr_in = (struct sockaddr_in *)&client_addr_storage; inet_ntop(AF_INET, &addr_in->sin_addr, ip_str, INET_ADDRSTRLEN); } else if (client_addr_storage.ss_family == AF_INET6) { - // IPv6 struct sockaddr_in6 *addr_in6 = (struct sockaddr_in6 *)&client_addr_storage; inet_ntop(AF_INET6, &addr_in6->sin6_addr, ip_str, INET6_ADDRSTRLEN); } @@ -94,9 +91,13 @@ void WinSocket::acceptConnection(SOCKET &client_socket, void *client_address) { if (this->client_ip == it) { - throw IPBlackListedException(); + return false; } } + + this->client_sockets.push_back(client_socket); + + return true; } ssize_t WinSocket::receiveData(SOCKET client_socket, char *buffer, unsigned int buffer_size)