Skip to content

Commit

Permalink
[NFC][SPIR-V] Refactor SpirvGroupNonUniformOps (#6596)
Browse files Browse the repository at this point in the history
A follow-up change will use the PartitionedExclusiveScanNV
GroupOperation, which requires that an additional operand is added to
all GroupNonUniformArithmetic instructions. This means that some of the
SPIR-V opcodes which are currently categorized as unary will become
either unary or binary depending on the GroupOp. Since the arity
distinctions between the OpGroupNonUniform* instructions were already
somewhat arbitrary, I'm prefacing that change by refactoring them into a
single SpirvGroupNonUniformOp instruction type for better reusability.

Follow up: #6608
  • Loading branch information
sudonatalie authored May 15, 2024
1 parent ff623f8 commit d9caef5
Show file tree
Hide file tree
Showing 13 changed files with 125 additions and 262 deletions.
13 changes: 3 additions & 10 deletions tools/clang/include/clang/SPIRV/SpirvBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,17 +238,10 @@ class SpirvBuilder {

/// \brief Creates an operation with the given OpGroupNonUniform* SPIR-V
/// opcode.
SpirvNonUniformElect *createGroupNonUniformElect(spv::Op op,
QualType resultType,
spv::Scope execScope,
SourceLocation);
SpirvNonUniformUnaryOp *createGroupNonUniformUnaryOp(
SourceLocation, spv::Op op, QualType resultType, spv::Scope execScope,
SpirvInstruction *operand,
llvm::Optional<spv::GroupOperation> groupOp = llvm::None);
SpirvNonUniformBinaryOp *createGroupNonUniformBinaryOp(
SpirvGroupNonUniformOp *createGroupNonUniformOp(
spv::Op op, QualType resultType, spv::Scope execScope,
SpirvInstruction *operand1, SpirvInstruction *operand2, SourceLocation);
llvm::ArrayRef<SpirvInstruction *> operands, SourceLocation,
llvm::Optional<spv::GroupOperation> groupOp = llvm::None);

/// \brief Creates an atomic instruction with the given parameters and returns
/// its pointer.
Expand Down
97 changes: 17 additions & 80 deletions tools/clang/include/clang/SPIRV/SpirvInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,7 @@ class SpirvInstruction {

IK_SetMeshOutputsEXT, // OpSetMeshOutputsEXT

// The following section is for group non-uniform instructions.
// Used by LLVM-style RTTI; order matters.
IK_GroupNonUniformBinaryOp, // Group non-uniform binary operations
IK_GroupNonUniformElect, // OpGroupNonUniformElect
IK_GroupNonUniformUnaryOp, // Group non-uniform unary operations
IK_GroupNonUniformOp, // Group non-uniform operations

IK_ImageOp, // OpImage*
IK_ImageQuery, // OpImageQuery*
Expand Down Expand Up @@ -1495,102 +1491,43 @@ class SpirvFunctionCall : public SpirvInstruction {
llvm::SmallVector<SpirvInstruction *, 4> args;
};

/// \brief Base for OpGroupNonUniform* instructions
/// \brief OpGroupNonUniform* instructions
class SpirvGroupNonUniformOp : public SpirvInstruction {
public:
// For LLVM-style RTTI
static bool classof(const SpirvInstruction *inst) {
return inst->getKind() >= IK_GroupNonUniformBinaryOp &&
inst->getKind() <= IK_GroupNonUniformUnaryOp;
}

spv::Scope getExecutionScope() const { return execScope; }

protected:
SpirvGroupNonUniformOp(Kind kind, spv::Op opcode, QualType resultType,
SourceLocation loc, spv::Scope scope);

private:
spv::Scope execScope;
};

/// \brief OpGroupNonUniform* binary instructions.
class SpirvNonUniformBinaryOp : public SpirvGroupNonUniformOp {
public:
SpirvNonUniformBinaryOp(spv::Op opcode, QualType resultType,
SourceLocation loc, spv::Scope scope,
SpirvInstruction *arg1, SpirvInstruction *arg2);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvNonUniformBinaryOp)

// For LLVM-style RTTI
static bool classof(const SpirvInstruction *inst) {
return inst->getKind() == IK_GroupNonUniformBinaryOp;
}

bool invokeVisitor(Visitor *v) override;

SpirvInstruction *getArg1() const { return arg1; }
SpirvInstruction *getArg2() const { return arg2; }
void replaceOperand(
llvm::function_ref<SpirvInstruction *(SpirvInstruction *)> remapOp,
bool inEntryFunctionWrapper) override {
arg1 = remapOp(arg1);
arg2 = remapOp(arg2);
}

private:
SpirvInstruction *arg1;
SpirvInstruction *arg2;
};

/// \brief OpGroupNonUniformElect instruction. This is currently the only
/// non-uniform instruction that takes no other arguments.
class SpirvNonUniformElect : public SpirvGroupNonUniformOp {
public:
SpirvNonUniformElect(QualType resultType, SourceLocation loc,
spv::Scope scope);
SpirvGroupNonUniformOp(spv::Op opcode, QualType resultType, spv::Scope scope,
llvm::ArrayRef<SpirvInstruction *> operands,
SourceLocation loc,
llvm::Optional<spv::GroupOperation> group);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvNonUniformElect)
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvGroupNonUniformOp)

// For LLVM-style RTTI
static bool classof(const SpirvInstruction *inst) {
return inst->getKind() == IK_GroupNonUniformElect;
return inst->getKind() == IK_GroupNonUniformOp;
}

bool invokeVisitor(Visitor *v) override;
};

/// \brief OpGroupNonUniform* unary instructions.
class SpirvNonUniformUnaryOp : public SpirvGroupNonUniformOp {
public:
SpirvNonUniformUnaryOp(spv::Op opcode, QualType resultType,
SourceLocation loc, spv::Scope scope,
llvm::Optional<spv::GroupOperation> group,
SpirvInstruction *arg);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvNonUniformUnaryOp)

// For LLVM-style RTTI
static bool classof(const SpirvInstruction *inst) {
return inst->getKind() == IK_GroupNonUniformUnaryOp;
}
spv::Scope getExecutionScope() const { return execScope; }

bool invokeVisitor(Visitor *v) override;
llvm::ArrayRef<SpirvInstruction *> getOperands() const { return operands; }

SpirvInstruction *getArg() const { return arg; }
bool hasGroupOp() const { return groupOp.hasValue(); }
spv::GroupOperation getGroupOp() const { return groupOp.getValue(); }

void replaceOperand(
llvm::function_ref<SpirvInstruction *(SpirvInstruction *)> remapOp,
bool inEntryFunctionWrapper) override {
arg = remapOp(arg);
for (auto *operand : getOperands()) {
operand = remapOp(operand);
}
if (inEntryFunctionWrapper)
setAstResultType(arg->getAstResultType());
setAstResultType(getOperands()[0]->getAstResultType());
}

private:
SpirvInstruction *arg;
spv::Scope execScope;
llvm::SmallVector<SpirvInstruction *, 4> operands;
llvm::Optional<spv::GroupOperation> groupOp;
};

Expand Down
4 changes: 1 addition & 3 deletions tools/clang/include/clang/SPIRV/SpirvVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,7 @@ class Visitor {
DEFINE_VISIT_METHOD(SpirvEndPrimitive)
DEFINE_VISIT_METHOD(SpirvExtInst)
DEFINE_VISIT_METHOD(SpirvFunctionCall)
DEFINE_VISIT_METHOD(SpirvNonUniformBinaryOp)
DEFINE_VISIT_METHOD(SpirvNonUniformElect)
DEFINE_VISIT_METHOD(SpirvNonUniformUnaryOp)
DEFINE_VISIT_METHOD(SpirvGroupNonUniformOp)
DEFINE_VISIT_METHOD(SpirvImageOp)
DEFINE_VISIT_METHOD(SpirvImageQuery)
DEFINE_VISIT_METHOD(SpirvImageSparseTexelsResident)
Expand Down
33 changes: 3 additions & 30 deletions tools/clang/lib/SPIRV/EmitVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1087,35 +1087,7 @@ bool EmitVisitor::visit(SpirvFunctionCall *inst) {
return true;
}

bool EmitVisitor::visit(SpirvNonUniformBinaryOp *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
curInst.push_back(typeHandler.getOrCreateConstantInt(
llvm::APInt(32, static_cast<uint32_t>(inst->getExecutionScope())),
context.getUIntType(32), /* isSpecConst */ false));
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getArg1()));
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getArg2()));
finalizeInstruction(&mainBinary);
emitDebugNameForInstruction(getOrAssignResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}

bool EmitVisitor::visit(SpirvNonUniformElect *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
curInst.push_back(typeHandler.getOrCreateConstantInt(
llvm::APInt(32, static_cast<uint32_t>(inst->getExecutionScope())),
context.getUIntType(32), /* isSpecConst */ false));
finalizeInstruction(&mainBinary);
emitDebugNameForInstruction(getOrAssignResultId<SpirvInstruction>(inst),
inst->getDebugName());
return true;
}

bool EmitVisitor::visit(SpirvNonUniformUnaryOp *inst) {
bool EmitVisitor::visit(SpirvGroupNonUniformOp *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
Expand All @@ -1124,7 +1096,8 @@ bool EmitVisitor::visit(SpirvNonUniformUnaryOp *inst) {
context.getUIntType(32), /* isSpecConst */ false));
if (inst->hasGroupOp())
curInst.push_back(static_cast<uint32_t>(inst->getGroupOp()));
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getArg()));
for (auto *operand : inst->getOperands())
curInst.push_back(getOrAssignResultId<SpirvInstruction>(operand));
finalizeInstruction(&mainBinary);
emitDebugNameForInstruction(getOrAssignResultId<SpirvInstruction>(inst),
inst->getDebugName());
Expand Down
4 changes: 1 addition & 3 deletions tools/clang/lib/SPIRV/EmitVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,7 @@ class EmitVisitor : public Visitor {
bool visit(SpirvCompositeInsert *) override;
bool visit(SpirvExtInst *) override;
bool visit(SpirvFunctionCall *) override;
bool visit(SpirvNonUniformBinaryOp *) override;
bool visit(SpirvNonUniformElect *) override;
bool visit(SpirvNonUniformUnaryOp *) override;
bool visit(SpirvGroupNonUniformOp *) override;
bool visit(SpirvImageOp *) override;
bool visit(SpirvImageQuery *) override;
bool visit(SpirvImageSparseTexelsResident *) override;
Expand Down
14 changes: 3 additions & 11 deletions tools/clang/lib/SPIRV/LiteralTypeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,17 +294,9 @@ bool LiteralTypeVisitor::visit(SpirvVectorShuffle *inst) {
return true;
}

bool LiteralTypeVisitor::visit(SpirvNonUniformUnaryOp *inst) {
// Went through each non-uniform binary operation and made sure the following
// does not result in a wrong type deduction.
tryToUpdateInstLitType(inst->getArg(), inst->getAstResultType());
return true;
}

bool LiteralTypeVisitor::visit(SpirvNonUniformBinaryOp *inst) {
// Went through each non-uniform unary operation and made sure the following
// does not result in a wrong type deduction.
tryToUpdateInstLitType(inst->getArg1(), inst->getAstResultType());
bool LiteralTypeVisitor::visit(SpirvGroupNonUniformOp *inst) {
for (auto *operand : inst->getOperands())
tryToUpdateInstLitType(operand, inst->getAstResultType());
return true;
}

Expand Down
3 changes: 1 addition & 2 deletions tools/clang/lib/SPIRV/LiteralTypeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ class LiteralTypeVisitor : public Visitor {
bool visit(SpirvBitFieldExtract *) override;
bool visit(SpirvSelect *) override;
bool visit(SpirvVectorShuffle *) override;
bool visit(SpirvNonUniformUnaryOp *) override;
bool visit(SpirvNonUniformBinaryOp *) override;
bool visit(SpirvGroupNonUniformOp *) override;
bool visit(SpirvLoad *) override;
bool visit(SpirvStore *) override;
bool visit(SpirvConstantComposite *) override;
Expand Down
3 changes: 1 addition & 2 deletions tools/clang/lib/SPIRV/PervertexInputVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ class PervertexInputVisitor : public Visitor {
REMAP_FUNC_OP(ImageOp)
REMAP_FUNC_OP(ExtInst)
REMAP_FUNC_OP(Atomic)
REMAP_FUNC_OP(NonUniformBinaryOp)
REMAP_FUNC_OP(BitFieldInsert)
REMAP_FUNC_OP(BitFieldExtract)
REMAP_FUNC_OP(IntrinsicInstruction)
Expand All @@ -115,7 +114,7 @@ class PervertexInputVisitor : public Visitor {
REMAP_FUNC_OP(Select)
REMAP_FUNC_OP(Switch)
REMAP_FUNC_OP(CopyObject)
REMAP_FUNC_OP(NonUniformUnaryOp)
REMAP_FUNC_OP(GroupNonUniformOp)

private:
///< Whether in entry function wrapper, which will influence replace steps.
Expand Down
11 changes: 3 additions & 8 deletions tools/clang/lib/SPIRV/PreciseVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,9 @@ bool PreciseVisitor::visit(SpirvUnaryOp *inst) {
return true;
}

bool PreciseVisitor::visit(SpirvNonUniformBinaryOp *inst) {
inst->getArg1()->setPrecise(inst->isPrecise());
inst->getArg2()->setPrecise(inst->isPrecise());
return true;
}

bool PreciseVisitor::visit(SpirvNonUniformUnaryOp *inst) {
inst->getArg()->setPrecise(inst->isPrecise());
bool PreciseVisitor::visit(SpirvGroupNonUniformOp *inst) {
for (auto *operand : inst->getOperands())
operand->setPrecise(inst->isPrecise());
return true;
}

Expand Down
3 changes: 1 addition & 2 deletions tools/clang/lib/SPIRV/PreciseVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ class PreciseVisitor : public Visitor {
bool visit(SpirvStore *) override;
bool visit(SpirvBinaryOp *) override;
bool visit(SpirvUnaryOp *) override;
bool visit(SpirvNonUniformBinaryOp *) override;
bool visit(SpirvNonUniformUnaryOp *) override;
bool visit(SpirvGroupNonUniformOp *) override;
bool visit(SpirvExtInst *) override;
bool visit(SpirvFunctionCall *) override;

Expand Down
29 changes: 5 additions & 24 deletions tools/clang/lib/SPIRV/SpirvBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,32 +431,13 @@ SpirvSpecConstantBinaryOp *SpirvBuilder::createSpecConstantBinaryOp(
return instruction;
}

SpirvNonUniformElect *SpirvBuilder::createGroupNonUniformElect(
spv::Op op, QualType resultType, spv::Scope execScope, SourceLocation loc) {
assert(insertPoint && "null insert point");
auto *instruction =
new (context) SpirvNonUniformElect(resultType, loc, execScope);
insertPoint->addInstruction(instruction);
return instruction;
}

SpirvNonUniformUnaryOp *SpirvBuilder::createGroupNonUniformUnaryOp(
SourceLocation loc, spv::Op op, QualType resultType, spv::Scope execScope,
SpirvInstruction *operand, llvm::Optional<spv::GroupOperation> groupOp) {
assert(insertPoint && "null insert point");
auto *instruction = new (context)
SpirvNonUniformUnaryOp(op, resultType, loc, execScope, groupOp, operand);
insertPoint->addInstruction(instruction);
return instruction;
}

SpirvNonUniformBinaryOp *SpirvBuilder::createGroupNonUniformBinaryOp(
SpirvGroupNonUniformOp *SpirvBuilder::createGroupNonUniformOp(
spv::Op op, QualType resultType, spv::Scope execScope,
SpirvInstruction *operand1, SpirvInstruction *operand2,
SourceLocation loc) {
llvm::ArrayRef<SpirvInstruction *> operands, SourceLocation loc,
llvm::Optional<spv::GroupOperation> groupOp) {
assert(insertPoint && "null insert point");
auto *instruction = new (context) SpirvNonUniformBinaryOp(
op, resultType, loc, execScope, operand1, operand2);
auto *instruction = new (context)
SpirvGroupNonUniformOp(op, resultType, execScope, operands, loc, groupOp);
insertPoint->addInstruction(instruction);
return instruction;
}
Expand Down
Loading

0 comments on commit d9caef5

Please sign in to comment.