diff --git a/test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py b/test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py
index 731d59d23aeb1..1ccfe61a9c13b 100644
--- a/test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py
+++ b/test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py
@@ -45,10 +45,14 @@ def get_cuda_version():
     "weight_only_linear requires CUDA >= 11.2 and CUDA_ARCH >= 8",
 )
 class TestFusedWeightOnlyLinearPass_Fp32(PassTest):
+    def is_program_valid(self, program):
+        return True
+
     def build_ir_progam(self):
+        pir_program = None
         with paddle.pir_utils.IrGuard():
-            self.pir_program = paddle.static.Program()
-            with paddle.pir.core.program_guard(self.pir_program):
+            pir_program = paddle.static.Program()
+            with paddle.pir.core.program_guard(pir_program):
                 x = paddle.static.data(
                     name='x', shape=[3, 64, 64], dtype=self.dtype
                 )
@@ -75,11 +79,14 @@ def build_ir_progam(self):
             "pd_op.matmul": 0,
             "pd_op.add": 0,
         }
+        return pir_program
 
     def setUp(self):
         self.place_runtime = "gpu"
         self.dtype = 'float32'
-        self.build_ir_progam()
+
+    def sample_program(self):
+        yield self.build_ir_progam(), False
 
     def test_check_output(self):
         self.check_pass_correct()
@@ -89,7 +96,6 @@ class TestFusedWeightOnlyLinearPass_Fp16(TestFusedWeightOnlyLinearPass_Fp32):
     def setUp(self):
         self.place_runtime = "gpu"
         self.dtype = 'float16'
-        self.build_ir_progam()
 
 
 if __name__ == "__main__":