From 5da9c385e2610d91a9fde4268e5c964deed93059 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Thu, 5 May 2022 16:38:38 +0800 Subject: [PATCH 1/4] Add ONNX Export --- paddlenlp/trainer/trainer_base.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/paddlenlp/trainer/trainer_base.py b/paddlenlp/trainer/trainer_base.py index 363877795596..d93ddf8f5a29 100644 --- a/paddlenlp/trainer/trainer_base.py +++ b/paddlenlp/trainer/trainer_base.py @@ -1042,14 +1042,16 @@ def save_model(self, output_dir: Optional[str]=None): def export_model(self, input_spec=None, load_best_model=False, - output_dir: Optional[str]=None): - """ Export paddle inference model. + output_dir: Optional[str]=None, + export_model_format: Optional[str]="Paddle"): + """ Export paddle inference model or ONNX model. Args: input_spec (paddle.static.InputSpec, optional): InputSpec describes the signature information of the model input, such as shape , dtype , name. Defaults to None. load_best_model (bool, optional): Load best model. Defaults to False. output_dir (Optional[str], optional): Output dir to save the exported model. Defaults to None. + export_model_format (Optional[str], optional): Export model format. Defaults to Paddle. """ if output_dir is None: @@ -1079,14 +1081,21 @@ def export_model(self, model = unwrap_model(self.model) model.eval() - # Convert to static graph with specific input description - model = paddle.jit.to_static(model, input_spec=input_spec) - - # Save in static graph model. - save_path = os.path.join(output_dir, "inference", "infer") - logger.info("Exporting inference model to %s" % save_path) - paddle.jit.save(model, save_path) - logger.info("Inference model exported.") + if export_model_format == "Paddle": + # Convert to static graph with specific input description + model = paddle.jit.to_static(model, input_spec=input_spec) + + # Save in static graph model. + save_path = os.path.join(output_dir, "inference", "infer") + logger.info("Exporting inference model to %s" % save_path) + paddle.jit.save(model, save_path) + logger.info("Inference model exported.") + elif export_model_format == "ONNX": + # Export ONNX model. + save_path = os.path.join(output_dir, "onnx", "model") + logger.info("Exporting ONNX model to %s" % save_path) + paddle.onnx.export(model, save_path, input_spec=input_spec) + logger.info("ONNX model exported.") def _save_checkpoint(self, model, metrics=None): # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" From e40965a4654e657c3ad37a7d95ff3490459002f4 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Thu, 5 May 2022 17:09:55 +0800 Subject: [PATCH 2/4] deal with comments --- paddlenlp/trainer/trainer_base.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/paddlenlp/trainer/trainer_base.py b/paddlenlp/trainer/trainer_base.py index d93ddf8f5a29..b4e9ac48c0aa 100644 --- a/paddlenlp/trainer/trainer_base.py +++ b/paddlenlp/trainer/trainer_base.py @@ -1043,7 +1043,7 @@ def export_model(self, input_spec=None, load_best_model=False, output_dir: Optional[str]=None, - export_model_format: Optional[str]="Paddle"): + export_model_format: Optional[str]="paddle"): """ Export paddle inference model or ONNX model. Args: @@ -1051,7 +1051,7 @@ def export_model(self, such as shape , dtype , name. Defaults to None. load_best_model (bool, optional): Load best model. Defaults to False. output_dir (Optional[str], optional): Output dir to save the exported model. Defaults to None. - export_model_format (Optional[str], optional): Export model format. Defaults to Paddle. + export_model_format (Optional[str], optional): Export model format. There are two options: paddle or onnx, defaults to paddle. """ if output_dir is None: @@ -1081,7 +1081,8 @@ def export_model(self, model = unwrap_model(self.model) model.eval() - if export_model_format == "Paddle": + export_model_format = export_model_format.lower() + if export_model_format == "paddle": # Convert to static graph with specific input description model = paddle.jit.to_static(model, input_spec=input_spec) @@ -1090,12 +1091,15 @@ def export_model(self, logger.info("Exporting inference model to %s" % save_path) paddle.jit.save(model, save_path) logger.info("Inference model exported.") - elif export_model_format == "ONNX": + elif export_model_format == "onnx": # Export ONNX model. save_path = os.path.join(output_dir, "onnx", "model") logger.info("Exporting ONNX model to %s" % save_path) paddle.onnx.export(model, save_path, input_spec=input_spec) logger.info("ONNX model exported.") + else: + logger.info( + "This export format is not supported, please select paddle or onnx!") def _save_checkpoint(self, model, metrics=None): # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" From 2fd0e372313555b6dfd1eadf3bc1dd67828d3b8c Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Thu, 5 May 2022 17:13:45 +0800 Subject: [PATCH 3/4] deal with comments --- paddlenlp/trainer/trainer_base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddlenlp/trainer/trainer_base.py b/paddlenlp/trainer/trainer_base.py index b4e9ac48c0aa..b06aafbd4516 100644 --- a/paddlenlp/trainer/trainer_base.py +++ b/paddlenlp/trainer/trainer_base.py @@ -1044,7 +1044,7 @@ def export_model(self, load_best_model=False, output_dir: Optional[str]=None, export_model_format: Optional[str]="paddle"): - """ Export paddle inference model or ONNX model. + """ Export paddle inference model or onnx model. Args: input_spec (paddle.static.InputSpec, optional): InputSpec describes the signature information of the model input, @@ -1099,7 +1099,8 @@ def export_model(self, logger.info("ONNX model exported.") else: logger.info( - "This export format is not supported, please select paddle or onnx!") + "This export format is not supported, please select paddle or onnx!" + ) def _save_checkpoint(self, model, metrics=None): # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" From c8f5f9d0c5fc5d3a7ec3a825bf583d71907211d1 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Fri, 6 May 2022 10:34:21 +0800 Subject: [PATCH 4/4] deal with comments --- paddlenlp/trainer/trainer_base.py | 10 +++++----- requirements.txt | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/paddlenlp/trainer/trainer_base.py b/paddlenlp/trainer/trainer_base.py index b06aafbd4516..5962ebd374fc 100644 --- a/paddlenlp/trainer/trainer_base.py +++ b/paddlenlp/trainer/trainer_base.py @@ -1043,7 +1043,7 @@ def export_model(self, input_spec=None, load_best_model=False, output_dir: Optional[str]=None, - export_model_format: Optional[str]="paddle"): + model_format: Optional[str]="paddle"): """ Export paddle inference model or onnx model. Args: @@ -1051,7 +1051,7 @@ def export_model(self, such as shape , dtype , name. Defaults to None. load_best_model (bool, optional): Load best model. Defaults to False. output_dir (Optional[str], optional): Output dir to save the exported model. Defaults to None. - export_model_format (Optional[str], optional): Export model format. There are two options: paddle or onnx, defaults to paddle. + model_format (Optional[str], optional): Export model format. There are two options: paddle or onnx, defaults to paddle. """ if output_dir is None: @@ -1081,8 +1081,8 @@ def export_model(self, model = unwrap_model(self.model) model.eval() - export_model_format = export_model_format.lower() - if export_model_format == "paddle": + model_format = model_format.lower() + if model_format == "paddle": # Convert to static graph with specific input description model = paddle.jit.to_static(model, input_spec=input_spec) @@ -1091,7 +1091,7 @@ def export_model(self, logger.info("Exporting inference model to %s" % save_path) paddle.jit.save(model, save_path) logger.info("Inference model exported.") - elif export_model_format == "onnx": + elif model_format == "onnx": # Export ONNX model. save_path = os.path.join(output_dir, "onnx", "model") logger.info("Exporting ONNX model to %s" % save_path) diff --git a/requirements.txt b/requirements.txt index 4a69e6ba033c..0b05fbc2682c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ datasets tqdm paddlefsl sentencepiece +paddle2onnx