Improve performance of Span.Clear #51534

Closed
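(For context, a minimal BenchmarkDotNet harness of the sort one might use to measure Span<T>.Clear throughput for this change; the element types, buffer sizes, and harness itself are illustrative assumptions, not taken from the PR.)

using System;
using BenchmarkDotNet.Attributes;
using BenchmarkDotNet.Running;

public class SpanClearBenchmark
{
    // Element counts are illustrative; the PR does not specify benchmark sizes.
    [Params(8, 64, 4096)]
    public int Length;

    private object[] _withReferences;
    private byte[] _withoutReferences;

    [GlobalSetup]
    public void Setup()
    {
        _withReferences = new object[Length];   // exercises the ClearWithReferences path
        _withoutReferences = new byte[Length];  // exercises the ClearWithoutReferences path
    }

    [Benchmark]
    public void ClearReferences() => _withReferences.AsSpan().Clear();

    [Benchmark]
    public void ClearBytes() => _withoutReferences.AsSpan().Clear();
}

public class Program
{
    public static void Main() => BenchmarkRunner.Run<SpanClearBenchmark>();
}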
@@ -90,7 +90,7 @@ public static void Fill<T>(ref T refData, nuint numElements, T value)

ref byte refDataAsBytes = ref Unsafe.As<T, byte>(ref refData);
nuint totalByteLength = numElements * (nuint)Unsafe.SizeOf<T>(); // get this calculation ready ahead of time
nuint stopLoopAtOffset = totalByteLength & (nuint)(nint)(2 * (int)-Vector<byte>.Count); // intentional sign extension carries the negative bit
nuint stopLoopAtOffset = totalByteLength & (nuint)(nint)(2 * -Vector<byte>.Count); // intentional sign extension carries the negative bit
nuint offset = 0;

// Loop, writing 2 vectors at a time.
@@ -122,6 +122,7 @@ public static void Fill<T>(ref T refData, nuint numElements, T value)
// fit an entire vector's worth of data. Instead of falling back to a loop, we'll write
// a vector at the very end of the buffer. This may involve overwriting previously
// populated data, which is fine since we're splatting the same value for all entries.
// (n.b. This statement is no longer valid if we try to ensure these writes are aligned.)
// There's no need to perform a length check here because we already performed this
// check before entering the vectorized code path.

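(For readers unfamiliar with the trick described in the comment above, here is a minimal standalone sketch of the same "overlapping final vector write" idea, written against a plain byte span rather than the CoreLib helper; the method name and structure are illustrative assumptions, not code from the PR.)

using System;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

static class FillSketch
{
    // Illustrative only: fill a byte span using full-vector writes, finishing with one
    // vector write flush against the END of the buffer. That final write may overlap
    // bytes already written, which is harmless because every write stores the same
    // splatted value.
    public static void Fill(Span<byte> buffer, byte value)
    {
        if (!Vector.IsHardwareAccelerated || buffer.Length < Vector<byte>.Count)
        {
            buffer.Fill(value); // scalar fallback
            return;
        }

        Vector<byte> splat = new Vector<byte>(value);
        ref byte dest = ref MemoryMarshal.GetReference(buffer);
        nuint length = (nuint)buffer.Length;
        nuint offset = 0;

        // Whole vectors, except possibly a partial chunk at the end.
        for (; offset + (nuint)Vector<byte>.Count <= length; offset += (nuint)Vector<byte>.Count)
        {
            Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref dest, offset), splat);
        }

        // One final (possibly overlapping) vector write covers any remaining tail.
        Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref dest, length - (nuint)Vector<byte>.Count), splat);
    }
}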
169 changes: 111 additions & 58 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.cs
@@ -3,6 +3,7 @@

using System.Diagnostics;

using System.Numerics;
using Internal.Runtime.CompilerServices;

namespace System
@@ -332,77 +333,129 @@ public static unsafe void ClearWithReferences(ref IntPtr ip, nuint pointerSizeLength)

public static unsafe void ClearWithReferences(ref IntPtr ip, nuint pointerSizeLength)
{
Debug.Assert((int)Unsafe.AsPointer(ref ip) % sizeof(IntPtr) == 0, "Should've been aligned on natural word boundary.");
Debug.Assert((nint)Unsafe.AsPointer(ref ip) % sizeof(IntPtr) == 0, "Should've been aligned on natural word boundary.");

// First write backward 8 natural words at a time.
// Writing backward allows us to get away with only simple modifications to the
// mov instruction's base and index registers between loop iterations.
// Since references are always natural word-aligned, our "unaligned" writes below will
// always be natural word-aligned as well. Even if the full SIMD write is split across
// pages, no core will ever observe any reference as containing a torn address.

for (; pointerSizeLength >= 8; pointerSizeLength -= 8)
if (Vector.IsHardwareAccelerated && pointerSizeLength >= (uint)(Vector<byte>.Count / sizeof(IntPtr)))
{
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -1) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -2) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -3) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -4) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -5) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -6) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -7) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -8) = default;
}
// We have enough data for at least one vectorized write.
// Perform that write now, potentially unaligned.

Debug.Assert(pointerSizeLength <= 7);

// The logic below works by trying to minimize the number of branches taken for any
// given range of lengths. For example, the lengths [ 4 .. 7 ] are handled by a single
// branch, [ 2 .. 3 ] are handled by a single branch, and [ 1 ] is handled by a single
// branch.
//
// We can write both forward and backward as a perf improvement. For example,
// the lengths [ 4 .. 7 ] can be handled by zeroing out the first four natural
// words and the last 3 natural words. In the best case (length = 7), there are
// no overlapping writes. In the worst case (length = 4), there are three
// overlapping writes near the middle of the buffer. In perf testing, the
// penalty for performing duplicate writes is less expensive than the penalty
// for complex branching.

if (pointerSizeLength >= 4)
{
goto Write4To7;
}
else if (pointerSizeLength >= 2)
{
goto Write2To3;
}
else if (pointerSizeLength > 0)
{
goto Write1;
Vector<byte> zero = default;
ref byte refDataAsBytes = ref Unsafe.As<IntPtr, byte>(ref ip);
Unsafe.WriteUnaligned(ref refDataAsBytes, zero);
Member: This assumes that Unsafe.WriteUnaligned is atomic for hardware accelerated vectors. It is not a safe assumption to make in general. It may be the reason for the failing Mono tests.

Member (PR author): It doesn't need to be atomic across the entire vector. It only needs to be atomic at native word granularity. @tannergooding might know more, but I believe every retail microprocessor on the market provides this guarantee.

jkotas (Member), Apr 20, 2021: You are making an assumption about the instruction that this is going to compile into. The compiler (LLVM in this case) is free to compile this into a byte-at-a-time copy, especially with optimizations off.

Member: Also, the IL interpreter is unlikely to execute this atomically either.

Member: If you would like to keep this, I think you may need #define ATOMIC_UNALIGNED_VECTOR_OPERATIONS that won't be set for MONO.

Member, replying to "Does the ARM spec actually guarantee this?": Yes for ARM64. [image]

I'm not 100% sure how it handles ARM32, given that would be 4x 32-bit reads/writes. Probably need to dig further into the spec and the ARM32-specific sections.

Member: In particular:

  • Reads to SIMD and floating-point registers of a 128-bit value that is 64-bit aligned in memory are treated as a pair of single-copy atomic 64-bit reads.
  • Writes from SIMD and floating-point registers of a 128-bit value that is 64-bit aligned in memory are treated as a pair of single-copy atomic 64-bit writes.

Member: Also noting, I don't think this guarantees that something like LLVM won't rewrite the copy loop to be "more efficient". I'd imagine Mono is explicitly marking certain loops to prevent LLVM from doing such optimizations.

GrabYourPitchforks (Member, PR author), Apr 20, 2021: What if we scrapped the Vector.IsHardwareAccelerated check and dropped down to raw SSE2 / ARM64 intrinsics instead? Presumably that wouldn't provide the opportunity for anybody to rewrite these methods in a non-compliant manner? Though I wonder if that would hurt perf, since it'd require pinning (and the associated stack spill) to use the intrinsics.

Could also use Jan's suggestion of wrapping the vectorized code in #if CORECLR, bypassing the whole issue.
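(A rough sketch of what that alternative might look like on x86: pin the destination and issue raw SSE2 stores directly. The type and method names are hypothetical, the ARM64 path and the scalar/remainder handling are omitted, and this operates on raw bytes rather than object references; it only illustrates the pinning-plus-intrinsics idea being discussed, not code from the PR.)

using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

internal static unsafe class ClearSketch
{
    // Illustrative x86-only fragment: pinning lets us use explicit 16-byte SSE2 stores,
    // at the cost of the fixed statement (the perf concern raised above).
    public static void ClearWithSse2(ref byte dest, nuint byteLength)
    {
        if (!Sse2.IsSupported || byteLength < 16)
        {
            return; // scalar fallback omitted for brevity
        }

        fixed (byte* pDest = &dest)
        {
            Vector128<byte> zero = Vector128<byte>.Zero;
            nuint offset = 0;
            for (; offset + 16 <= byteLength; offset += 16)
            {
                Sse2.Store(pDest + offset, zero);
            }

            // Final overlapping store covers any tail shorter than one vector.
            Sse2.Store(pDest + byteLength - 16, zero);
        }
    }
}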

tannergooding (Member), Apr 20, 2021: I imagine that depends on which instructions are being used and what optimizations LLVM is doing. It is conceivably "correct" in C/C++ code for it to recognize __mm_load_si128 and __mm_store_si128 as a "trivial copy loop" and to replace it with ERMSB.

(And https://godbolt.org/z/hsfE4nKnr shows that it does recognize the "improved" code and replace it with memcpy 😄)


// Now, attempt to align the rest of the writes.
// It's possible that the GC could kick in mid-method and unalign everything, but that should
// be rare enough that it's not worth worrying about here. Worst case it slows things down a bit.

nint offsetFromAligned = (nint)Unsafe.AsPointer(ref refDataAsBytes) & (Vector<byte>.Count - 1);
nuint totalByteLength = pointerSizeLength * (nuint)sizeof(IntPtr) + (nuint)offsetFromAligned - (nuint)Vector<byte>.Count;
refDataAsBytes = ref Unsafe.Add(ref refDataAsBytes, Vector<byte>.Count); // legal GC-trackable reference due to earlier length check
refDataAsBytes = ref Unsafe.Add(ref refDataAsBytes, -offsetFromAligned); // this subtraction MUST BE AFTER the addition above to avoid creating an intermediate invalid gcref
nuint offset = 0;

// Loop, writing 2 vectors at a time.

if (totalByteLength >= (uint)(2 * Vector<byte>.Count))
{
nuint stopLoopAtOffset = totalByteLength & (nuint)(nint)(2 * -Vector<byte>.Count); // intentional sign extension carries the negative bit

do
{
Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref refDataAsBytes, offset), zero);
Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref refDataAsBytes, offset + (nuint)Vector<byte>.Count), zero);
offset += (uint)(2 * Vector<byte>.Count);
} while (offset < stopLoopAtOffset);
}

// At this point, if any data remains to be written, it's strictly less than
// 2 * sizeof(Vector) bytes. The loop above had us write an even number of vectors.
// If the total byte length instead involves us writing an odd number of vectors, write
// one additional vector now. The bit check below tells us if we're in an "odd vector
// count" situation.

if ((totalByteLength & (nuint)Vector<byte>.Count) != 0)
{
Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref refDataAsBytes, offset), zero);
}

// It's possible that some small buffer remains to be populated - something that won't
// fit an entire vector's worth of data. Instead of falling back to a loop, we'll write
// a vector at the very end of the buffer. This may involve overwriting previously
// populated data, which is fine since we're just zeroing everything out anyway.
// There's no need to perform a length check here because we already performed this
// check before entering the vectorized code path.

Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref refDataAsBytes, totalByteLength - (nuint)Vector<byte>.Count), zero);
}
else
{
return; // nothing to write
}
// If we reached this point, vectorization is disabled, or there are too few
// elements for us to vectorize. Fall back to an unrolled loop.

nuint i = 0;

Write4To7:
Debug.Assert(pointerSizeLength >= 4);
// Write 8 elements at a time
// Skip this check if "write 8 elements" would've gone down the vectorized code path earlier

// Write first four and last three.
Unsafe.Add(ref ip, 2) = default;
Unsafe.Add(ref ip, 3) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -3) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -2) = default;
if (!Vector.IsHardwareAccelerated || Vector<byte>.Count / sizeof(IntPtr) > 8) // JIT turns this into constant true or false
Member: I'm missing it... why is this check necessary?

GrabYourPitchforks (Member, PR author): The JIT won't properly elide the following if without this check. Here's the codegen if this and the similar check a few lines down are removed. The fallback logic becomes much larger, but note that blocks d07b - d0ba and blocks d0c1 - d0e2 will never be executed. The explicit if in the code above is hinting to the JIT that it shouldn't worry about this scenario. (That's also the crux behind the "skip this check if..." comments above these lines, but I can wordsmith if necessary.)

00007ffa`30eed000 c5f877                 vzeroupper 
00007ffa`30eed003 4883fa04               cmp     rdx, 4
00007ffa`30eed007 726a                   jb      System_Private_CoreLib!System.SpanHelpers.ClearWithReferences(IntPtr ByRef,  UIntPtr)+0x73 (00007ffa`30eed073)

;; BEGIN MAIN VECTOR LOGIC
00007ffa`30eed009 c4e17c57c0             vxorps  ymm0, ymm0, ymm0
00007ffa`30eed00e c4e17d1101             vmovupd ymmword ptr [rcx], ymm0
; <snip>
00007ffa`30eed06f c5f877                 vzeroupper 
00007ffa`30eed072 c3                     ret     
;; END MAIN VECTOR LOGIC

;; BEGIN FALLBACK LOGIC
00007ffa`30eed073 33c0                   xor     eax, eax
00007ffa`30eed075 4883fa08               cmp     rdx, 8  ; rdx can never be < 8 (see lines d003 - d007 earlier)
00007ffa`30eed079 7241                   jb      System_Private_CoreLib!System.SpanHelpers.ClearWithReferences(IntPtr ByRef,  UIntPtr)+0xbc (00007ffa`30eed0bc)
00007ffa`30eed07b 4c8bc2                 mov     r8, rdx
00007ffa`30eed07e 4983e0f8               and     r8, 0FFFFFFFFFFFFFFF8h
00007ffa`30eed082 4c8bc8                 mov     r9, rax
00007ffa`30eed085 49c1e103               shl     r9, 3
00007ffa`30eed089 4533d2                 xor     r10d, r10d
00007ffa`30eed08c 4e891409               mov     qword ptr [rcx+r9], r10
00007ffa`30eed090 4e89540908             mov     qword ptr [rcx+r9+8], r10
00007ffa`30eed095 4e89540910             mov     qword ptr [rcx+r9+10h], r10
00007ffa`30eed09a 4e89540918             mov     qword ptr [rcx+r9+18h], r10
00007ffa`30eed09f 4e89540920             mov     qword ptr [rcx+r9+20h], r10
00007ffa`30eed0a4 4e89540928             mov     qword ptr [rcx+r9+28h], r10
00007ffa`30eed0a9 4e89540930             mov     qword ptr [rcx+r9+30h], r10
00007ffa`30eed0ae 4e89540938             mov     qword ptr [rcx+r9+38h], r10
00007ffa`30eed0b3 4883c008               add     rax, 8
00007ffa`30eed0b7 493bc0                 cmp     rax, r8
00007ffa`30eed0ba 72c6                   jb      System_Private_CoreLib!System.SpanHelpers.ClearWithReferences(IntPtr ByRef,  UIntPtr)+0x82 (00007ffa`30eed082)
00007ffa`30eed0bc f6c204                 test    dl, 4  ; rdx & 4 will always result in 0 (see lines d003 - d007 earlier)
00007ffa`30eed0bf 7421                   je      System_Private_CoreLib!System.SpanHelpers.ClearWithReferences(IntPtr ByRef,  UIntPtr)+0xe2 (00007ffa`30eed0e2)
00007ffa`30eed0c1 4c8bc0                 mov     r8, rax
00007ffa`30eed0c4 49c1e003               shl     r8, 3
00007ffa`30eed0c8 4533c9                 xor     r9d, r9d
00007ffa`30eed0cb 4e890c01               mov     qword ptr [rcx+r8], r9
00007ffa`30eed0cf 4e894c0108             mov     qword ptr [rcx+r8+8], r9
00007ffa`30eed0d4 4e894c0110             mov     qword ptr [rcx+r8+10h], r9
00007ffa`30eed0d9 4e894c0118             mov     qword ptr [rcx+r8+18h], r9
00007ffa`30eed0de 4883c004               add     rax, 4
00007ffa`30eed0e2 f6c202                 test    dl, 2
00007ffa`30eed0e5 7417                   je      System_Private_CoreLib!System.SpanHelpers.ClearWithReferences(IntPtr ByRef,  UIntPtr)+0xfe (00007ffa`30eed0fe)
00007ffa`30eed0e7 4c8bc0                 mov     r8, rax
00007ffa`30eed0ea 49c1e003               shl     r8, 3
00007ffa`30eed0ee 4533c9                 xor     r9d, r9d
00007ffa`30eed0f1 4e890c01               mov     qword ptr [rcx+r8], r9
00007ffa`30eed0f5 4e894c0108             mov     qword ptr [rcx+r8+8], r9
00007ffa`30eed0fa 4883c002               add     rax, 2
00007ffa`30eed0fe f6c201                 test    dl, 1
00007ffa`30eed101 7406                   je      System_Private_CoreLib!System.SpanHelpers.ClearWithReferences(IntPtr ByRef,  UIntPtr)+0x109 (00007ffa`30eed109)
00007ffa`30eed103 33d2                   xor     edx, edx
00007ffa`30eed105 488914c1               mov     qword ptr [rcx+rax*8], rdx
00007ffa`30eed109 c5f877                 vzeroupper 
00007ffa`30eed10c c3                     ret

{
if (pointerSizeLength >= 8)
{
nuint stopLoopAtOffset = pointerSizeLength & ~(nuint)7;
Member: Although I like this pattern, I'm not sure if

nuint stopLoopAt = pointerSizeLength - 8;
do
{
    // ...
} while ((i += 8) <= stopLoopAt);

is a "bit better".

-mov       r9,rdx
-and       r9,0FFFFFFFFFFFFFFF8
+lea       r9,[rdx+0FFF8]

I doubt it will make a measurable difference either way.

My thinking is:
The mov r9,rdx likely won't reach the cpu's backend, but the and needs to be executed.
For the lea this can be handled in the AGU.
And it's 3 bytes less code (hooray).
It's a bit more readable, though I assume anyone reading code like this is aware of these tricks.

Out of interest: did you have a reason to choose the bit-hack?

The same applies to other places too (for the record 😉).

GrabYourPitchforks (Member, PR author): The bit-twiddling hack is muscle memory, since it allows this pattern to work:

nuint i = 0;
nuint lastOffset = (whatever) & ~(nuint)7;
for (; i != lastOffset; i += 8)
{
    /* do work */
}

JIT will use the and from the preceding line as the input into the jcc at the start of the loop. Allows repurposing the bit-twiddling as the "are you even large enough?" comparison in the first place.

Though to your point, I agree that if we already do the comparison upfront, this twiddling isn't needed. Or we could skip the earlier comparison, eagerly perform the bit-twiddling, and use the result of the twiddling as the comparison.
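(A tiny sketch of that last variant — mask eagerly and reuse the masked value as the size check — using the public Unsafe API; the type and method names are illustrative, and the 0-7 element remainder handling from the PR is omitted.)

using System;
using System.Runtime.CompilerServices;

internal static class ClearLoopSketch
{
    // Illustrative: the masked length doubles as both the "are you even large enough?"
    // test and the loop bound, so no separate >= 8 comparison is needed up front.
    public static void ClearPointerSized(ref IntPtr ip, nuint pointerSizeLength)
    {
        nuint stopLoopAtOffset = pointerSizeLength & ~(nuint)7;
        nuint i = 0;
        if (stopLoopAtOffset != 0)
        {
            do
            {
                Unsafe.Add(ref ip, (nint)i + 0) = default;
                Unsafe.Add(ref ip, (nint)i + 1) = default;
                Unsafe.Add(ref ip, (nint)i + 2) = default;
                Unsafe.Add(ref ip, (nint)i + 3) = default;
                Unsafe.Add(ref ip, (nint)i + 4) = default;
                Unsafe.Add(ref ip, (nint)i + 5) = default;
                Unsafe.Add(ref ip, (nint)i + 6) = default;
                Unsafe.Add(ref ip, (nint)i + 7) = default;
            } while ((i += 8) < stopLoopAtOffset);
        }
        // Remaining 0-7 elements would be handled as in the PR (omitted here).
    }
}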

Member: Thanks!

do
{
Unsafe.Add(ref ip, (nint)i + 0) = default;
Unsafe.Add(ref ip, (nint)i + 1) = default;
Unsafe.Add(ref ip, (nint)i + 2) = default;
Unsafe.Add(ref ip, (nint)i + 3) = default;
Unsafe.Add(ref ip, (nint)i + 4) = default;
Unsafe.Add(ref ip, (nint)i + 5) = default;
Unsafe.Add(ref ip, (nint)i + 6) = default;
Unsafe.Add(ref ip, (nint)i + 7) = default;
} while ((i += 8) < stopLoopAtOffset);
}
}

Write2To3:
Debug.Assert(pointerSizeLength >= 2);
// Write next 4 elements if needed
// Skip this check if "write 4 elements" would've gone down the vectorized code path earlier

// Write first two and last one.
Unsafe.Add(ref ip, 1) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -1) = default;
if (!Vector.IsHardwareAccelerated || Vector<byte>.Count / sizeof(IntPtr) > 4) // JIT turns this into const true or false
{
if ((pointerSizeLength & 4) != 0)
{
Unsafe.Add(ref ip, (nint)i + 0) = default;
Unsafe.Add(ref ip, (nint)i + 1) = default;
Unsafe.Add(ref ip, (nint)i + 2) = default;
Unsafe.Add(ref ip, (nint)i + 3) = default;
i += 4;
}
}

Write1:
Debug.Assert(pointerSizeLength >= 1);
// Write next 2 elements if needed
// Skip this check if "write 2 elements" would've gone down the vectorized code path earlier

// Write only element.
ip = default;
if (!Vector.IsHardwareAccelerated || Vector<byte>.Count / sizeof(IntPtr) > 2) // JIT turns this into const true or false
{
if ((pointerSizeLength & 2) != 0)
{
Unsafe.Add(ref ip, (nint)i + 0) = default;
Unsafe.Add(ref ip, (nint)i + 1) = default;
i += 2;
}
}

// Write final element if needed

if ((pointerSizeLength & 1) != 0)
{
Unsafe.Add(ref ip, (nint)i) = default;
}
Comment on lines +454 to +457

Member: Suggested change: instead of

if ((pointerSizeLength & 1) != 0)
{
    Unsafe.Add(ref ip, (nint)i) = default;
}

the final element can be written at the end directly:

Unsafe.Add(ref ip, (nint)(pointerSizeLength - 1)) = default;

Same as #51365 (comment) applies here.

-test      dl,1
-je        short M01_L04
-mov       [rcx+rax],0x0
+mov       [rcx+rdx+0FFFF],0x0

Unfortunately in the other PR I forgot to look at the scalar path 😢

if ((numElements & 1) != 0)
{
Unsafe.Add(ref refData, (nint)i) = value;
}

Can you please update this? (I don't want to submit a PR for this mini change when you're on it.)

}
}
}
}