Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use new I/O engine in CLI commands and features #7010

Closed
wants to merge 8 commits into from
102 changes: 102 additions & 0 deletions lib/base/netstring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,85 @@ size_t NetString::WriteStringToStream(const Stream::Ptr& stream, const String& s
return msg.GetLength();
}

/**
* Reads data from a stream in netstring format.
*
* @param stream The stream to read from.
* @returns The String that has been read from the IOQueue.
* @exception invalid_argument The input stream is invalid.
* @see https://github.com/PeterScott/netstring-c/blob/master/netstring.c
*/
String NetString::ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& stream,
ssize_t maxMessageLength)
{
namespace asio = boost::asio;

size_t len = 0;
bool leadingZero = false;

for (uint_fast8_t readBytes = 0;; ++readBytes) {
char byte = 0;

{
asio::mutable_buffer byteBuf (&byte, 1);
asio::read(*stream, byteBuf);
}

if (isdigit(byte)) {
if (readBytes == 9) {
BOOST_THROW_EXCEPTION(std::invalid_argument("Length specifier must not exceed 9 characters"));
}

if (leadingZero) {
BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (leading zero)"));
}

len = len * 10u + size_t(byte - '0');

if (!readBytes && byte == '0') {
leadingZero = true;
}
} else if (byte == ':') {
if (!readBytes) {
BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (no length specifier)"));
}

break;
} else {
BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing :)"));
}
}

if (maxMessageLength >= 0 && len > maxMessageLength) {
std::stringstream errorMessage;
errorMessage << "Max data length exceeded: " << (maxMessageLength / 1024) << " KB";

BOOST_THROW_EXCEPTION(std::invalid_argument(errorMessage.str()));
}

String payload;

if (len) {
payload.Append(len, 0);

asio::mutable_buffer payloadBuf (&*payload.Begin(), payload.GetLength());
asio::read(*stream, payloadBuf);
}

char trailer = 0;

{
asio::mutable_buffer trailerBuf (&trailer, 1);
asio::read(*stream, trailerBuf);
}

if (trailer != ',') {
BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing ,)"));
}

return std::move(payload);
}

/**
* Reads data from a stream in netstring format.
*
Expand Down Expand Up @@ -197,6 +276,29 @@ String NetString::ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& str
return std::move(payload);
}

/**
* Writes data into a stream using the netstring format and returns bytes written.
*
* @param stream The stream.
* @param str The String that is to be written.
*
* @return The amount of bytes written.
*/
size_t NetString::WriteStringToStream(const std::shared_ptr<AsioTlsStream>& stream, const String& str)
{
namespace asio = boost::asio;

std::ostringstream msgbuf;
WriteStringToStream(msgbuf, str);

String msg = msgbuf.str();
asio::const_buffer msgBuf (msg.CStr(), msg.GetLength());

asio::write(*stream, msgBuf);

return msg.GetLength();
}

/**
* Writes data into a stream using the netstring format and returns bytes written.
*
Expand Down
2 changes: 2 additions & 0 deletions lib/base/netstring.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ class NetString
public:
static StreamReadStatus ReadStringFromStream(const Stream::Ptr& stream, String *message, StreamReadContext& context,
bool may_wait = false, ssize_t maxMessageLength = -1);
static String ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& stream, ssize_t maxMessageLength = -1);
static String ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& stream,
boost::asio::yield_context yc, ssize_t maxMessageLength = -1);
static size_t WriteStringToStream(const Stream::Ptr& stream, const String& message);
static size_t WriteStringToStream(const std::shared_ptr<AsioTlsStream>& stream, const String& message);
static size_t WriteStringToStream(const std::shared_ptr<AsioTlsStream>& stream, const String& message, boost::asio::yield_context yc);
static void WriteStringToStream(std::ostream& stream, const String& message);

Expand Down
29 changes: 29 additions & 0 deletions lib/base/tcpsocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,35 @@ class TcpSocket final : public Socket
void Connect(const String& node, const String& service);
};

template<class Socket>
void Connect(Socket& socket, const String& node, const String& service)
{
using boost::asio::ip::tcp;

tcp::resolver resolver (socket.get_io_service());
tcp::resolver::query query (node, service);
auto result (resolver.resolve(query));
auto current (result.begin());

for (;;) {
try {
socket.open(current->endpoint().protocol());
socket.set_option(tcp::socket::keep_alive(true));
socket.connect(current->endpoint());

break;
} catch (const std::exception&) {
if (++current == result.end()) {
throw;
}

if (socket.is_open()) {
socket.close();
}
}
}
}

template<class Socket>
void Connect(Socket& socket, const String& node, const String& service, boost::asio::yield_context yc)
{
Expand Down
5 changes: 5 additions & 0 deletions lib/base/tlsstream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,11 @@ String UnbufferedAsioTlsStream::GetVerifyError() const
return m_VerifyError;
}

std::shared_ptr<X509> UnbufferedAsioTlsStream::GetPeerCertificate()
{
return std::shared_ptr<X509>(SSL_get_peer_certificate(native_handle()), X509_free);
}

void UnbufferedAsioTlsStream::BeforeHandshake(handshake_type type)
{
namespace ssl = boost::asio::ssl;
Expand Down
1 change: 1 addition & 0 deletions lib/base/tlsstream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class UnbufferedAsioTlsStream : public AsioTcpTlsStream

bool IsVerifyOK() const;
String GetVerifyError() const;
std::shared_ptr<X509> GetPeerCertificate();

template<class... Args>
inline
Expand Down
Loading