Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【new ir】add __eq__ and __hash__ to compare opresult and value #55909

Merged
merged 1 commit into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,13 @@ void BindValue(py::module *m) {
.def("get_defining_op",
&Value::GetDefiningOp,
return_value_policy::reference)
.def("__eq__", &Value::operator==);
.def("__eq__", &Value::operator==)
.def("__eq__",
[](Value &self, OpResult &other) {
return self.impl() == other.value_impl();
})
.def("__hash__",
[](const Value &self) { return std::hash<ir::Value>{}(self); });
}

void BindOpOperand(py::module *m) {
Expand Down Expand Up @@ -218,18 +224,27 @@ void BindOpResult(py::module *m) {
ir::ArrayAttribute::get(ir::IrContext::Instance(),
stop_gradients));
})
.def("get_stop_gradient", [](OpResult &self) {
auto *defining_op = self.owner();
if (defining_op->HasAttribute(kAttrStopGradients)) {
auto stop_gradients = defining_op->attribute(kAttrStopGradients)
.dyn_cast<ir::ArrayAttribute>()
.AsVector();
return stop_gradients[self.GetResultIndex()]
.dyn_cast<ir::BoolAttribute>()
.data();
} else {
return false;
}
.def("get_stop_gradient",
[](OpResult &self) {
auto *defining_op = self.owner();
if (defining_op->HasAttribute(kAttrStopGradients)) {
auto stop_gradients = defining_op->attribute(kAttrStopGradients)
.dyn_cast<ir::ArrayAttribute>()
.AsVector();
return stop_gradients[self.GetResultIndex()]
.dyn_cast<ir::BoolAttribute>()
.data();
} else {
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

若没有定义stop_gradient属性,默认策略是False,这个要留意下后续是否有什么影响

}
})
.def("__eq__", &OpResult::operator==)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有了下面的重载,这一句是不是已经不需要了啊

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

下面的判断opresult == value 的重载

.def("__eq__",
[](OpResult &self, Value &other) {
return self.value_impl() == other.impl();
})
.def("__hash__", [](OpResult &self) {
return std::hash<ir::Value>{}(self.dyn_cast<ir::Value>());
});
}

Expand Down
9 changes: 9 additions & 0 deletions paddle/ir/core/value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,15 @@ detail::OpResultImpl *OpResult::impl() const {
return reinterpret_cast<detail::OpResultImpl *>(impl_);
}

bool OpResult::operator==(const OpResult &other) const {
return impl_ == other.impl_;
}

detail::ValueImpl *OpResult::value_impl() const {
IR_ENFORCE(impl_, "Can't use value_impl() interface while value is null.");
return impl_;
}

uint32_t OpResult::GetValidInlineIndex(uint32_t index) {
uint32_t max_inline_index =
ir::detail::OpResultImpl::GetMaxInlineResultIndex();
Expand Down
5 changes: 5 additions & 0 deletions paddle/ir/core/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,12 @@ class IR_API OpResult : public Value {

uint32_t GetResultIndex() const;

bool operator==(const OpResult &other) const;

friend Operation;

detail::ValueImpl *value_impl() const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为什么不返回const detail::ValueImpl * ?


private:
static uint32_t GetValidInlineIndex(uint32_t index);

Expand All @@ -209,4 +213,5 @@ struct hash<ir::Value> {
return std::hash<const ir::detail::ValueImpl *>()(obj.impl_);
}
};

} // namespace std
24 changes: 17 additions & 7 deletions test/ir/new_ir/test_ir_pybind.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,26 @@ def test_value(self):
matmul_op.result(0).set_stop_gradient(True)
self.assertEqual(matmul_op.result(0).get_stop_gradient(), True)

# test opresult hash
result_set = set()
for opresult in matmul_op.results():
result_set.add(opresult)

# self.assertTrue(add_op.operands()[0].source() in result_set)
# self.assertEqual(add_op.operands_source()[0] , matmul_op.results()[0],)
# test opresult hash and hash(opresult) == hash(operesult)
self.assertTrue(add_op.operands()[0].source() in result_set)
# test value hash and hash(value) == hash(operesult)
self.assertTrue(add_op.operands_source()[0] in result_set)
# test value == value
self.assertEqual(
add_op.operands_source()[0], add_op.operands_source()[0]
)
# test value == opresult
self.assertEqual(add_op.operands_source()[0], matmul_op.results()[0])
# test opresult == value
self.assertEqual(
add_op.operands()[0].source(), add_op.operands_source()[0]
)
# test opresult == opresult
self.assertEqual(add_op.operands()[0].source(), matmul_op.results()[0])

self.assertEqual(
tanh_op.operands()[0].source().get_defining_op().name(), "pd.add"
Expand All @@ -100,10 +114,6 @@ def test_value(self):
tanh_op.operands()[0].source().get_defining_op().name(), "pd.matmul"
)

self.assertEqual(
tanh_op.operands()[0].source().get_defining_op(),
tanh_op.operands_source()[0].get_defining_op(),
)
self.assertEqual(add_op.result(0).use_empty(), True)

def test_type(self):
Expand Down