diff --git a/source/opt/eliminate_dead_output_stores_pass.cpp b/source/opt/eliminate_dead_output_stores_pass.cpp
index 99711a16e81..8a4e8e1dc2c 100644
--- a/source/opt/eliminate_dead_output_stores_pass.cpp
+++ b/source/opt/eliminate_dead_output_stores_pass.cpp
@@ -27,6 +27,8 @@ constexpr uint32_t kOpDecorateBuiltInLiteralInIdx = 2;
 constexpr uint32_t kOpDecorateMemberBuiltInLiteralInIdx = 3;
 constexpr uint32_t kOpAccessChainIdx0InIdx = 1;
 constexpr uint32_t kOpConstantValueInIdx = 0;
+constexpr uint32_t kVariableStorageClassInIdx = 0;
+constexpr uint32_t kSpvTypePointerTypeIdInIdx = 1;
 }  // namespace
 
 Pass::Status EliminateDeadOutputStoresPass::Process() {
@@ -45,6 +47,67 @@ bool EliminateDeadOutputStoresPass::IsLiveBuiltin(uint32_t bi) {
   return live_builtins_->find(bi) != live_builtins_->end();
 }
 
+bool EliminateDeadOutputStoresPass::IsVariableRead(
+    Instruction const* const variable) {
+  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
+
+  // Only direct OpLoad users count as reads; loads made through an access
+  // chain rooted at |variable| are not inspected. WhileEachUser stops at the
+  // first user for which the callback returns false, so a false result means
+  // that a load was found.
+  return !def_use_mgr->WhileEachUser(variable, [](Instruction* user) {
+    return user->opcode() != spv::Op::OpLoad;
+  });
+}
+
+bool EliminateDeadOutputStoresPass::DemoteToPrivate(
+    Instruction* const variable) {
+  assert(spv::StorageClass(variable->GetSingleWordInOperand(
+             kVariableStorageClassInIdx)) == spv::StorageClass::Output);
+
+  auto type_mgr = context()->get_type_mgr();
+
+  // Get the pointee type of the variable.
+  auto const res_type = get_def_use_mgr()->GetDef(variable->type_id());
+  auto const pointee_type_id =
+      res_type->GetSingleWordInOperand(kSpvTypePointerTypeIdInIdx);
+
+  // Find a pointer type with the same pointee that uses private storage. Do
+  // this before touching |variable| so that a failure leaves it unmodified.
+  auto const private_ptr_type_id =
+      type_mgr->FindPointerToType(pointee_type_id, spv::StorageClass::Private);
+  if (private_ptr_type_id == 0) {
+    return false;
+  }
+  auto const private_ptr_type = get_def_use_mgr()->GetDef(private_ptr_type_id);
+  context()->UpdateDefUse(private_ptr_type);
+
+  // Remove variable decorations that are not valid for private storage.
+  get_def_use_mgr()->ForEachUser(variable, [this](Instruction* user) {
+    if (user->opcode() == spv::Op::OpDecorate) {
+      auto const decoration = spv::Decoration(user->GetSingleWordInOperand(1));
+      if (decoration == spv::Decoration::Location ||
+          decoration == spv::Decoration::Component) {
+        kill_list_.push_back(user);
+      }
+    }
+  });
+
+  // Switch the variable to private storage and to the matching pointer type.
+  variable->SetInOperand(kVariableStorageClassInIdx,
+                         {uint32_t(spv::StorageClass::Private)});
+  variable->SetResultType(private_ptr_type_id);
+  context()->UpdateDefUse(variable);
+
+  // Move the variable after its new result type so the type is defined first.
+  // NOTE(review): the caller invokes this from inside its loop over |var|;
+  // confirm that relocating the instruction does not disturb that iteration.
+  variable->RemoveFromList();
+  variable->InsertAfter(private_ptr_type);
+
+  return true;
+}
+
 bool EliminateDeadOutputStoresPass::AnyLocsAreLive(uint32_t start,
                                                    uint32_t count) {
   auto finish = start + count;
@@ -194,6 +257,17 @@ Pass::Status EliminateDeadOutputStoresPass::DoDeadOutputStoreElimination() {
     if (ptr_type->storage_class() != spv::StorageClass::Output) {
       continue;
     }
+
+    // Stores to write-only output variables are safe to eliminate; however,
+    // stores to read-write output variables may be required for functional
+    // correctness. If a read-write output variable is detected, then it will
+    // be demoted to a private variable and any associated stores will not be
+    // eliminated.
+    if (IsVariableRead(&var)) {
+      if (!DemoteToPrivate(&var)) return Status::Failure;
+      continue;
+    }
+
     // If builtin decoration on variable, process as builtin.
     auto var_id = var.result_id();
     bool is_builtin = false;
@@ -217,7 +291,7 @@ Pass::Status EliminateDeadOutputStoresPass::DoDeadOutputStoreElimination() {
     // locations are dead, kill store or all access chain's stores
     def_use_mgr->ForEachUser(
         var_id, [this, &var, is_builtin](Instruction* user) {
-          auto op = user->opcode();
+          auto const op = user->opcode();
           if (op == spv::Op::OpEntryPoint || op == spv::Op::OpName ||
               op == spv::Op::OpDecorate || user->IsNonSemanticInstruction())
             return;
diff --git a/source/opt/eliminate_dead_output_stores_pass.h b/source/opt/eliminate_dead_output_stores_pass.h
index 676d4f4f000..0be3d422604 100644
--- a/source/opt/eliminate_dead_output_stores_pass.h
+++ b/source/opt/eliminate_dead_output_stores_pass.h
@@ -69,6 +69,12 @@ class EliminateDeadOutputStoresPass : public Pass {
   // Return true if builtin |bi| is live.
   bool IsLiveBuiltin(uint32_t bi);
 
+  // Return true if an OpLoad is using this variable.
+  bool IsVariableRead(Instruction const* const var);
+
+  // Demote the variable's storage class from output to private.
+  bool DemoteToPrivate(Instruction* const var);
+
   std::unordered_set<uint32_t>* live_locs_;
   std::unordered_set<uint32_t>* live_builtins_;
 
diff --git a/test/opt/eliminate_dead_output_stores_test.cpp b/test/opt/eliminate_dead_output_stores_test.cpp
index 4c2e44c0016..e4667ad4985 100644
--- a/test/opt/eliminate_dead_output_stores_test.cpp
+++ b/test/opt/eliminate_dead_output_stores_test.cpp
@@ -946,6 +946,73 @@ TEST_F(ElimDeadOutputStoresTest, VertMultipleLocationsF16) {
                                                        &live_builtins);
 }
 
+TEST_F(ElimDeadOutputStoresTest, DemoteOutputVariableToPrivate) {
+  //
+  // #version 450
+  //
+  // layout(location=0) out vec4 pos;
+  // void main()
+  // {
+  //     pos = vec4(0.0);
+  //     gl_Position = pos;
+  // }
+
+  const std::string text = R"(OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Vertex %main "main" %pos %_
+               OpSource GLSL 450
+               OpName %main "main"
+               OpName %pos "pos"
+               OpName %gl_PerVertex "gl_PerVertex"
+               OpMemberName %gl_PerVertex 0 "gl_Position"
+               OpMemberName %gl_PerVertex 1 "gl_PointSize"
+               OpMemberName %gl_PerVertex 2 "gl_ClipDistance"
+               OpMemberName %gl_PerVertex 3 "gl_CullDistance"
+               OpName %_ ""
+               OpDecorate %pos Location 0
+               OpMemberDecorate %gl_PerVertex 0 BuiltIn Position
+               OpMemberDecorate %gl_PerVertex 1 BuiltIn PointSize
+               OpMemberDecorate %gl_PerVertex 2 BuiltIn ClipDistance
+               OpMemberDecorate %gl_PerVertex 3 BuiltIn CullDistance
+               OpDecorate %gl_PerVertex Block
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+;CHECK: %_ptr_Private_v4float = OpTypePointer Private %v4float
+        %pos = OpVariable %_ptr_Output_v4float Output
+;CHECK: %pos = OpVariable %_ptr_Private_v4float Private
+    %float_0 = OpConstant %float 0
+         %11 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+       %uint = OpTypeInt 32 0
+     %uint_1 = OpConstant %uint 1
+%_arr_float_uint_1 = OpTypeArray %float %uint_1
+%gl_PerVertex = OpTypeStruct %v4float %float %_arr_float_uint_1 %_arr_float_uint_1
+%_ptr_Output_gl_PerVertex = OpTypePointer Output %gl_PerVertex
+          %_ = OpVariable %_ptr_Output_gl_PerVertex Output
+        %int = OpTypeInt 32 1
+      %int_0 = OpConstant %int 0
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+               OpStore %pos %11
+;CHECK: OpStore %pos %11
+         %20 = OpLoad %v4float %pos
+         %21 = OpAccessChain %_ptr_Output_v4float %_ %int_0
+               OpStore %21 %20
+               OpReturn
+               OpFunctionEnd)";
+
+  SetTargetEnv(SPV_ENV_VULKAN_1_3);
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+
+  std::unordered_set<uint32_t> live_inputs{};
+  std::unordered_set<uint32_t> live_builtins{};
+  SinglePassRunAndMatch<EliminateDeadOutputStoresPass>(text, true, &live_inputs,
+                                                       &live_builtins);
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools