From afce826e4135eb0c65c14b79a51fe726dff82d78 Mon Sep 17 00:00:00 2001 From: Guoxia Wang Date: Mon, 22 Aug 2022 10:54:20 +0800 Subject: [PATCH] add paddle version check help function (#650) --- fleetx/core/engine/eager_engine.py | 8 +++++--- fleetx/utils/download.py | 3 ++- fleetx/utils/version.py | 21 +++++++++++++++++++++ 3 files changed, 28 insertions(+), 4 deletions(-) create mode 100644 fleetx/utils/version.py diff --git a/fleetx/core/engine/eager_engine.py b/fleetx/core/engine/eager_engine.py index 06a2b6319c4fb..0f40eca286416 100644 --- a/fleetx/core/engine/eager_engine.py +++ b/fleetx/core/engine/eager_engine.py @@ -30,6 +30,7 @@ from fleetx.core.engine.basic_engine import BasicEngine from fleetx.core.module.basic_module import BasicModule from fleetx.utils.tensor_fusion_helper import all_reduce_parameters +from fleetx.utils.version import version_check class EagerEngine(BasicEngine): @@ -81,6 +82,7 @@ def configure_optimizers(self): """ super().__init__() + version_check() self.mode = mode @@ -492,8 +494,8 @@ def load(self): opt_dict = paddle.load(opt_path) self._module.optimizer.set_state_dict(opt_dict) else: - raise ValueError("No optimizer checkpoint file found in %s." % - opt_path) + raise ValueError( + "No optimizer checkpoint file found in %s." % opt_path) if os.path.exists(meta_path): meta_dict = paddle.load(meta_path) @@ -507,7 +509,7 @@ def load(self): self._module.global_step = resume_step else: raise ValueError("No meta checkpoint file found in %s." % - meta_path) + meta_path) logger.info("successfully load checkpoints") else: diff --git a/fleetx/utils/download.py b/fleetx/utils/download.py index 0123436c6ce52..095a8f5c671a6 100644 --- a/fleetx/utils/download.py +++ b/fleetx/utils/download.py @@ -116,7 +116,8 @@ def _download(url, fullname): def download(url, path): local_rank = 0 world_size = 1 - if paddle.fluid.core.is_compiled_with_dist(): + if paddle.fluid.core.is_compiled_with_dist( + ) and paddle.distributed.get_world_size() > 1: local_rank = paddle.distributed.ParallelEnv().dev_id world_size = paddle.distributed.get_world_size() if world_size > 1 and local_rank != 0: diff --git a/fleetx/utils/version.py b/fleetx/utils/version.py new file mode 100644 index 0000000000000..f899a1ced33d6 --- /dev/null +++ b/fleetx/utils/version.py @@ -0,0 +1,21 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle + + +def version_check(): + version = paddle.version.full_version + if version != '0.0.0': + paddle.utils.require_version(min_version='2.3.0')