diff --git a/paddlenlp/trainer/trainer_base.py b/paddlenlp/trainer/trainer_base.py
index 363877795596..5962ebd374fc 100644
--- a/paddlenlp/trainer/trainer_base.py
+++ b/paddlenlp/trainer/trainer_base.py
@@ -1042,14 +1042,16 @@ def save_model(self, output_dir: Optional[str]=None):
     def export_model(self,
                      input_spec=None,
                      load_best_model=False,
-                     output_dir: Optional[str]=None):
-        """ Export paddle inference model.
+                     output_dir: Optional[str]=None,
+                     model_format: Optional[str]="paddle"):
+        """ Export paddle inference model or onnx model.
 
         Args:
             input_spec (paddle.static.InputSpec, optional): InputSpec describes the signature information
                 of the model input, such as shape , dtype , name. Defaults to None.
             load_best_model (bool, optional): Load best model. Defaults to False.
             output_dir (Optional[str], optional): Output dir to save the exported model. Defaults to None.
+            model_format (Optional[str], optional): Export model format. There are two options: paddle or onnx, defaults to paddle.
         """
 
         if output_dir is None:
@@ -1079,14 +1081,26 @@ def export_model(self,
         model = unwrap_model(self.model)
         model.eval()
 
-        # Convert to static graph with specific input description
-        model = paddle.jit.to_static(model, input_spec=input_spec)
-
-        # Save in static graph model.
-        save_path = os.path.join(output_dir, "inference", "infer")
-        logger.info("Exporting inference model to %s" % save_path)
-        paddle.jit.save(model, save_path)
-        logger.info("Inference model exported.")
+        model_format = model_format.lower()
+        if model_format == "paddle":
+            # Convert to static graph with specific input description
+            model = paddle.jit.to_static(model, input_spec=input_spec)
+
+            # Save in static graph model.
+            save_path = os.path.join(output_dir, "inference", "infer")
+            logger.info("Exporting inference model to %s" % save_path)
+            paddle.jit.save(model, save_path)
+            logger.info("Inference model exported.")
+        elif model_format == "onnx":
+            # Export ONNX model.
+            save_path = os.path.join(output_dir, "onnx", "model")
+            logger.info("Exporting ONNX model to %s" % save_path)
+            paddle.onnx.export(model, save_path, input_spec=input_spec)
+            logger.info("ONNX model exported.")
+        else:
+            logger.info(
+                "This export format is not supported, please select paddle or onnx!"
+            )
 
     def _save_checkpoint(self, model, metrics=None):
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
diff --git a/requirements.txt b/requirements.txt
index 4a69e6ba033c..0b05fbc2682c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ datasets
 tqdm
 paddlefsl
 sentencepiece
+paddle2onnx
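
A minimal usage sketch of the new `model_format` option. The trainer construction and the `InputSpec` shapes/names below are illustrative assumptions, not part of this diff; only the `model_format` argument and the output paths come from the change above.

```python
import paddle

# Sketch only: assumes `trainer` is an already-built paddlenlp Trainer whose
# model takes `input_ids` and `token_type_ids`; shapes/dtypes are illustrative.
input_spec = [
    paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"),
    paddle.static.InputSpec(shape=[None, None], dtype="int64", name="token_type_ids"),
]

# Default (model_format="paddle"), unchanged behavior: saves a static-graph
# inference model under <output_dir>/inference/infer.* via paddle.jit.save.
trainer.export_model(input_spec=input_spec, output_dir="./export")

# New in this diff: model_format="onnx" exports to <output_dir>/onnx/model
# via paddle.onnx.export, which depends on the paddle2onnx package added
# to requirements.txt.
trainer.export_model(input_spec=input_spec, output_dir="./export", model_format="onnx")
```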