diff --git a/source/opcode.cpp b/source/opcode.cpp index c80e3a001b..d87e828738 100644 --- a/source/opcode.cpp +++ b/source/opcode.cpp @@ -444,15 +444,32 @@ bool spvOpcodeIsReturn(SpvOp opcode) { } } +bool spvOpcodeIsAbort(SpvOp opcode) { + switch (opcode) { + case SpvOpKill: + case SpvOpUnreachable: + case SpvOpTerminateInvocation: + case SpvOpTerminateRayKHR: + case SpvOpIgnoreIntersectionKHR: + return true; + default: + return false; + } +} + bool spvOpcodeIsReturnOrAbort(SpvOp opcode) { - return spvOpcodeIsReturn(opcode) || opcode == SpvOpKill || - opcode == SpvOpUnreachable || opcode == SpvOpTerminateInvocation; + return spvOpcodeIsReturn(opcode) || spvOpcodeIsAbort(opcode); } bool spvOpcodeIsBlockTerminator(SpvOp opcode) { return spvOpcodeIsBranch(opcode) || spvOpcodeIsReturnOrAbort(opcode); } +bool spvOpcodeTerminatesExecution(SpvOp opcode) { + return opcode == SpvOpKill || opcode == SpvOpTerminateInvocation || + opcode == SpvOpTerminateRayKHR || opcode == SpvOpIgnoreIntersectionKHR; +} + bool spvOpcodeIsBaseOpaqueType(SpvOp opcode) { switch (opcode) { case SpvOpTypeImage: diff --git a/source/opcode.h b/source/opcode.h index 3702cb35fb..c8525a253a 100644 --- a/source/opcode.h +++ b/source/opcode.h @@ -110,10 +110,18 @@ bool spvOpcodeIsBranch(SpvOp opcode); // Returns true if the given opcode is a return instruction. bool spvOpcodeIsReturn(SpvOp opcode); +// Returns true if the given opcode aborts execution. +bool spvOpcodeIsAbort(SpvOp opcode); + // Returns true if the given opcode is a return instruction or it aborts // execution. bool spvOpcodeIsReturnOrAbort(SpvOp opcode); +// Returns true if the given opcode is a kill instruction or it terminates +// execution. Note that branches, returns, and unreachables do not terminate +// execution. +bool spvOpcodeTerminatesExecution(SpvOp opcode); + // Returns true if the given opcode is a basic block terminator. bool spvOpcodeIsBlockTerminator(SpvOp opcode); diff --git a/source/opt/basic_block.cpp b/source/opt/basic_block.cpp index b7e122c413..e82a744af5 100644 --- a/source/opt/basic_block.cpp +++ b/source/opt/basic_block.cpp @@ -230,7 +230,7 @@ std::string BasicBlock::PrettyPrint(uint32_t options) const { std::ostringstream str; ForEachInst([&str, options](const Instruction* inst) { str << inst->PrettyPrint(options); - if (!IsTerminatorInst(inst->opcode())) { + if (!spvOpcodeIsBlockTerminator(inst->opcode())) { str << std::endl; } }); diff --git a/source/opt/inline_pass.cpp b/source/opt/inline_pass.cpp index 88f395f07e..8159ebf787 100644 --- a/source/opt/inline_pass.cpp +++ b/source/opt/inline_pass.cpp @@ -383,9 +383,7 @@ std::unique_ptr InlinePass::InlineReturn( uint32_t returnLabelId = 0; for (auto callee_block_itr = calleeFn->begin(); callee_block_itr != calleeFn->end(); ++callee_block_itr) { - if (callee_block_itr->tail()->opcode() == SpvOpUnreachable || - callee_block_itr->tail()->opcode() == SpvOpKill || - callee_block_itr->tail()->opcode() == SpvOpTerminateInvocation) { + if (spvOpcodeIsAbort(callee_block_itr->tail()->opcode())) { returnLabelId = context()->TakeNextId(); break; } @@ -759,8 +757,7 @@ bool InlinePass::IsInlinableFunction(Function* func) { bool InlinePass::ContainsKillOrTerminateInvocation(Function* func) const { return !func->WhileEachInst([](Instruction* inst) { - const auto opcode = inst->opcode(); - return (opcode != SpvOpKill) && (opcode != SpvOpTerminateInvocation); + return !spvOpcodeTerminatesExecution(inst->opcode()); }); } diff --git a/source/opt/ir_loader.cpp b/source/opt/ir_loader.cpp index 06099ce066..70e5144aac 100644 --- a/source/opt/ir_loader.cpp +++ b/source/opt/ir_loader.cpp @@ -137,7 +137,7 @@ bool IrLoader::AddInstruction(const spv_parsed_instruction_t* inst) { return false; } block_ = MakeUnique(std::move(spv_inst)); - } else if (IsTerminatorInst(opcode)) { + } else if (spvOpcodeIsBlockTerminator(opcode)) { if (function_ == nullptr) { Error(consumer_, src, loc, "terminator instruction outside function"); return false; diff --git a/source/opt/module.cpp b/source/opt/module.cpp index 9d3b0edc6e..0c88601097 100644 --- a/source/opt/module.cpp +++ b/source/opt/module.cpp @@ -188,7 +188,7 @@ void Module::ToBinary(std::vector* binary, bool skip_nop) const { i->ToBinaryWithoutAttachedDebugInsts(binary); } // Update the last line instruction. - if (IsTerminatorInst(opcode) || opcode == SpvOpNoLine) { + if (spvOpcodeIsBlockTerminator(opcode) || opcode == SpvOpNoLine) { last_line_inst = nullptr; } else if (opcode == SpvOpLoopMerge || opcode == SpvOpSelectionMerge) { between_merge_and_branch = true; diff --git a/source/opt/reflect.h b/source/opt/reflect.h index d374e6823f..c7d46df548 100644 --- a/source/opt/reflect.h +++ b/source/opt/reflect.h @@ -59,10 +59,6 @@ inline bool IsCompileTimeConstantInst(SpvOp opcode) { inline bool IsSpecConstantInst(SpvOp opcode) { return opcode >= SpvOpSpecConstantTrue && opcode <= SpvOpSpecConstantOp; } -inline bool IsTerminatorInst(SpvOp opcode) { - return (opcode >= SpvOpBranch && opcode <= SpvOpUnreachable) || - (opcode == SpvOpTerminateInvocation); -} } // namespace opt } // namespace spvtools diff --git a/test/opt/inline_test.cpp b/test/opt/inline_test.cpp index d3ef09c3fc..2939901338 100644 --- a/test/opt/inline_test.cpp +++ b/test/opt/inline_test.cpp @@ -2581,6 +2581,63 @@ OpFunctionEnd SinglePassRunAndCheck(before, after, false, true); } +TEST_F(InlineTest, InlineFuncWithOpTerminateRayNotInContinue) { + const std::string text = + R"( + OpCapability RayTracingKHR + OpExtension "SPV_KHR_ray_tracing" + OpMemoryModel Logical GLSL450 + OpEntryPoint AnyHitKHR %MyAHitMain2 "MyAHitMain2" %a + OpSource HLSL 630 + OpName %a "a" + OpName %MyAHitMain2 "MyAHitMain2" + OpName %param_var_a "param.var.a" + OpName %src_MyAHitMain2 "src.MyAHitMain2" + OpName %a_0 "a" + OpName %bb_entry "bb.entry" + %int = OpTypeInt 32 1 +%_ptr_IncomingRayPayloadKHR_int = OpTypePointer IncomingRayPayloadKHR %int + %void = OpTypeVoid + %6 = OpTypeFunction %void +%_ptr_Function_int = OpTypePointer Function %int + %14 = OpTypeFunction %void %_ptr_Function_int + %a = OpVariable %_ptr_IncomingRayPayloadKHR_int IncomingRayPayloadKHR +%MyAHitMain2 = OpFunction %void None %6 + %7 = OpLabel +%param_var_a = OpVariable %_ptr_Function_int Function + %10 = OpLoad %int %a + OpStore %param_var_a %10 + %11 = OpFunctionCall %void %src_MyAHitMain2 %param_var_a + %13 = OpLoad %int %param_var_a + OpStore %a %13 + OpReturn + OpFunctionEnd +%src_MyAHitMain2 = OpFunction %void None %14 + %a_0 = OpFunctionParameter %_ptr_Function_int + %bb_entry = OpLabel + %17 = OpLoad %int %a_0 + OpStore %a %17 + OpTerminateRayKHR + OpFunctionEnd + +; CHECK: %MyAHitMain2 = OpFunction %void None +; CHECK-NEXT: OpLabel +; CHECK-NEXT: %param_var_a = OpVariable %_ptr_Function_int Function +; CHECK-NEXT: OpLoad %int %a +; CHECK-NEXT: OpStore %param_var_a {{%\d+}} +; CHECK-NEXT: OpLoad %int %param_var_a +; CHECK-NEXT: OpStore %a {{%\d+}} +; CHECK-NEXT: OpTerminateRayKHR +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpLoad %int %param_var_a +; CHECK-NEXT: OpStore %a %16 +; CHECK-NEXT: OpReturn +; CHECK-NEXT: OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, false); +} + TEST_F(InlineTest, EarlyReturnFunctionInlined) { // #version 140 //