diff --git a/src/app/CMakeLists.txt b/src/app/CMakeLists.txt index 873b611fc..ef5334f57 100644 --- a/src/app/CMakeLists.txt +++ b/src/app/CMakeLists.txt @@ -1,4 +1,4 @@ add_library(clio_app) -target_sources(clio_app PRIVATE CliArgs.cpp ClioApplication.cpp) +target_sources(clio_app PRIVATE CliArgs.cpp ClioApplication.cpp WebHandlers.cpp) target_link_libraries(clio_app PUBLIC clio_etl clio_etlng clio_feed clio_web clio_rpc) diff --git a/src/app/ClioApplication.cpp b/src/app/ClioApplication.cpp index aceb9a290..c06494765 100644 --- a/src/app/ClioApplication.cpp +++ b/src/app/ClioApplication.cpp @@ -19,6 +19,7 @@ #include "app/ClioApplication.hpp" +#include "app/WebHandlers.hpp" #include "data/AmendmentCenter.hpp" #include "data/BackendFactory.hpp" #include "etl/ETLService.hpp" @@ -30,11 +31,9 @@ #include "rpc/RPCEngine.hpp" #include "rpc/WorkQueue.hpp" #include "rpc/common/impl/HandlerProvider.hpp" -#include "util/Assert.hpp" #include "util/build/Build.hpp" #include "util/config/Config.hpp" #include "util/log/Logger.hpp" -#include "util/prometheus/Http.hpp" #include "util/prometheus/Prometheus.hpp" #include "web/AdminVerificationStrategy.hpp" #include "web/RPCServerHandler.hpp" @@ -52,11 +51,14 @@ #include #include #include +#include +#include #include #include #include #include +#include #include #include #include @@ -65,14 +67,6 @@ namespace app { namespace { -auto constexpr HealthCheckHTML = R"html( - - - Test page for Clio -

Clio Test

This page shows Clio http(s) connectivity is working.

- -)html"; - /** * @brief Start context threads * @@ -158,72 +152,18 @@ ClioApplication::run(bool const useNgWebServer) } auto const adminVerifier = std::move(expectedAdminVerifier).value(); - auto httpServer = web::ng::make_Server(config_, ioc); + auto httpServer = web::ng::make_Server(config_, OnConnectCheck{dosGuard}, DisconnectHook{dosGuard}, ioc); if (not httpServer.has_value()) { LOG(util::LogService::error()) << "Error creating web server: " << httpServer.error(); return EXIT_FAILURE; } - httpServer->onGet( - "/metrics", - [adminVerifier]( - web::ng::Request const& request, - web::ng::ConnectionMetadata& connectionMetadata, - web::SubscriptionContextPtr, - boost::asio::yield_context - ) -> web::ng::Response { - auto const maybeHttpRequest = request.asHttpRequest(); - ASSERT(maybeHttpRequest.has_value(), "Got not a http request in Get"); - auto const& httpRequest = maybeHttpRequest->get(); - - // FIXME(#1702): Using veb server thread to handle prometheus request. Better to post on work queue. - auto maybeResponse = util::prometheus::handlePrometheusRequest( - httpRequest, adminVerifier->isAdmin(httpRequest, connectionMetadata.ip()) - ); - ASSERT(maybeResponse.has_value(), "Got unexpected request for Prometheus"); - return web::ng::Response{std::move(maybeResponse).value(), request}; - } - ); - - httpServer->onGet( - "/health", - [](web::ng::Request const& request, - web::ng::ConnectionMetadata&, - web::SubscriptionContextPtr, - boost::asio::yield_context) -> web::ng::Response { - return web::ng::Response{boost::beast::http::status::ok, HealthCheckHTML, request}; - } - ); - - util::Logger webServerLog{"WebServer"}; - auto onRequest = [adminVerifier, &webServerLog, &handler]( - web::ng::Request const& request, - web::ng::ConnectionMetadata& connectionMetadata, - web::SubscriptionContextPtr subscriptionContext, - boost::asio::yield_context yield - ) -> web::ng::Response { - LOG(webServerLog.info()) << connectionMetadata.tag() - << "Received request from ip = " << connectionMetadata.ip() - << " - posting to WorkQueue"; - - connectionMetadata.setIsAdmin([&adminVerifier, &request, &connectionMetadata]() { - return adminVerifier->isAdmin(request.httpHeaders(), connectionMetadata.ip()); - }); - - try { - return handler(request, connectionMetadata, std::move(subscriptionContext), yield); - } catch (std::exception const&) { - return web::ng::Response{ - boost::beast::http::status::internal_server_error, - rpc::makeError(rpc::RippledError::rpcINTERNAL), - request - }; - } - }; - - httpServer->onPost("/", onRequest); - httpServer->onWs(onRequest); + httpServer->onGet("/metrics", MetricsHandler{adminVerifier}); + httpServer->onGet("/health", HealthCheckHandler{}); + auto requestHandler = RequestHandler{adminVerifier, handler, dosGuard}; + httpServer->onPost("/", requestHandler); + httpServer->onWs(std::move(requestHandler)); auto const maybeError = httpServer->run(); if (maybeError.has_value()) { diff --git a/src/app/WebHandlers.cpp b/src/app/WebHandlers.cpp new file mode 100644 index 000000000..7a1a1a412 --- /dev/null +++ b/src/app/WebHandlers.cpp @@ -0,0 +1,111 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2024, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#include "app/WebHandlers.hpp" + +#include "util/Assert.hpp" +#include "util/prometheus/Http.hpp" +#include "web/AdminVerificationStrategy.hpp" +#include "web/SubscriptionContextInterface.hpp" +#include "web/dosguard/DOSGuardInterface.hpp" +#include "web/ng/Connection.hpp" +#include "web/ng/Request.hpp" +#include "web/ng/Response.hpp" + +#include +#include + +#include +#include +#include + +namespace app { + +OnConnectCheck::OnConnectCheck(web::dosguard::DOSGuardInterface& dosguard) : dosguard_{dosguard} +{ +} + +std::expected +OnConnectCheck::operator()(web::ng::Connection const& connection) +{ + dosguard_.get().increment(connection.ip()); + if (not dosguard_.get().isOk(connection.ip())) { + return std::unexpected{ + web::ng::Response{boost::beast::http::status::too_many_requests, "Too many requests", connection} + }; + } + + return {}; +} + +DisconnectHook::DisconnectHook(web::dosguard::DOSGuardInterface& dosguard) : dosguard_{dosguard} +{ +} + +void +DisconnectHook::operator()(web::ng::Connection const& connection) +{ + dosguard_.get().decrement(connection.ip()); +} + +MetricsHandler::MetricsHandler(std::shared_ptr adminVerifier) + : adminVerifier_{std::move(adminVerifier)} +{ +} + +web::ng::Response +MetricsHandler::operator()( + web::ng::Request const& request, + web::ng::ConnectionMetadata& connectionMetadata, + web::SubscriptionContextPtr, + boost::asio::yield_context +) +{ + auto const maybeHttpRequest = request.asHttpRequest(); + ASSERT(maybeHttpRequest.has_value(), "Got not a http request in Get"); + auto const& httpRequest = maybeHttpRequest->get(); + + // FIXME(#1702): Using veb server thread to handle prometheus request. Better to post on work queue. + auto maybeResponse = util::prometheus::handlePrometheusRequest( + httpRequest, adminVerifier_->isAdmin(httpRequest, connectionMetadata.ip()) + ); + ASSERT(maybeResponse.has_value(), "Got unexpected request for Prometheus"); + return web::ng::Response{std::move(maybeResponse).value(), request}; +} + +web::ng::Response +HealthCheckHandler::operator()( + web::ng::Request const& request, + web::ng::ConnectionMetadata&, + web::SubscriptionContextPtr, + boost::asio::yield_context +) +{ + static auto constexpr HealthCheckHTML = R"html( + + + Test page for Clio +

Clio Test

This page shows Clio http(s) connectivity is working.

+ +)html"; + + return web::ng::Response{boost::beast::http::status::ok, HealthCheckHTML, request}; +} + +} // namespace app diff --git a/src/app/WebHandlers.hpp b/src/app/WebHandlers.hpp new file mode 100644 index 000000000..d66f7c783 --- /dev/null +++ b/src/app/WebHandlers.hpp @@ -0,0 +1,234 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2024, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#pragma once + +#include "rpc/Errors.hpp" +#include "util/log/Logger.hpp" +#include "web/AdminVerificationStrategy.hpp" +#include "web/SubscriptionContextInterface.hpp" +#include "web/dosguard/DOSGuardInterface.hpp" +#include "web/ng/Connection.hpp" +#include "web/ng/Request.hpp" +#include "web/ng/Response.hpp" + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace app { + +/** + * @brief A function object that checks if the connection is allowed to proceed. + */ +class OnConnectCheck { + std::reference_wrapper dosguard_; + +public: + /** + * @brief Construct a new OnConnectCheck object + * + * @param dosguard The DOSGuardInterface to use for checking the connection. + */ + OnConnectCheck(web::dosguard::DOSGuardInterface& dosguard); + + /** + * @brief Check if the connection is allowed to proceed. + * + * @param connection The connection to check. + * @return A response if the connection is not allowed to proceed or void otherwise. + */ + std::expected + operator()(web::ng::Connection const& connection); +}; + +/** + * @brief A function object to be called when a connection is disconnected. + */ +class DisconnectHook { + std::reference_wrapper dosguard_; + +public: + /** + * @brief Construct a new DisconnectHook object + * + * @param dosguard The DOSGuardInterface to use for disconnecting the connection. + */ + DisconnectHook(web::dosguard::DOSGuardInterface& dosguard); + + /** + * @brief The call of the function object. + * + * @param connection The connection which has disconnected. + */ + void + operator()(web::ng::Connection const& connection); +}; + +/** + * @brief A function object that handles the metrics endpoint. + */ +class MetricsHandler { + std::shared_ptr adminVerifier_; + +public: + /** + * @brief Construct a new MetricsHandler object + * + * @param adminVerifier The AdminVerificationStrategy to use for verifying the connection for admin access. + */ + MetricsHandler(std::shared_ptr adminVerifier); + + /** + * @brief The call of the function object. + * + * @param request The request to handle. + * @param connectionMetadata The connection metadata. + * @return The response to the request. + */ + web::ng::Response + operator()( + web::ng::Request const& request, + web::ng::ConnectionMetadata& connectionMetadata, + web::SubscriptionContextPtr, + boost::asio::yield_context + ); +}; + +/** + * @brief A function object that handles the health check endpoint. + */ +class HealthCheckHandler { +public: + /** + * @brief The call of the function object. + * + * @param request The request to handle. + * @return The response to the request + */ + web::ng::Response + operator()( + web::ng::Request const& request, + web::ng::ConnectionMetadata&, + web::SubscriptionContextPtr, + boost::asio::yield_context + ); +}; + +/** + * @brief A function object that handles the websocket endpoint. + * + * @tparam RpcHandlerType The type of the RPC handler. + */ +template +class RequestHandler { + util::Logger webServerLog_{"WebServer"}; + std::shared_ptr adminVerifier_; + std::reference_wrapper rpcHandler_; + std::reference_wrapper dosguard_; + +public: + /** + * @brief Construct a new RequestHandler object + * + * @param adminVerifier The AdminVerificationStrategy to use for verifying the connection for admin access. + * @param rpcHandler The RPC handler to use for handling the request. + * @param dosguard The DOSGuardInterface to use for checking the connection. + */ + RequestHandler( + std::shared_ptr adminVerifier, + RpcHandlerType& rpcHandler, + web::dosguard::DOSGuardInterface& dosguard + ) + : adminVerifier_(std::move(adminVerifier)), rpcHandler_(rpcHandler), dosguard_(dosguard) + { + } + + /** + * @brief The call of the function object. + * + * @param request The request to handle. + * @param connectionMetadata The connection metadata. + * @param subscriptionContext The subscription context. + * @param yield The yield context. + * @return The response to the request. + */ + web::ng::Response + operator()( + web::ng::Request const& request, + web::ng::ConnectionMetadata& connectionMetadata, + web::SubscriptionContextPtr subscriptionContext, + boost::asio::yield_context yield + ) + { + if (not dosguard_.get().request(connectionMetadata.ip())) { + auto error = rpc::makeError(rpc::RippledError::rpcSLOW_DOWN); + + if (not request.isHttp()) { + try { + auto requestJson = boost::json::parse(request.message()); + if (requestJson.is_object() && requestJson.as_object().contains("id")) + error["id"] = requestJson.as_object().at("id"); + error["request"] = request.message(); + } catch (std::exception const&) { + error["request"] = request.message(); + } + } + return web::ng::Response{boost::beast::http::status::service_unavailable, error, request}; + } + LOG(webServerLog_.info()) << connectionMetadata.tag() + << "Received request from ip = " << connectionMetadata.ip() + << " - posting to WorkQueue"; + + connectionMetadata.setIsAdmin([this, &request, &connectionMetadata]() { + return adminVerifier_->isAdmin(request.httpHeaders(), connectionMetadata.ip()); + }); + + try { + auto response = rpcHandler_(request, connectionMetadata, std::move(subscriptionContext), yield); + + if (not dosguard_.get().add(connectionMetadata.ip(), response.message().size())) { + auto jsonResponse = boost::json::parse(response.message()).as_object(); + jsonResponse["warning"] = "load"; + if (jsonResponse.contains("warnings") && jsonResponse["warnings"].is_array()) { + jsonResponse["warnings"].as_array().push_back(rpc::makeWarning(rpc::warnRPC_RATE_LIMIT)); + } else { + jsonResponse["warnings"] = boost::json::array{rpc::makeWarning(rpc::warnRPC_RATE_LIMIT)}; + } + response.setMessage(jsonResponse); + } + + return response; + } catch (std::exception const&) { + return web::ng::Response{ + boost::beast::http::status::internal_server_error, + rpc::makeError(rpc::RippledError::rpcINTERNAL), + request + }; + } + } +}; + +} // namespace app diff --git a/src/web/ng/Response.cpp b/src/web/ng/Response.cpp index 545c989ae..6407f4e81 100644 --- a/src/web/ng/Response.cpp +++ b/src/web/ng/Response.cpp @@ -22,6 +22,7 @@ #include "util/Assert.hpp" #include "util/OverloadSet.hpp" #include "util/build/Build.hpp" +#include "web/ng/Connection.hpp" #include "web/ng/Request.hpp" #include @@ -33,6 +34,7 @@ #include #include +#include #include #include #include @@ -45,42 +47,63 @@ namespace web::ng { namespace { -template -consteval bool -isString() -{ - return std::is_same_v; -} +struct MessageData { + template + MessageData(MessageType message) + { + if constexpr (std::is_same_v) { + body = std::move(message); + contentType = "text/html"; + } else { + body = boost::json::serialize(message); + contentType = "application/json"; + } + } + + std::string body; + std::string contentType; +}; http::response -prepareResponse(http::response response, http::request const& request) +prepareResponse(http::response response, bool keepAlive) { response.set(http::field::server, fmt::format("clio-server-{}", util::build::getClioVersionString())); - response.keep_alive(request.keep_alive()); + response.keep_alive(keepAlive); response.prepare_payload(); return response; } +http::response +makeHttpData(MessageData messageData, http::status status, uint16_t httpVersion, bool keepAlive) +{ + http::response result{status, httpVersion, std::move(messageData.body)}; + result.set(http::field::content_type, messageData.contentType); + return prepareResponse(std::move(result), keepAlive); +} + template std::variant, std::string> makeData(http::status status, MessageType message, Request const& request) { - std::string body; - if constexpr (isString()) { - body = std::move(message); - } else { - body = boost::json::serialize(message); - } + MessageData messageData{std::move(message)}; if (not request.isHttp()) - return body; + return std::move(messageData).body; auto const& httpRequest = request.asHttpRequest()->get(); - std::string const contentType = isString() ? "text/html" : "application/json"; + return makeHttpData(std::move(messageData), status, httpRequest.version(), httpRequest.keep_alive()); +} + +template +std::variant, std::string> +makeData(http::status status, MessageType message, Connection const& connection) +{ + MessageData messageData{std::move(message)}; + + if (connection.wasUpgraded()) + return std::move(messageData).body; - http::response result{status, httpRequest.version(), std::move(body)}; - result.set(http::field::content_type, contentType); - return prepareResponse(std::move(result), httpRequest); + return makeHttpData(std::move(messageData), status, 11, false); } } // namespace @@ -95,10 +118,20 @@ Response::Response(boost::beast::http::status status, boost::json::object const& { } +Response::Response(boost::beast::http::status status, boost::json::object const& message, Connection const& connection) + : data_{makeData(status, message, connection)} +{ +} + +Response::Response(boost::beast::http::status status, std::string message, Connection const& connection) + : data_{makeData(status, std::move(message), connection)} +{ +} + Response::Response(boost::beast::http::response response, Request const& request) { ASSERT(request.isHttp(), "Request must be HTTP to construct response from HTTP response"); - data_ = prepareResponse(std::move(response), request.asHttpRequest()->get()); + data_ = prepareResponse(std::move(response), request.asHttpRequest()->get().keep_alive()); } std::string const& @@ -115,6 +148,34 @@ Response::message() const ); } +void +Response::setMessage(std::string newMessage) +{ + if (std::holds_alternative(data_)) { + std::get(data_) = std::move(newMessage); + return; + } + MessageData messageData{std::move(newMessage)}; + auto const& oldHttpResponse = std::get>(data_); + data_ = makeHttpData( + std::move(messageData), oldHttpResponse.result(), oldHttpResponse.version(), oldHttpResponse.keep_alive() + ); +} + +void +Response::setMessage(boost::json::object const& newMessage) +{ + MessageData messageData{newMessage}; + if (std::holds_alternative(data_)) { + std::get(data_) = std::move(messageData).body; + return; + } + auto const& oldHttpResponse = std::get>(data_); + data_ = makeHttpData( + std::move(messageData), oldHttpResponse.result(), oldHttpResponse.version(), oldHttpResponse.keep_alive() + ); +} + http::response Response::intoHttpResponse() && { diff --git a/src/web/ng/Response.hpp b/src/web/ng/Response.hpp index b9d197283..3657021bb 100644 --- a/src/web/ng/Response.hpp +++ b/src/web/ng/Response.hpp @@ -32,6 +32,8 @@ namespace web::ng { +class Connection; + /** * @brief Represents an HTTP or Websocket response. */ @@ -60,6 +62,26 @@ class Response { */ Response(boost::beast::http::status status, boost::json::object const& message, Request const& request); + /** + * @brief Construct a Response from string. Content type will be text/html. + * + * @param status The HTTP status. + * @param message The message to send. + * @param connection The connection that triggered this response. Used to determine whether the response should + * contain HTTP or WebSocket data. + */ + Response(boost::beast::http::status status, boost::json::object const& message, Connection const& connection); + + /** + * @brief Construct a Response from string. Content type will be text/html. + * + * @param status The HTTP status. + * @param message The message to send. + * @param connection The connection that triggered this response. Used to determine whether the response should + * contain HTTP or WebSocket data. + */ + Response(boost::beast::http::status status, std::string message, Connection const& connection); + /** * @brief Construct a Response from HTTP response. * @@ -76,6 +98,22 @@ class Response { std::string const& message() const; + /** + * @brief Replace existing message (or body) with new message. + * + * @param newMessage The new message. + */ + void + setMessage(std::string newMessage); + + /** + * @brief Replace existing message (or body) with new message. + * + * @param newMessage The new message. + */ + void + setMessage(boost::json::object const& newMessage); + /** * @brief Convert the Response to an HTTP response. * @note The Response must be constructed with an HTTP request. diff --git a/src/web/ng/Server.cpp b/src/web/ng/Server.cpp index 1628ba33f..8d905e257 100644 --- a/src/web/ng/Server.cpp +++ b/src/web/ng/Server.cpp @@ -124,12 +124,13 @@ detectSsl(boost::asio::ip::tcp::socket socket, boost::asio::yield_context yield) return SslDetectionResult{.socket = tcpStream.release_socket(), .isSsl = isSsl, .buffer = std::move(buffer)}; } -std::expected +std::expected> makeConnection( SslDetectionResult sslDetectionResult, std::optional& sslContext, std::string ip, util::TagDecoratorFactory& tagDecoratorFactory, + Server::OnConnectCheck onConnectCheck, boost::asio::yield_context yield ) { @@ -154,6 +155,13 @@ makeConnection( ); } + auto expectedSuccess = onConnectCheck(*connection); + if (not expectedSuccess.has_value()) { + connection->send(std::move(expectedSuccess).error(), yield); + connection->close(yield); + return std::unexpected{std::nullopt}; + } + auto const expectedIsUpgrade = connection->isUpgradeRequested(yield); if (not expectedIsUpgrade.has_value()) { return std::unexpected{ @@ -182,13 +190,16 @@ Server::Server( ProcessingPolicy processingPolicy, std::optional parallelRequestLimit, util::TagDecoratorFactory tagDecoratorFactory, - std::optional maxSubscriptionSendQueueSize + std::optional maxSubscriptionSendQueueSize, + OnConnectCheck onConnectCheck, + OnDisconnectHook onDisconnectHook ) : ctx_{ctx} , sslContext_{std::move(sslContext)} , tagDecoratorFactory_{tagDecoratorFactory} - , connectionHandler_{processingPolicy, parallelRequestLimit, tagDecoratorFactory_, maxSubscriptionSendQueueSize} + , connectionHandler_{processingPolicy, parallelRequestLimit, tagDecoratorFactory_, maxSubscriptionSendQueueSize, std::move(onDisconnectHook)} , endpoint_{std::move(endpoint)} + , onConnectCheck_{std::move(onConnectCheck)} { } @@ -269,13 +280,18 @@ Server::handleConnection(boost::asio::ip::tcp::socket socket, boost::asio::yield return; } - // TODO(kuznetsss): check ip with dosguard here - auto connectionExpected = makeConnection( - std::move(sslDetectionResult).value(), sslContext_, std::move(ip).value(), tagDecoratorFactory_, yield + std::move(sslDetectionResult).value(), + sslContext_, + std::move(ip).value(), + tagDecoratorFactory_, + onConnectCheck_, + yield ); if (not connectionExpected.has_value()) { - LOG(log_.info()) << "Error creating a connection: " << connectionExpected.error(); + if (connectionExpected.error().has_value()) { + LOG(log_.info()) << "Error creating a connection: " << *connectionExpected.error(); + } return; } @@ -288,7 +304,12 @@ Server::handleConnection(boost::asio::ip::tcp::socket socket, boost::asio::yield } std::expected -make_Server(util::Config const& config, boost::asio::io_context& context) +make_Server( + util::Config const& config, + Server::OnConnectCheck onConnectCheck, + Server::OnDisconnectHook onDisconnectHook, + boost::asio::io_context& context +) { auto const serverConfig = config.section("server"); @@ -321,7 +342,9 @@ make_Server(util::Config const& config, boost::asio::io_context& context) processingPolicy, parallelRequestLimit, util::TagDecoratorFactory(config), - maxSubscriptionSendQueueSize + maxSubscriptionSendQueueSize, + std::move(onConnectCheck), + std::move(onDisconnectHook) }; } diff --git a/src/web/ng/Server.hpp b/src/web/ng/Server.hpp index 674d9167c..1df7361a0 100644 --- a/src/web/ng/Server.hpp +++ b/src/web/ng/Server.hpp @@ -22,8 +22,10 @@ #include "util/Taggable.hpp" #include "util/config/Config.hpp" #include "util/log/Logger.hpp" +#include "web/ng/Connection.hpp" #include "web/ng/MessageHandler.hpp" #include "web/ng/ProcessingPolicy.hpp" +#include "web/ng/Response.hpp" #include "web/ng/impl/ConnectionHandler.hpp" #include @@ -42,6 +44,19 @@ namespace web::ng { * @brief Web server class. */ class Server { +public: + /** + * @brief Check to perform for each new client connection. The check takes client ip as input and returns a Response + * if the check failed. Response will be sent to the client and the connection will be closed. + */ + using OnConnectCheck = std::function(Connection const&)>; + + /** + * @brief Hook called when any connection disconnects + */ + using OnDisconnectHook = impl::ConnectionHandler::OnDisconnectHook; + +private: util::Logger log_{"WebServer"}; util::Logger perfLog_{"Performance"}; @@ -53,6 +68,8 @@ class Server { impl::ConnectionHandler connectionHandler_; boost::asio::ip::tcp::endpoint endpoint_; + OnConnectCheck onConnectCheck_; + bool running_{false}; public: @@ -67,6 +84,8 @@ class Server { * if processingPolicy is parallel. * @param tagDecoratorFactory The tag decorator factory. * @param maxSubscriptionSendQueueSize The maximum size of the subscription send queue. + * @param onConnectCheck The check to perform on each connection. + * @param onDisconnectHook The hook to call on each disconnection. */ Server( boost::asio::io_context& ctx, @@ -75,7 +94,9 @@ class Server { ProcessingPolicy processingPolicy, std::optional parallelRequestLimit, util::TagDecoratorFactory tagDecoratorFactory, - std::optional maxSubscriptionSendQueueSize + std::optional maxSubscriptionSendQueueSize, + OnConnectCheck onConnectCheck, + OnDisconnectHook onDisconnectHook ); /** @@ -141,11 +162,18 @@ class Server { * @brief Create a new Server. * * @param config The configuration. + * @param onConnectCheck The check to perform on each client connection. + * @param onDisconnectHook The hook to call when client disconnects. * @param context The boost::asio::io_context to use. * * @return The Server or an error message. */ std::expected -make_Server(util::Config const& config, boost::asio::io_context& context); +make_Server( + util::Config const& config, + Server::OnConnectCheck onConnectCheck, + Server::OnDisconnectHook onDisconnectHook, + boost::asio::io_context& context +); } // namespace web::ng diff --git a/src/web/ng/impl/ConnectionHandler.cpp b/src/web/ng/impl/ConnectionHandler.cpp index 60a09e34a..f60545a3c 100644 --- a/src/web/ng/impl/ConnectionHandler.cpp +++ b/src/web/ng/impl/ConnectionHandler.cpp @@ -106,12 +106,14 @@ ConnectionHandler::ConnectionHandler( ProcessingPolicy processingPolicy, std::optional maxParallelRequests, util::TagDecoratorFactory& tagFactory, - std::optional maxSubscriptionSendQueueSize + std::optional maxSubscriptionSendQueueSize, + OnDisconnectHook onDisconnectHook ) : processingPolicy_{processingPolicy} , maxParallelRequests_{maxParallelRequests} , tagFactory_{tagFactory} , maxSubscriptionSendQueueSize_{maxSubscriptionSendQueueSize} + , onDisconnectHook_{std::move(onDisconnectHook)} { } @@ -171,6 +173,7 @@ ConnectionHandler::processConnection(ConnectionPtr connectionPtr, boost::asio::y connectionRef.close(yield); signalConnection.disconnect(); + onDisconnectHook_(connectionRef); } void diff --git a/src/web/ng/impl/ConnectionHandler.hpp b/src/web/ng/impl/ConnectionHandler.hpp index 874f0fbc9..0572539a7 100644 --- a/src/web/ng/impl/ConnectionHandler.hpp +++ b/src/web/ng/impl/ConnectionHandler.hpp @@ -44,6 +44,8 @@ namespace web::ng::impl { class ConnectionHandler { public: + using OnDisconnectHook = std::function; + struct StringHash { using hash_type = std::hash; using is_transparent = void; @@ -68,6 +70,8 @@ class ConnectionHandler { std::reference_wrapper tagFactory_; std::optional maxSubscriptionSendQueueSize_; + OnDisconnectHook onDisconnectHook_; + TargetToHandlerMap getHandlers_; TargetToHandlerMap postHandlers_; std::optional wsHandler_; @@ -79,7 +83,8 @@ class ConnectionHandler { ProcessingPolicy processingPolicy, std::optional maxParallelRequests, util::TagDecoratorFactory& tagFactory, - std::optional maxSubscriptionSendQueueSize + std::optional maxSubscriptionSendQueueSize, + OnDisconnectHook onDisconnectHook ); void diff --git a/src/web/ng/impl/HttpConnection.hpp b/src/web/ng/impl/HttpConnection.hpp index 3a598c38b..f714edc53 100644 --- a/src/web/ng/impl/HttpConnection.hpp +++ b/src/web/ng/impl/HttpConnection.hpp @@ -65,6 +65,13 @@ class UpgradableConnection : public Connection { util::TagDecoratorFactory const& tagDecoratorFactory, boost::asio::yield_context yield ) = 0; + + virtual std::optional + sendRaw( + boost::beast::http::response response, + boost::asio::yield_context yield, + std::chrono::steady_clock::duration timeout = DEFAULT_TIMEOUT + ) = 0; }; using UpgradableConnectionPtr = std::unique_ptr; @@ -106,21 +113,31 @@ class HttpConnection : public UpgradableConnection { } std::optional - send( - Response response, + sendRaw( + boost::beast::http::response response, boost::asio::yield_context yield, std::chrono::steady_clock::duration timeout = DEFAULT_TIMEOUT ) override { - auto const httpResponse = std::move(response).intoHttpResponse(); boost::system::error_code error; boost::beast::get_lowest_layer(stream_).expires_after(timeout); - boost::beast::http::async_write(stream_, httpResponse, yield[error]); + boost::beast::http::async_write(stream_, response, yield[error]); if (error) return error; return std::nullopt; } + std::optional + send( + Response response, + boost::asio::yield_context yield, + std::chrono::steady_clock::duration timeout = DEFAULT_TIMEOUT + ) override + { + auto httpResponse = std::move(response).intoHttpResponse(); + return sendRaw(std::move(httpResponse), yield, timeout); + } + std::expected receive(boost::asio::yield_context yield, std::chrono::steady_clock::duration timeout = DEFAULT_TIMEOUT) override { diff --git a/tests/common/web/dosguard/DOSGuardMock.hpp b/tests/common/web/dosguard/DOSGuardMock.hpp new file mode 100644 index 000000000..a465d8b7f --- /dev/null +++ b/tests/common/web/dosguard/DOSGuardMock.hpp @@ -0,0 +1,41 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2024, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#pragma once + +#include "web/dosguard/DOSGuardInterface.hpp" + +#include + +#include +#include +#include + +struct DOSGuardMockImpl : web::dosguard::DOSGuardInterface { + MOCK_METHOD(bool, isWhiteListed, (std::string_view const ip), (const, noexcept, override)); + MOCK_METHOD(bool, isOk, (std::string const& ip), (const, noexcept, override)); + MOCK_METHOD(void, increment, (std::string const& ip), (noexcept, override)); + MOCK_METHOD(void, decrement, (std::string const& ip), (noexcept, override)); + MOCK_METHOD(bool, add, (std::string const& ip, uint32_t size), (noexcept, override)); + MOCK_METHOD(bool, request, (std::string const& ip), (noexcept, override)); + MOCK_METHOD(void, clear, (), (noexcept, override)); +}; + +using DOSGuardMock = testing::NiceMock; +using DOSGuardStrictMock = testing::StrictMock; diff --git a/tests/common/web/ng/MockConnection.hpp b/tests/common/web/ng/MockConnection.hpp index c4f015728..96ab959f5 100644 --- a/tests/common/web/ng/MockConnection.hpp +++ b/tests/common/web/ng/MockConnection.hpp @@ -19,7 +19,6 @@ #pragma once -#include "util/Taggable.hpp" #include "web/ng/Connection.hpp" #include "web/ng/Error.hpp" #include "web/ng/Request.hpp" diff --git a/tests/common/web/ng/impl/MockHttpConnection.hpp b/tests/common/web/ng/impl/MockHttpConnection.hpp index 53f204d08..fdb02929e 100644 --- a/tests/common/web/ng/impl/MockHttpConnection.hpp +++ b/tests/common/web/ng/impl/MockHttpConnection.hpp @@ -29,6 +29,8 @@ #include #include +#include +#include #include #include @@ -48,6 +50,15 @@ struct MockHttpConnectionImpl : web::ng::impl::UpgradableConnection { (override) ); + MOCK_METHOD( + SendReturnType, + sendRaw, + (boost::beast::http::response, + boost::asio::yield_context, + std::chrono::steady_clock::duration), + (override) + ); + using ReceiveReturnType = std::expected; MOCK_METHOD( ReceiveReturnType, diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index b454474d4..be3e3cda8 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -5,6 +5,7 @@ target_sources( PRIVATE # Common ConfigTests.cpp app/CliArgsTests.cpp + app/WebHandlersTests.cpp data/AmendmentCenterTests.cpp data/BackendCountersTests.cpp data/BackendInterfaceTests.cpp diff --git a/tests/unit/app/WebHandlersTests.cpp b/tests/unit/app/WebHandlersTests.cpp new file mode 100644 index 000000000..4c5e6b34d --- /dev/null +++ b/tests/unit/app/WebHandlersTests.cpp @@ -0,0 +1,321 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2024, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#include "app/WebHandlers.hpp" +#include "util/AsioContextTestFixture.hpp" +#include "util/LoggerFixtures.hpp" +#include "util/MockPrometheus.hpp" +#include "util/Taggable.hpp" +#include "util/config/Config.hpp" +#include "web/AdminVerificationStrategy.hpp" +#include "web/SubscriptionContextInterface.hpp" +#include "web/dosguard/DOSGuardMock.hpp" +#include "web/ng/Connection.hpp" +#include "web/ng/MockConnection.hpp" +#include "web/ng/Request.hpp" +#include "web/ng/Response.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +using namespace app; +namespace http = boost::beast::http; + +struct WebHandlersTest : virtual NoLoggerFixture { + DOSGuardStrictMock dosGuardMock_; + util::TagDecoratorFactory tagFactory_{util::Config{}}; + std::string const ip_ = "some ip"; + StrictMockConnection connectionMock_{ip_, boost::beast::flat_buffer{}, tagFactory_}; + + struct AdminVerificationStrategyMock : web::AdminVerificationStrategy { + MOCK_METHOD(bool, isAdmin, (RequestHeader const&, std::string_view), (const, override)); + }; + using AdminVerificationStrategyStrictMockPtr = std::shared_ptr>; +}; + +struct OnConnectCheckTests : WebHandlersTest { + OnConnectCheck onConnectCheck_{dosGuardMock_}; +}; + +TEST_F(OnConnectCheckTests, Ok) +{ + EXPECT_CALL(dosGuardMock_, increment(ip_)); + EXPECT_CALL(dosGuardMock_, isOk(ip_)).WillOnce(testing::Return(true)); + EXPECT_TRUE(onConnectCheck_(connectionMock_).has_value()); +} + +TEST_F(OnConnectCheckTests, RateLimited) +{ + EXPECT_CALL(dosGuardMock_, increment(ip_)); + EXPECT_CALL(dosGuardMock_, isOk(ip_)).WillOnce(testing::Return(false)); + EXPECT_CALL(connectionMock_, wasUpgraded).WillOnce(testing::Return(false)); + + auto response = onConnectCheck_(connectionMock_); + ASSERT_FALSE(response.has_value()); + auto const httpResponse = std::move(response).error().intoHttpResponse(); + EXPECT_EQ(httpResponse.result(), boost::beast::http::status::too_many_requests); + EXPECT_EQ(httpResponse.body(), "Too many requests"); +} + +struct DisconnectHookTests : WebHandlersTest { + DisconnectHook disconnectHook_{dosGuardMock_}; +}; + +TEST_F(DisconnectHookTests, CallsDecrement) +{ + EXPECT_CALL(dosGuardMock_, decrement(ip_)); + disconnectHook_(connectionMock_); +} + +struct MetricsHandlerTests : util::prometheus::WithPrometheus, SyncAsioContextTest, WebHandlersTest { + AdminVerificationStrategyStrictMockPtr adminVerifier_{ + std::make_shared>() + }; + + MetricsHandler metricsHandler_{adminVerifier_}; + web::ng::Request request_{http::request{http::verb::get, "/metrics", 11}}; +}; + +TEST_F(MetricsHandlerTests, Call) +{ + EXPECT_CALL(*adminVerifier_, isAdmin).WillOnce(testing::Return(true)); + runSpawn([&](boost::asio::yield_context yield) { + auto response = metricsHandler_(request_, connectionMock_, nullptr, yield); + auto const httpResponse = std::move(response).intoHttpResponse(); + EXPECT_EQ(httpResponse.result(), boost::beast::http::status::ok); + }); +} + +struct HealthCheckHandlerTests : SyncAsioContextTest, WebHandlersTest { + web::ng::Request request_{http::request{http::verb::get, "/", 11}}; + HealthCheckHandler healthCheckHandler_; +}; + +TEST_F(HealthCheckHandlerTests, Call) +{ + runSpawn([&](boost::asio::yield_context yield) { + auto response = healthCheckHandler_(request_, connectionMock_, nullptr, yield); + auto const httpResponse = std::move(response).intoHttpResponse(); + EXPECT_EQ(httpResponse.result(), boost::beast::http::status::ok); + }); +} + +struct RequestHandlerTest : SyncAsioContextTest, WebHandlersTest { + AdminVerificationStrategyStrictMockPtr adminVerifier_{ + std::make_shared>() + }; + + struct RpcHandlerMock { + MOCK_METHOD( + web::ng::Response, + call, + (web::ng::Request const&, + web::ng::ConnectionMetadata const&, + web::SubscriptionContextPtr, + boost::asio::yield_context), + () + ); + + web::ng::Response + operator()( + web::ng::Request const& request, + web::ng::ConnectionMetadata const& connectionMetadata, + web::SubscriptionContextPtr subscriptionContext, + boost::asio::yield_context yield + ) + { + return call(request, connectionMetadata, std::move(subscriptionContext), yield); + } + }; + + testing::StrictMock rpcHandler_; + StrictMockConnection connectionMock_{ip_, boost::beast::flat_buffer{}, tagFactory_}; + RequestHandler requestHandler_{adminVerifier_, rpcHandler_, dosGuardMock_}; +}; + +TEST_F(RequestHandlerTest, DosguardRateLimited_Http) +{ + web::ng::Request const request{http::request{http::verb::get, "/", 11}}; + + EXPECT_CALL(dosGuardMock_, request(ip_)).WillOnce(testing::Return(false)); + + runSpawn([&](boost::asio::yield_context yield) { + auto response = requestHandler_(request, connectionMock_, nullptr, yield); + auto const httpResponse = std::move(response).intoHttpResponse(); + + EXPECT_EQ(httpResponse.result(), boost::beast::http::status::service_unavailable); + + auto const body = boost::json::parse(httpResponse.body()).as_object(); + EXPECT_EQ(body.at("error").as_string(), "slowDown"); + EXPECT_EQ(body.at("error_code").as_int64(), 10); + EXPECT_EQ(body.at("status").as_string(), "error"); + EXPECT_FALSE(body.contains("id")); + EXPECT_FALSE(body.contains("request")); + }); +} + +TEST_F(RequestHandlerTest, DosguardRateLimited_Ws) +{ + auto const requestMessage = R"json({"some": "request", "id": "some id"})json"; + web::ng::Request::HttpHeaders const headers{}; + web::ng::Request const request{requestMessage, headers}; + + EXPECT_CALL(dosGuardMock_, request(ip_)).WillOnce(testing::Return(false)); + + runSpawn([&](boost::asio::yield_context yield) { + auto const response = requestHandler_(request, connectionMock_, nullptr, yield); + auto const message = boost::json::parse(response.message()).as_object(); + + EXPECT_EQ(message.at("error").as_string(), "slowDown"); + EXPECT_EQ(message.at("error_code").as_int64(), 10); + EXPECT_EQ(message.at("status").as_string(), "error"); + EXPECT_EQ(message.at("id").as_string(), "some id"); + EXPECT_EQ(message.at("request").as_string(), requestMessage); + }); +} + +TEST_F(RequestHandlerTest, DosguardRateLimited_Ws_ErrorParsing) +{ + auto const requestMessage = R"json(some request "id": "some id")json"; + web::ng::Request::HttpHeaders const headers{}; + web::ng::Request const request{requestMessage, headers}; + + EXPECT_CALL(dosGuardMock_, request(ip_)).WillOnce(testing::Return(false)); + + runSpawn([&](boost::asio::yield_context yield) { + auto const response = requestHandler_(request, connectionMock_, nullptr, yield); + auto const message = boost::json::parse(response.message()).as_object(); + + EXPECT_EQ(message.at("error").as_string(), "slowDown"); + EXPECT_EQ(message.at("error_code").as_int64(), 10); + EXPECT_EQ(message.at("status").as_string(), "error"); + EXPECT_FALSE(message.contains("id")); + EXPECT_EQ(message.at("request").as_string(), requestMessage); + }); +} + +TEST_F(RequestHandlerTest, RpcHandlerThrows) +{ + web::ng::Request const request{http::request{http::verb::get, "/", 11}}; + + EXPECT_CALL(dosGuardMock_, request(ip_)).WillOnce(testing::Return(true)); + EXPECT_CALL(*adminVerifier_, isAdmin).WillOnce(testing::Return(true)); + EXPECT_CALL(rpcHandler_, call).WillOnce(testing::Throw(std::runtime_error{"some error"})); + + runSpawn([&](boost::asio::yield_context yield) { + auto response = requestHandler_(request, connectionMock_, nullptr, yield); + + auto const httpResponse = std::move(response).intoHttpResponse(); + + EXPECT_EQ(httpResponse.result(), boost::beast::http::status::internal_server_error); + + auto const body = boost::json::parse(httpResponse.body()).as_object(); + EXPECT_EQ(body.at("error").as_string(), "internal"); + EXPECT_EQ(body.at("error_code").as_int64(), 73); + EXPECT_EQ(body.at("status").as_string(), "error"); + }); +} + +TEST_F(RequestHandlerTest, NoErrors) +{ + web::ng::Request const request{http::request{http::verb::get, "/", 11}}; + web::ng::Response const response{http::status::ok, "some response", request}; + auto const httpResponse = web::ng::Response{response}.intoHttpResponse(); + + EXPECT_CALL(dosGuardMock_, request(ip_)).WillOnce(testing::Return(true)); + EXPECT_CALL(*adminVerifier_, isAdmin).WillOnce(testing::Return(true)); + EXPECT_CALL(rpcHandler_, call).WillOnce(testing::Return(response)); + EXPECT_CALL(dosGuardMock_, add(ip_, testing::_)).WillOnce(testing::Return(true)); + + runSpawn([&](boost::asio::yield_context yield) { + auto actualResponse = requestHandler_(request, connectionMock_, nullptr, yield); + + auto const actualHttpResponse = std::move(actualResponse).intoHttpResponse(); + + EXPECT_EQ(actualHttpResponse.result(), httpResponse.result()); + EXPECT_EQ(actualHttpResponse.body(), httpResponse.body()); + EXPECT_EQ(actualHttpResponse.version(), 11); + }); +} + +TEST_F(RequestHandlerTest, ResponseDosGuardWarning_ResponseHasWarnings) +{ + web::ng::Request const request{http::request{http::verb::get, "/", 11}}; + web::ng::Response const response{ + http::status::ok, R"json({"some":"response", "warnings":["some warning"]})json", request + }; + auto const httpResponse = web::ng::Response{response}.intoHttpResponse(); + + EXPECT_CALL(dosGuardMock_, request(ip_)).WillOnce(testing::Return(true)); + EXPECT_CALL(*adminVerifier_, isAdmin).WillOnce(testing::Return(true)); + EXPECT_CALL(rpcHandler_, call).WillOnce(testing::Return(response)); + EXPECT_CALL(dosGuardMock_, add(ip_, testing::_)).WillOnce(testing::Return(false)); + + runSpawn([&](boost::asio::yield_context yield) { + auto actualResponse = requestHandler_(request, connectionMock_, nullptr, yield); + + auto const actualHttpResponse = std::move(actualResponse).intoHttpResponse(); + + EXPECT_EQ(actualHttpResponse.result(), httpResponse.result()); + EXPECT_EQ(actualHttpResponse.version(), 11); + + auto actualBody = boost::json::parse(actualHttpResponse.body()).as_object(); + EXPECT_EQ(actualBody.at("some").as_string(), "response"); + EXPECT_EQ(actualBody.at("warnings").as_array().size(), 2); + }); +} + +TEST_F(RequestHandlerTest, ResponseDosGuardWarning_ResponseDoesntHaveWarnings) +{ + web::ng::Request const request{http::request{http::verb::get, "/", 11}}; + web::ng::Response const response{http::status::ok, R"json({"some":"response"})json", request}; + auto const httpResponse = web::ng::Response{response}.intoHttpResponse(); + + EXPECT_CALL(dosGuardMock_, request(ip_)).WillOnce(testing::Return(true)); + EXPECT_CALL(*adminVerifier_, isAdmin).WillOnce(testing::Return(true)); + EXPECT_CALL(rpcHandler_, call).WillOnce(testing::Return(response)); + EXPECT_CALL(dosGuardMock_, add(ip_, testing::_)).WillOnce(testing::Return(false)); + + runSpawn([&](boost::asio::yield_context yield) { + auto actualResponse = requestHandler_(request, connectionMock_, nullptr, yield); + + auto const actualHttpResponse = std::move(actualResponse).intoHttpResponse(); + + EXPECT_EQ(actualHttpResponse.result(), httpResponse.result()); + EXPECT_EQ(actualHttpResponse.version(), 11); + + auto actualBody = boost::json::parse(actualHttpResponse.body()).as_object(); + EXPECT_EQ(actualBody.at("some").as_string(), "response"); + EXPECT_EQ(actualBody.at("warnings").as_array().size(), 1); + }); +} diff --git a/tests/unit/web/dosguard/IntervalSweepHandlerTests.cpp b/tests/unit/web/dosguard/IntervalSweepHandlerTests.cpp index 9c9127047..9ca842ba9 100644 --- a/tests/unit/web/dosguard/IntervalSweepHandlerTests.cpp +++ b/tests/unit/web/dosguard/IntervalSweepHandlerTests.cpp @@ -19,7 +19,7 @@ #include "util/AsioContextTestFixture.hpp" #include "util/config/Config.hpp" -#include "web/dosguard/DOSGuardInterface.hpp" +#include "web/dosguard/DOSGuardMock.hpp" #include "web/dosguard/IntervalSweepHandler.hpp" #include @@ -40,10 +40,7 @@ struct IntervalSweepHandlerTest : SyncAsioContextTest { } )JSON"; - struct DosGuardMock : BaseDOSGuard { - MOCK_METHOD(void, clear, (), (noexcept, override)); - }; - testing::StrictMock guardMock; + DOSGuardStrictMock guardMock; util::Config cfg{boost::json::parse(JSONData)}; IntervalSweepHandler sweepHandler{cfg, ctx, guardMock}; diff --git a/tests/unit/web/ng/RequestTests.cpp b/tests/unit/web/ng/RequestTests.cpp index fc3e5da2f..b45c55ce2 100644 --- a/tests/unit/web/ng/RequestTests.cpp +++ b/tests/unit/web/ng/RequestTests.cpp @@ -34,7 +34,10 @@ using namespace web::ng; namespace http = boost::beast::http; -struct RequestTest : public ::testing::Test {}; +struct RequestTest : public ::testing::Test { + static Request::HttpHeaders const headers_; +}; +Request::HttpHeaders const RequestTest::headers_ = {}; struct RequestMethodTestBundle { std::string testName; @@ -65,7 +68,7 @@ INSTANTIATE_TEST_SUITE_P( }, RequestMethodTestBundle{ .testName = "WebSocket", - .request = Request{"websocket message", Request::HttpHeaders{}}, + .request = Request{"websocket message", RequestTest::headers_}, .expectedMethod = Request::Method::Websocket, }, RequestMethodTestBundle{ @@ -101,7 +104,7 @@ INSTANTIATE_TEST_SUITE_P( }, RequestIsHttpTestBundle{ .testName = "WebSocketRequest", - .request = Request{"websocket message", Request::HttpHeaders{}}, + .request = Request{"websocket message", RequestTest::headers_}, .expectedIsHttp = false, } ), @@ -124,7 +127,7 @@ TEST_F(RequestAsHttpRequestTest, HttpRequest) TEST_F(RequestAsHttpRequestTest, WebSocketRequest) { - Request const request{"websocket message", Request::HttpHeaders{}}; + Request const request{"websocket message", RequestTest::headers_}; auto const maybeHttpRequest = request.asHttpRequest(); EXPECT_FALSE(maybeHttpRequest.has_value()); } @@ -142,7 +145,7 @@ TEST_F(RequestMessageTest, HttpRequest) TEST_F(RequestMessageTest, WebSocketRequest) { std::string const message = "websocket message"; - Request const request{message, Request::HttpHeaders{}}; + Request const request{message, RequestTest::headers_}; EXPECT_EQ(request.message(), message); } @@ -171,7 +174,7 @@ INSTANTIATE_TEST_SUITE_P( }, RequestTargetTestBundle{ .testName = "WebSocketRequest", - .request = Request{"websocket message", Request::HttpHeaders{}}, + .request = Request{"websocket message", RequestTest::headers_}, .expectedTarget = std::nullopt, } ), diff --git a/tests/unit/web/ng/ResponseTests.cpp b/tests/unit/web/ng/ResponseTests.cpp index 754f41357..6e2e1021b 100644 --- a/tests/unit/web/ng/ResponseTests.cpp +++ b/tests/unit/web/ng/ResponseTests.cpp @@ -17,10 +17,14 @@ */ //============================================================================== +#include "util/Taggable.hpp" #include "util/build/Build.hpp" +#include "util/config/Config.hpp" +#include "web/ng/MockConnection.hpp" #include "web/ng/Request.hpp" #include "web/ng/Response.hpp" +#include #include #include #include @@ -29,6 +33,7 @@ #include #include #include +#include #include #include @@ -41,21 +46,23 @@ struct ResponseDeathTest : testing::Test {}; TEST_F(ResponseDeathTest, intoHttpResponseWithoutHttpData) { - Request const request{"some messsage", Request::HttpHeaders{}}; - web::ng::Response response{boost::beast::http::status::ok, "message", request}; + Request::HttpHeaders const headers{}; + Request const request{"some message", headers}; + Response response{boost::beast::http::status::ok, "message", request}; EXPECT_DEATH(std::move(response).intoHttpResponse(), ""); } TEST_F(ResponseDeathTest, asConstBufferWithHttpData) { Request const request{http::request{http::verb::get, "/", 11}}; - web::ng::Response const response{boost::beast::http::status::ok, "message", request}; + Response const response{boost::beast::http::status::ok, "message", request}; EXPECT_DEATH(response.asWsResponse(), ""); } struct ResponseTest : testing::Test { int const httpVersion_ = 11; http::status const responseStatus_ = http::status::ok; + Request::HttpHeaders const headers_; }; TEST_F(ResponseTest, intoHttpResponse) @@ -63,7 +70,7 @@ TEST_F(ResponseTest, intoHttpResponse) Request const request{http::request{http::verb::post, "/", httpVersion_, "some message"}}; std::string const responseMessage = "response message"; - web::ng::Response response{responseStatus_, responseMessage, request}; + Response response{responseStatus_, responseMessage, request}; auto const httpResponse = std::move(response).intoHttpResponse(); EXPECT_EQ(httpResponse.result(), responseStatus_); @@ -83,7 +90,7 @@ TEST_F(ResponseTest, intoHttpResponseJson) Request const request{http::request{http::verb::post, "/", httpVersion_, "some message"}}; boost::json::object const responseMessage{{"key", "value"}}; - web::ng::Response response{responseStatus_, responseMessage, request}; + Response response{responseStatus_, responseMessage, request}; auto const httpResponse = std::move(response).intoHttpResponse(); EXPECT_EQ(httpResponse.result(), responseStatus_); @@ -100,9 +107,9 @@ TEST_F(ResponseTest, intoHttpResponseJson) TEST_F(ResponseTest, asConstBuffer) { - Request const request("some request", Request::HttpHeaders{}); + Request const request("some request", headers_); std::string const responseMessage = "response message"; - web::ng::Response const response{responseStatus_, responseMessage, request}; + Response const response{responseStatus_, responseMessage, request}; auto const buffer = response.asWsResponse(); EXPECT_EQ(buffer.size(), responseMessage.size()); @@ -113,9 +120,9 @@ TEST_F(ResponseTest, asConstBuffer) TEST_F(ResponseTest, asConstBufferJson) { - Request const request("some request", Request::HttpHeaders{}); + Request const request("some request", headers_); boost::json::object const responseMessage{{"key", "value"}}; - web::ng::Response const response{responseStatus_, responseMessage, request}; + Response const response{responseStatus_, responseMessage, request}; auto const buffer = response.asWsResponse(); EXPECT_EQ(buffer.size(), boost::json::serialize(responseMessage).size()); @@ -123,3 +130,88 @@ TEST_F(ResponseTest, asConstBufferJson) std::string const messageFromBuffer{static_cast(buffer.data()), buffer.size()}; EXPECT_EQ(messageFromBuffer, boost::json::serialize(responseMessage)); } + +TEST_F(ResponseTest, createFromStringAndConnection) +{ + util::TagDecoratorFactory tagDecoratorFactory{util::Config{}}; + StrictMockConnection connection{"some ip", boost::beast::flat_buffer{}, tagDecoratorFactory}; + std::string const responseMessage = "response message"; + + EXPECT_CALL(connection, wasUpgraded()).WillOnce(testing::Return(false)); + Response response{responseStatus_, responseMessage, connection}; + + EXPECT_EQ(response.message(), responseMessage); + auto const httpResponse = std::move(response).intoHttpResponse(); + EXPECT_EQ(httpResponse.result(), responseStatus_); + auto const it = httpResponse.find(http::field::content_type); + ASSERT_NE(it, httpResponse.end()); + EXPECT_EQ(it->value(), "text/html"); +} + +TEST_F(ResponseTest, createFromJsonAndConnection) +{ + util::TagDecoratorFactory tagDecoratorFactory{util::Config{}}; + StrictMockConnection connection{"some ip", boost::beast::flat_buffer{}, tagDecoratorFactory}; + boost::json::object const responseMessage{{"key", "value"}}; + + EXPECT_CALL(connection, wasUpgraded()).WillOnce(testing::Return(false)); + Response response{responseStatus_, responseMessage, connection}; + + EXPECT_EQ(response.message(), boost::json::serialize(responseMessage)); + auto const httpResponse = std::move(response).intoHttpResponse(); + EXPECT_EQ(httpResponse.result(), responseStatus_); + auto const it = httpResponse.find(http::field::content_type); + ASSERT_NE(it, httpResponse.end()); + EXPECT_EQ(it->value(), "application/json"); +} + +TEST_F(ResponseTest, setMessageString_HttpResponse) +{ + Request const request{http::request{http::verb::post, "/", httpVersion_, "some request"}}; + Response response{boost::beast::http::status::ok, "message", request}; + + std::string const newMessage = "new message"; + response.setMessage(newMessage); + + EXPECT_EQ(response.message(), newMessage); + auto const httpResponse = std::move(response).intoHttpResponse(); + auto it = httpResponse.find(http::field::content_type); + ASSERT_NE(it, httpResponse.end()); + EXPECT_EQ(it->value(), "text/html"); +} + +TEST_F(ResponseTest, setMessageString_WsResponse) +{ + Request const request{"some request", headers_}; + Response response{boost::beast::http::status::ok, "message", request}; + + std::string const newMessage = "new message"; + response.setMessage(newMessage); + + EXPECT_EQ(response.message(), newMessage); +} + +TEST_F(ResponseTest, setMessageJson_HttpResponse) +{ + Request const request{http::request{http::verb::post, "/", httpVersion_, "some request"}}; + Response response{boost::beast::http::status::ok, "message", request}; + + boost::json::object const newMessage{{"key", "value"}}; + response.setMessage(newMessage); + + auto const httpResponse = std::move(response).intoHttpResponse(); + auto it = httpResponse.find(http::field::content_type); + ASSERT_NE(it, httpResponse.end()); + EXPECT_EQ(it->value(), "application/json"); +} + +TEST_F(ResponseTest, setMessageJson_WsResponse) +{ + Request const request{"some request", headers_}; + Response response{boost::beast::http::status::ok, "message", request}; + + boost::json::object const newMessage{{"key", "value"}}; + response.setMessage(newMessage); + + EXPECT_EQ(response.message(), boost::json::serialize(newMessage)); +} diff --git a/tests/unit/web/ng/ServerTests.cpp b/tests/unit/web/ng/ServerTests.cpp index 3a0e524e9..e58058a3f 100644 --- a/tests/unit/web/ng/ServerTests.cpp +++ b/tests/unit/web/ng/ServerTests.cpp @@ -36,6 +36,7 @@ #include #include #include +#include #include #include #include @@ -68,7 +69,8 @@ struct MakeServerTest : NoLoggerFixture, testing::WithParamInterface std::expected { return {}; }, [](auto&&) {}, ioContext_); EXPECT_EQ(expectedServer.has_value(), GetParam().expectSuccess); } @@ -159,7 +161,9 @@ struct ServerTest : SyncAsioContextTest { boost::json::object{{"server", boost::json::object{{"ip", "127.0.0.1"}, {"port", serverPort_}}}} }; - std::expected server_ = make_Server(config_, ctx); + Server::OnConnectCheck emptyOnConnectCheck_ = [](auto&&) -> std::expected { return {}; }; + + std::expected server_ = make_Server(config_, emptyOnConnectCheck_, [](auto&&) {}, ctx); std::string requestMessage_ = "some request"; std::string const headerName_ = "Some-header"; @@ -181,8 +185,17 @@ TEST_F(ServerTest, BadEndpoint) boost::asio::ip::tcp::endpoint const endpoint{boost::asio::ip::address_v4::from_string("1.2.3.4"), 0}; util::TagDecoratorFactory const tagDecoratorFactory{util::Config{boost::json::value{}}}; Server server{ - ctx, endpoint, std::nullopt, ProcessingPolicy::Sequential, std::nullopt, tagDecoratorFactory, std::nullopt + ctx, + endpoint, + std::nullopt, + ProcessingPolicy::Sequential, + std::nullopt, + tagDecoratorFactory, + std::nullopt, + emptyOnConnectCheck_, + [](auto&&) {} }; + auto maybeError = server.run(); ASSERT_TRUE(maybeError.has_value()); EXPECT_THAT(*maybeError, testing::HasSubstr("Error creating TCP acceptor")); @@ -224,6 +237,170 @@ TEST_F(ServerHttpTest, ClientDisconnects) runContext(); } +TEST_F(ServerHttpTest, OnConnectCheck) +{ + auto const serverPort = tests::util::generateFreePort(); + boost::asio::ip::tcp::endpoint const endpoint{boost::asio::ip::address_v4::from_string("0.0.0.0"), serverPort}; + util::TagDecoratorFactory const tagDecoratorFactory{util::Config{boost::json::value{}}}; + + testing::StrictMock(Connection const&)>> onConnectCheck; + + Server server{ + ctx, + endpoint, + std::nullopt, + ProcessingPolicy::Sequential, + std::nullopt, + tagDecoratorFactory, + std::nullopt, + onConnectCheck.AsStdFunction(), + [](auto&&) {} + }; + + HttpAsyncClient client{ctx}; + + boost::asio::spawn(ctx, [&](boost::asio::yield_context yield) { + boost::asio::steady_timer timer{yield.get_executor()}; + + EXPECT_CALL(onConnectCheck, Call) + .WillOnce([&timer](Connection const& connection) -> std::expected { + EXPECT_EQ(connection.ip(), "127.0.0.1"); + timer.cancel(); + return {}; + }); + + auto maybeError = + client.connect("127.0.0.1", std::to_string(serverPort), yield, std::chrono::milliseconds{100}); + [&]() { ASSERT_FALSE(maybeError.has_value()) << maybeError->message(); }(); + + // Have to send a request here because the server does async_detect_ssl() which waits for some data to appear + client.send( + http::request{http::verb::get, "/", 11, requestMessage_}, + yield, + std::chrono::milliseconds{100} + ); + + // Wait for the onConnectCheck to be called + timer.expires_after(std::chrono::milliseconds{100}); + boost::system::error_code error; // Unused + timer.async_wait(yield[error]); + + client.gracefulShutdown(); + ctx.stop(); + }); + + server.run(); + + runContext(); +} + +TEST_F(ServerHttpTest, OnConnectCheckFailed) +{ + auto const serverPort = tests::util::generateFreePort(); + boost::asio::ip::tcp::endpoint const endpoint{boost::asio::ip::address_v4::from_string("0.0.0.0"), serverPort}; + util::TagDecoratorFactory const tagDecoratorFactory{util::Config{boost::json::value{}}}; + + testing::StrictMock(Connection const&)>> onConnectCheck; + + Server server{ + ctx, + endpoint, + std::nullopt, + ProcessingPolicy::Sequential, + std::nullopt, + tagDecoratorFactory, + std::nullopt, + onConnectCheck.AsStdFunction(), + [](auto&&) {} + }; + + HttpAsyncClient client{ctx}; + + EXPECT_CALL(onConnectCheck, Call).WillOnce([](Connection const& connection) { + EXPECT_EQ(connection.ip(), "127.0.0.1"); + return std::unexpected{ + Response{http::status::too_many_requests, boost::json::object{{"error", "some error"}}, connection} + }; + }); + + boost::asio::spawn(ctx, [&](boost::asio::yield_context yield) { + auto maybeError = + client.connect("127.0.0.1", std::to_string(serverPort), yield, std::chrono::milliseconds{100}); + [&]() { ASSERT_FALSE(maybeError.has_value()) << maybeError->message(); }(); + + // Have to send a request here because the server does async_detect_ssl() which waits for some data to appear + client.send( + http::request{http::verb::get, "/", 11, requestMessage_}, + yield, + std::chrono::milliseconds{100} + ); + + auto const response = client.receive(yield, std::chrono::milliseconds{100}); + [&]() { ASSERT_TRUE(response.has_value()) << response.error().message(); }(); + EXPECT_EQ(response->result(), http::status::too_many_requests); + EXPECT_EQ(response->body(), R"json({"error":"some error"})json"); + EXPECT_EQ(response->version(), 11); + + client.gracefulShutdown(); + ctx.stop(); + }); + + server.run(); + + runContext(); +} + +TEST_F(ServerHttpTest, OnDisconnectHook) +{ + auto const serverPort = tests::util::generateFreePort(); + boost::asio::ip::tcp::endpoint const endpoint{boost::asio::ip::address_v4::from_string("0.0.0.0"), serverPort}; + util::TagDecoratorFactory const tagDecoratorFactory{util::Config{boost::json::value{}}}; + + testing::StrictMock> OnDisconnectHookMock; + + Server server{ + ctx, + endpoint, + std::nullopt, + ProcessingPolicy::Sequential, + std::nullopt, + tagDecoratorFactory, + std::nullopt, + emptyOnConnectCheck_, + OnDisconnectHookMock.AsStdFunction() + }; + + HttpAsyncClient client{ctx}; + + boost::asio::spawn(ctx, [&](boost::asio::yield_context yield) { + boost::asio::steady_timer timer{ctx.get_executor(), std::chrono::milliseconds{100}}; + + EXPECT_CALL(OnDisconnectHookMock, Call).WillOnce([&timer](auto&&) { timer.cancel(); }); + + auto maybeError = + client.connect("127.0.0.1", std::to_string(serverPort), yield, std::chrono::milliseconds{100}); + [&]() { ASSERT_FALSE(maybeError.has_value()) << maybeError->message(); }(); + + client.send( + http::request{http::verb::get, "/", 11, requestMessage_}, + yield, + std::chrono::milliseconds{100} + ); + + client.gracefulShutdown(); + + // Wait for OnDisconnectHook is called + boost::system::error_code error; + timer.async_wait(yield[error]); + + ctx.stop(); + }); + + server.run(); + + runContext(); +} + TEST_P(ServerHttpTest, RequestResponse) { HttpAsyncClient client{ctx}; @@ -300,7 +477,8 @@ TEST_F(ServerTest, WsRequestResponse) { WebSocketAsyncClient client{ctx}; - Response const response{http::status::ok, "some response", Request{requestMessage_, Request::HttpHeaders{}}}; + Request::HttpHeaders const headers{}; + Response const response{http::status::ok, "some response", Request{requestMessage_, headers}}; boost::asio::spawn(ctx, [&](boost::asio::yield_context yield) { auto maybeError = diff --git a/tests/unit/web/ng/impl/ConnectionHandlerTests.cpp b/tests/unit/web/ng/impl/ConnectionHandlerTests.cpp index aa75217cc..c7ef7336c 100644 --- a/tests/unit/web/ng/impl/ConnectionHandlerTests.cpp +++ b/tests/unit/web/ng/impl/ConnectionHandlerTests.cpp @@ -64,7 +64,14 @@ namespace websocket = boost::beast::websocket; struct ConnectionHandlerTest : SyncAsioContextTest { ConnectionHandlerTest(ProcessingPolicy policy, std::optional maxParallelConnections) - : tagFactory_{util::Config{}}, connectionHandler_{policy, maxParallelConnections, tagFactory_, std::nullopt} + : tagFactory_{util::Config{}} + , connectionHandler_{ + policy, + maxParallelConnections, + tagFactory_, + std::nullopt, + onDisconnectMock_.AsStdFunction() + } { } @@ -93,6 +100,7 @@ struct ConnectionHandlerTest : SyncAsioContextTest { return Request{std::forward(args)...}; } + testing::StrictMock> onDisconnectMock_; util::TagDecoratorFactory tagFactory_; ConnectionHandler connectionHandler_; @@ -101,6 +109,8 @@ struct ConnectionHandlerTest : SyncAsioContextTest { std::make_unique("1.2.3.4", beast::flat_buffer{}, tagDecoratorFactory_); StrictMockWsConnectionPtr mockWsConnection_ = std::make_unique("1.2.3.4", beast::flat_buffer{}, tagDecoratorFactory_); + + Request::HttpHeaders headers_; }; struct ConnectionHandlerSequentialProcessingTest : ConnectionHandlerTest { @@ -113,6 +123,9 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, ReceiveError) { EXPECT_CALL(*mockHttpConnection_, wasUpgraded).WillOnce(Return(false)); EXPECT_CALL(*mockHttpConnection_, receive).WillOnce(Return(makeError(http::error::end_of_stream))); + EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) { + EXPECT_EQ(&c, connectionPtr); + }); runSpawn([this](boost::asio::yield_context yield) { connectionHandler_.processConnection(std::move(mockHttpConnection_), yield); @@ -124,6 +137,9 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, ReceiveError_CloseConnection) EXPECT_CALL(*mockHttpConnection_, wasUpgraded).WillOnce(Return(false)); EXPECT_CALL(*mockHttpConnection_, receive).WillOnce(Return(makeError(boost::asio::error::timed_out))); EXPECT_CALL(*mockHttpConnection_, close); + EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) { + EXPECT_EQ(&c, connectionPtr); + }); runSpawn([this](boost::asio::yield_context yield) { connectionHandler_.processConnection(std::move(mockHttpConnection_), yield); @@ -134,7 +150,7 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_NoHandler_Send) { EXPECT_CALL(*mockHttpConnection_, wasUpgraded).WillOnce(Return(false)); EXPECT_CALL(*mockHttpConnection_, receive) - .WillOnce(Return(makeRequest("some_request", Request::HttpHeaders{}))) + .WillOnce(Return(makeRequest("some_request", headers_))) .WillOnce(Return(makeError(websocket::error::closed))); EXPECT_CALL(*mockHttpConnection_, send).WillOnce([](Response response, auto&&, auto&&) { @@ -142,6 +158,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_NoHandler_Send) return std::nullopt; }); + EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) { + EXPECT_EQ(&c, connectionPtr); + }); + runSpawn([this](boost::asio::yield_context yield) { connectionHandler_.processConnection(std::move(mockHttpConnection_), yield); }); @@ -165,6 +185,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_BadTarget_Send) return std::nullopt; }); + EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) { + EXPECT_EQ(&c, connectionPtr); + }); + runSpawn([this](boost::asio::yield_context yield) { connectionHandler_.processConnection(std::move(mockHttpConnection_), yield); }); @@ -182,6 +206,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_BadMethod_Send) return std::nullopt; }); + EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) { + EXPECT_EQ(&c, connectionPtr); + }); + runSpawn([this](boost::asio::yield_context yield) { connectionHandler_.processConnection(std::move(mockHttpConnection_), yield); }); @@ -199,7 +227,7 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_Send) EXPECT_CALL(*mockWsConnection_, wasUpgraded).WillOnce(Return(true)); EXPECT_CALL(*mockWsConnection_, receive) - .WillOnce(Return(makeRequest(requestMessage, Request::HttpHeaders{}))) + .WillOnce(Return(makeRequest(requestMessage, headers_))) .WillOnce(Return(makeError(websocket::error::closed))); EXPECT_CALL(wsHandlerMock, Call).WillOnce([&](Request const& request, auto&&, auto&&, auto&&) { @@ -212,6 +240,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_Send) return std::nullopt; }); + EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockWsConnection_.get()](Connection const& c) { + EXPECT_EQ(&c, connectionPtr); + }); + runSpawn([this](boost::asio::yield_context yield) { connectionHandler_.processConnection(std::move(mockWsConnection_), yield); }); @@ -228,7 +260,7 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, SendSubscriptionMessage) EXPECT_CALL(*mockWsConnection_, wasUpgraded).WillOnce(Return(true)); EXPECT_CALL(*mockWsConnection_, receive) - .WillOnce(Return(makeRequest("", Request::HttpHeaders{}))) + .WillOnce(Return(makeRequest("", headers_))) .WillOnce(Return(makeError(websocket::error::closed))); EXPECT_CALL(wsHandlerMock, Call) @@ -246,6 +278,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, SendSubscriptionMessage) return std::nullopt; }); + EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockWsConnection_.get()](Connection const& c) { + EXPECT_EQ(&c, connectionPtr); + }); + runSpawn([this](boost::asio::yield_context yield) { connectionHandler_.processConnection(std::move(mockWsConnection_), yield); }); @@ -262,7 +298,7 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, SubscriptionContextIsDisconnec EXPECT_CALL(*mockWsConnection_, wasUpgraded).WillOnce(Return(true)); testing::Expectation const expectationReceiveCalled = EXPECT_CALL(*mockWsConnection_, receive) - .WillOnce(Return(makeRequest("", Request::HttpHeaders{}))) + .WillOnce(Return(makeRequest("", headers_))) .WillOnce(Return(makeError(websocket::error::closed))); EXPECT_CALL(wsHandlerMock, Call) @@ -276,6 +312,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, SubscriptionContextIsDisconnec EXPECT_CALL(onDisconnectHook, Call).After(expectationReceiveCalled); + EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockWsConnection_.get()](Connection const& c) { + EXPECT_EQ(&c, connectionPtr); + }); + runSpawn([this](boost::asio::yield_context yield) { connectionHandler_.processConnection(std::move(mockWsConnection_), yield); }); @@ -314,6 +354,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, SubscriptionContextIsNullForHt EXPECT_CALL(*mockHttpConnection_, close); + EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) { + EXPECT_EQ(&c, connectionPtr); + }); + runSpawn([this](boost::asio::yield_context yield) { connectionHandler_.processConnection(std::move(mockHttpConnection_), yield); }); @@ -354,6 +398,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_Send_Loop) EXPECT_CALL(*mockHttpConnection_, close); + EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) { + EXPECT_EQ(&c, connectionPtr); + }); + runSpawn([this](boost::asio::yield_context yield) { connectionHandler_.processConnection(std::move(mockHttpConnection_), yield); }); @@ -385,6 +433,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_SendError) return makeError(http::error::end_of_stream).error(); }); + EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) { + EXPECT_EQ(&c, connectionPtr); + }); + runSpawn([this](boost::asio::yield_context yield) { connectionHandler_.processConnection(std::move(mockHttpConnection_), yield); }); @@ -408,7 +460,7 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Stop) if (connectionClosed) { return makeError(websocket::error::closed); } - return makeRequest(requestMessage, Request::HttpHeaders{}); + return makeRequest(requestMessage, headers_); }); EXPECT_CALL(wsHandlerMock, Call).Times(3).WillRepeatedly([&](Request const& request, auto&&, auto&&, auto&&) { @@ -429,6 +481,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Stop) EXPECT_CALL(*mockWsConnection_, close).WillOnce([&connectionClosed]() { connectionClosed = true; }); + EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockWsConnection_.get()](Connection const& c) { + EXPECT_EQ(&c, connectionPtr); + }); + runSpawn([this](boost::asio::yield_context yield) { connectionHandler_.processConnection(std::move(mockWsConnection_), yield); }); @@ -459,6 +515,10 @@ TEST_F(ConnectionHandlerParallelProcessingTest, ReceiveError) EXPECT_CALL(*mockHttpConnection_, wasUpgraded).WillOnce(Return(false)); EXPECT_CALL(*mockHttpConnection_, receive).WillOnce(Return(makeError(http::error::end_of_stream))); + EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) { + EXPECT_EQ(&c, connectionPtr); + }); + runSpawn([this](boost::asio::yield_context yield) { connectionHandler_.processConnection(std::move(mockHttpConnection_), yield); }); @@ -476,7 +536,7 @@ TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send) EXPECT_CALL(*mockWsConnection_, wasUpgraded).WillOnce(Return(true)); EXPECT_CALL(*mockWsConnection_, receive) - .WillOnce(Return(makeRequest(requestMessage, Request::HttpHeaders{}))) + .WillOnce(Return(makeRequest(requestMessage, headers_))) .WillOnce(Return(makeError(websocket::error::closed))); EXPECT_CALL(wsHandlerMock, Call).WillOnce([&](Request const& request, auto&&, auto&&, auto&&) { @@ -489,6 +549,10 @@ TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send) return std::nullopt; }); + EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockWsConnection_.get()](Connection const& c) { + EXPECT_EQ(&c, connectionPtr); + }); + runSpawn([this](boost::asio::yield_context yield) { connectionHandler_.processConnection(std::move(mockWsConnection_), yield); }); @@ -504,7 +568,7 @@ TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send_Loop) std::string const requestMessage = "some message"; std::string const responseMessage = "some response"; - auto const returnRequest = [&](auto&&, auto&&) { return makeRequest(requestMessage, Request::HttpHeaders{}); }; + auto const returnRequest = [&](auto&&, auto&&) { return makeRequest(requestMessage, headers_); }; EXPECT_CALL(*mockWsConnection_, wasUpgraded).WillOnce(Return(true)); EXPECT_CALL(*mockWsConnection_, receive) @@ -524,6 +588,10 @@ TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send_Loop) return std::nullopt; }); + EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockWsConnection_.get()](Connection const& c) { + EXPECT_EQ(&c, connectionPtr); + }); + runSpawn([this](boost::asio::yield_context yield) { connectionHandler_.processConnection(std::move(mockWsConnection_), yield); }); @@ -539,7 +607,7 @@ TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send_Loop_TooMany std::string const requestMessage = "some message"; std::string const responseMessage = "some response"; - auto const returnRequest = [&](auto&&, auto&&) { return makeRequest(requestMessage, Request::HttpHeaders{}); }; + auto const returnRequest = [&](auto&&, auto&&) { return makeRequest(requestMessage, headers_); }; testing::Sequence const sequence; EXPECT_CALL(*mockWsConnection_, wasUpgraded).WillOnce(Return(true)); @@ -583,6 +651,10 @@ TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send_Loop_TooMany .Times(2) .WillRepeatedly(Return(std::nullopt)); + EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockWsConnection_.get()](Connection const& c) { + EXPECT_EQ(&c, connectionPtr); + }); + runSpawn([this](boost::asio::yield_context yield) { connectionHandler_.processConnection(std::move(mockWsConnection_), yield); }); diff --git a/tests/unit/web/ng/impl/ErrorHandlingTests.cpp b/tests/unit/web/ng/impl/ErrorHandlingTests.cpp index a1d3e7a1b..2bf6ea968 100644 --- a/tests/unit/web/ng/impl/ErrorHandlingTests.cpp +++ b/tests/unit/web/ng/impl/ErrorHandlingTests.cpp @@ -49,7 +49,8 @@ struct ng_ErrorHandlingTests : NoLoggerFixture { { if (isHttp) return Request{http::request{http::verb::post, "/", 11, body.value_or("")}}; - return Request{body.value_or(""), Request::HttpHeaders{}}; + static Request::HttpHeaders const headers_; + return Request{body.value_or(""), headers_}; } }; diff --git a/tests/unit/web/ng/impl/WsConnectionTests.cpp b/tests/unit/web/ng/impl/WsConnectionTests.cpp index 48c8527d6..2b89ba3e0 100644 --- a/tests/unit/web/ng/impl/WsConnectionTests.cpp +++ b/tests/unit/web/ng/impl/WsConnectionTests.cpp @@ -51,7 +51,8 @@ struct web_WsConnectionTests : SyncAsioContextTest { util::TagDecoratorFactory tagDecoratorFactory_{util::Config{boost::json::object{{"log_tag_style", "int"}}}}; TestHttpServer httpServer_{ctx, "localhost"}; WebSocketAsyncClient wsClient_{ctx}; - Request request_{"some request", Request::HttpHeaders{}}; + Request::HttpHeaders const headers_; + Request request_{"some request", headers_}; std::unique_ptr acceptConnection(boost::asio::yield_context yield)