Add the barebones support for using embedded masking with AVX512 #97675
Tagging subscribers to this area: @JulieLeeMSFT, @jakobbotsch

Issue Details

As per the title, this adds the minimal recognition/support for emitting AVX-512 instructions with embedded masking support.

It currently only targets Avx512F.Add but can be trivially expanded to other instructions via a table driven approach after this first PR is merged.

For something like:

public static Vector512<double> M(Vector512<double> x, Vector512<double> y)
{
    return Vector512.ConditionalSelect(Vector512.Equals(x, x), x + x, y + y);
}

We will emit the following:

vmovups zmm0, zmmword ptr [rdx]
vcmppd k1, zmm0, zmm0, 0
vmovups zmm1, zmmword ptr [r8]
vaddpd zmm1, zmm1, zmm1
vaddpd zmm1 {k1}, zmm0, zmm0 ; <--- embedded mask used here
vmovups zmmword ptr [rcx], zmm1
mov rax, rcx
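
To make the effect of the embedded mask concrete, here is a scalar C++ sketch of what the emitted sequence computes per lane. It is only an emulation under stated assumptions: `Vec8d`, `CompareEqualMask`, `MaskedAdd`, and `SelectSumSketch` are hypothetical names introduced for illustration, not JIT or runtime APIs, and the sketch models merge-masking, where lanes whose mask bit is clear keep the destination's previous value.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Stand-in for a 512-bit vector of eight doubles (hypothetical type).
using Vec8d = std::array<double, 8>;

// Builds a per-lane mask, one bit per lane, set where a[i] == b[i].
// This plays the role of "vcmppd k1, zmm0, zmm0, 0" in the listing above.
std::uint8_t CompareEqualMask(const Vec8d& a, const Vec8d& b)
{
    std::uint8_t mask = 0;
    for (std::size_t i = 0; i < a.size(); i++)
    {
        if (a[i] == b[i])
        {
            mask |= static_cast<std::uint8_t>(1u << i);
        }
    }
    return mask;
}

// Merge-masked add: lanes whose mask bit is set receive a[i] + b[i];
// the remaining lanes keep whatever dst already held.
void MaskedAdd(Vec8d& dst, std::uint8_t mask, const Vec8d& a, const Vec8d& b)
{
    for (std::size_t i = 0; i < dst.size(); i++)
    {
        if (((mask >> i) & 1u) != 0)
        {
            dst[i] = a[i] + b[i];
        }
    }
}

// Mirrors the order of the emitted code: compute y + y into the result first,
// then overwrite only the selected lanes with x + x.
Vec8d SelectSumSketch(const Vec8d& x, const Vec8d& y)
{
    const std::uint8_t k1 = CompareEqualMask(x, x);

    Vec8d result{};
    for (std::size_t i = 0; i < result.size(); i++)
    {
        result[i] = y[i] + y[i];
    }

    MaskedAdd(result, k1, x, x);
    return result;
}
```

The point of the sketch is that the select and the add collapse into one masked operation, so no separate blend step is needed after `x + x` is computed.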
// We have several pieces of information we need to encode but which are only applicable
// to a subset of instrDescs. To accommodate that, we define a several _idCustom# bitfields
// and then some defineds to make accessing them simpler
These bits are "expensive" and impact the maximum size of "small" constants, so I opted to repurpose these existing 3 bits that are only used for IF_LABEL, IF_METHOD, and related formats. They will never conflict with the SIMD instructions so this ends up being a nice way to fit it in, IMO.
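
As a rough illustration of the idea of sharing a few bits between instruction formats that can never overlap, consider the C++ sketch below. Everything in it is hypothetical and heavily simplified: `InstrDescSketch`, `InsFormat`, and the accessor names are invented for this example and do not reflect the actual layout or names in the JIT's emitter.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical instruction formats; the real emitter has many more.
enum class InsFormat : std::uint8_t
{
    Label,
    Method,
    SimdRRR,
};

// Hypothetical, simplified descriptor. The same 3 bits mean different things
// depending on the format, so no extra storage is needed for the mask context.
struct InstrDescSketch
{
    InsFormat    format;
    std::uint8_t custom : 3; // shared bits; interpretation depends on `format`

    static bool IsLabelLike(InsFormat f)
    {
        return (f == InsFormat::Label) || (f == InsFormat::Method);
    }

    // Only valid for formats that do not already use the shared bits.
    void SetMaskContext(unsigned aaa)
    {
        assert(!IsLabelLike(format));
        assert(aaa <= 7);
        custom = static_cast<std::uint8_t>(aaa & 0x7u);
    }

    unsigned GetMaskContext() const
    {
        assert(!IsLabelLike(format));
        return custom;
    }

    // Only valid for label-like formats.
    void SetLabelBits(unsigned bits)
    {
        assert(IsLabelLike(format));
        custom = static_cast<std::uint8_t>(bits & 0x7u);
    }
};
```

The asserts document the invariant the comment relies on: a given descriptor reads the shared bits through exactly one interpretation, chosen by its format.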
No TP impact, so this works!
{
    regNumber maskReg = static_cast<regNumber>(id->idGetEvexAaaContext() + KBASE);

    if (maskReg == REG_K0)
K0 is special and basically means "don't mask".
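
A small C++ sketch of that special case, assuming the 3-bit mask field selects k0 through k7 and that a value of 0 means the instruction is emitted without a mask. `MaskSuffix` and `FormatDest` are hypothetical helpers written for this example; they are not part of the emitter.

```cpp
#include <cassert>
#include <string>

// Returns the text appended to the destination operand for a given 3-bit
// mask field. A field value of 0 selects k0, which means "don't mask",
// so nothing is appended in that case.
std::string MaskSuffix(unsigned aaa)
{
    assert(aaa <= 7);

    if (aaa == 0)
    {
        return "";
    }
    return " {k" + std::to_string(aaa) + "}";
}

// Formats a destination register together with its optional mask.
std::string FormatDest(const std::string& destReg, unsigned aaa)
{
    return destReg + MaskSuffix(aaa);
}
```

With these helpers, `FormatDest("zmm1", 1)` yields the `zmm1 {k1}` form seen in the disassembly above, while a field value of 0 leaves the operand unchanged.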
// TODO-AVX512F-CQ: Expand this to the full set of APIs and make it table driven
// using IsEmbMaskingCompatible. For now, however, limit it to some explicit ids
// for prototyping purposes.
return (AsHWIntrinsic()->GetHWIntrinsicId() == NI_AVX512F_Add);
This is the main bit that actually drives which intrinsics can support embedded masking. In actuality this is most intrinsics and so the actual flag we'll add to the table is HW_Flag_EmbMaskingIncompatible to indicate the few that can't use it.
Just wanted to get the baseline support up first and then go and plumb through the minor connecting points after to avoid any excess churn based on feedback, etc.
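
One way the table-driven version could look is sketched below in C++. This is an assumption about the eventual shape, not the follow-up implementation: the enum members, the table contents, and `IsEmbMaskingCompatibleSketch` are all hypothetical, and which intrinsics would actually carry the flag is not stated in this PR. Only the flag value 0x20000000 is taken from the diff.

```cpp
#include <cstdint>

// Hypothetical intrinsic ids; the last entry stands in for "one of the few
// that can't use embedded masking".
enum IntrinsicIdSketch
{
    Sketch_Add,
    Sketch_Subtract,
    Sketch_HypotheticalUnmaskable,
    Sketch_Count,
};

enum IntrinsicFlagsSketch : std::uint32_t
{
    SketchFlag_None                   = 0,
    SketchFlag_EmbMaskingIncompatible = 0x20000000, // same value as in the diff
};

// Hypothetical per-intrinsic flag table, indexed by IntrinsicIdSketch.
constexpr std::uint32_t kSketchFlags[Sketch_Count] = {
    SketchFlag_None,                   // Sketch_Add
    SketchFlag_None,                   // Sketch_Subtract
    SketchFlag_EmbMaskingIncompatible, // Sketch_HypotheticalUnmaskable
};

// Because most intrinsics are compatible, the table marks the exceptions and
// compatibility is the absence of the flag.
constexpr bool IsEmbMaskingCompatibleSketch(IntrinsicIdSketch id)
{
    return (kSketchFlags[id] & SketchFlag_EmbMaskingIncompatible) == 0;
}
```

Marking the exceptions rather than the common case keeps the table entries for most intrinsics untouched, which matches the reasoning given for choosing an "incompatible" flag.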
@@ -556,6 +556,10 @@ enum GenTreeFlags : unsigned int
    GTF_MDARRLEN_NONFAULTING = 0x20000000, // GT_MDARR_LENGTH -- An MD array length operation that cannot fault. Same as GT_IND_NONFAULTING.

    GTF_MDARRLOWERBOUND_NONFAULTING = 0x20000000, // GT_MDARR_LOWER_BOUND -- An MD array lower bound operation that cannot fault. Same as GT_IND_NONFAULTING.

#ifdef TARGET_XARCH
    GTF_HW_EM_OP = 0x10000000, // GT_HWINTRINSIC -- node is used as an operand to an embedded mask
I tried to figure out a way to do this without a flag, but we have several intrinsics that can be contained for multiple reasons and so it ends up being a bit cleaner to just tag the node that is contained for embedded masking purposes to differentiate it.
After we finish plumbing through the rest of the instructions, we might find it's actually not needed (we might be able to introduce helper IDs to disambiguate, for example) and we can remove it then if appropriate.
Can this be under ifdef FEATURE_HW_INTRINSICS?
It can be, but there isn't a strong need for it, so I'd prefer to handle it in a follow up to avoid the additional CI churn (will fix it in this PR if I need to push anything for the seemingly unrelated tls_InlinedThreadStatic failure though).
{
    GenTree* op2 = node->Op(2);

    if (op2->IsEmbMaskOp())
Much like with the embedded rounding support, this approach allows us to avoid recursion or tracking unnecessary state separately. Instead we can recognize the special scenario up front and extract the info to the tracked insOpt and pass it through. This allows us to reuse all the existing code paths.
// TODO-AVX512-CQ: Ensure we can support embedded operations on RMW intrinsics
assert(!op2->isRMWHWIntrinsic(compiler));
This will need some more handling for cases like FMA which are RMW already. Those cases may need to emit a movaps with embedded masking + the underlying instruction with embedded masking to ensure we get all the correct codegen.
Diff results for #97675

Assembly diffs

Assembly diffs for linux/x64 ran on windows/x64
Diffs are based on 2,249,675 contexts (981,298 MinOpts, 1,268,377 FullOpts).
MISSED contexts: 134 (0.01%)
Overall (-6 bytes)
FullOpts (-6 bytes)

Assembly diffs for windows/x64 ran on windows/x64
Diffs are based on 2,098,432 contexts (926,221 MinOpts, 1,172,211 FullOpts).
MISSED contexts: 138 (0.01%)
Overall (-151 bytes)
FullOpts (-151 bytes)

Throughput diffs

Throughput diffs for linux/arm64 ran on windows/x64
MinOpts (-0.00% to +0.01%)

Throughput diffs for linux/x64 ran on windows/x64
MinOpts (-0.01% to 0.00%)

Throughput diffs for windows/arm64 ran on windows/x64
MinOpts (-0.00% to +0.01%)
Diff results for #97675

Assembly diffs

Assembly diffs for windows/x64 ran on linux/x64
Diffs are based on 2,098,432 contexts (926,221 MinOpts, 1,172,211 FullOpts).
MISSED contexts: 138 (0.01%)
Overall (-151 bytes)
FullOpts (-151 bytes)

Throughput diffs

Throughput diffs for linux/arm64 ran on windows/x64
MinOpts (-0.01% to +0.00%)

Throughput diffs for linux/x64 ran on windows/x64
Overall (-0.01% to +0.00%)
MinOpts (-0.01% to +0.00%)
FullOpts (-0.01% to +0.00%)

Throughput diffs for windows/arm64 ran on windows/x64
MinOpts (-0.00% to +0.01%)

Throughput diffs for windows/x86 ran on windows/x86
MinOpts (+0.00% to +0.01%)
Diff results for #97675

Assembly diffs

Assembly diffs for linux/x64 ran on windows/x64
Diffs are based on 2,249,675 contexts (981,298 MinOpts, 1,268,377 FullOpts).
MISSED contexts: 134 (0.01%)
Overall (-6 bytes)
FullOpts (-6 bytes)
Diff results for #97675

Assembly diffs

Assembly diffs for linux/x64 ran on windows/x64
Diffs are based on 2,249,675 contexts (981,298 MinOpts, 1,268,377 FullOpts).
MISSED contexts: 134 (0.01%)
Overall (-6 bytes)
FullOpts (-6 bytes)

Assembly diffs for windows/x64 ran on windows/x64
Diffs are based on 2,227,722 contexts (987,923 MinOpts, 1,239,799 FullOpts).
MISSED contexts: 138 (0.01%)
Overall (-151 bytes)
FullOpts (-151 bytes)

Throughput diffs

Throughput diffs for linux/x64 ran on windows/x64
Overall (-0.01% to -0.00%)
MinOpts (-0.01% to +0.00%)
FullOpts (-0.01% to -0.00%)

Throughput diffs for windows/arm64 ran on windows/x64
MinOpts (-0.00% to +0.01%)

Throughput diffs for windows/x86 ran on windows/x86
MinOpts (+0.00% to +0.01%)
    HW_Flag_EmbRoundingCompatible = 0x10000000,

    // The intrinsic is an embedded masking incompatible intrinsic
    HW_Flag_EmbMaskingIncompatible = 0x20000000,
Likewise this? Should it be under #ifdef FEATURE_HW_INTRINSICS?
This is in the hwintrinsics header, so it's already effectively FEATURE_HW_INTRINSICS.
The diffs you highlighted are from a mistake in one of the earlier commits that I fixed this morning and I expect them to go away in the very latest. I wouldn't really expect any diffs to show up in our current code since we avoided several of the patterns while the support didn't exist and it only applies to
Diff results for #97675

Throughput diffs

Throughput diffs for linux/arm64 ran on windows/x64
MinOpts (-0.00% to +0.01%)

Throughput diffs for linux/x64 ran on windows/x64
MinOpts (-0.01% to +0.00%)

Throughput diffs for windows/x86 ran on linux/x86
MinOpts (+0.00% to +0.01%)
Diff results for #97675

Throughput diffs

Throughput diffs for linux/x64 ran on linux/x64
MinOpts (-0.01% to +0.00%)

Throughput diffs for osx/arm64 ran on linux/x64
MinOpts (-0.00% to +0.01%)
@dotnet/runtime-infrastructure, @MichalStrehovsky I'm seeing some failures here like the following:
It's unclear what's causing this and I don't see any issues elsewhere. The only reference to that variable I can find is in the aot ILCompiler: src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/TlsRootNode.cs
But of course there aren't any changes to that or related areas in this PR. Any help would be appreciated.
Diff results for #97675

Throughput diffs

Throughput diffs for linux/x64 ran on linux/x64
MinOpts (-0.01% to +0.00%)

Throughput diffs for linux/arm64 ran on windows/x64
MinOpts (-0.01% to +0.00%)

Throughput diffs for windows/x86 ran on windows/x86
MinOpts (+0.00% to +0.01%)
https://github.com/dotnet/runtime/pull/97675/files#diff-125fb0b9396b0e85ee82c6d592cdf1750778a7cff38ad382f7c2966ef588dce1R780 might have affected it. With the bits overriding, it might be that we are not setting the
Thanks for the help Kunal! -- Issue was that a check was slightly wrong and so
Diff results for #97675

Throughput diffs

Throughput diffs for linux/arm64 ran on windows/x64
MinOpts (-0.01% to +0.00%)

Throughput diffs for windows/arm64 ran on windows/x64
MinOpts (-0.00% to +0.01%)

Throughput diffs for windows/x64 ran on windows/x64
MinOpts (+0.00% to +0.01%)

Throughput diffs for windows/x86 ran on linux/x86
MinOpts (+0.00% to +0.01%)
Diff results for #97675

Throughput diffs

Throughput diffs for linux/arm64 ran on windows/x64
MinOpts (-0.01% to +0.00%)

Throughput diffs for windows/arm64 ran on windows/x64
MinOpts (-0.00% to +0.01%)

Throughput diffs for windows/x64 ran on windows/x64
MinOpts (+0.00% to +0.01%)

Throughput diffs for windows/x86 ran on windows/x86
MinOpts (+0.00% to +0.01%)
LGTM
bool IsEmbMaskOp()
{
    bool result = (gtFlags & GTF_HW_EM_OP) != 0;
    assert(!result || (gtOper == GT_HWINTRINSIC));
This assert is invalid: the same bit used by GTF_HW_EM_OP is used by other flags for other GenTree types (e.g., GTF_OVERFLOW).
You could have:

    assert(gtOper == GT_HWINTRINSIC);

if you only allow HWINTRINSIC nodes to call this. Or, you need it to be dynamic:

    if (gtOper == GT_HWINTRINSIC)
    {
        result = (gtFlags & GTF_HW_EM_OP) != 0;
    }
    else
    {
        result = false;
    }
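
To show why the original assert can fire on unrelated nodes, here is a self-contained C++ sketch of the aliasing problem and of the dynamic check. The types and names (`GenTreeSketch`, `OperSketch`, the `kFlag...Sketch` constants) are hypothetical simplifications, not the real GenTree definitions; the only detail taken from the discussion is that both flags occupy the same bit, 0x10000000.

```cpp
#include <cstdint>

// Hypothetical, reduced set of node kinds.
enum OperSketch
{
    Oper_Add,
    Oper_HWIntrinsic,
};

// Both flags share one bit; which meaning applies depends on the node's oper.
constexpr std::uint32_t kFlagOverflowSketch = 0x10000000;
constexpr std::uint32_t kFlagHwEmOpSketch   = 0x10000000;

struct GenTreeSketch
{
    OperSketch    oper;
    std::uint32_t flags;

    // Dynamic form of the check: the bit only counts as "embedded mask
    // operand" when the node is a hardware intrinsic. For any other oper the
    // same bit carries a different meaning and must be ignored here.
    bool IsEmbMaskOp() const
    {
        return (oper == Oper_HWIntrinsic) && ((flags & kFlagHwEmOpSketch) != 0);
    }
};
```

In this sketch an add node with the overflow flag set has the shared bit on, so testing the bit alone would report it as an embedded mask operand; checking the oper first avoids that.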
Fixed with #98306
As per the title, this adds the minimal recognition/support for emitting AVX-512 instructions with embedded masking support.
It currently only targets Avx512F.Add but can be trivially expanded to other instructions via a table driven approach after this first PR is merged.