From 2f077edca5f54a6ba36d51bc3dff986c8e894e65 Mon Sep 17 00:00:00 2001 From: wangruting Date: Wed, 2 Aug 2023 08:03:47 +0000 Subject: [PATCH] add eq and hash --- paddle/fluid/pybind/ir.cc | 41 ++++++++++++++++++++++---------- paddle/ir/core/value.cc | 9 +++++++ paddle/ir/core/value.h | 5 ++++ test/ir/new_ir/test_ir_pybind.py | 24 +++++++++++++------ 4 files changed, 59 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 8f805bf06ec4c..afc69f61e3bc6 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -176,7 +176,13 @@ void BindValue(py::module *m) { .def("get_defining_op", &Value::GetDefiningOp, return_value_policy::reference) - .def("__eq__", &Value::operator==); + .def("__eq__", &Value::operator==) + .def("__eq__", + [](Value &self, OpResult &other) { + return self.impl() == other.value_impl(); + }) + .def("__hash__", + [](const Value &self) { return std::hash{}(self); }); } void BindOpOperand(py::module *m) { @@ -218,18 +224,27 @@ void BindOpResult(py::module *m) { ir::ArrayAttribute::get(ir::IrContext::Instance(), stop_gradients)); }) - .def("get_stop_gradient", [](OpResult &self) { - auto *defining_op = self.owner(); - if (defining_op->HasAttribute(kAttrStopGradients)) { - auto stop_gradients = defining_op->attribute(kAttrStopGradients) - .dyn_cast() - .AsVector(); - return stop_gradients[self.GetResultIndex()] - .dyn_cast() - .data(); - } else { - return false; - } + .def("get_stop_gradient", + [](OpResult &self) { + auto *defining_op = self.owner(); + if (defining_op->HasAttribute(kAttrStopGradients)) { + auto stop_gradients = defining_op->attribute(kAttrStopGradients) + .dyn_cast() + .AsVector(); + return stop_gradients[self.GetResultIndex()] + .dyn_cast() + .data(); + } else { + return false; + } + }) + .def("__eq__", &OpResult::operator==) + .def("__eq__", + [](OpResult &self, Value &other) { + return self.value_impl() == other.impl(); + }) + .def("__hash__", [](OpResult &self) { + return std::hash{}(self.dyn_cast()); }); } diff --git a/paddle/ir/core/value.cc b/paddle/ir/core/value.cc index 666be5481c418..018342aa81547 100644 --- a/paddle/ir/core/value.cc +++ b/paddle/ir/core/value.cc @@ -122,6 +122,15 @@ detail::OpResultImpl *OpResult::impl() const { return reinterpret_cast(impl_); } +bool OpResult::operator==(const OpResult &other) const { + return impl_ == other.impl_; +} + +detail::ValueImpl *OpResult::value_impl() const { + IR_ENFORCE(impl_, "Can't use value_impl() interface while value is null."); + return impl_; +} + uint32_t OpResult::GetValidInlineIndex(uint32_t index) { uint32_t max_inline_index = ir::detail::OpResultImpl::GetMaxInlineResultIndex(); diff --git a/paddle/ir/core/value.h b/paddle/ir/core/value.h index 88f23cd1ee517..86a5566393a22 100644 --- a/paddle/ir/core/value.h +++ b/paddle/ir/core/value.h @@ -192,8 +192,12 @@ class IR_API OpResult : public Value { uint32_t GetResultIndex() const; + bool operator==(const OpResult &other) const; + friend Operation; + detail::ValueImpl *value_impl() const; + private: static uint32_t GetValidInlineIndex(uint32_t index); @@ -209,4 +213,5 @@ struct hash { return std::hash()(obj.impl_); } }; + } // namespace std diff --git a/test/ir/new_ir/test_ir_pybind.py b/test/ir/new_ir/test_ir_pybind.py index 4d32b39ce2154..e3debf308b682 100644 --- a/test/ir/new_ir/test_ir_pybind.py +++ b/test/ir/new_ir/test_ir_pybind.py @@ -84,12 +84,26 @@ def test_value(self): matmul_op.result(0).set_stop_gradient(True) self.assertEqual(matmul_op.result(0).get_stop_gradient(), True) + # test opresult hash result_set = set() for opresult in matmul_op.results(): result_set.add(opresult) - - # self.assertTrue(add_op.operands()[0].source() in result_set) - # self.assertEqual(add_op.operands_source()[0] , matmul_op.results()[0],) + # test opresult hash and hash(opresult) == hash(operesult) + self.assertTrue(add_op.operands()[0].source() in result_set) + # test value hash and hash(value) == hash(operesult) + self.assertTrue(add_op.operands_source()[0] in result_set) + # test value == value + self.assertEqual( + add_op.operands_source()[0], add_op.operands_source()[0] + ) + # test value == opresult + self.assertEqual(add_op.operands_source()[0], matmul_op.results()[0]) + # test opresult == value + self.assertEqual( + add_op.operands()[0].source(), add_op.operands_source()[0] + ) + # test opresult == opresult + self.assertEqual(add_op.operands()[0].source(), matmul_op.results()[0]) self.assertEqual( tanh_op.operands()[0].source().get_defining_op().name(), "pd.add" @@ -100,10 +114,6 @@ def test_value(self): tanh_op.operands()[0].source().get_defining_op().name(), "pd.matmul" ) - self.assertEqual( - tanh_op.operands()[0].source().get_defining_op(), - tanh_op.operands_source()[0].get_defining_op(), - ) self.assertEqual(add_op.result(0).use_empty(), True) def test_type(self):