Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(MQ cleanup) Remove some unsafe code from System.Xml #43379

Merged
merged 10 commits into from
Nov 3, 2020
76 changes: 30 additions & 46 deletions src/libraries/System.Private.Xml/src/System/Xml/Base64Decoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ internal override bool IsFull
}
}

internal override unsafe int Decode(char[] chars, int startPos, int len)
internal override int Decode(char[] chars, int startPos, int len)
{
if (chars == null)
{
Expand All @@ -65,19 +65,14 @@ internal override unsafe int Decode(char[] chars, int startPos, int len)
{
return 0;
}
int bytesDecoded, charsDecoded;
fixed (char* pChars = &chars[startPos])
{
fixed (byte* pBytes = &_buffer![_curIndex])
{
Decode(pChars, pChars + len, pBytes, pBytes + (_endIndex - _curIndex), out charsDecoded, out bytesDecoded);
}
}

Decode(chars.AsSpan(startPos, len), _buffer.AsSpan(_curIndex, _endIndex - _curIndex), out int charsDecoded, out int bytesDecoded);

_curIndex += bytesDecoded;
return charsDecoded;
}

internal override unsafe int Decode(string str, int startPos, int len)
internal override int Decode(string str, int startPos, int len)
{
if (str == null)
{
Expand All @@ -101,14 +96,7 @@ internal override unsafe int Decode(string str, int startPos, int len)
return 0;
}

int bytesDecoded, charsDecoded;
fixed (char* pChars = str)
{
fixed (byte* pBytes = &_buffer![_curIndex])
{
Decode(pChars + startPos, pChars + startPos + len, pBytes, pBytes + (_endIndex - _curIndex), out charsDecoded, out bytesDecoded);
}
}
Decode(str.AsSpan(startPos, len), _buffer.AsSpan(_curIndex, _endIndex - _curIndex), out int charsDecoded, out int bytesDecoded);

_curIndex += bytesDecoded;
return charsDecoded;
Expand Down Expand Up @@ -151,29 +139,28 @@ private static byte[] ConstructMapBase64()
return mapBase64;
}

private unsafe void Decode(char* pChars, char* pCharsEndPos,
byte* pBytes, byte* pBytesEndPos,
out int charsDecoded, out int bytesDecoded)
private void Decode(ReadOnlySpan<char> chars, Span<byte> bytes, out int charsDecoded, out int bytesDecoded)
{
#if DEBUG
Debug.Assert(pCharsEndPos - pChars >= 0);
Debug.Assert(pBytesEndPos - pBytes >= 0);
#endif

// walk hex digits pairing them up and shoving the value of each pair into a byte
byte* pByte = pBytes;
char* pChar = pChars;
int iByte = 0;
int iChar = 0;
int b = _bits;
int bFilled = _bitsFilled;
while (pChar < pCharsEndPos && pByte < pBytesEndPos)

while ((uint)iChar < (uint)chars.Length)
{
char ch = *pChar;
if ((uint)iByte >= (uint)bytes.Length)
{
break; // ran out of space in the destination buffer
}

char ch = chars[iChar];
// end?
if (ch == '=')
{
break;
}
pChar++;
iChar++;

// ignore whitespace
if (XmlCharType.IsWhiteSpace(ch))
Expand All @@ -184,7 +171,7 @@ private unsafe void Decode(char* pChars, char* pCharsEndPos,
int digit;
if (ch > 122 || (digit = s_mapBase64[ch]) == Invalid)
{
throw new XmlException(SR.Xml_InvalidBase64Value, new string(pChars, 0, (int)(pCharsEndPos - pChars)));
throw new XmlException(SR.Xml_InvalidBase64Value, chars.ToString());
GrabYourPitchforks marked this conversation as resolved.
Show resolved Hide resolved
}

b = (b << 6) | digit;
Expand All @@ -193,44 +180,41 @@ private unsafe void Decode(char* pChars, char* pCharsEndPos,
if (bFilled >= 8)
{
// get top eight valid bits
*pByte++ = (byte)((b >> (bFilled - 8)) & 0xFF);
bytes[iByte++] = (byte)((b >> (bFilled - 8)) & 0xFF);
bFilled -= 8;

if (pByte == pBytesEndPos)
if (iByte == bytes.Length)
{
goto Return;
}
}
}

if (pChar < pCharsEndPos && *pChar == '=')
if ((uint)iChar < (uint)chars.Length && chars[iChar] == '=')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you need to cast it everywhere? Should you just make it unsigned from the start?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This follows the "bounds check elision" pattern used elsewhere in the runtime code base, where the indexer is cast to uint only for the purposes of comparison against Length. All other uses of the indexer remain as normal signed integers.

I'm not a huge fan of this pattern, since I wish we'd just use nuint everywhere and call it a day. :)

{
bFilled = 0;
// ignore padding chars
do
{
pChar++;
} while (pChar < pCharsEndPos && *pChar == '=');
iChar++;
} while ((uint)iChar < (uint)chars.Length && chars[iChar] == '=');

// ignore whitespace after the padding chars
if (pChar < pCharsEndPos)
while ((uint)iChar < (uint)chars.Length)
{
do
if (!XmlCharType.IsWhiteSpace(chars[iChar++]))
{
if (!XmlCharType.IsWhiteSpace(*pChar++))
{
throw new XmlException(SR.Xml_InvalidBase64Value, new string(pChars, 0, (int)(pCharsEndPos - pChars)));
}
} while (pChar < pCharsEndPos);
throw new XmlException(SR.Xml_InvalidBase64Value, chars.ToString());
krwq marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

Return:
_bits = b;
_bitsFilled = bFilled;

bytesDecoded = (int)(pByte - pBytes);
charsDecoded = (int)(pChar - pChars);
bytesDecoded = iByte;
charsDecoded = iChar;
}
}
}
77 changes: 29 additions & 48 deletions src/libraries/System.Private.Xml/src/System/Xml/BinHexDecoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ internal override bool IsFull
}
}

internal override unsafe int Decode(char[] chars, int startPos, int len)
internal override int Decode(char[] chars, int startPos, int len)
{
if (chars == null)
{
Expand All @@ -61,20 +61,15 @@ internal override unsafe int Decode(char[] chars, int startPos, int len)
return 0;
}

int bytesDecoded, charsDecoded;
fixed (char* pChars = &chars[startPos])
{
fixed (byte* pBytes = &_buffer![_curIndex])
{
Decode(pChars, pChars + len, pBytes, pBytes + (_endIndex - _curIndex),
ref _hasHalfByteCached, ref _cachedHalfByte, out charsDecoded, out bytesDecoded);
}
}
Decode(chars.AsSpan(startPos, len), _buffer.AsSpan(_curIndex, _endIndex - _curIndex),
ref _hasHalfByteCached, ref _cachedHalfByte,
out int charsDecoded, out int bytesDecoded);

_curIndex += bytesDecoded;
return charsDecoded;
}

internal override unsafe int Decode(string str, int startPos, int len)
internal override int Decode(string str, int startPos, int len)
{
if (str == null)
{
Expand All @@ -98,15 +93,9 @@ internal override unsafe int Decode(string str, int startPos, int len)
return 0;
}

int bytesDecoded, charsDecoded;
fixed (char* pChars = str)
{
fixed (byte* pBytes = &_buffer![_curIndex])
{
Decode(pChars + startPos, pChars + startPos + len, pBytes, pBytes + (_endIndex - _curIndex),
ref _hasHalfByteCached, ref _cachedHalfByte, out charsDecoded, out bytesDecoded);
}
}
Decode(str.AsSpan(startPos, len), _buffer.AsSpan(_curIndex, _endIndex - _curIndex),
ref _hasHalfByteCached, ref _cachedHalfByte,
out int charsDecoded, out int bytesDecoded);

_curIndex += bytesDecoded;
return charsDecoded;
Expand Down Expand Up @@ -135,7 +124,7 @@ internal override void SetNextOutputBuffer(Array buffer, int index, int count)
//
// Static methods
//
public static unsafe byte[] Decode(char[] chars, bool allowOddChars)
public static byte[] Decode(char[] chars, bool allowOddChars)
{
if (chars == null)
{
Expand All @@ -149,17 +138,10 @@ public static unsafe byte[] Decode(char[] chars, bool allowOddChars)
}

byte[] bytes = new byte[(len + 1) / 2];
int bytesDecoded, charsDecoded;
bool hasHalfByteCached = false;
byte cachedHalfByte = 0;

fixed (char* pChars = &chars[0])
{
fixed (byte* pBytes = &bytes[0])
{
Decode(pChars, pChars + len, pBytes, pBytes + bytes.Length, ref hasHalfByteCached, ref cachedHalfByte, out charsDecoded, out bytesDecoded);
}
}
Decode(chars, bytes, ref hasHalfByteCached, ref cachedHalfByte, out int charsDecoded, out int bytesDecoded);

if (hasHalfByteCached && !allowOddChars)
{
Expand All @@ -168,9 +150,7 @@ public static unsafe byte[] Decode(char[] chars, bool allowOddChars)

if (bytesDecoded < bytes.Length)
{
byte[] tmp = new byte[bytesDecoded];
Buffer.BlockCopy(bytes, 0, tmp, 0, bytesDecoded);
bytes = tmp;
Array.Resize(ref bytes, bytesDecoded);
}

return bytes;
Expand All @@ -180,22 +160,23 @@ public static unsafe byte[] Decode(char[] chars, bool allowOddChars)
// Private methods
//

private static unsafe void Decode(char* pChars, char* pCharsEndPos,
byte* pBytes, byte* pBytesEndPos,
ref bool hasHalfByteCached, ref byte cachedHalfByte,
out int charsDecoded, out int bytesDecoded)
private static void Decode(ReadOnlySpan<char> chars,
Span<byte> bytes,
ref bool hasHalfByteCached, ref byte cachedHalfByte,
out int charsDecoded, out int bytesDecoded)
{
#if DEBUG
Debug.Assert(pCharsEndPos - pChars >= 0);
Debug.Assert(pBytesEndPos - pBytes >= 0);
#endif
int iByte = 0;
int iChar = 0;

char* pChar = pChars;
byte* pByte = pBytes;
while (pChar < pCharsEndPos && pByte < pBytesEndPos)
for (; iChar < chars.Length; iChar++)
{
if ((uint)iByte >= (uint)bytes.Length)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code is written this way to avoid bounds checks later in the method.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@GrabYourPitchforks, (maybe you already know it but) since #40180 was merged, the uint cast workaround is not needed in some cases where constant operands are involved. It turned out that in some cases, RyuJIT does not elide the bound check with uint casts, where it was doing before. I posted some findings on macOS here: #11623 (comment). It is unlikely that this construct with (non const) operands is affected by that change, but perhaps would be good to double check with the latest master. I have a feeling that in some places in the framework, we can remove the uint cast, as they are deoptomized.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@am11 This is a good observation! As these other optimizations come online I think it'll be useful to perform a libraries-wide sweep of all of these patterns. It's always great to make the code more readable while maintaining peak efficiency. :)

{
break; // ran out of space in the destination buffer
}

byte halfByte;
char ch = *pChar++;
char ch = chars[iChar];

int val = HexConverter.FromChar(ch);
if (val != 0xFF)
Expand All @@ -208,12 +189,12 @@ private static unsafe void Decode(char* pChars, char* pCharsEndPos,
}
else
{
throw new XmlException(SR.Xml_InvalidBinHexValue, new string(pChars, 0, (int)(pCharsEndPos - pChars)));
throw new XmlException(SR.Xml_InvalidBinHexValue, chars.ToString());
}

if (hasHalfByteCached)
{
*pByte++ = (byte)((cachedHalfByte << 4) + halfByte);
bytes[iByte++] = (byte)((cachedHalfByte << 4) + halfByte);
hasHalfByteCached = false;
}
else
Expand All @@ -223,8 +204,8 @@ private static unsafe void Decode(char* pChars, char* pCharsEndPos,
}
}

bytesDecoded = (int)(pByte - pBytes);
charsDecoded = (int)(pChar - pChars);
bytesDecoded = iByte;
charsDecoded = iChar;
}
}
}
Loading