diff --git a/paddle/fluid/ir/dialect/op_generator/python_c_gen.py b/paddle/fluid/ir/dialect/op_generator/python_c_gen.py index 340aa9569818a..a890a8db5d249 100644 --- a/paddle/fluid/ir/dialect/op_generator/python_c_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/python_c_gen.py @@ -283,11 +283,12 @@ def _gen_cast_attrs(self, op_info, op_name, with_mutable): def _gen_one_impl(self, op_info, op_name): input_name_list = op_info.input_name_list + output_name_list = op_info.output_name_list attr_name_list = op_info.attribute_name_list mutable_attr_name_list = op_info.mutable_attribute_name_list no_mutable_attr_name_list = op_info.non_mutable_attribute_name_list - if op_name == "send_v2": + if len(output_name_list) == 0: ret = NO_OUTPUT_API_IMPL_TEMPLATE.format( api_name=op_name, inputs=self._gen_inputs(op_info, op_name), diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_ops.yaml b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_ops.yaml index a5d2f42fc1ba1..da4c252af7217 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_ops.yaml +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_ops.yaml @@ -92,7 +92,7 @@ output : Tensor(out) infer_meta: func: RecvV2InferMeta - param: [peer, dtype, out_shape] + param: [ring_id, dynamic_shape, peer, out_shape, dtype] kernel : func : recv_v2 param : [ring_id, dynamic_shape, peer, out_shape, dtype, use_calc_stream] diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index 1526ba2ec021b..d5da3a2f8bc87 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -181,6 +181,48 @@ void PRecvArrayInferMeta(int peer, out->set_dtype(dtype); } +void RecvV2InferMeta(const int ring_id, + const bool dynamic_shape, + const int peer, + const std::vector& out_shape, + DataType dtype, + MetaTensor* out) { + PADDLE_ENFORCE_GE( + peer, + 0, + errors::InvalidArgument( + "The peer (%d) for recv_v2 op must be non-negative.", peer)); + + PADDLE_ENFORCE_GE( + ring_id, + 0, + 
errors::InvalidArgument( + "The ring_id (%d) for recv_v2 op must be non-negative.", ring_id)); + + PADDLE_ENFORCE_GE(out_shape.size(), + 1, + errors::InvalidArgument( + "The size of the output shape must be greater than 0 " + "but the value given is %d.", + out_shape.size())); + + if (!dynamic_shape) { + for (size_t i = 0; i < out_shape.size(); ++i) { + PADDLE_ENFORCE_GE(out_shape[i], + 1, + errors::InvalidArgument( + "The shape attribute for recv_v2 must be set " + "explicitly, but the %dth element is %d which " + "is less than 1. Or dynamic_shape should be " + "set to True for both send_v2 and recv_v2.", + i, + out_shape[i])); + } + out->set_dims(phi::make_ddim(out_shape)); + } + out->set_dtype(dtype); +} + void TruncatedGaussianRandomInferMeta(const std::vector& shape, float mean, float std, diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h index 775df0cc6ab47..bc73942c8ec1c 100644 --- a/paddle/phi/infermeta/nullary.h +++ b/paddle/phi/infermeta/nullary.h @@ -76,6 +76,13 @@ void PRecvArrayInferMeta(int peer, const std::vector& out_shape, MetaTensor* out); +void RecvV2InferMeta(const int ring_id, + const bool dynamic_shape, + const int peer, + const std::vector& out_shape, + DataType dtype, + MetaTensor* out); + void TruncatedGaussianRandomInferMeta(const std::vector& shape, float mean, float std, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 9932f9612d70f..aa1b6526cd5f8 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -405,50 +405,6 @@ void CConcatInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) { out->set_dtype(x.dtype()); } -void SendV2InferMeta(const int peer, const int ring_id) { - PADDLE_ENFORCE_GE( - peer, - 0, - errors::InvalidArgument( - "The peer (%d) for send_v2 op must be non-negative.", peer)); - PADDLE_ENFORCE_GE( - ring_id, - 0, - errors::InvalidArgument( - "The ring_id (%d) for send_v2 op must be non-negative.", ring_id)); -} - -void 
RecvV2InferMeta(int peer, - DataType dtype, - const std::vector& out_shape, - MetaTensor* out) { - PADDLE_ENFORCE_GE( - peer, - 0, - errors::InvalidArgument( - "The peer (%d) for p_recv op must be non-negative.", peer)); - - PADDLE_ENFORCE_GE(out_shape.size(), - 1, - errors::InvalidArgument( - "The size of the output shape must be greater than 0 " - "but the value given is %d.", - out_shape.size())); - - for (size_t i = 0; i < out_shape.size(); ++i) { - PADDLE_ENFORCE_GE( - out_shape[i], - 1, - errors::InvalidArgument("The shape attribute for recv must be set " - "explicitly, but the %dth element is %d which " - "is less than 1. Or dynamic_shape should be " - "set to True for both send_v2 and recv_v2.", - i, - out_shape[i])); - } - out->set_dtype(dtype); -} - void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) { auto dims = x.dims(); auto rank = dims.size(); @@ -3045,6 +3001,19 @@ void PSendArrayInferMeta(const MetaTensor& x, int peer) { "The peer (%d) for p_send op must be non-negative.", peer)); } +void SendV2InferMeta(const int peer, const int ring_id) { + PADDLE_ENFORCE_GE( + peer, + 0, + errors::InvalidArgument( + "The peer (%d) for send_v2 op must be non-negative.", peer)); + PADDLE_ENFORCE_GE( + ring_id, + 0, + errors::InvalidArgument( + "The ring_id (%d) for send_v2 op must be non-negative.", ring_id)); +} + void PoolInferMeta(const MetaTensor& x, const std::vector& kernel_size, const std::vector& strides, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 2bf90048d30d3..a3b7e87d86d0b 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -73,13 +73,6 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out); void CConcatInferMeta(const MetaTensor& x, int nranks, MetaTensor* out); -void SendV2InferMeta(const int peer, const int ring_id); - -void RecvV2InferMeta(int peer, - DataType dtype, - const std::vector& out_shape, - MetaTensor* out); - void 
ChannelShuffleInferMeta(const MetaTensor& x, int groups, const std::string& data_format, @@ -448,6 +441,8 @@ void PSendInferMeta(const MetaTensor& x, int peer); void PSendArrayInferMeta(const MetaTensor& x, int peer); +void SendV2InferMeta(const int peer, const int ring_id); + void QrInferMeta(const MetaTensor& x, const std::string& mode, MetaTensor* q, diff --git a/python/paddle/distributed/auto_parallel/static/engine.py b/python/paddle/distributed/auto_parallel/static/engine.py index 07bebf3d0c5ae..ff0c3004cf605 100644 --- a/python/paddle/distributed/auto_parallel/static/engine.py +++ b/python/paddle/distributed/auto_parallel/static/engine.py @@ -833,6 +833,14 @@ def _initialize(self, mode): dist_main_program, self._place, dist_context ) + # NOTE(zhaoyinglia): Temporarily disable new IR in the executor while the startup program runs. + use_new_ir = False + if auto_utils.use_new_ir(): + use_new_ir = True + paddle.framework.set_flags( + {"FLAGS_enable_new_ir_in_executor": False} + ) + if self._executor is None: self._executor = paddle.static.Executor(self._place) uninitialized = [] @@ -860,6 +868,11 @@ def _initialize(self, mode): ] self._executor.run(dist_startup_prog) + if use_new_ir: + paddle.framework.set_flags( + {"FLAGS_enable_new_ir_in_executor": True} + ) + def fit( self, train_data, diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py index 3441914518822..8ec0ba2e09f98 100644 --- a/python/paddle/distributed/auto_parallel/static/utils.py +++ b/python/paddle/distributed/auto_parallel/static/utils.py @@ -2423,6 +2423,19 @@ def use_new_executor(): ] +def use_new_ir(): + enable_new_ir_in_executor = os.environ.get( + 'FLAGS_enable_new_ir_in_executor', None + ) + return enable_new_ir_in_executor in [ + 1, + '1', + True, + 'True', + 'true', + ] + + def get_pp_stage(dist_context, rank): pp_idx = None for idx, process_mesh in enumerate(dist_context.process_meshes): diff --git a/test/auto_parallel/CMakeLists.txt
b/test/auto_parallel/CMakeLists.txt index eae16e0245468..3af649a809cca 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -78,20 +78,23 @@ if(WITH_DISTRIBUTE AND WITH_GPU) PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) py_test_modules(test_pass_quantization MODULES test_pass_quantization) set_tests_properties(test_pass_quantization - PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 60) + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 60) py_test_modules(test_reshard_s_to_r MODULES test_reshard_s_to_r) set_tests_properties(test_reshard_s_to_r - PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100) + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) py_test_modules(test_reshard_r_to_s MODULES test_reshard_r_to_s) set_tests_properties(test_reshard_r_to_s - PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100) + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) py_test_modules(test_reshard_r_to_p MODULES test_reshard_r_to_p) set_tests_properties(test_reshard_r_to_p - PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100) + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) py_test_modules(test_semi_auto_parallel_basic MODULES test_semi_auto_parallel_basic) set_tests_properties(test_semi_auto_parallel_basic - PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100) + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) + py_test_modules(test_gpt_with_newir MODULES test_gpt_with_newir) + set_tests_properties(test_gpt_with_newir + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) # End of unittests WITH multi cards and timeout # NOTE(zyl): unittests WITH multi cards and WITHOUT timeout diff --git a/test/auto_parallel/gpt_with_newir.py b/test/auto_parallel/gpt_with_newir.py new file mode 100644 index 0000000000000..4ddfd5a76ffe0 --- /dev/null +++ b/test/auto_parallel/gpt_with_newir.py @@ -0,0 +1,161 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import unittest + +import numpy as np +from get_gpt_model import FakeDataset, generate_model + +import paddle +from paddle.distributed import ParallelEnv +from paddle.distributed.fleet import auto + +paddle.enable_static() + + +def apply_pass(): + strategy = auto.Strategy() + strategy.auto_mode = "semi" + strategy.reinit = True + return strategy + + +def reset_prog(): + paddle.framework.switch_main_program(paddle.static.Program()) + paddle.framework.switch_startup_program(paddle.static.Program()) + paddle.utils.unique_name.switch() + + +class TestNewIR(unittest.TestCase): + def setUp(self): + self.batch_size = 2 + self.batch_num = 5 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + os.environ['FLAGS_new_executor_micro_batching'] = 'True' + paddle.set_flags({'FLAGS_embedding_deterministic': 1}) + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + + def init(self, engine): + paddle.seed(2021) + np.random.seed(2021) + random.seed(2021) + paddle.distributed.fleet.init(is_collective=True) + place = paddle.CUDAPlace(ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_engine(self, mode): + reset_prog() + + strategy = apply_pass() + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=None) + model, loss = generate_model(mode) + + engine = auto.Engine(model, loss, 
opt, strategy=strategy) + self.init(engine) + return engine + + def check_results(self, ref_losses, check_losses): + np.testing.assert_equal( + ref_losses, + check_losses, + err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( + __class__, ref_losses, check_losses, ref_losses - check_losses + ), + ) + + def enable_new_ir(self, flag): + paddle.set_flags({'FLAGS_enable_new_ir_in_executor': flag}) # for c++ + os.environ['FLAGS_enable_new_ir_in_executor'] = str(flag) # for python + + def test_dp(self): + self.enable_new_ir(False) + engine_dp_prog = self.get_engine("dp") + out_dp_prog = engine_dp_prog.fit( + self.dataset, 3, batch_size=self.batch_size, log_freq=1 + ) + + self.enable_new_ir(True) + engine_dp_ir = self.get_engine("dp") + out_dp_ir = engine_dp_ir.fit( + self.dataset, 3, batch_size=self.batch_size, log_freq=1 + ) + + self.check_results( + out_dp_prog.history["loss"][0], out_dp_ir.history["loss"][0] + ) + + def test_mp(self): + self.enable_new_ir(False) + engine_mp_prog = self.get_engine("mp") + out_mp_prog = engine_mp_prog.fit( + self.dataset, 3, batch_size=self.batch_size, log_freq=1 + ) + + self.enable_new_ir(True) + engine_mp_ir = self.get_engine("mp") + out_mp_ir = engine_mp_ir.fit( + self.dataset, 3, batch_size=self.batch_size, log_freq=1 + ) + + self.check_results( + out_mp_prog.history["loss"][0], out_mp_ir.history["loss"][0] + ) + + def test_pp(self): + # naive pipeline parallel without schedule + self.enable_new_ir(False) + engine_pp_prog = self.get_engine("pp") + out_pp_prog = engine_pp_prog.fit( + self.dataset, 3, batch_size=self.batch_size, log_freq=1 + ) + + self.enable_new_ir(True) + # send_v2/recv_v2 dynamic_shape is True + engine_pp_ir = self.get_engine("pp") + out_pp_ir = engine_pp_ir.fit( + self.dataset, 3, batch_size=self.batch_size, log_freq=1 + ) + + if paddle.distributed.get_rank() == 1: + self.check_results( + out_pp_prog.history["loss"][0], out_pp_ir.history["loss"][0] + ) + + # send_v2/recv_v2 dynamic_shape is False +
engine_pp_prog1 = self.get_engine("pp") + dataloader_pp_prog = engine_pp_prog1.dataloader( + self.dataset, + batch_size=self.batch_size, + sample_split=3, + mode="train", + ) + engine_pp_prog1.prepare(mode="train") + for op in engine_pp_prog1.main_program.global_block().ops: + if op.type in ["send_v2", "recv_v2"]: + op.desc._set_attr("dynamic_shape", False) + for data in dataloader_pp_prog: + out_pp_prog1 = engine_pp_prog1.run(data, mode="train") + + if paddle.distributed.get_rank() == 1: + self.check_results( + out_pp_prog1["loss"], out_pp_ir.history["loss"][0] + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/test_gpt_with_newir.py b/test/auto_parallel/test_gpt_with_newir.py new file mode 100644 index 0000000000000..2f736d8a3b297 --- /dev/null +++ b/test/auto_parallel/test_gpt_with_newir.py @@ -0,0 +1,55 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+import subprocess
+import sys
+import tempfile
+import unittest
+
+
+class TestGPTNewIR(unittest.TestCase):
+    """Run gpt_with_newir.py on devices 0,1 via paddle.distributed.launch."""
+
+    def test_gpt_newir(self):
+        """Launch the sibling script in a subprocess and require exit code 0."""
+        file_dir = os.path.dirname(os.path.abspath(__file__))
+        launch_model_path = os.path.join(file_dir, "gpt_with_newir.py")
+
+        # Run the child under `coverage` only when WITH_COVERAGE is "ON".
+        coverage_args = []
+        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
+            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
+
+        # The context manager removes the log dir even if the assertion fails.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            cmd = [
+                sys.executable,
+                "-u",
+                *coverage_args,
+                "-m",
+                "paddle.distributed.launch",
+                "--devices",
+                "0,1",
+                "--log_dir",
+                tmp_dir,
+                launch_model_path,
+            ]
+
+            process = subprocess.run(cmd, check=False)
+            self.assertEqual(process.returncode, 0)
+
+
+if __name__ == "__main__":
+    unittest.main()