Skip to content

Commit

Permalink
Add the barebones support for using embedded masking with AVX512 (#97675
Browse files Browse the repository at this point in the history
)

* Add the barebones support for using embedded masking with AVX512

* Applying formatting patch

* Add some basic asserts to ensure _idCustom# isn't used incorrectly

* Ensure that the instruction check is correct for TlsGD
  • Loading branch information
tannergooding authored Jan 31, 2024
1 parent b91ed70 commit cd460db
Show file tree
Hide file tree
Showing 11 changed files with 501 additions and 60 deletions.
68 changes: 60 additions & 8 deletions src/coreclr/jit/emit.h
Original file line number Diff line number Diff line change
Expand Up @@ -768,12 +768,26 @@ class emitter
unsigned _idLargeDsp : 1; // does a large displacement follow?
unsigned _idLargeCall : 1; // large call descriptor used

unsigned _idBound : 1; // jump target / frame offset bound
#ifndef TARGET_ARMARCH
unsigned _idCallRegPtr : 1; // IL indirect calls: addr in reg
#endif
unsigned _idTlsGD : 1; // Used to store information related to TLS GD access on linux
unsigned _idNoGC : 1; // Some helpers don't get recorded in GC tables
// We have several pieces of information we need to encode but which are only applicable
// to a subset of instrDescs. To accommodate that, we define several _idCustom# bitfields
// and then some defines to make accessing them simpler

unsigned _idCustom1 : 1;
unsigned _idCustom2 : 1;
unsigned _idCustom3 : 1;

// Aliases that give the shared _idCustom# bitfields a meaningful name per use.
// The first three are single-bit flags; _idEvexAaaContext recombines the same three bits
// into a 3-bit value (bit 2 = _idCustom3, bit 1 = _idCustom2, bit 0 = _idCustom1).
#define _idBound _idCustom1 /* jump target / frame offset bound */
#define _idTlsGD _idCustom2 /* Used to store information related to TLS GD access on linux */
#define _idNoGC _idCustom3  /* Some helpers don't get recorded in GC tables */
// The whole expansion is parenthesized so that it stays a single operand when the macro is
// used inside a larger expression (e.g. next to ==, +, * or &), instead of having the
// neighbouring operator bind to _idCustom1 alone.
#define _idEvexAaaContext                                                                                              \
    ((_idCustom3 << 2) | (_idCustom2 << 1) | _idCustom1) /* bits used for the EVEX.aaa context */

#if !defined(TARGET_ARMARCH)
unsigned _idCustom4 : 1;

#define _idCallRegPtr _idCustom4 /* IL indirect calls : addr in reg */
#define _idEvexZContext _idCustom4 /* bits used for the EVEX.z context */
#endif // !TARGET_ARMARCH

#if defined(TARGET_XARCH)
// EVEX.b can indicate several contexts: embedded broadcast, embedded rounding.
// For normal and embedded broadcast intrinsics, EVEX.L'L has the same semantic, vector length.
Expand Down Expand Up @@ -1578,30 +1592,36 @@ class emitter

bool idIsBound() const
{
assert(!IsAvx512OrPriorInstruction(_idIns));
return _idBound != 0;
}
void idSetIsBound()
{
    // Marks the jump target / frame offset as bound. The flag aliases _idCustom1
    // (also an EVEX.aaa bit), hence the guard against SIMD/opmask instructions.
    assert(!IsAvx512OrPriorInstruction(_idIns));

    _idBound = 1;
}

#ifndef TARGET_ARMARCH
bool idIsCallRegPtr() const
{
    // _idCallRegPtr is an alias of _idCustom4, which doubles as the EVEX.z context bit,
    // so it is only meaningful for non-SIMD/opmask instructions.
    assert(!IsAvx512OrPriorInstruction(_idIns));

    const bool addrIsInReg = (_idCallRegPtr != 0);
    return addrIsInReg;
}
void idSetIsCallRegPtr()
{
    // Records that this is an indirect call whose address is in a register. The flag
    // aliases _idCustom4 (shared with EVEX.z), hence the guard below.
    assert(!IsAvx512OrPriorInstruction(_idIns));

    _idCallRegPtr = 1;
}
#endif
#endif // !TARGET_ARMARCH

bool idIsTlsGD() const
{
assert(!IsAvx512OrPriorInstruction(_idIns));
return _idTlsGD != 0;
}
void idSetTlsGD()
{
    // Flags the descriptor as carrying TLS GD access information. The flag aliases
    // _idCustom2 (also an EVEX.aaa bit), hence the guard against SIMD/opmask instructions.
    assert(!IsAvx512OrPriorInstruction(_idIns));

    _idTlsGD = 1;
}

Expand All @@ -1610,10 +1630,12 @@ class emitter
// code, it is not necessary to generate GC info for a call so labeled.
bool idIsNoGC() const
{
assert(!IsAvx512OrPriorInstruction(_idIns));
return _idNoGC != 0;
}
void idSetIsNoGC(bool val)
{
    // Sets or clears the no-GC flag. The flag aliases _idCustom3 (also an EVEX.aaa bit),
    // hence the guard against SIMD/opmask instructions.
    assert(!IsAvx512OrPriorInstruction(_idIns));

    _idNoGC = val ? 1 : 0;
}

Expand All @@ -1625,7 +1647,8 @@ class emitter

void idSetEvexbContext(insOpts instOptions)
{
assert(_idEvexbContext == 0);
assert(!idIsEvexbContextSet());

if (instOptions == INS_OPTS_EVEX_eb_er_rd)
{
_idEvexbContext = 1;
Expand All @@ -1648,6 +1671,34 @@ class emitter
{
return _idEvexbContext;
}

unsigned idGetEvexAaaContext() const
{
    // The aaa bits live in the shared _idCustom1.._idCustom3 fields, which carry other
    // flags for non-SIMD instructions, so only SIMD/opmask instructions may read them.
    assert(IsAvx512OrPriorInstruction(_idIns));

    // Recombine the three single-bit fields into one value (bit 2 .. bit 0).
    const unsigned aaaContext = _idEvexAaaContext;
    return aaaContext;
}

void idSetEvexAaaContext(insOpts instOptions)
{
    // The context is expected to still be zero on entry; the getter additionally
    // asserts that this is a SIMD/opmask instruction.
    assert(idGetEvexAaaContext() == 0);

    // Isolate the aaa field of the options and shift it down to bit 0.
    const unsigned aaaBits = static_cast<unsigned>((instOptions & INS_OPTS_EVEX_aaa_MASK) >> 2);

    // Scatter the three bits across the shared custom bitfields, low bit first.
    _idCustom1 = (aaaBits & 1);
    _idCustom2 = ((aaaBits >> 1) & 1);
    _idCustom3 = ((aaaBits >> 2) & 1);
}

bool idIsEvexZContextSet() const
{
    // _idEvexZContext is an alias of _idCustom4, which means "call via register" for
    // non-SIMD instructions, so only SIMD/opmask instructions may read it as EVEX.z.
    assert(IsAvx512OrPriorInstruction(_idIns));

    const bool zContextSet = (_idEvexZContext != 0);
    return zContextSet;
}

void idSetEvexZContext()
{
    // The bit is expected to be clear on entry; the query used here also asserts
    // that this is a SIMD/opmask instruction.
    assert(!idIsEvexZContextSet());

    _idEvexZContext = 1;
}
#endif

#ifdef TARGET_ARMARCH
Expand Down Expand Up @@ -2222,6 +2273,7 @@ class emitter
void emitDispInsHex(instrDesc* id, BYTE* code, size_t sz);
void emitDispEmbBroadcastCount(instrDesc* id);
void emitDispEmbRounding(instrDesc* id);
void emitDispEmbMasking(instrDesc* id);
void emitDispIns(instrDesc* id,
bool isNew,
bool doffs,
Expand Down
97 changes: 67 additions & 30 deletions src/coreclr/jit/emitxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,6 @@ bool emitter::IsKInstruction(instruction ins)
return (flags & KInstruction) != 0;
}

//------------------------------------------------------------------------
// IsAvx512OrPriorInstruction: Checks whether an instruction is an SSE, AVX, AVX512,
//    or K (opmask) instruction.
//
// Arguments:
//    ins - The instruction to check.
//
// Returns:
//    `true` if `ins` lies in the range [INS_FIRST_SSE_INSTRUCTION, INS_LAST_AVX512_INSTRUCTION];
//    `false` otherwise.
//
// Notes:
//    K instructions would technically belong under the VEX encoding umbrella, but because of
//    the instruction table encoding they had to be pulled out with the rest of the `INST5`
//    definitions.
//
bool emitter::IsAvx512OrPriorInstruction(instruction ins)
{
    // TODO-XArch-AVX512: Fix check once AVX512 instructions are added.
    const bool atOrAfterFirstSse    = (ins >= INS_FIRST_SSE_INSTRUCTION);
    const bool atOrBeforeLastAvx512 = (ins <= INS_LAST_AVX512_INSTRUCTION);

    return atOrAfterFirstSse && atOrBeforeLastAvx512;
}

bool emitter::IsAVXOnlyInstruction(instruction ins)
{
return (ins >= INS_FIRST_AVX_INSTRUCTION) && (ins <= INS_LAST_AVX_INSTRUCTION);
Expand Down Expand Up @@ -1304,9 +1287,10 @@ bool emitter::TakesEvexPrefix(const instrDesc* id) const
#define DEFAULT_BYTE_EVEX_PREFIX 0x62F07C0800000000ULL

#define DEFAULT_BYTE_EVEX_PREFIX_MASK 0xFFFFFFFF00000000ULL
#define BBIT_IN_BYTE_EVEX_PREFIX 0x0000001000000000ULL
#define LBIT_IN_BYTE_EVEX_PREFIX 0x0000002000000000ULL
#define LPRIMEBIT_IN_BYTE_EVEX_PREFIX 0x0000004000000000ULL
#define EVEX_B_BIT 0x0000001000000000ULL
#define ZBIT_IN_BYTE_EVEX_PREFIX 0x0000008000000000ULL

//------------------------------------------------------------------------
// AddEvexPrefix: Add default EVEX prefix with only LL' bits set.
Expand Down Expand Up @@ -1344,7 +1328,7 @@ emitter::code_t emitter::AddEvexPrefix(const instrDesc* id, code_t code, emitAtt

if (id->idIsEvexbContextSet())
{
code |= EVEX_B_BIT;
code |= BBIT_IN_BYTE_EVEX_PREFIX;

if (!id->idHasMem())
{
Expand Down Expand Up @@ -1385,6 +1369,8 @@ emitter::code_t emitter::AddEvexPrefix(const instrDesc* id, code_t code, emitAtt
{
case IF_RWR_RRD_ARD_RRD:
{
assert(id->idGetEvexAaaContext() == 0);

CnsVal cnsVal;
emitGetInsAmdCns(id, &cnsVal);

Expand All @@ -1394,6 +1380,8 @@ emitter::code_t emitter::AddEvexPrefix(const instrDesc* id, code_t code, emitAtt

case IF_RWR_RRD_MRD_RRD:
{
assert(id->idGetEvexAaaContext() == 0);

CnsVal cnsVal;
emitGetInsDcmCns(id, &cnsVal);

Expand All @@ -1403,6 +1391,8 @@ emitter::code_t emitter::AddEvexPrefix(const instrDesc* id, code_t code, emitAtt

case IF_RWR_RRD_SRD_RRD:
{
assert(id->idGetEvexAaaContext() == 0);

CnsVal cnsVal;
emitGetInsCns(id, &cnsVal);

Expand All @@ -1412,12 +1402,24 @@ emitter::code_t emitter::AddEvexPrefix(const instrDesc* id, code_t code, emitAtt

case IF_RWR_RRD_RRD_RRD:
{
assert(id->idGetEvexAaaContext() == 0);
maskReg = id->idReg4();
break;
}

default:
{
unsigned aaaContext = id->idGetEvexAaaContext();

if (aaaContext != 0)
{
maskReg = static_cast<regNumber>(aaaContext + KBASE);

if (id->idIsEvexZContextSet())
{
code |= ZBIT_IN_BYTE_EVEX_PREFIX;
}
}
break;
}
}
Expand Down Expand Up @@ -4170,9 +4172,8 @@ UNATIVE_OFFSET emitter::emitInsSizeAM(instrDesc* id, code_t code)
}

// If this is just "call reg", we're done.
if (id->idIsCallRegPtr())
if (((ins == INS_call) || (ins == INS_tail_i_jmp)) && id->idIsCallRegPtr())
{
assert(ins == INS_call || ins == INS_tail_i_jmp);
assert(dsp == 0);
return size;
}
Expand Down Expand Up @@ -6822,7 +6823,9 @@ void emitter::emitIns_R_R_A(
id->idIns(ins);
id->idReg1(reg1);
id->idReg2(reg2);

SetEvexBroadcastIfNeeded(id, instOptions);
SetEvexEmbMaskIfNeeded(id, instOptions);

emitHandleMemOp(indir, id, (ins == INS_mulx) ? IF_RWR_RWR_ARD : emitInsModeFormat(ins, IF_RRD_RRD_ARD), ins);

Expand Down Expand Up @@ -6947,7 +6950,9 @@ void emitter::emitIns_R_R_C(instruction ins,
id->idReg1(reg1);
id->idReg2(reg2);
id->idAddr()->iiaFieldHnd = fldHnd;

SetEvexBroadcastIfNeeded(id, instOptions);
SetEvexEmbMaskIfNeeded(id, instOptions);

UNATIVE_OFFSET sz = emitInsSizeCV(id, insCodeRM(ins));
id->idCodeSize(sz);
Expand All @@ -6974,12 +6979,13 @@ void emitter::emitIns_R_R_R(
id->idReg2(reg1);
id->idReg3(reg2);

if ((instOptions & INS_OPTS_b_MASK) != INS_OPTS_NONE)
if ((instOptions & INS_OPTS_EVEX_b_MASK) != 0)
{
// if EVEX.b needs to be set in this path, then it should be embedded rounding.
assert(UseEvexEncoding());
id->idSetEvexbContext(instOptions);
}
SetEvexEmbMaskIfNeeded(id, instOptions);

UNATIVE_OFFSET sz = emitInsSizeRR(id, insCodeRM(ins));
id->idCodeSize(sz);
Expand All @@ -7001,7 +7007,9 @@ void emitter::emitIns_R_R_S(
id->idReg1(reg1);
id->idReg2(reg2);
id->idAddr()->iiaLclVar.initLclVarAddr(varx, offs);

SetEvexBroadcastIfNeeded(id, instOptions);
SetEvexEmbMaskIfNeeded(id, instOptions);

#ifdef DEBUG
id->idDebugOnlyInfo()->idVarRefOffs = emitVarRefOffs;
Expand Down Expand Up @@ -10785,6 +10793,28 @@ void emitter::emitDispEmbRounding(instrDesc* id)
}
}

//------------------------------------------------------------------------
// emitDispEmbMasking: Print the embedded-masking annotation for an instruction.
//
// Arguments:
//    id - The instruction descriptor
//
// Notes:
//    Prints " {<mask register>}" and, when the EVEX.z context is set, " {z}".
//    Nothing is printed when the mask register derived from the descriptor is REG_K0.
//
void emitter::emitDispEmbMasking(instrDesc* id)
{
    // The stored aaa context is an offset from KBASE identifying the mask register.
    const regNumber kReg = static_cast<regNumber>(id->idGetEvexAaaContext() + KBASE);

    if (kReg != REG_K0)
    {
        printf(" {%s}", emitRegName(kReg));

        if (id->idIsEvexZContextSet())
        {
            printf(" {z}");
        }
    }
}

//--------------------------------------------------------------------
// emitDispIns: Dump the given instruction to jitstdout.
//
Expand Down Expand Up @@ -11033,7 +11063,7 @@ void emitter::emitDispIns(
case IF_AWR:
case IF_ARW:
{
if (id->idIsCallRegPtr())
if (((ins == INS_call) || (ins == INS_tail_i_jmp)) && id->idIsCallRegPtr())
{
printf("%s", emitRegName(id->idAddr()->iiaAddrMode.amBaseReg));
}
Expand Down Expand Up @@ -11184,7 +11214,9 @@ void emitter::emitDispIns(
case IF_RRW_RRD_ARD:
case IF_RWR_RWR_ARD:
{
printf("%s, %s, %s", emitRegName(id->idReg1(), attr), emitRegName(id->idReg2(), attr), sstr);
printf("%s", emitRegName(id->idReg1(), attr));
emitDispEmbMasking(id);
printf(", %s, %s", emitRegName(id->idReg2(), attr), sstr);
emitDispAddrMode(id);
emitDispEmbBroadcastCount(id);
break;
Expand Down Expand Up @@ -11458,7 +11490,9 @@ void emitter::emitDispIns(
case IF_RRW_RRD_SRD:
case IF_RWR_RWR_SRD:
{
printf("%s, %s, %s", emitRegName(id->idReg1(), attr), emitRegName(id->idReg2(), attr), sstr);
printf("%s", emitRegName(id->idReg1(), attr));
emitDispEmbMasking(id);
printf(", %s, %s", emitRegName(id->idReg2(), attr), sstr);
emitDispFrameRef(id->idAddr()->iiaLclVar.lvaVarNum(), id->idAddr()->iiaLclVar.lvaOffset(),
id->idDebugOnlyInfo()->idVarRefOffs, asmfm);
emitDispEmbBroadcastCount(id);
Expand Down Expand Up @@ -11652,8 +11686,9 @@ void emitter::emitDispIns(
reg2 = reg3;
reg3 = tmp;
}
printf("%s, ", emitRegName(id->idReg1(), attr));
printf("%s, ", emitRegName(reg2, attr));
printf("%s", emitRegName(id->idReg1(), attr));
emitDispEmbMasking(id);
printf(", %s, ", emitRegName(reg2, attr));
printf("%s", emitRegName(reg3, attr));
emitDispEmbRounding(id);
break;
Expand Down Expand Up @@ -11964,7 +11999,9 @@ void emitter::emitDispIns(
case IF_RRW_RRD_MRD:
case IF_RWR_RWR_MRD:
{
printf("%s, %s, %s", emitRegName(id->idReg1(), attr), emitRegName(id->idReg2(), attr), sstr);
printf("%s", emitRegName(id->idReg1(), attr));
emitDispEmbMasking(id);
printf(", %s, %s", emitRegName(id->idReg2(), attr), sstr);
offs = emitGetInsDsp(id);
emitDispClsVar(id->idAddr()->iiaFieldHnd, offs, ID_INFO_DSP_RELOC);
emitDispEmbBroadcastCount(id);
Expand Down Expand Up @@ -12918,7 +12955,7 @@ BYTE* emitter::emitOutputAM(BYTE* dst, instrDesc* id, code_t code, CnsVal* addc)
#else
dst += emitOutputLong(dst, dsp);
#endif
if (id->idIsTlsGD())
if (!IsAvx512OrPriorInstruction(ins) && id->idIsTlsGD())
{
addlDelta = -4;
emitRecordRelocationWithAddlDelta((void*)(dst - sizeof(INT32)), (void*)dsp, IMAGE_REL_TLSGD,
Expand Down Expand Up @@ -16648,7 +16685,7 @@ size_t emitter::emitOutputInstr(insGroup* ig, instrDesc* id, BYTE** dp)
}

#ifdef DEBUG
if (ins == INS_call && !id->idIsTlsGD())
if ((ins == INS_call) && !id->idIsTlsGD())
{
emitRecordCallSite(emitCurCodeOffs(*dp), id->idDebugOnlyInfo()->idCallSig,
(CORINFO_METHOD_HANDLE)id->idDebugOnlyInfo()->idMemCookie);
Expand Down
Loading

0 comments on commit cd460db

Please sign in to comment.