-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【new ir】add __eq__ and __hash__ to compare opresult and value #55909
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -176,7 +176,13 @@ void BindValue(py::module *m) { | |
.def("get_defining_op", | ||
&Value::GetDefiningOp, | ||
return_value_policy::reference) | ||
.def("__eq__", &Value::operator==); | ||
.def("__eq__", &Value::operator==) | ||
.def("__eq__", | ||
[](Value &self, OpResult &other) { | ||
return self.impl() == other.value_impl(); | ||
}) | ||
.def("__hash__", | ||
[](const Value &self) { return std::hash<ir::Value>{}(self); }); | ||
} | ||
|
||
void BindOpOperand(py::module *m) { | ||
|
@@ -218,18 +224,27 @@ void BindOpResult(py::module *m) { | |
ir::ArrayAttribute::get(ir::IrContext::Instance(), | ||
stop_gradients)); | ||
}) | ||
.def("get_stop_gradient", [](OpResult &self) { | ||
auto *defining_op = self.owner(); | ||
if (defining_op->HasAttribute(kAttrStopGradients)) { | ||
auto stop_gradients = defining_op->attribute(kAttrStopGradients) | ||
.dyn_cast<ir::ArrayAttribute>() | ||
.AsVector(); | ||
return stop_gradients[self.GetResultIndex()] | ||
.dyn_cast<ir::BoolAttribute>() | ||
.data(); | ||
} else { | ||
return false; | ||
} | ||
.def("get_stop_gradient", | ||
[](OpResult &self) { | ||
auto *defining_op = self.owner(); | ||
if (defining_op->HasAttribute(kAttrStopGradients)) { | ||
auto stop_gradients = defining_op->attribute(kAttrStopGradients) | ||
.dyn_cast<ir::ArrayAttribute>() | ||
.AsVector(); | ||
return stop_gradients[self.GetResultIndex()] | ||
.dyn_cast<ir::BoolAttribute>() | ||
.data(); | ||
} else { | ||
return false; | ||
} | ||
}) | ||
.def("__eq__", &OpResult::operator==) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 有了下面的重载,这一句是不是已经不需要了啊 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 下面的判断opresult == value 的重载 |
||
.def("__eq__", | ||
[](OpResult &self, Value &other) { | ||
return self.value_impl() == other.impl(); | ||
}) | ||
.def("__hash__", [](OpResult &self) { | ||
return std::hash<ir::Value>{}(self.dyn_cast<ir::Value>()); | ||
}); | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -192,8 +192,12 @@ class IR_API OpResult : public Value { | |
|
||
uint32_t GetResultIndex() const; | ||
|
||
bool operator==(const OpResult &other) const; | ||
|
||
friend Operation; | ||
|
||
detail::ValueImpl *value_impl() const; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里为什么不返回 |
||
|
||
private: | ||
static uint32_t GetValidInlineIndex(uint32_t index); | ||
|
||
|
@@ -209,4 +213,5 @@ struct hash<ir::Value> { | |
return std::hash<const ir::detail::ValueImpl *>()(obj.impl_); | ||
} | ||
}; | ||
|
||
} // namespace std |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
若没有定义stop_gradient属性,默认策略是False,这个要留意下后续是否有什么影响