From 6cf677ba6b9bcaf12a9a779a8209f546485ca473 Mon Sep 17 00:00:00 2001 From: zhangbaizhou Date: Mon, 8 Apr 2024 08:42:15 +0000 Subject: [PATCH 1/2] fix bug in broadcast_with_cf.cc --- .../operator/transforms/lowering_pass/broadcast_with_cf.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.cc b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.cc index 7068221d77fe5..8b156d9c1a747 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.cc @@ -108,11 +108,12 @@ void UpdateGroupShapeExprs( const auto& origin_shape_or_data = origin_group->GetShapeOrDataExprs(origin_val); if (origin_shape_or_data.data()) { + std::vector shape_dim_expr_shape = { + symbol::DimExpr(static_cast(shape_dim_expr.size()))}; new_group->SetShapeOrDataExprs( new_val, symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs( - std::vector{shape_dim_expr.size()}, - shape_dim_expr)}); + shape_dim_expr_shape, shape_dim_expr)}); } else { new_group->SetShapeOrDataExprs( new_val, From 0db9e690b1c56fd9ffc4ae11b1c79b0b82cdef7d Mon Sep 17 00:00:00 2001 From: zhangbaizhou Date: Tue, 9 Apr 2024 07:24:57 +0000 Subject: [PATCH 2/2] add test for cinn dynamic shape --- test/ir/pir/cinn/test_dynamic_shape.py | 132 +++++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 test/ir/pir/cinn/test_dynamic_shape.py diff --git a/test/ir/pir/cinn/test_dynamic_shape.py b/test/ir/pir/cinn/test_dynamic_shape.py new file mode 100644 index 0000000000000..92efc24037893 --- /dev/null +++ b/test/ir/pir/cinn/test_dynamic_shape.py @@ -0,0 +1,132 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import unittest + +import numpy + +os.environ['FLAGS_cinn_new_group_scheduler'] = '1' +os.environ['FLAGS_group_schedule_tiling_first'] = '1' +os.environ['FLAGS_prim_all'] = 'true' +os.environ['FLAGS_prim_enable_dynamic'] = 'true' +os.environ['FLAGS_print_ir'] = '1' +os.environ['FLAGS_enable_pir_api'] = '1' +os.environ['FLAGS_use_cinn'] = '1' +os.environ['FLAGS_cinn_bucket_compile'] = '1' +os.environ['FLAGS_cinn_new_cluster_op_method'] = '1' +os.environ['FLAGS_deny_cinn_ops'] = 'slice;' + +import paddle + +build_strategy = paddle.static.BuildStrategy() +build_strategy.build_cinn_pass = True + + +def generate_input_spec(rank_dtype_list): + input_spec = [] + for rank, dtype in rank_dtype_list: + input_spec.append( + paddle.static.InputSpec(shape=[None] * rank, dtype=dtype) + ) + return input_spec + + +class TestTrivialFusion(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + pass + + def compare_result(self, dy_compute, input_spec, data_init): + inputs = data_init() + dy_out = dy_compute(*inputs) + static_compute = paddle.jit.to_static( + full_graph=True, + build_strategy=build_strategy, + input_spec=input_spec, + )(dy_compute) + st_out = static_compute(*inputs) + numpy.testing.assert_allclose(dy_out, st_out, atol=1e-5, rtol=1e-6) + + def test_simple_trivial_fusions(self): + def func(x): + x = x * 2 + x = x + 1 + x = paddle.nn.functional.relu(x) + x = paddle.transpose(x, perm=[0, 2, 1]) + x = x.reshape((-1, 128)) + return x + + def init(): + x = paddle.rand((32, 32, 128)) + return (x,) + + input_spec = generate_input_spec([(3, 'float32')]) + self.compare_result(func, input_spec, init) + + def test_trivial_fusion_slice_and_concat(self): + def func(x, y): + x = x * 2 + y = y * 2 + x = x[:, :, :64] + y = y[:, :, :64] + z = paddle.concat([x, y], axis=-1) + return z + + def init(): + x = paddle.rand((32, 32, 128)) + y = paddle.rand((32, 32, 128)) + return (x, y) + + input_spec = generate_input_spec([(3, 'float32'), (3, 'float32')]) + self.compare_result(func, input_spec, init) + + def test_trivial_fusion_gather_nd(self): + def func(x, y): + x = x * 2 + output = paddle.gather_nd(x, y) + return output + + def init(): + x = paddle.to_tensor( + [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]] + ) + index = paddle.to_tensor([[0, 1]]) + return (x, index) + + input_spec = [ + paddle.static.InputSpec(shape=[None, None, None], dtype='float32'), + paddle.static.InputSpec(shape=[None, 2], dtype='int32'), + ] + self.compare_result(func, input_spec, init) + + def test_broadcast(self): + def func(x, y): + output = x + y + return output + + def init(): + x = paddle.rand((32, 1)) + y = paddle.rand((1, 32)) + return (x, y) + + input_spec = generate_input_spec([(2, 'float32'), (2, 'float32')]) + self.compare_result(func, input_spec, init) + + +if __name__ == "__main__": + unittest.main()