From 7d17cb1b8ce87121f4cafe7ca55d5c32bd964c3c Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Fri, 12 Apr 2024 06:54:21 +0000 Subject: [PATCH 1/6] add nulltype saveload --- .../include/deserialize_utils.h | 2 + .../serialize_deserialize/include/schema.h | 2 + .../include/serialize_utils.h | 3 + .../src/ir_deserialize.cc | 1 + .../serialize_deserialize/src/ir_serialize.cc | 27 ++- python/paddle/static/io.py | 211 ++++++++++++++---- test/legacy_test/test_cumsum_op.py | 29 +++ 7 files changed, 223 insertions(+), 52 deletions(-) diff --git a/paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h b/paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h index a10727a59e6dd..c740fb6f9be5e 100644 --- a/paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h +++ b/paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h @@ -213,6 +213,8 @@ pir::Type parseType(Json* type_json) { size_t offset = data_json.at(4).get(); return pir::DenseTensorType::get( ctx, dtype, ddim, data_layout, lod, offset); + } else if (type_name == NULL_TYPE) { + return pir::Type(); } else { PADDLE_ENFORCE(false, phi::errors::InvalidArgument( diff --git a/paddle/fluid/pir/serialize_deserialize/include/schema.h b/paddle/fluid/pir/serialize_deserialize/include/schema.h index dd764a8c42e22..ba5a966e08fb7 100644 --- a/paddle/fluid/pir/serialize_deserialize/include/schema.h +++ b/paddle/fluid/pir/serialize_deserialize/include/schema.h @@ -58,4 +58,6 @@ make sure all the key mutually exclusive */ // type/attr's contents which is json::array. #define DATA "data" +// NULL_TYPE +#define NULL_TYPE "NULL" } // namespace pir diff --git a/paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h b/paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h index 435688ff51b33..94de15c2f2768 100644 --- a/paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h +++ b/paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h @@ -214,6 +214,9 @@ Json writeType(const pir::Type& type) { content.push_back(type_.offset()); type_json[DATA] = content; return type_json; + } else if (!type) { + type_json[ID] = NULL_TYPE; + return type_json; } else { PADDLE_ENFORCE( false, phi::errors::InvalidArgument("Unknown Type when write type")); diff --git a/paddle/fluid/pir/serialize_deserialize/src/ir_deserialize.cc b/paddle/fluid/pir/serialize_deserialize/src/ir_deserialize.cc index 12f46f33604c3..88ee2ba168476 100644 --- a/paddle/fluid/pir/serialize_deserialize/src/ir_deserialize.cc +++ b/paddle/fluid/pir/serialize_deserialize/src/ir_deserialize.cc @@ -17,6 +17,7 @@ namespace pir { void ProgramReader::RecoverProgram(Json* program_json, pir::Program* recover_program) { + id_value_map[0] = pir::Value(); ReadProgram(program_json, recover_program); VLOG(6) << "Finish json to program."; return; diff --git a/paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc b/paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc index d4d508377289f..25212601e00c2 100644 --- a/paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc +++ b/paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc @@ -93,14 +93,19 @@ Json ProgramWriter::WriteBlockArg(const pir::Value& value) { Json ProgramWriter::WriteValue(const pir::Value& value) { Json var_json; - // Json var = value; + if (value) { + value_id_map[value] = value_id_; + var_json[ID] = value_id_; + VLOG(6) << "Finish write value " << value_id_; + value_id_++; + } else { + var_json[ID] = 0; // NULL_TYPE + VLOG(6) << "Finish write 
NULL_TYPE value."; + } + Json var = WriteType(value.type()); - value_id_map[value] = value_id_; - var_json[ID] = value_id_; var_json[TYPE_TYPE] = var; - VLOG(6) << "Finish write value " << value_id_; - value_id_++; return var_json; } @@ -136,9 +141,15 @@ Json ProgramWriter::WriteOp(const pir::Operation& op) { Json ProgramWriter::WriteOpOperand(const pir::OpOperand& op_operand) { Json operand_json = Json::object(); - int64_t id = value_id_map[op_operand.source()]; - operand_json[ID] = id; - VLOG(6) << "Finish write OpOperand " << id; + if (op_operand.source()) { + int64_t id = value_id_map[op_operand.source()]; + operand_json[ID] = id; + VLOG(6) << "Finish write OpOperand " << id; + } else { + operand_json[ID] = 0; // NULL_VALUE + VLOG(6) << "Finish write NULL_VALUE OpOperand."; + } + return operand_json; } diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py index 0d423716665cd..b5c00410e6a0c 100644 --- a/python/paddle/static/io.py +++ b/python/paddle/static/io.py @@ -23,6 +23,7 @@ import numpy as np import paddle +from paddle import pir from paddle.base import ( CompiledProgram, Program, @@ -36,6 +37,7 @@ from paddle.base.framework import ( Parameter, dygraph_not_support, + in_pir_mode, process_type_promotion, static_only, ) @@ -75,7 +77,7 @@ def _check_args(caller, args, supported_args=None, deprecated_args=None): def _check_vars(name, var_list): if not isinstance(var_list, list): var_list = [var_list] - if not all(isinstance(var, Variable) for var in var_list): + if not all(isinstance(var, (Variable, pir.Value)) for var in var_list): raise ValueError( f"'{name}' should be a Variable or a list of Variable." ) @@ -109,7 +111,7 @@ def _get_valid_program(program=None): warnings.warn( "The input is a CompiledProgram, this is not recommended." ) - if not isinstance(program, Program): + if not isinstance(program, (Program, paddle.pir.Program)): raise TypeError( "The type of input program is invalid, expected type is base.Program, but received %s" % type(program) @@ -191,6 +193,100 @@ def append_fetch_ops( ) +def normalize_pir_program(program, feed_vars, fetch_vars, **kwargs): + """ + + Normalize/Optimize a program according to feed_vars and fetch_vars. + + Args: + program(Program): Specify a program you want to optimize. + feed_vars(Tensor | list[Tensor]): Values needed by inference. + fetch_vars(Tensor | list[Tensor]): Values returned by inference. + kwargs: Supported keys including ``skip_prune_program``. + - skip_prune_program(bool): whether to skip pruning program. Defaults to False. + + Returns: + Program: Normalized/Optimized program. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> paddle.enable_static() + + >>> path_prefix = "./infer_model" + + # User defined network, here a softmax regression example + >>> image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32') + >>> label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') + >>> predict = paddle.static.nn.fc(image, 10, activation='softmax') + + >>> loss = paddle.nn.functional.cross_entropy(predict, label) + + >>> exe = paddle.static.Executor(paddle.CPUPlace()) + >>> exe.run(paddle.static.default_startup_program()) + + # normalize main program. 
+ >>> program = paddle.static.default_main_program()
            >>> normalized_program = paddle.static.normalize_pir_program(program, [image], [predict])

    """
    if not isinstance(program, pir.Program):
        raise TypeError(
            "program type must be `paddle.pir.Program`, but received `%s`"
            % type(program)
        )
    if not isinstance(feed_vars, list):
        feed_vars = [feed_vars]
    if not all(isinstance(v, pir.Value) for v in feed_vars):
        raise TypeError(
            "feed_vars type must be a pir.Value or a list of pir.Value."
        )
    if not isinstance(fetch_vars, list):
        fetch_vars = [fetch_vars]
    if not all(isinstance(v, pir.Value) for v in fetch_vars):
        raise TypeError(
            "fetch_vars type must be a pir.Value or a list of pir.Value."
        )

    # TODO(Ruting) remind users to set auc_states to 0 if an auc op is found.

    # fix the bug that the activation op's output as target will be pruned.
    # will affect the inference performance.
    # TODO(Superjomn) add an IR pass to remove 1-scale op.
    with paddle.static.program_guard(program):
        uniq_fetch_vars = []
        for i, var in enumerate(fetch_vars):
            if var.dtype != paddle.bool:
                var = paddle.scale(var, 1.0, name=f"save_infer_model/scale_{i}")
            uniq_fetch_vars.append(var)
        fetch_vars = uniq_fetch_vars

    # serialize program
    copy_program = program.clone()
    global_block = copy_program.global_block()
    remove_ops = []
    for op in global_block.ops:
        if op.name() == "pd_op.feed" or op.name() == "pd_op.fetch":
            remove_ops.append(op)

    for op in remove_ops:
        global_block.remove_op(op)

    # feed_var_names = [var.name for var in feed_vars]

    # skip_prune_program = kwargs.get('skip_prune_program', False)
    # if not skip_prune_program:
    #     copy_program = copy_program._prune_with_input(
    #         feeded_var_names=feed_var_names, targets=fetch_vars
    #     )
    # copy_program = copy_program._inference_optimize(prune_read_op=True)
    # fetch_var_names = [var.name for var in fetch_vars]
    # prepend_feed_ops(copy_program, feed_var_names)
    # append_fetch_ops(copy_program, fetch_var_names)

    return copy_program


def normalize_program(program, feed_vars, fetch_vars, **kwargs):
    """

@@ -578,7 +674,12 @@ def save_inference_model(
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise
-    model_path = path_prefix + ".pdmodel"
+
+    if in_pir_mode():
+        model_path = path_prefix + ".json"
+    else:
+        model_path = path_prefix + ".pdmodel"
+
    params_path = path_prefix + ".pdiparams"
    if os.path.isdir(model_path):
        raise ValueError(f"'{model_path}' is an existing directory.")
@@ -596,40 +697,46 @@ def save_inference_model(
    program = process_type_promotion(program)

    clip_extra = kwargs.get('clip_extra', True)
-    program = normalize_program(
-        program,
-        feed_vars,
-        fetch_vars,
-        skip_prune_program=kwargs.get('skip_prune_program', False),
-    )
-    # serialize and save program
-    legacy_format = kwargs.get('legacy_format', False)
-    program_bytes = _serialize_program(
-        program._remove_training_info(clip_extra=clip_extra),
-        legacy_format=legacy_format,
-    )
-    save_to_file(model_path, program_bytes)
-
-    vars = list(filter(is_persistable, program.list_vars()))
-
-    if len(list(vars)) == 0:
-        warnings.warn(
-            "no variable in your model, please ensure there are any variables in your model to save"
+    if in_pir_mode():
+        program = normalize_pir_program(
+            program,
+            feed_vars,
+            fetch_vars,
+            skip_prune_program=kwargs.get('skip_prune_program', False),
+        )
+        paddle.core.serialize_pir_program(
+            program, model_path, 1, True, False, True
        )
-    if len(vars) > 0:
-        save_dirname = os.path.dirname(params_path)
-        params_filename = 
os.path.basename(params_path) - save_vars( - executor, - dirname=save_dirname, - main_program=program, - predicate=is_persistable, - filename=params_filename, + else: + legacy_format = kwargs.get('legacy_format', False) + program_bytes = _serialize_program( + program._remove_training_info(clip_extra=clip_extra), + legacy_format=legacy_format, ) + save_to_file(model_path, program_bytes) + + vars = list(filter(is_persistable, program.list_vars())) + + if len(list(vars)) == 0: + warnings.warn( + "no variable in your model, please ensure there are any variables in your model to save" + ) + + if len(vars) > 0: + save_dirname = os.path.dirname(params_path) + params_filename = os.path.basename(params_path) + save_vars( + executor, + dirname=save_dirname, + main_program=program, + predicate=is_persistable, + filename=params_filename, + ) + @static_only def deserialize_program(data): @@ -905,12 +1012,17 @@ def load_inference_model(path_prefix, executor, **kwargs): raise ValueError( "params_filename cannot be None when path_prefix is None." ) - program_bytes = model_filename + # deserialize bytes to program - program = deserialize_program(program_bytes) + if in_pir_mode(): + program = paddle.static.Program() + paddle.base.core.deserialize_pir_program(model_filename, program, 1) + else: + program_bytes = model_filename + program = deserialize_program(program_bytes) - # do type promotion - program = process_type_promotion(program) + # do type promotion + program = process_type_promotion(program) vars = list(filter(is_persistable, program.list_vars())) if len(vars) > 0: @@ -932,7 +1044,10 @@ def load_inference_model(path_prefix, executor, **kwargs): # set model_path and params_path in new way, # path_prefix represents a file path without suffix in this case. if not kwargs: - model_path = path_prefix + ".pdmodel" + if in_pir_mode(): + model_path = path_prefix + ".json" + else: + model_path = path_prefix + ".pdmodel" params_path = path_prefix + ".pdiparams" # set model_path and params_path in old way for compatible, # path_prefix represents a directory path. 
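Note: with the hunks above, the artifact format is chosen by IR mode on both the save and the load side. In PIR mode, save_inference_model serializes the program to <path_prefix>.json via paddle.core.serialize_pir_program and load_inference_model looks for the same suffix, while the legacy branches keep <path_prefix>.pdmodel plus <path_prefix>.pdiparams. A minimal save-side sketch, assuming a network built under paddle.pir_utils.IrGuard() as in the test added by this patch (the toy network and path are illustrative, not taken from the patch):

    import os

    import paddle

    paddle.enable_static()
    with paddle.pir_utils.IrGuard():
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            x = paddle.static.data(name='x', shape=[4, 8], dtype='float32')
            out = paddle.nn.functional.relu(x)
        exe = paddle.static.Executor(paddle.CPUPlace())
        exe.run(startup_prog)
        # In PIR mode the program structure is written as JSON, not protobuf.
        paddle.static.save_inference_model(
            './infer_model', [x], [out], exe, program=main_prog
        )
        assert os.path.exists('./infer_model.json')

Note that in this first patch the PIR branch writes only the program structure; the save_vars parameter dump still runs only in the legacy branch.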
@@ -943,9 +1058,14 @@ def load_inference_model(path_prefix, executor, **kwargs): if model_filename is None: model_path = os.path.join(path_prefix, "__model__") else: - model_path = os.path.join( - path_prefix, model_filename + ".pdmodel" - ) + if in_pir_mode(): + model_path = os.path.join( + path_prefix, model_filename + ".json" + ) + else: + model_path = os.path.join( + path_prefix, model_filename + ".pdmodel" + ) if not os.path.exists(model_path): model_path = os.path.join(path_prefix, model_filename) # set params_path @@ -962,13 +1082,16 @@ def load_inference_model(path_prefix, executor, **kwargs): f" model path: {model_path}, params path: {params_path}" ) - program_bytes = load_from_file(model_path) - # deserialize bytes to program - program = deserialize_program(program_bytes) + if in_pir_mode(): + program = paddle.static.Program() + paddle.base.core.deserialize_pir_program(model_path, program, 1) + else: + program_bytes = load_from_file(model_path) + program = deserialize_program(program_bytes) - # do type promotion - program = process_type_promotion(program) + # do type promotion + program = process_type_promotion(program) vars = list(filter(is_persistable, program.list_vars())) if len(vars) > 0: diff --git a/test/legacy_test/test_cumsum_op.py b/test/legacy_test/test_cumsum_op.py index 5cc45e0b0b117..1996f6e80a4f8 100644 --- a/test/legacy_test/test_cumsum_op.py +++ b/test/legacy_test/test_cumsum_op.py @@ -576,6 +576,35 @@ def test_static_and_infer(self): infer_out = output_handle.copy_to_cpu() np.testing.assert_allclose(static_out[0], infer_out) + def test_static(self): + paddle.enable_static() + np_x = np.random.randn(9, 10, 11).astype('float32') + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): + # run static + x = paddle.static.data( + shape=np_x.shape, name='x', dtype=np_x.dtype + ) + linear = paddle.nn.Linear(np_x.shape[-1], np_x.shape[-1]) + linear_out = linear(x) + relu_out = paddle.nn.functional.relu(linear_out) + axis = paddle.full([1], 2, dtype='int64') + out = paddle.cumsum(relu_out, axis=axis) + loss = paddle.mean(out) + sgd = paddle.optimizer.SGD(learning_rate=0.0) + sgd.minimize(paddle.mean(out)) + + exe = paddle.static.Executor(self.place) + exe.run(startup_prog) + static_out = exe.run(feed={'x': np_x}, fetch_list=[out]) + + # run infer + paddle.static.save_inference_model( + self.save_path, [x], [out], exe, program=main_prog + ) + class TestCumSumOpFp16(unittest.TestCase): @test_with_pir_api From 3b259872b25d9fd04033471cdbcd1dacd3938781 Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Fri, 12 Apr 2024 07:33:39 +0000 Subject: [PATCH 2/6] add load test --- python/paddle/static/io.py | 36 ++++++++++++++++-------------- test/legacy_test/test_cumsum_op.py | 9 ++++++++ 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py index 2c6c045681c9c..d979aef42ae94 100644 --- a/python/paddle/static/io.py +++ b/python/paddle/static/io.py @@ -1086,6 +1086,8 @@ def load_inference_model(path_prefix, executor, **kwargs): if in_pir_mode(): program = paddle.static.Program() paddle.base.core.deserialize_pir_program(model_path, program, 1) + + return [program, [], []] else: program_bytes = load_from_file(model_path) program = deserialize_program(program_bytes) @@ -1093,26 +1095,26 @@ def load_inference_model(path_prefix, executor, **kwargs): # do type promotion program = 
process_type_promotion(program) - vars = list(filter(is_persistable, program.list_vars())) - if len(vars) > 0: - load_dirname = os.path.dirname(params_path) - params_filename = os.path.basename(params_path) + vars = list(filter(is_persistable, program.list_vars())) + if len(vars) > 0: + load_dirname = os.path.dirname(params_path) + params_filename = os.path.basename(params_path) - load_vars( - executor, - dirname=load_dirname, - main_program=program, - predicate=is_persistable, - filename=params_filename, - ) + load_vars( + executor, + dirname=load_dirname, + main_program=program, + predicate=is_persistable, + filename=params_filename, + ) - feed_target_names = program.desc.get_feed_target_names() - fetch_target_names = program.desc.get_fetch_target_names() - fetch_targets = [ - program.global_block().var(name) for name in fetch_target_names - ] + feed_target_names = program.desc.get_feed_target_names() + fetch_target_names = program.desc.get_fetch_target_names() + fetch_targets = [ + program.global_block().var(name) for name in fetch_target_names + ] - return [program, feed_target_names, fetch_targets] + return [program, feed_target_names, fetch_targets] @dygraph_not_support diff --git a/test/legacy_test/test_cumsum_op.py b/test/legacy_test/test_cumsum_op.py index 1996f6e80a4f8..0ce0b8393860b 100644 --- a/test/legacy_test/test_cumsum_op.py +++ b/test/legacy_test/test_cumsum_op.py @@ -605,6 +605,15 @@ def test_static(self): self.save_path, [x], [out], exe, program=main_prog ) + load_program, _, _ = paddle.static.load_inference_model( + self.save_path, exe + ) + + self.assertEqual( + len(load_program.global_block().ops) + 1, + len(main_prog.global_block().ops), + ) + class TestCumSumOpFp16(unittest.TestCase): @test_with_pir_api From 1e56979f4c0ff2394acfeca7257cf1adea6acd6b Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Fri, 12 Apr 2024 08:25:13 +0000 Subject: [PATCH 3/6] add save --- python/paddle/static/io.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py index d979aef42ae94..2959a23fa511d 100644 --- a/python/paddle/static/io.py +++ b/python/paddle/static/io.py @@ -711,6 +711,12 @@ def save_inference_model( ) else: + program = normalize_program( + program, + feed_vars, + fetch_vars, + skip_prune_program=kwargs.get('skip_prune_program', False), + ) legacy_format = kwargs.get('legacy_format', False) program_bytes = _serialize_program( program._remove_training_info(clip_extra=clip_extra), From c766dbcc9e7a526927922c4a6abb3d9eff445a1e Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Fri, 12 Apr 2024 09:32:00 +0000 Subject: [PATCH 4/6] modify --- python/paddle/static/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py index 2959a23fa511d..af587b5a7e3cd 100644 --- a/python/paddle/static/io.py +++ b/python/paddle/static/io.py @@ -232,7 +232,7 @@ def normalize_pir_program(program, feed_vars, fetch_vars, **kwargs): >>> normalized_program = paddle.static.normalize_program(program, [image], [predict]) """ - if not isinstance(paddle.static.Program): + if not isinstance(program, paddle.static.Program): raise TypeError( "program type must be `paddle.static.Program`, but received `%s`" % type(program) From 387c686c4c8a5d2f46c527d9fd977528a1f009f7 Mon Sep 17 00:00:00 2001 From: "wangruting@baidu.com" Date: Mon, 15 Apr 2024 06:33:27 +0000 Subject: [PATCH 5/6] modify load_inference_model --- python/paddle/static/io.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-)

diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py
index af587b5a7e3cd..2b807b30832c2 100644
--- a/python/paddle/static/io.py
+++ b/python/paddle/static/io.py
@@ -1093,7 +1093,6 @@ def load_inference_model(path_prefix, executor, **kwargs):
            program = paddle.static.Program()
            paddle.base.core.deserialize_pir_program(model_path, program, 1)

-            return [program, [], []]
        else:
            program_bytes = load_from_file(model_path)
            program = deserialize_program(program_bytes)
@@ -1121,6 +1120,7 @@ def load_inference_model(path_prefix, executor, **kwargs):
        ]

        return [program, feed_target_names, fetch_targets]
+    return [program, [], []]


 @dygraph_not_support


From 241e8547ed4c3b69b93cbdb0466a292d85a57be0 Mon Sep 17 00:00:00 2001
From: "wangruting@baidu.com"
Date: Mon, 15 Apr 2024 09:34:39 +0000
Subject: [PATCH 6/6] modify load_inference_model

---
 python/paddle/static/io.py | 218 +++++++++++++++++++++++++++++--------
 1 file changed, 173 insertions(+), 45 deletions(-)

diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py
index 2b807b30832c2..934cce5ad26ea 100644
--- a/python/paddle/static/io.py
+++ b/python/paddle/static/io.py
@@ -1001,6 +1001,8 @@ def load_inference_model(path_prefix, executor, **kwargs):
    # fetch_targets, we can use an executor to run the inference
    # program to get the inference result.
    """
+    if in_pir_mode():
+        return load_pir_inference_model(path_prefix, executor, **kwargs)
    # check kwargs
    supported_args = ('model_filename', 'params_filename')
    deprecated_args = ('pserver_endpoints',)
@@ -1018,17 +1020,12 @@ def load_inference_model(path_prefix, executor, **kwargs):
        raise ValueError(
            "params_filename cannot be None when path_prefix is None."
        )
-
+        program_bytes = model_filename
        # deserialize bytes to program
-        if in_pir_mode():
-            program = paddle.static.Program()
-            paddle.base.core.deserialize_pir_program(model_filename, program, 1)
-        else:
-            program_bytes = model_filename
-            program = deserialize_program(program_bytes)
+        program = deserialize_program(program_bytes)

-        # do type promotion
-        program = process_type_promotion(program)
+        # do type promotion
+        program = process_type_promotion(program)

        vars = list(filter(is_persistable, program.list_vars()))
        if len(vars) > 0:
@@ -1050,10 +1047,7 @@ def load_inference_model(path_prefix, executor, **kwargs):
        # set model_path and params_path in new way,
        # path_prefix represents a file path without suffix in this case.
        if not kwargs:
-            if in_pir_mode():
-                model_path = path_prefix + ".json"
-            else:
-                model_path = path_prefix + ".pdmodel"
+            model_path = path_prefix + ".pdmodel"
            params_path = path_prefix + ".pdiparams"
        # set model_path and params_path in old way for compatible,
        # path_prefix represents a directory path. 
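Note: the in_pir_mode() dispatch added at the top of load_inference_model keeps a single public entry point: under PIR it forwards to the new load_pir_inference_model, which reads the .json artifact directly. A hedged usage sketch (the path is illustrative, and in this series the PIR loader returns empty feed/fetch lists):

    import paddle

    paddle.enable_static()
    with paddle.pir_utils.IrGuard():
        exe = paddle.static.Executor(paddle.CPUPlace())
        # Under IrGuard, in_pir_mode() is true, so this call forwards to
        # load_pir_inference_model and, with no kwargs, it looks for
        # './infer_model.json'.
        program, feed_names, fetch_targets = paddle.static.load_inference_model(
            './infer_model', exe
        )
        # The PIR loader in this series returns [program, [], []], so feed
        # names and fetch targets are not yet recoverable from the result.
        assert feed_names == [] and fetch_targets == []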
@@ -1064,14 +1058,9 @@ def load_inference_model(path_prefix, executor, **kwargs): if model_filename is None: model_path = os.path.join(path_prefix, "__model__") else: - if in_pir_mode(): - model_path = os.path.join( - path_prefix, model_filename + ".json" - ) - else: - model_path = os.path.join( - path_prefix, model_filename + ".pdmodel" - ) + model_path = os.path.join( + path_prefix, model_filename + ".pdmodel" + ) if not os.path.exists(model_path): model_path = os.path.join(path_prefix, model_filename) # set params_path @@ -1088,38 +1077,177 @@ def load_inference_model(path_prefix, executor, **kwargs): f" model path: {model_path}, params path: {params_path}" ) + program_bytes = load_from_file(model_path) + # deserialize bytes to program - if in_pir_mode(): - program = paddle.static.Program() - paddle.base.core.deserialize_pir_program(model_path, program, 1) + program = deserialize_program(program_bytes) - else: - program_bytes = load_from_file(model_path) - program = deserialize_program(program_bytes) + # do type promotion + program = process_type_promotion(program) - # do type promotion - program = process_type_promotion(program) + vars = list(filter(is_persistable, program.list_vars())) + if len(vars) > 0: + load_dirname = os.path.dirname(params_path) + params_filename = os.path.basename(params_path) - vars = list(filter(is_persistable, program.list_vars())) - if len(vars) > 0: - load_dirname = os.path.dirname(params_path) - params_filename = os.path.basename(params_path) + load_vars( + executor, + dirname=load_dirname, + main_program=program, + predicate=is_persistable, + filename=params_filename, + ) - load_vars( - executor, - dirname=load_dirname, - main_program=program, - predicate=is_persistable, - filename=params_filename, + feed_target_names = program.desc.get_feed_target_names() + fetch_target_names = program.desc.get_fetch_target_names() + fetch_targets = [ + program.global_block().var(name) for name in fetch_target_names + ] + + return [program, feed_target_names, fetch_targets] + + +@static_only +def load_pir_inference_model(path_prefix, executor, **kwargs): + """ + + Load inference model from a given path. By this API, you can get the model + structure(Inference Program) and model parameters. + + Args: + path_prefix(str | None): One of the following: + - Directory path to save model + model name without suffix. + - Set to None when reading the model from memory. + executor(Executor): The executor to run for loading inference model. + See :ref:`api_guide_executor_en` for more details about it. + kwargs: Supported keys including 'model_filename', 'params_filename'. Attention please, kwargs is used for backward compatibility mainly. + + - model_filename(str): specify model_filename if you don't want to use default name. + + - params_filename(str): specify params_filename if you don't want to use default name. + + Returns: + list: The return of this API is a list with three elements: + (program, feed_target_names, fetch_targets). The `program` is a + ``Program`` (refer to :ref:`api_guide_Program_en`), which is used for inference. + The `feed_target_names` is a list of ``str``, which contains names of variables + that need to feed data in the inference program. The `fetch_targets` is a list of + ``Variable`` (refer to :ref:`api_guide_Program_en`). It contains variables from which + we can get inference results. + + Examples: + .. 
code-block:: python

            >>> import paddle
            >>> import numpy as np

            >>> paddle.enable_static()

            # Build the model
            >>> startup_prog = paddle.static.default_startup_program()
            >>> main_prog = paddle.static.default_main_program()
            >>> with paddle.static.program_guard(main_prog, startup_prog):
            ...     image = paddle.static.data(name="img", shape=[64, 784])
            ...     w = paddle.create_parameter(shape=[784, 200], dtype='float32')
            ...     b = paddle.create_parameter(shape=[200], dtype='float32')
            ...     hidden_w = paddle.matmul(x=image, y=w)
            ...     hidden_b = paddle.add(hidden_w, b)
            >>> exe = paddle.static.Executor(paddle.CPUPlace())
            >>> exe.run(startup_prog)

            # Save the inference model
            >>> path_prefix = "./infer_model"
            >>> paddle.static.save_inference_model(path_prefix, [image], [hidden_b], exe)

            >>> [inference_program, feed_target_names, fetch_targets] = (
            ...     paddle.static.load_inference_model(path_prefix, exe))
            >>> tensor_img = np.array(np.random.random((64, 784)), dtype=np.float32)
            >>> results = exe.run(inference_program,
            ...                   feed={feed_target_names[0]: tensor_img},
            ...                   fetch_list=fetch_targets)

            # In this example, the inference program was saved in file
            # "./infer_model.json" and parameters were saved in file
            # "./infer_model.pdiparams".
            # By the inference program, feed_target_names and
            # fetch_targets, we can use an executor to run the inference
            # program to get the inference result.
    """
    # check kwargs
    supported_args = ('model_filename', 'params_filename')
    deprecated_args = ('pserver_endpoints',)
    caller = inspect.currentframe().f_code.co_name
    _check_args(caller, kwargs, supported_args, deprecated_args)

    # load from memory
    if path_prefix is None:
        _logger.warning(
            "Load inference model from memory is deprecated. Please specify path_prefix."
        )
        model_filename = kwargs.get('model_filename', None)
        params_filename = kwargs.get('params_filename', None)
        if params_filename is None:
            raise ValueError(
                "params_filename cannot be None when path_prefix is None."
            )

        # deserialize bytes to program
        program = paddle.static.Program()
        paddle.base.core.deserialize_pir_program(model_filename, program, 1)

        vars = list(filter(is_persistable, program.list_vars()))
        if len(vars) > 0:
            load_vars(
                executor,
                # load from memory, dirname is None
                dirname=None,
                main_program=program,
                predicate=is_persistable,
                filename=params_filename,
            )
    # load from file
    else:
        # check and norm path_prefix
        path_prefix = _normalize_path_prefix(path_prefix)
        dir_path = os.path.dirname(path_prefix)
        if not os.path.isdir(dir_path):
            raise ValueError(f"There is no directory named {dir_path}")
        # set model_path and params_path in new way,
        # path_prefix represents a file path without suffix in this case.
        if not kwargs:
            model_path = path_prefix + ".json"
            params_path = path_prefix + ".pdiparams"
        # set model_path and params_path in old way for compatible,
        # path_prefix represents a directory path. 
+ else: + model_filename = kwargs.get('model_filename', None) + params_filename = kwargs.get('params_filename', None) + # set model_path + if model_filename is None: + model_path = os.path.join(path_prefix, "__model__") + else: + model_path = os.path.join(path_prefix, model_filename + ".json") + + if not os.path.exists(model_path): + model_path = os.path.join(path_prefix, model_filename) + # set params_path + if params_filename is None: + params_path = os.path.join(path_prefix, "") + else: + params_path = os.path.join( + path_prefix, params_filename + ".pdiparams" ) + if not os.path.exists(params_path): + params_path = os.path.join(path_prefix, params_filename) + _logger.warning( + "The old way to load inference model is deprecated. Please specify path_prefix." + f" model path: {model_path}, params path: {params_path}" + ) - feed_target_names = program.desc.get_feed_target_names() - fetch_target_names = program.desc.get_fetch_target_names() - fetch_targets = [ - program.global_block().var(name) for name in fetch_target_names - ] + # deserialize bytes to program + program = paddle.static.Program() + paddle.base.core.deserialize_pir_program(model_path, program, 1) - return [program, feed_target_names, fetch_targets] return [program, [], []]
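Note: taken together, the six patches give PIR programs a JSON round trip through the inference-model APIs, exercised by the new cumsum test. A condensed end-to-end sketch in the spirit of that test (shapes and paths are illustrative; normalize_pir_program strips pd_op.feed/pd_op.fetch and inserts scale ops before serializing, so the reloaded block need not match the original op for op):

    import numpy as np

    import paddle

    paddle.enable_static()
    with paddle.pir_utils.IrGuard():
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            x = paddle.static.data(name='x', shape=[3, 4], dtype='float32')
            out = paddle.cumsum(paddle.nn.functional.relu(x), axis=1)
        exe = paddle.static.Executor(paddle.CPUPlace())
        exe.run(startup_prog)
        exe.run(main_prog,
                feed={'x': np.random.randn(3, 4).astype('float32')},
                fetch_list=[out])

        paddle.static.save_inference_model(
            './infer_model', [x], [out], exe, program=main_prog
        )
        loaded, _, _ = paddle.static.load_inference_model('./infer_model', exe)
        # Compare op counts rather than asserting an exact relation; the saved
        # program had feed/fetch ops removed and scale ops added on the way out.
        print(len(main_prog.global_block().ops), len(loaded.global_block().ops))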