diff --git a/paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h b/paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h
index 1eaa8843033f1..d4aaefe81c983 100644
--- a/paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h
+++ b/paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h
@@ -212,6 +212,8 @@ pir::Type parseType(Json* type_json) {
     size_t offset = data_json.at(4).get<size_t>();
     return pir::DenseTensorType::get(
         ctx, dtype, ddim, data_layout, lod, offset);
+  } else if (type_name == NULL_TYPE) {
+    return pir::Type();
   } else {
     PADDLE_ENFORCE(false,
                    phi::errors::InvalidArgument(
diff --git a/paddle/fluid/pir/serialize_deserialize/include/schema.h b/paddle/fluid/pir/serialize_deserialize/include/schema.h
index 75973b99ca049..d444bee469596 100644
--- a/paddle/fluid/pir/serialize_deserialize/include/schema.h
+++ b/paddle/fluid/pir/serialize_deserialize/include/schema.h
@@ -62,4 +62,6 @@ namespace pir {
 // type/attr's contents which is json::array.
 #define DATA "D"
 
+// NULL_TYPE
+#define NULL_TYPE "NULL"
 }  // namespace pir
diff --git a/paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h b/paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h
index 0b75579e080b3..a6cae97f135d9 100644
--- a/paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h
+++ b/paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h
@@ -248,6 +248,9 @@ Json writeType(const pir::Type& type) {
     content.push_back(type_.offset());
     type_json[DATA] = content;
     return type_json;
+  } else if (!type) {
+    type_json[ID] = NULL_TYPE;
+    return type_json;
   } else {
     PADDLE_ENFORCE(
         false, phi::errors::InvalidArgument("Unknown Type when write type"));
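A sketch of the on-disk shape these three headers define may help review. `DATA` (`"D"`) and `NULL_TYPE` (`"NULL"`) come from schema.h above; the key that the `ID` macro expands to (`"ID"`) and the `"t_dtensor"` type tag are assumptions made for illustration only:

```python
import json

# A mock of the type encoding these headers describe, not the real writer.
# A null pir::Type carries only its NULL_TYPE tag and no DATA entry;
# concrete types carry their contents under "D".
def write_type(type_name, data=None):
    type_json = {"ID": type_name}  # "ID" key is an assumption
    if data is not None:
        type_json["D"] = data  # type contents, a json::array in C++
    return type_json

dense = write_type("t_dtensor", data=["f32", [2, 3], "NCHW", [], 0])
null = write_type("NULL")  # a null pir::Type round-trips as {"ID": "NULL"}
print(json.dumps(dense))
print(json.dumps(null))
```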
diff --git a/paddle/fluid/pir/serialize_deserialize/src/ir_deserialize.cc b/paddle/fluid/pir/serialize_deserialize/src/ir_deserialize.cc
index 12f46f33604c3..88ee2ba168476 100644
--- a/paddle/fluid/pir/serialize_deserialize/src/ir_deserialize.cc
+++ b/paddle/fluid/pir/serialize_deserialize/src/ir_deserialize.cc
@@ -17,6 +17,7 @@ namespace pir {
 void ProgramReader::RecoverProgram(Json* program_json,
                                    pir::Program* recover_program) {
+  id_value_map[0] = pir::Value();
   ReadProgram(program_json, recover_program);
   VLOG(6) << "Finish json to program.";
   return;
diff --git a/paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc b/paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc
index 7af995074461b..21067aa83906d 100644
--- a/paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc
+++ b/paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc
@@ -93,14 +93,19 @@ Json ProgramWriter::WriteBlockArg(const pir::Value& value) {
 Json ProgramWriter::WriteValue(const pir::Value& value) {
   Json var_json;
-  // Json var = value;
+  if (value) {
+    value_id_map[value] = value_id_;
+    var_json[ID] = value_id_;
+    VLOG(6) << "Finish write value " << value_id_;
+    value_id_++;
+  } else {
+    var_json[ID] = 0;  // NULL_TYPE
+    VLOG(6) << "Finish write NULL_TYPE value.";
+  }
+
   Json var = WriteType(value.type());
-  value_id_map[value] = value_id_;
-  var_json[ID] = value_id_;
   var_json[TYPE_TYPE] = var;
-  VLOG(6) << "Finish write value " << value_id_;
-  value_id_++;
   return var_json;
 }
@@ -136,9 +141,15 @@ Json ProgramWriter::WriteOp(const pir::Operation& op) {
 Json ProgramWriter::WriteOpOperand(const pir::OpOperand& op_operand) {
   Json operand_json = Json::object();
-  int64_t id = value_id_map[op_operand.source()];
-  operand_json[ID] = id;
-  VLOG(6) << "Finish write OpOperand " << id;
+  if (op_operand.source()) {
+    int64_t id = value_id_map[op_operand.source()];
+    operand_json[ID] = id;
+    VLOG(6) << "Finish write OpOperand " << id;
+  } else {
+    operand_json[ID] = 0;  // NULL_VALUE
+    VLOG(6) << "Finish write NULL_VALUE OpOperand.";
+  }
+
   return operand_json;
 }
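Together, the two .cc changes reserve value id 0 as a null sentinel: the writer emits id 0 for a null value or operand source, and the reader pre-seeds `id_value_map[0]` with a null `pir::Value` so null operands resolve without a lookup miss. A minimal Python mock of that contract (not the real `ProgramWriter`/`ProgramReader`; the starting id for real values is an assumption):

```python
# Mock of the id-0 sentinel contract established above.
class MockWriter:
    def __init__(self):
        self.value_id_map = {}
        self.next_id = 1  # assumption: real value ids start above 0

    def write_operand(self, source):
        if source is None:  # mirrors `if (op_operand.source())`
            return {"ID": 0}
        if source not in self.value_id_map:
            self.value_id_map[source] = self.next_id
            self.next_id += 1
        return {"ID": self.value_id_map[source]}

class MockReader:
    def __init__(self):
        # mirrors `id_value_map[0] = pir::Value();` in RecoverProgram
        self.id_value_map = {0: None}

    def read_operand(self, operand_json):
        return self.id_value_map[operand_json["ID"]]

w, r = MockWriter(), MockReader()
assert r.read_operand(w.write_operand(None)) is None  # null round-trips
```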
diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py
index a5bb29104f89b..934cce5ad26ea 100644
--- a/python/paddle/static/io.py
+++ b/python/paddle/static/io.py
@@ -23,6 +23,7 @@
 import numpy as np
 
 import paddle
+from paddle import pir
 from paddle.base import (
     CompiledProgram,
     Program,
@@ -36,6 +37,7 @@
 from paddle.base.framework import (
     Parameter,
     dygraph_not_support,
+    in_pir_mode,
     process_type_promotion,
     static_only,
 )
@@ -75,7 +77,7 @@ def _check_args(caller, args, supported_args=None, deprecated_args=None):
 def _check_vars(name, var_list):
     if not isinstance(var_list, list):
         var_list = [var_list]
-    if not all(isinstance(var, Variable) for var in var_list):
+    if not all(isinstance(var, (Variable, pir.Value)) for var in var_list):
         raise ValueError(
             f"'{name}' should be a Variable or a list of Variable."
         )
@@ -191,6 +193,100 @@ def append_fetch_ops(
     )
 
 
+def normalize_pir_program(program, feed_vars, fetch_vars, **kwargs):
+    """
+
+    Normalize/Optimize a program according to feed_vars and fetch_vars.
+
+    Args:
+        program(Program): Specify a program you want to optimize.
+        feed_vars(Tensor | list[Tensor]): Values needed by inference.
+        fetch_vars(Tensor | list[Tensor]): Values returned by inference.
+        kwargs: Supported keys including ``skip_prune_program``.
+            - skip_prune_program(bool): whether to skip pruning program. Defaults to False.
+
+    Returns:
+        Program: Normalized/Optimized program.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+
+            >>> paddle.enable_static()
+
+            >>> path_prefix = "./infer_model"
+
+            # User-defined network, here a softmax regression example
+            >>> image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
+            >>> label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
+            >>> predict = paddle.static.nn.fc(image, 10, activation='softmax')
+
+            >>> loss = paddle.nn.functional.cross_entropy(predict, label)
+
+            >>> exe = paddle.static.Executor(paddle.CPUPlace())
+            >>> exe.run(paddle.static.default_startup_program())
+
+            # normalize main program.
+            >>> program = paddle.static.default_main_program()
+            >>> normalized_program = paddle.static.normalize_pir_program(program, [image], [predict])
+
+    """
+    if not isinstance(program, paddle.static.Program):
+        raise TypeError(
+            "program type must be `paddle.static.Program`, but received `%s`"
+            % type(program)
+        )
+    if not isinstance(feed_vars, list):
+        feed_vars = [feed_vars]
+    if not all(isinstance(v, pir.Value) for v in feed_vars):
+        raise TypeError("feed_vars type must be a Value or a list of Value.")
+    if not isinstance(fetch_vars, list):
+        fetch_vars = [fetch_vars]
+    if not all(isinstance(v, pir.Value) for v in fetch_vars):
+        raise TypeError(
+            "fetch_vars type must be a Value or a list of Value."
+        )
+
+    # TODO(Ruting) remind users to set auc_states to 0 if auc op were found.
+
+    # Fix the bug that an activation op's output used as a target would be
+    # pruned, which would hurt inference performance.
+    # TODO(Superjomn) add an IR pass to remove 1-scale op.
+    with paddle.static.program_guard(program):
+        uniq_fetch_vars = []
+        for i, var in enumerate(fetch_vars):
+            if var.dtype != paddle.bool:
+                var = paddle.scale(var, 1.0, name=f"save_infer_model/scale_{i}")
+            uniq_fetch_vars.append(var)
+        fetch_vars = uniq_fetch_vars
+
+    # serialize program
+    copy_program = program.clone()
+    global_block = copy_program.global_block()
+    remove_ops = []
+    for op in global_block.ops:
+        if op.name() == "pd_op.feed" or op.name() == "pd_op.fetch":
+            remove_ops.append(op)
+
+    for op in remove_ops:
+        global_block.remove_op(op)
+
+    # feed_var_names = [var.name for var in feed_vars]
+
+    # skip_prune_program = kwargs.get('skip_prune_program', False)
+    # if not skip_prune_program:
+    #     copy_program = copy_program._prune_with_input(
+    #         feeded_var_names=feed_var_names, targets=fetch_vars
+    #     )
+    # copy_program = copy_program._inference_optimize(prune_read_op=True)
+    # fetch_var_names = [var.name for var in fetch_vars]
+    # prepend_feed_ops(copy_program, feed_var_names)
+    # append_fetch_ops(copy_program, fetch_var_names)
+
+    return copy_program
+
+
 def normalize_program(program, feed_vars, fetch_vars, **kwargs):
     """
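For reviewers, a usage sketch of the new normalization step, assuming `normalize_pir_program` is exported under `paddle.static` as its docstring suggests; the effect to observe is one `pd_op.scale` appended per non-bool fetch var and any `pd_op.feed`/`pd_op.fetch` ops stripped from the returned clone:

```python
import paddle

paddle.enable_static()
# Build a tiny PIR program and normalize it (IrGuard is used the same way
# in the test added at the end of this patch).
with paddle.pir_utils.IrGuard():
    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name='x', shape=[None, 4], dtype='float32')
        y = paddle.nn.functional.relu(x)
    normalized = paddle.static.normalize_pir_program(main_prog, [x], [y])
    # Expect the original ops plus a pd_op.scale named
    # "save_infer_model/scale_0", and no feed/fetch ops.
    print([op.name() for op in normalized.global_block().ops])
```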
@@ -578,7 +674,12 @@ def save_inference_model(
     except OSError as e:
         if e.errno != errno.EEXIST:
             raise
-    model_path = path_prefix + ".pdmodel"
+
+    if in_pir_mode():
+        model_path = path_prefix + ".json"
+    else:
+        model_path = path_prefix + ".pdmodel"
+
     params_path = path_prefix + ".pdiparams"
     if os.path.isdir(model_path):
         raise ValueError(f"'{model_path}' is an existing directory.")
@@ -596,40 +697,52 @@ def save_inference_model(
     program = process_type_promotion(program)
 
     clip_extra = kwargs.get('clip_extra', True)
-    program = normalize_program(
-        program,
-        feed_vars,
-        fetch_vars,
-        skip_prune_program=kwargs.get('skip_prune_program', False),
-    )
-    # serialize and save program
-    legacy_format = kwargs.get('legacy_format', False)
-    program_bytes = _serialize_program(
-        program._remove_training_info(clip_extra=clip_extra),
-        legacy_format=legacy_format,
-    )
-
-    save_to_file(model_path, program_bytes)
-    vars = list(filter(is_persistable, program.list_vars()))
-
-    if len(list(vars)) == 0:
-        warnings.warn(
-            "no variable in your model, please ensure there are any variables in your model to save"
-        )
-
-    if len(vars) > 0:
-        save_dirname = os.path.dirname(params_path)
-        params_filename = os.path.basename(params_path)
-        save_vars(
-            executor,
-            dirname=save_dirname,
-            main_program=program,
-            predicate=is_persistable,
-            filename=params_filename,
-        )
+    if in_pir_mode():
+        program = normalize_pir_program(
+            program,
+            feed_vars,
+            fetch_vars,
+            skip_prune_program=kwargs.get('skip_prune_program', False),
+        )
+        paddle.core.serialize_pir_program(
+            program, model_path, 1, True, False, True
+        )
+    else:
+        program = normalize_program(
+            program,
+            feed_vars,
+            fetch_vars,
+            skip_prune_program=kwargs.get('skip_prune_program', False),
+        )
+        # serialize and save program
+        legacy_format = kwargs.get('legacy_format', False)
+        program_bytes = _serialize_program(
+            program._remove_training_info(clip_extra=clip_extra),
+            legacy_format=legacy_format,
+        )
+        save_to_file(model_path, program_bytes)
+
+        vars = list(filter(is_persistable, program.list_vars()))
+
+        if len(vars) == 0:
+            warnings.warn(
+                "no variable in your model, please ensure there are variables in your model to save"
+            )
+
+        if len(vars) > 0:
+            save_dirname = os.path.dirname(params_path)
+            params_filename = os.path.basename(params_path)
+            save_vars(
+                executor,
+                dirname=save_dirname,
+                main_program=program,
+                predicate=is_persistable,
+                filename=params_filename,
+            )
 
 
 @static_only
 def deserialize_program(data):
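A hedged end-to-end sketch of the new save branch: under `IrGuard`, `save_inference_model` takes the `in_pir_mode()` path and writes `<prefix>.json` plus `<prefix>.pdiparams` (the prefix below is arbitrary; the `program=` kwarg mirrors the test added at the end of this patch):

```python
import paddle

paddle.enable_static()
with paddle.pir_utils.IrGuard():
    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        image = paddle.static.data(name="img", shape=[64, 784], dtype="float32")
        w = paddle.create_parameter(shape=[784, 200], dtype="float32")
        b = paddle.create_parameter(shape=[200], dtype="float32")
        out = paddle.add(paddle.matmul(image, w), b)
    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup_prog)
    paddle.static.save_inference_model(
        "./pir_infer_model", [image], [out], exe, program=main_prog
    )
    # Expect ./pir_infer_model.json and ./pir_infer_model.pdiparams on disk.
```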
+ """ + # check kwargs + supported_args = ('model_filename', 'params_filename') + deprecated_args = ('pserver_endpoints',) + caller = inspect.currentframe().f_code.co_name + _check_args(caller, kwargs, supported_args, deprecated_args) + + # load from memory + if path_prefix is None: + _logger.warning( + "Load inference model from memory is deprecated. Please specify path_prefix." + ) + model_filename = kwargs.get('model_filename', None) + params_filename = kwargs.get('params_filename', None) + if params_filename is None: + raise ValueError( + "params_filename cannot be None when path_prefix is None." + ) + + # deserialize bytes to program + program = paddle.static.Program() + paddle.base.core.deserialize_pir_program(model_filename, program, 1) + + vars = list(filter(is_persistable, program.list_vars())) + if len(vars) > 0: + load_vars( + executor, + # load from memory, dirname is None + dirname=None, + main_program=program, + predicate=is_persistable, + filename=params_filename, + ) + # load from file + else: + # check and norm path_prefix + path_prefix = _normalize_path_prefix(path_prefix) + dir_path = os.path.dirname(path_prefix) + if not os.path.isdir(dir_path): + raise ValueError(f"There is no directory named {dir_path}") + # set model_path and params_path in new way, + # path_prefix represents a file path without suffix in this case. + if not kwargs: + model_path = path_prefix + ".json" + params_path = path_prefix + ".pdiparams" + # set model_path and params_path in old way for compatible, + # path_prefix represents a directory path. + else: + model_filename = kwargs.get('model_filename', None) + params_filename = kwargs.get('params_filename', None) + # set model_path + if model_filename is None: + model_path = os.path.join(path_prefix, "__model__") + else: + model_path = os.path.join(path_prefix, model_filename + ".json") + + if not os.path.exists(model_path): + model_path = os.path.join(path_prefix, model_filename) + # set params_path + if params_filename is None: + params_path = os.path.join(path_prefix, "") + else: + params_path = os.path.join( + path_prefix, params_filename + ".pdiparams" + ) + if not os.path.exists(params_path): + params_path = os.path.join(path_prefix, params_filename) + _logger.warning( + "The old way to load inference model is deprecated. Please specify path_prefix." 
+ f" model path: {model_path}, params path: {params_path}" + ) + + # deserialize bytes to program + program = paddle.static.Program() + paddle.base.core.deserialize_pir_program(model_path, program, 1) + + return [program, [], []] + + @dygraph_not_support def save_vars( executor, diff --git a/test/legacy_test/test_cumsum_op.py b/test/legacy_test/test_cumsum_op.py index 5cc45e0b0b117..0ce0b8393860b 100644 --- a/test/legacy_test/test_cumsum_op.py +++ b/test/legacy_test/test_cumsum_op.py @@ -576,6 +576,44 @@ def test_static_and_infer(self): infer_out = output_handle.copy_to_cpu() np.testing.assert_allclose(static_out[0], infer_out) + def test_static(self): + paddle.enable_static() + np_x = np.random.randn(9, 10, 11).astype('float32') + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): + # run static + x = paddle.static.data( + shape=np_x.shape, name='x', dtype=np_x.dtype + ) + linear = paddle.nn.Linear(np_x.shape[-1], np_x.shape[-1]) + linear_out = linear(x) + relu_out = paddle.nn.functional.relu(linear_out) + axis = paddle.full([1], 2, dtype='int64') + out = paddle.cumsum(relu_out, axis=axis) + loss = paddle.mean(out) + sgd = paddle.optimizer.SGD(learning_rate=0.0) + sgd.minimize(paddle.mean(out)) + + exe = paddle.static.Executor(self.place) + exe.run(startup_prog) + static_out = exe.run(feed={'x': np_x}, fetch_list=[out]) + + # run infer + paddle.static.save_inference_model( + self.save_path, [x], [out], exe, program=main_prog + ) + + load_program, _, _ = paddle.static.load_inference_model( + self.save_path, exe + ) + + self.assertEqual( + len(load_program.global_block().ops) + 1, + len(main_prog.global_block().ops), + ) + class TestCumSumOpFp16(unittest.TestCase): @test_with_pir_api