Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EXT_mesh_shader validation support #5640

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions source/val/validate_annotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ spv_result_t ValidateDecorationTarget(ValidationState_t& _, spv::Decoration dec,
return fail(0) << "must be a variable";
}
break;
case spv::Decoration::PerPrimitiveNV:
if (target->opcode() != spv::Op::OpVariable) {
return fail(0) << "must be a memory object declaration";
Comment on lines +147 to +148
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Memory object declarations include OpFunctionParameter (and are handled in the next block of decorations). The spec matches that, is this a bug here or in the extension?

}
break;
case spv::Decoration::NoPerspective:
case spv::Decoration::Flat:
case spv::Decoration::Patch:
Expand Down Expand Up @@ -361,6 +366,15 @@ spv_result_t ValidateMemberDecorate(ValidationState_t& _,
<< _.SpvDecorationString(decoration)
<< " cannot be applied to structure members";
}

const auto target_id = inst->GetOperandAs<uint32_t>(0);
const auto target = _.FindDef(target_id);
if (decoration == spv::Decoration::PerPrimitiveNV &&
target->opcode() != spv::Op::OpTypeStruct) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< _.SpvDecorationString(decoration)
<< " must be a memory object declaration";
}
Comment on lines +370 to +377
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like it would be caught above already at line 346.


return SPV_SUCCESS;
}
Expand Down
106 changes: 104 additions & 2 deletions source/val/validate_builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,36 @@ class BuiltInsValidator {
// instruction.
void Update(const Instruction& inst);

uint32_t GetMeshEntryPoint() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you only have a single mesh entry point in a module? I'm confused why this is a single entry point.

if (mesh_entry_point_ == 0) {
for (const uint32_t entry_point : _.entry_points()) {
// Every entry point from which this function is called needs to have
// Execution Mode DepthReplacing.
const auto* models = _.GetExecutionModels(entry_point);
if (models->find(spv::ExecutionModel::MeshEXT ) != models->end() ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the CI bot has not kicked off yet, but it will for sure fail for not running clang-format, so I would suggest running it on this PR

Suggested change
if (models->find(spv::ExecutionModel::MeshEXT ) != models->end() ||
if (models->find(spv::ExecutionModel::MeshEXT) != models->end() ||

models->find(spv::ExecutionModel::MeshNV ) != models->end()) {
mesh_entry_point_ = entry_point;
break;
}
}
}
return mesh_entry_point_;
}

bool isMeshInterfaceVar(const Instruction& inst) {
uint32_t mesh_entry_point = GetMeshEntryPoint();
if (!mesh_entry_point) return false;
for (const auto& desc : _.entry_point_descriptions(mesh_entry_point)) {
for (auto interface : desc.interfaces) {
if (inst.id() == interface) {
return true;
}
}
}
return false;
}


ValidationState_t& _;

// Mapping id -> list of rules which validate instruction referencing the
Expand All @@ -684,7 +714,7 @@ class BuiltInsValidator {
// or to no_entry_points_. The pointer is guaranteed to never be null.
const std::vector<uint32_t> no_entry_points;
const std::vector<uint32_t>* entry_points_ = &no_entry_points;

uint32_t mesh_entry_point_ = 0;
// Execution models with which the current function can be called.
std::set<spv::ExecutionModel> execution_models_;
};
Expand Down Expand Up @@ -2146,6 +2176,28 @@ spv_result_t BuiltInsValidator::ValidatePrimitiveIdAtDefinition(
return error;
}
}

if (isMeshInterfaceVar(inst)) {
if (_.HasCapability(spv::Capability::MeshShadingEXT) &&
!_.HasDecoration(inst.id(), spv::Decoration::PerPrimitiveEXT)) {
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
<< _.VkErrorID(7040)
<< "According to the Vulkan spec the variable decorated with "
"Builtin PrimitiveId within the MeshEXT Execution Model must "
"also be decorated with the PerPrimitiveEXT decoration. ";
}
#if 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you forget to remove this?

const spv::StorageClass storage_class =
inst.GetOperandAs<spv::StorageClass>(2);
if (storage_class != spv::StorageClass::Output) {
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
<< _.VkErrorID(4336)
<< "According to the Vulkan spec the variable decorated with "
"Builtin PrimitiveId within the MeshEXT Execution Model must "
"must be declared using the Output Storage Class. ";
}
#endif
}
}

// Seed at reference checks with this built-in.
Expand Down Expand Up @@ -2753,6 +2805,21 @@ spv_result_t BuiltInsValidator::ValidateLayerOrViewportIndexAtDefinition(
return error;
}
}

if (isMeshInterfaceVar(inst) &&
_.HasCapability(spv::Capability::MeshShadingEXT) &&
!_.HasDecoration(inst.id(), spv::Decoration::PerPrimitiveEXT)) {
const spv::BuiltIn label = spv::BuiltIn(decoration.params()[0]);
uint32_t vkerrid = (label == spv::BuiltIn::Layer) ? 7039 : 7060;
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
<< _.VkErrorID(vkerrid)
<< "According to the Vulkan spec the variable decorated with "
"Builtin "
<< _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
decoration.params()[0])
<< " within the MeshEXT Execution Model must also be decorated "
"with the PerPrimitiveEXT decoration. ";
}
}

// Seed at reference checks with this built-in.
Expand Down Expand Up @@ -3458,6 +3525,7 @@ spv_result_t BuiltInsValidator::ValidateViewIndexAtReference(
referenced_from_inst, execution_model);
}
}

}

if (function_id_ == 0) {
Expand Down Expand Up @@ -4167,6 +4235,23 @@ spv_result_t BuiltInsValidator::ValidateRayTracingBuiltinsAtReference(
spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtDefinition(
const Decoration& decoration, const Instruction& inst) {
if (spvIsVulkanEnv(_.context()->target_env)) {
uint32_t mesh_entry_point = GetMeshEntryPoint();
assert(mesh_entry_point);
bool execution_mode_OuputPoints = false;
bool execution_mode_OutputLinesEXT = false;
bool execution_mode_OutputTrianglesEXT = false;

const auto* modes = _.GetExecutionModes(mesh_entry_point);
if (modes->find(spv::ExecutionMode::OutputPoints) != modes->end()) {
execution_mode_OuputPoints = true;
}
if (modes->find(spv::ExecutionMode::OutputLinesEXT) != modes->end()) {
execution_mode_OutputLinesEXT = true;
}
if (modes->find(spv::ExecutionMode::OutputTrianglesEXT) !=
modes->end()) {
execution_mode_OutputTrianglesEXT = true;
}
const spv::BuiltIn builtin = spv::BuiltIn(decoration.params()[0]);
uint32_t vuid = GetVUIDForBuiltin(builtin, VUIDErrorType);
if (builtin == spv::BuiltIn::PrimitivePointIndicesEXT) {
Expand All @@ -4185,6 +4270,12 @@ spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtDefinition(
})) {
return error;
}
if (!execution_mode_OuputPoints) {
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
<< _.VkErrorID(7042)
<< "The PrimitivePointIndicesEXT decoration must be used with "
"the OutputPoints Execution Mode";
}
}
if (builtin == spv::BuiltIn::PrimitiveLineIndicesEXT) {
if (spv_result_t error = ValidateArrayedI32Vec(
Expand All @@ -4203,6 +4294,12 @@ spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtDefinition(
})) {
return error;
}
if (!execution_mode_OutputLinesEXT) {
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
<< _.VkErrorID(7048)
<< "The PrimitiveLineIndicesEXT decoration must be used with "
"the OutputLinesEXT Execution Mode";
}
}
if (builtin == spv::BuiltIn::PrimitiveTriangleIndicesEXT) {
if (spv_result_t error = ValidateArrayedI32Vec(
Expand All @@ -4221,6 +4318,12 @@ spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtDefinition(
})) {
return error;
}
if (!execution_mode_OutputTrianglesEXT) {
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
<< _.VkErrorID(7054)
<< "The PrimitiveTriangleIndicesEXT decoration must be used with "
"the OutputTrianglesEXT Execution Mode";
}
}
}
// Seed at reference checks with this built-in.
Expand Down Expand Up @@ -4249,7 +4352,6 @@ spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtReference(
referenced_from_inst)
<< " " << GetStorageClassDesc(referenced_from_inst);
}

for (const spv::ExecutionModel execution_model : execution_models_) {
if (execution_model != spv::ExecutionModel::MeshEXT) {
uint32_t vuid = GetVUIDForBuiltin(builtin, VUIDErrorExecutionModel);
Expand Down
14 changes: 14 additions & 0 deletions source/val/validate_decorations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,7 @@ spv_result_t CheckDecorationsOfEntryPoints(ValidationState_t& vstate) {
int num_workgroup_variables = 0;
int num_workgroup_variables_with_block = 0;
int num_workgroup_variables_with_aliased = 0;
bool has_task_payload = false;
for (const auto& desc : descs) {
std::unordered_set<Instruction*> seen_vars;
for (auto interface : desc.interfaces) {
Expand All @@ -779,6 +780,19 @@ spv_result_t CheckDecorationsOfEntryPoints(ValidationState_t& vstate) {
}
const spv::StorageClass storage_class =
var_instr->GetOperandAs<spv::StorageClass>(2);
if (vstate.version() >= SPV_SPIRV_VERSION_WORD(1, 5)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this version gated on 1.5? SPV_EXT_mesh_shader requires SPIR-V 1.4. What is intended to be checked here?

// SPV_EXT_mesh_shader, at most one task payload is permitted
// per entry point
if (storage_class == spv::StorageClass::TaskPayloadWorkgroupEXT) {
if (has_task_payload) {
return vstate.diag(SPV_ERROR_INVALID_ID, var_instr)
<< "There can be at most one OpVariable with storage "
"class TaskPayloadWorkgroupEXT associated with "
"an OpEntryPoint";
}
has_task_payload = true;
}
}
if (vstate.version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
// Starting in 1.4, OpEntryPoint must list all global variables
// it statically uses and those interfaces must be unique.
Expand Down
56 changes: 54 additions & 2 deletions source/val/validate_mesh_shading.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,32 @@
// Validates ray query instructions from SPV_KHR_ray_query

#include "source/opcode.h"
#include "source/spirv_target_env.h"
#include "source/val/instruction.h"
#include "source/val/validate.h"
#include "source/val/validation_state.h"

namespace spvtools {
namespace val {

bool IsInterfaceVariable(ValidationState_t& _, const Instruction* inst,
spv::ExecutionModel model) {
bool foundInterface = false;
for (auto entry_point : _.entry_points()) {
const auto* models = _.GetExecutionModels(entry_point);
if (models->find(model) == models->end()) return false;
for (const auto& desc : _.entry_point_descriptions(entry_point)) {
for (auto interface : desc.interfaces) {
if (inst->id() == interface) {
foundInterface = true;
break;
}
}
}
}
return foundInterface;
}

spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst) {
const spv::Op opcode = inst->opcode();
switch (opcode) {
Expand Down Expand Up @@ -103,15 +122,48 @@ spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Primitive Count must be a 32-bit unsigned int scalar";
}

break;
}

case spv::Op::OpWritePackedPrimitiveIndices4x8NV: {
// No validation rules (for the moment).
break;
}

case spv::Op::OpVariable: {
if (_.HasCapability(spv::Capability::MeshShadingEXT)) {
bool meshInterfaceVar = IsInterfaceVariable(
_, inst, spv::ExecutionModel::MeshEXT);
bool fragInterfaceVar = IsInterfaceVariable(
_, inst, spv::ExecutionModel::Fragment);

const spv::StorageClass storage_class =
inst->GetOperandAs<spv::StorageClass>(2);
bool storage_output = (storage_class == spv::StorageClass::Output);
bool storage_input = (storage_class == spv::StorageClass::Input);

if (_.HasDecoration(inst->id(), spv::Decoration::PerPrimitiveEXT)) {
if (fragInterfaceVar && !storage_input) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "PerPrimitiveEXT decoration must be applied only to "
"variables in the Input Storage Class in the Fragment "
"Execution Model.";
}

if (meshInterfaceVar && !storage_output) {
std::string vkerror = (spvIsVulkanEnv(_.context()->target_env))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VkErrorID already will return "" if its not a Vulkan Env, so you don't need to do that here

? _.VkErrorID(4336)
: "";
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< vkerror
<< "PerPrimitiveEXT decoration must be applied only to "
"variables in the Output Storage Class in the "
"Storage Class in the MeshEXT Execution Model.";
}
}
}
break;
}
default:
break;
}
Expand Down
19 changes: 19 additions & 0 deletions source/val/validate_mode_setting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,15 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
"tessellation execution model.";
}
}
if (spvIsVulkanEnv(_.context()->target_env)) {
if (_.HasCapability(spv::Capability::MeshShadingEXT) &&
inst->GetOperandAs<uint32_t>(2) == 0) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< _.VkErrorID(7330)
<< "In mesh shaders using the MeshEXT Execution Model the "
"OutputVertices Execution Mode must be greater than 0";
}
}
break;
case spv::ExecutionMode::OutputLinesEXT:
case spv::ExecutionMode::OutputTrianglesEXT:
Expand All @@ -557,6 +566,16 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
"execution "
"model.";
}
if (mode == spv::ExecutionMode::OutputPrimitivesEXT &&
spvIsVulkanEnv(_.context()->target_env)) {
if (_.HasCapability(spv::Capability::MeshShadingEXT) &&
inst->GetOperandAs<uint32_t>(2) == 0) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< _.VkErrorID(7331)
<< "In mesh shaders using the MeshEXT Execution Model the "
"OutputPrimitivesEXT Execution Mode must be greater than 0";
}
}
break;
case spv::ExecutionMode::QuadDerivativesKHR:
if (!std::all_of(models->begin(), models->end(),
Expand Down
Loading