Skip to content

Commit

Permalink
【pir_save_load】add null type saveload and modify save_inference_model…
Browse files Browse the repository at this point in the history
… api (PaddlePaddle#63438)

* add nulltype saveload

* add load test

* add save

* modify

* modify load_inference_model

* modify loadinference
  • Loading branch information
xiaoguoguo626807 authored and Asthestarsfalll committed Apr 17, 2024
1 parent b60182a commit b0a9c82
Show file tree
Hide file tree
Showing 7 changed files with 354 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ pir::Type parseType(Json* type_json) {
size_t offset = data_json.at(4).get<size_t>();
return pir::DenseTensorType::get(
ctx, dtype, ddim, data_layout, lod, offset);
} else if (type_name == NULL_TYPE) {
return pir::Type();
} else {
PADDLE_ENFORCE(false,
phi::errors::InvalidArgument(
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pir/serialize_deserialize/include/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,6 @@ namespace pir {
// type/attr's contents which is json::array.
#define DATA "D"

// NULL_TYPE
#define NULL_TYPE "NULL"
} // namespace pir
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,9 @@ Json writeType(const pir::Type& type) {
content.push_back(type_.offset());
type_json[DATA] = content;
return type_json;
} else if (!type) {
type_json[ID] = NULL_TYPE;
return type_json;
} else {
PADDLE_ENFORCE(
false, phi::errors::InvalidArgument("Unknown Type when write type"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
namespace pir {
void ProgramReader::RecoverProgram(Json* program_json,
pir::Program* recover_program) {
id_value_map[0] = pir::Value();
ReadProgram(program_json, recover_program);
VLOG(6) << "Finish json to program.";
return;
Expand Down
27 changes: 19 additions & 8 deletions paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,19 @@ Json ProgramWriter::WriteBlockArg(const pir::Value& value) {

Json ProgramWriter::WriteValue(const pir::Value& value) {
Json var_json;
// Json var = value;
if (value) {
value_id_map[value] = value_id_;
var_json[ID] = value_id_;
VLOG(6) << "Finish write value " << value_id_;
value_id_++;
} else {
var_json[ID] = 0; // NULL_TYPE
VLOG(6) << "Finish write NULL_TYPE value.";
}

Json var = WriteType(value.type());
value_id_map[value] = value_id_;
var_json[ID] = value_id_;
var_json[TYPE_TYPE] = var;
VLOG(6) << "Finish write value " << value_id_;

value_id_++;
return var_json;
}

Expand Down Expand Up @@ -136,9 +141,15 @@ Json ProgramWriter::WriteOp(const pir::Operation& op) {

Json ProgramWriter::WriteOpOperand(const pir::OpOperand& op_operand) {
Json operand_json = Json::object();
int64_t id = value_id_map[op_operand.source()];
operand_json[ID] = id;
VLOG(6) << "Finish write OpOperand " << id;
if (op_operand.source()) {
int64_t id = value_id_map[op_operand.source()];
operand_json[ID] = id;
VLOG(6) << "Finish write OpOperand " << id;
} else {
operand_json[ID] = 0; // NULL_VALUE
VLOG(6) << "Finish write NULL_VALUE OpOperand.";
}

return operand_json;
}

Expand Down
Loading

0 comments on commit b0a9c82

Please sign in to comment.