Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Base64::Decode performance #11467

Merged
8 commits merged into from
Oct 26, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/actions/spelling/expect/expect.txt
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ cfae
Cfg
cfie
cfiex
Cfj
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interesting -- I wonder if we should avoid adding base64 snippets to the expect list and instead exclude them by pattern or by file?

Cfk
cfte
CFuzz
cgscrn
Expand Down Expand Up @@ -1535,6 +1537,7 @@ NOCOLOR
NOCOMM
NOCONTEXTHELP
NOCOPYBITS
NOCRLF
nodiscard
NODUP
noexcept
Expand Down Expand Up @@ -1897,6 +1900,7 @@ PUNICODE
pushd
putchar
putwchar
Pvf
PVOID
pwch
PWCHAR
Expand Down Expand Up @@ -2025,6 +2029,7 @@ riid
Rike
RIPMSG
RIS
Rjf
RMENU
rng
roadmap
Expand Down Expand Up @@ -2860,4 +2865,6 @@ zamora
ZCmd
ZCtrl
zsh
ZWc
zwn
zxcvbnm
7 changes: 6 additions & 1 deletion src/terminal/parser/OutputStateMachineEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,12 @@ bool OutputStateMachineEngine::_GetOscSetClipboard(const std::wstring_view strin
}
else
{
return Base64::s_Decode(substr, content);
try
{
Base64::Decode(substr, content);
return true;
}
CATCH_LOG()
lhecker marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down
271 changes: 109 additions & 162 deletions src/terminal/parser/base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,190 +4,137 @@
#include "precomp.h"
#include "base64.hpp"

using namespace Microsoft::Console::VirtualTerminal;

static const char base64Chars[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
static const char padChar = '=';

#pragma warning(disable : 26446 26447 26482 26485 26493 26494)

// Routine Description:
// - Encode a string using base64. Every group of three input characters is
//   turned into one quantum (four output characters). When fewer than three
//   characters remain at the end, the last quantum is completed with padding.
// - Each input code unit is treated as a single byte.
//   NOTE(review): code units above 0xFF cannot be represented by this scheme.
//   Their 6-bit groups are masked so the alphabet lookup can never read out of
//   bounds -- confirm that callers only pass byte-range text.
// Arguments:
// - src - String to base64 encode.
// Return Value:
// - the encoded string. An empty input yields an empty string.
std::wstring Base64::s_Encode(const std::wstring_view src) noexcept
{
    std::wstring dst;
    if (src.empty())
    {
        return dst;
    }
    dst.reserve((src.size() + 2) / 3 * 4);

    // Translates a 6-bit group into its alphabet character. The mask keeps the
    // index inside the 64-entry alphabet no matter what the input contained.
    const auto sextet = [](const unsigned int bits) noexcept -> wchar_t {
        return base64Chars[bits & 0x3f];
    };

    // Number of characters that form complete three-character groups.
    // Working with indices (instead of `src.cend() - 2`) avoids forming an
    // iterator in front of the buffer when src holds a single character.
    const size_t fullLen = src.size() - src.size() % 3;

    // Encode each three chars into one quantum (four chars).
    for (size_t i = 0; i < fullLen; i += 3)
    {
        const auto a = static_cast<unsigned int>(src[i]);
        const auto b = static_cast<unsigned int>(src[i + 1]);
        const auto c = static_cast<unsigned int>(src[i + 2]);
        dst.push_back(sextet(a >> 2));
        dst.push_back(sextet((a & 0x03) << 4 | b >> 4));
        dst.push_back(sextet((b & 0x0f) << 2 | c >> 6));
        dst.push_back(sextet(c));
    }

    // Here only zero, or one, or two chars are left. We may need to add paddings.
    const size_t remaining = src.size() - fullLen;
    if (remaining != 0)
    {
        const auto a = static_cast<unsigned int>(src[fullLen]);
        dst.push_back(sextet(a >> 2));
        if (remaining == 2) // Two chars left: three data characters, one padding.
        {
            const auto b = static_cast<unsigned int>(src[fullLen + 1]);
            dst.push_back(sextet((a & 0x03) << 4 | b >> 4));
            dst.push_back(sextet((b & 0x0f) << 2));
        }
        else // Only one char left: two data characters, two paddings.
        {
            dst.push_back(sextet((a & 0x03) << 4));
            dst.push_back(padChar);
        }
        dst.push_back(padChar);
    }

    return dst;
}

#pragma warning(disable : 26446) // Prefer to use gsl::at() instead of unchecked subscript operator (bounds.4).
#pragma warning(disable : 26481) // Don't use pointer arithmetic. Use span instead (bounds.1).
#pragma warning(disable : 26482) // Only index into arrays using constant expressions (bounds.2).
using namespace Microsoft::Console::VirtualTerminal;

// Routine Description:
// - Decode a base64 string. This requires the base64 string is properly padded.
// Otherwise, false will be returned.
// Arguments:
// - src - String to decode.
// - dst - Destination to decode into.
// Return Value:
// - true if decoding successfully, otherwise false.
bool Base64::s_Decode(const std::wstring_view src, std::wstring& dst) noexcept
// Lookup table mapping a 7-bit ASCII code to its 6-bit base64 value.
// * 'A'-'Z' map to 0-25, 'a'-'z' to 26-51 and '0'-'9' to 52-61.
// * Both '+' and '-' map to 62 and both '/' and '_' map to 63, so a single
//   table covers the base64 and the base64url alphabet at the same time.
// * Every other entry, including the padding character '=', holds 255,
//   which marks the code as "not a base64 character".
// clang-format off
static constexpr uint8_t decodeTable[128] = {
255 /* NUL */, 255 /* SOH */, 255 /* STX */, 255 /* ETX */, 255 /* EOT */, 255 /* ENQ */, 255 /* ACK */, 255 /* BEL */, 255 /* BS */, 255 /* HT */, 255 /* LF */, 255 /* VT */, 255 /* FF */, 255 /* CR */, 255 /* SO */, 255 /* SI */,
255 /* DLE */, 255 /* DC1 */, 255 /* DC2 */, 255 /* DC3 */, 255 /* DC4 */, 255 /* NAK */, 255 /* SYN */, 255 /* ETB */, 255 /* CAN */, 255 /* EM */, 255 /* SUB */, 255 /* ESC */, 255 /* FS */, 255 /* GS */, 255 /* RS */, 255 /* US */,
255 /* SP */, 255 /* ! */, 255 /* " */, 255 /* # */, 255 /* $ */, 255 /* % */, 255 /* & */, 255 /* ' */, 255 /* ( */, 255 /* ) */, 255 /* * */, 62 /* + */, 255 /* , */, 62 /* - */, 255 /* . */, 63 /* / */,
52 /* 0 */, 53 /* 1 */, 54 /* 2 */, 55 /* 3 */, 56 /* 4 */, 57 /* 5 */, 58 /* 6 */, 59 /* 7 */, 60 /* 8 */, 61 /* 9 */, 255 /* : */, 255 /* ; */, 255 /* < */, 255 /* = */, 255 /* > */, 255 /* ? */,
255 /* @ */, 0 /* A */, 1 /* B */, 2 /* C */, 3 /* D */, 4 /* E */, 5 /* F */, 6 /* G */, 7 /* H */, 8 /* I */, 9 /* J */, 10 /* K */, 11 /* L */, 12 /* M */, 13 /* N */, 14 /* O */,
15 /* P */, 16 /* Q */, 17 /* R */, 18 /* S */, 19 /* T */, 20 /* U */, 21 /* V */, 22 /* W */, 23 /* X */, 24 /* Y */, 25 /* Z */, 255 /* [ */, 255 /* \ */, 255 /* ] */, 255 /* ^ */, 63 /* _ */,
255 /* ` */, 26 /* a */, 27 /* b */, 28 /* c */, 29 /* d */, 30 /* e */, 31 /* f */, 32 /* g */, 33 /* h */, 34 /* i */, 35 /* j */, 36 /* k */, 37 /* l */, 38 /* m */, 39 /* n */, 40 /* o */,
41 /* p */, 42 /* q */, 43 /* r */, 44 /* s */, 45 /* t */, 46 /* u */, 47 /* v */, 48 /* w */, 49 /* x */, 50 /* y */, 51 /* z */, 255 /* { */, 255 /* | */, 255 /* } */, 255 /* ~ */, 255 /* DEL */,
};
// clang-format on

// Decodes a UTF-8 string encoded with RFC 4648 (Base64) and returns it as UTF-16 in dst.
// It supports both variants of the RFC (base64 and base64url), but
// throws an error for non-alphabet characters, including newlines.
// * Throws an exception for all invalid base64 inputs.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well,

// * Doesn't support whitespace and will throw an exception for such strings.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i bet that this will come bite us later, but i am willing to take that risk

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, it’s "return ERROR_INVALID_DATA" now. I’ll update the comment.
Is there anything else you’re worried about?

// * Doesn't validate the number of trailing "=". Those are basically ignored.
// Strings like "YQ===" will be accepted as valid input and simply result in "a".
void Base64::Decode(const std::wstring_view& src, std::wstring& dst)
{
std::string mbStr;
int state = 0;
char tmp;

const auto len = src.size() / 4 * 3;
if (len == 0)
std::string result;
result.resize(((src.size() + 3) / 4) * 3);

// in and inEnd may be nullptr if src.empty().
// The remaining code in this function ensures not to read from in if src.empty().
#pragma warning(suppress : 26429) // Symbol 'in' is never tested for nullness, it can be marked as not_null (f.23).
auto in = src.data();
auto inEnd = in + src.size();
// Sometimes in programming you have to ask yourself what the right offset for a pointer is.
// Is 4 enough? Certainly not. 6 on the other hand is just way too much. Clearly 5 is just right.
//
// In all seriousness however the offset is 5, because the batched loop reads 4 characters at a time,
// a base64 string can end with two "=" and the batched loop doesn't handle any such "=".
// Additionally the while() condition of the batched loop would make a lot more sense if it were using <=,
// but for reasons outlined below it needs to use < so we need to add 1 back again.
// We thus get -4-2+1 which is -5.
//
// There's a special reason we need to use < and not <= for the loop:
// In C++ it's undefined behavior to perform any pointer arithmetic that leads to unallocated memory,
// which is why we can't just write `inEnd - 6` as that might be UB if `src.size()` is less than 6.
// We would thus need to write `inEnd - min(6, src.size())` in combination with `<=` for the batched loop.
// But if `src.size()` is actually less than 6 then `inEnd` is equal to the initial `in`, aka: an empty range.
// In such cases we'd enter the batched loop and read from `in` despite us not wanting to enter the loop.
// We can fix the issue by using < instead and adding +1 to the offset.
//
// Yes this works.
const auto inEndBatched = inEnd - std::min<size_t>(5, src.size());

// outBeg and out may be nullptr if src.empty().
// The remaining code in this function ensures not to write to out if src.empty().
const auto outBeg = result.data();
#pragma warning(suppress : 26429) // Symbol 'out' is never tested for nullness, it can be marked as not_null (f.23).
auto out = outBeg;

uint_fast32_t r = 0;
uint_fast16_t error = 0;

#define accumulate(ch) \
lhecker marked this conversation as resolved.
Show resolved Hide resolved
do \
{ \
const auto n = decodeTable[ch & 0x7f]; \
error |= (ch | n) & 0xff80; \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wat

how does this work?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay so error is effectively, "is the loaded value 255 (the invalid sentinel in the table) or the character larger than 7 bits (clearly invalid ASCII)"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some comments explaining this.

r = r << 6 | n; \
} while (0)

// If src.empty() then `in == inEndBatched == nullptr` and this is skipped.
while (in < inEndBatched)
{
return false;
const auto a = *in++;
const auto b = *in++;
const auto c = *in++;
const auto d = *in++;

accumulate(a);
accumulate(b);
accumulate(c);
accumulate(d);

*out++ = gsl::narrow_cast<char>(r >> 16);
*out++ = gsl::narrow_cast<char>(r >> 8);
*out++ = gsl::narrow_cast<char>(r >> 0);
}
mbStr.reserve(len);

auto iter = src.cbegin();
while (iter < src.cend())
{
if (s_IsSpace(*iter)) // Skip whitespace anywhere.
{
iter++;
continue;
}

if (*iter == padChar)
{
break;
}
uint_fast8_t ri = 0;

auto pos = strchr(base64Chars, *iter);
if (!pos) // A non-base64 character found.
// If src.empty() then `in == inEnd == nullptr` and this is skipped.
for (; in < inEnd; ++in)
{
return false;
if (const auto ch = *in; ch != '=')
{
accumulate(ch);
ri++;
}
}

switch (state)
switch (ri)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ri => remainder index?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I called it r because it's the accumulation "register" (since on a CPU level r will almost certainly live inside a register most of the time).
And since i is usually a counter for something I called it ri, the "register index-counter-thingy".

"remainder index" however is much better. I'll steal that. 😄

{
case 0:
tmp = (char)(pos - base64Chars) << 2;
state = 1;
break;
case 1:
tmp |= (char)(pos - base64Chars) >> 4;
mbStr += tmp;
tmp = (char)((pos - base64Chars) & 0x0f) << 4;
state = 2;
break;
case 2:
tmp |= (char)(pos - base64Chars) >> 2;
mbStr += tmp;
tmp = (char)((pos - base64Chars) & 0x03) << 6;
state = 3;
*out++ = gsl::narrow_cast<char>(r >> 4);
break;
case 3:
tmp |= pos - base64Chars;
mbStr += tmp;
state = 0;
break;
default:
*out++ = gsl::narrow_cast<char>(r >> 10);
*out++ = gsl::narrow_cast<char>(r >> 2);
break;
}

iter++;
}

if (iter < src.cend()) // Padding char is met.
{
iter++;
switch (state)
{
// Invalid when state is 0 or 1.
case 0:
case 1:
return false;
case 2:
// Skip any number of spaces.
while (iter < src.cend() && s_IsSpace(*iter))
{
iter++;
}
// Make sure there is another trailing padding character.
if (iter == src.cend() || *iter != padChar)
{
return false;
}
iter++; // Skip the padding character and fallthrough to "single trailing padding character" case.
[[fallthrough]];
case 3:
while (iter < src.cend())
{
if (!s_IsSpace(*iter))
{
return false;
}
iter++;
}
case 4:
*out++ = gsl::narrow_cast<char>(r >> 16);
*out++ = gsl::narrow_cast<char>(r >> 8);
*out++ = gsl::narrow_cast<char>(r >> 0);
break;
default:
error |= ri;
break;
}
}
else if (state != 0) // When no padding, we must be in state 0.

#undef accumulate

if (error)
{
return false;
throw std::runtime_error("invalid base64");
}

return SUCCEEDED(til::u8u16(mbStr, dst));
}

// Routine Description:
// - Check if parameter is a base64 whitespace. Only carriage return or line feed
// is valid whitespace.
// Arguments:
// - ch - Character to check.
// Return Value:
// - true iff ch is a carriage return or line feed.
constexpr bool Base64::s_IsSpace(const wchar_t ch) noexcept
{
return ch == L'\r' || ch == L'\n';
result.resize(out - outBeg);
THROW_IF_FAILED(til::u8u16(result, dst));
}
6 changes: 1 addition & 5 deletions src/terminal/parser/base64.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ namespace Microsoft::Console::VirtualTerminal
// Collection of static base64 helpers. The class holds no state.
class Base64
{
public:
    // Encodes src as base64 and returns the result, padded with '=' where the
    // input length is not a multiple of three.
    static std::wstring s_Encode(const std::wstring_view src) noexcept;

    // Decodes the base64 text in src into dst.
    // Returns true if decoding succeeded, otherwise false.
    static bool s_Decode(const std::wstring_view src, std::wstring& dst) noexcept;

    // Decodes the base64 text in src into dst and throws on invalid input.
    // This is called from outside the class (the OSC clipboard handling wraps
    // it in a try/catch), so it has to be publicly accessible.
    static void Decode(const std::wstring_view& src, std::wstring& dst);

private:
    // Returns true iff ch is a carriage return or a line feed.
    static constexpr bool s_IsSpace(const wchar_t ch) noexcept;
};
}
Loading