Skip to content

Commit

Permalink
[PIR] add an api to get param map when translating (#58044)
Browse files Browse the repository at this point in the history
* add an api to get param map when translating

* fix

* fix

* resolve conflicts
  • Loading branch information
kangguangli authored Oct 13, 2023
1 parent 0217838 commit e638c66
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 9 deletions.
11 changes: 11 additions & 0 deletions paddle/fluid/ir_adaptor/translator/program_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -623,5 +623,16 @@ void ProgramTranslator::SetIsPersisableAttributeForAllValue(
}
}

std::unordered_map<std::string, std::vector<pir::Value>>
ProgramTranslator::VarDesc2Value() {
std::unordered_map<std::string, std::vector<pir::Value>> var_desc_2_value;
for (const auto& [var_name, value_info_list] : param_map_) {
for (const auto& value_info : value_info_list) {
var_desc_2_value[var_name].push_back(value_info.value);
}
}
return var_desc_2_value;
}

} // namespace translator
} // namespace paddle
2 changes: 2 additions & 0 deletions paddle/fluid/ir_adaptor/translator/program_translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class ProgramTranslator {

void Translate();

std::unordered_map<std::string, std::vector<pir::Value>> VarDesc2Value();

private:
const ProgramDesc* legacy_program_; // not owned
pir::Program* program_; // not owned
Expand Down
67 changes: 59 additions & 8 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "paddle/pir/core/builtin_op.h"

#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
#include "paddle/fluid/ir_adaptor/translator/translate.h"
#include "paddle/fluid/ir_adaptor/translator/utils.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
Expand Down Expand Up @@ -340,6 +341,14 @@ void BindOperation(py::module *m) {
});
}

py::str Value2String(const Value &self) {
std::ostringstream print_stream;
print_stream << "Value(";
print_stream << GetValueInfo(self);
print_stream << ")";
return print_stream.str();
}

void BindValue(py::module *m) {
py::class_<Value> value(*m, "Value", R"DOC(
Value class represents the SSA value in the IR system. It is a directed edge
Expand Down Expand Up @@ -374,13 +383,8 @@ void BindValue(py::module *m) {
})
.def("__hash__",
[](const Value &self) { return std::hash<pir::Value>{}(self); })
.def("__str__", [](const Value &self) -> py::str {
std::ostringstream print_stream;
print_stream << "Value(";
print_stream << GetValueInfo(self);
print_stream << ")";
return print_stream.str();
});
.def("__str__", &Value2String)
.def("__repr__", &Value2String);
}

void BindOpOperand(py::module *m) {
Expand Down Expand Up @@ -1146,7 +1150,7 @@ void BindUtils(pybind11::module *m) {
y_s = paddle.matmul(x_s, x_s)
z_s = paddle.add(y_s, y_s)
k_s = paddle.tanh(z_s)
newir_program = ir.translate_to_new_ir(main_program.desc)
newir_program = pir.translate_to_new_ir(main_program.desc)
print(newir_program)
Expand All @@ -1166,6 +1170,53 @@ void BindUtils(pybind11::module *m) {
Returns:
list[str] : List of unregistered operators in paddle dialect, the name is expressed by origin op name.
)DOC");
m->def(
"translate_to_new_ir_with_param_map",
[](const framework::ProgramDesc &legacy_program) {
auto ir_ctx = pir::IrContext::Instance();
auto program = std::make_shared<pir::Program>(ir_ctx);
translator::ProgramTranslator program_translator(&legacy_program,
program.get());
program_translator.Translate();
return std::make_pair(program, program_translator.VarDesc2Value());
},
R"DOC(
Convert Fluid Program to New IR Program and get the mappings of VarDesc -> pir::Value.
Args:
legacy_program (ProgramDesc): The Fluid Program that will be converted.
Returns:
Program: The New IR Program
dict[str, pir::Value]: Mapping between VarDesc(by name) and pir::Value.
Raises:
PreconditionNotMet: If legacy_program has multi block will raise error.
Examples:
.. code-block:: python
import paddle
from paddle import pir
paddle.enable_static()
x = paddle.randn([4, 4])
main_program, start_program = (
paddle.static.Program(),
paddle.static.Program(),
)
with paddle.static.program_guard(main_program, start_program):
x_s = paddle.static.data('x', [4, 4], x.dtype)
x_s.stop_gradient = False
y_s = paddle.matmul(x_s, x_s)
z_s = paddle.add(y_s, y_s)
k_s = paddle.tanh(z_s)
newir_program, mappings = pir.translate_to_new_ir_with_param_map(main_program.desc)
print(newir_program)
print(mappings)
)DOC");
}

void BindIrPass(pybind11::module *m) {
Expand Down
1 change: 1 addition & 0 deletions python/paddle/pir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from paddle.base.libpaddle.pir import ( # noqa: F401
translate_to_new_ir,
translate_to_new_ir_with_param_map,
set_global_program,
set_insertion_point,
reset_insertion_point_to_start,
Expand Down
3 changes: 2 additions & 1 deletion test/ir/new_ir/test_special_op_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def test_op(self):
x = paddle.to_tensor([2, 3, 4], 'float64')
y = paddle.cast(x, 'uint8')

_ = pir.translate_to_new_ir(main_program.desc)
_, mappings = pir.translate_to_new_ir_with_param_map(main_program.desc)
assert len(str(mappings)) > 0, "no mapping found"


class TestElementwiseOpTranscriber(unittest.TestCase):
Expand Down

0 comments on commit e638c66

Please sign in to comment.