Skip to content

Commit

Permalink
[PIR]Store Python data in Operation (#62750)
Browse files Browse the repository at this point in the history
* store data in operation

* delete lod

* rename persistable

* fix append_backward

* fix lod

* remove pir test for data feeder

* fix amp

* support return none

* amend

* perfect set property

* fix descontruct bug
  • Loading branch information
YuanRisheng authored Mar 26, 2024
1 parent 66a4faa commit c3f5747
Show file tree
Hide file tree
Showing 19 changed files with 163 additions and 49 deletions.
70 changes: 44 additions & 26 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,40 @@ pir::Value apply(Value self, py::object func) {
return out;
}

#define DEF_VALUE_BOOL_PROPERTY(name) \
def_property( \
name, \
[](Value self) { \
auto bool_data = self.attribute<BoolAttribute>(name); \
return !bool_data || bool_data.data(); \
}, \
[](Value self, bool bool_data) { \
self.set_attribute( \
name, BoolAttribute::get(pir::IrContext::Instance(), bool_data)); \
})

#define DEF_VALUE_POINTER_PROPERTY(name) \
def_property( \
name, \
[](Value self) -> py::object { \
auto prop_ptr = self.property(name); \
if (!prop_ptr) { \
return py::cast<py::none>(Py_None); \
} \
auto py_data = reinterpret_cast<PyObject *>(prop_ptr); \
py::object obj = py::object(py::handle(py_data), true); \
return obj; \
}, \
[](Value self, py::object obj) { \
pir::PropertiesDeleter deleter = [](void *python_obj) { \
Py_DECREF(python_obj); \
}; \
PyObject *pointer_data = obj.release().ptr(); \
pir::Property value_property(reinterpret_cast<void *>(pointer_data), \
deleter); \
self.set_property(name, value_property); \
})

void BindValue(py::module *m) {
py::class_<Value> value(*m,
"Value",
Expand All @@ -834,8 +868,7 @@ void BindValue(py::module *m) {
The constructor of Value should not be invoked directly. Value can be automatically constructed
when build network.
)DOC",
pybind11::dynamic_attr());
)DOC");
g_ir_value_pytype = reinterpret_cast<PyTypeObject *>(value.ptr());
value.def(py::init<>())
.def_property_readonly(
Expand Down Expand Up @@ -916,30 +949,15 @@ void BindValue(py::module *m) {
return true;
}
})
.def_property(
"stop_gradient",
[](Value self) {
auto stop_gradient =
self.attribute<BoolAttribute>(kAttrStopGradients);
return !stop_gradient || stop_gradient.data();
},
[](Value self, bool stop_gradient) {
self.set_attribute(
kAttrStopGradients,
BoolAttribute::get(pir::IrContext::Instance(), stop_gradient));
})
.def_property(
"persistable",
[](Value self) {
auto persistable =
self.attribute<BoolAttribute>(kAttrIsPersistable);
return !persistable || persistable.data();
},
[](Value self, bool persistable) {
self.set_attribute(
kAttrIsPersistable,
BoolAttribute::get(pir::IrContext::Instance(), persistable));
})
.DEF_VALUE_BOOL_PROPERTY("stop_gradient")
.DEF_VALUE_BOOL_PROPERTY("trainable")
.DEF_VALUE_BOOL_PROPERTY("persistable")
.DEF_VALUE_BOOL_PROPERTY("need_clip")
.DEF_VALUE_BOOL_PROPERTY("is_distributed")
.DEF_VALUE_BOOL_PROPERTY("is_parameter")
.DEF_VALUE_POINTER_PROPERTY("optimize_attr")
.DEF_VALUE_POINTER_PROPERTY("regularizer")
.DEF_VALUE_POINTER_PROPERTY("do_model_average")
.def("all_used_ops",
[](Value &self) -> py::list {
py::list op_list;
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/include/core/attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "paddle/pir/include/core/type_id.h"

constexpr char kAttrStopGradients[] = "stop_gradient";
constexpr char kAttrIsPersistable[] = "is_persistable";
constexpr char kAttrIsPersistable[] = "persistable";
constexpr char kAttrOpDistAttr[] = "op_dist_attr";

namespace pir {
Expand Down
3 changes: 3 additions & 0 deletions paddle/pir/include/core/op_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class IR_API OpResult : public Value {
Attribute attribute(const std::string &key) const;
void set_attribute(const std::string &key, Attribute value);

void *property(const std::string &key) const;
void set_property(const std::string &key, const Property &value);

private:
friend Operation;
OpResult(detail::OpResultImpl *impl); // NOLINT
Expand Down
9 changes: 9 additions & 0 deletions paddle/pir/include/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ class IR_API alignas(8) Operation final
return attributes_.find(key) != attributes_.end();
}

void set_value_property(const std::string &key,
const Property &value,
size_t index);

void *value_property(const std::string &key, size_t index) const;

///
/// \brief op ouput related public interfaces
///
Expand Down Expand Up @@ -266,6 +272,9 @@ class IR_API alignas(8) Operation final

AttributeMap attributes_;

// store data that user create by Python
std::vector<PropertyMap> value_properties_;

OpInfo info_;

static uint64_t GenerateId() {
Expand Down
1 change: 1 addition & 0 deletions paddle/pir/include/core/operation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
namespace pir {
class Block;
using AttributeMap = std::unordered_map<std::string, Attribute>;
using PropertyMap = std::unordered_map<std::string, Property>;

//===----------------------------------------------------------------------===//
// OperationArgument
Expand Down
6 changes: 6 additions & 0 deletions paddle/pir/include/core/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

namespace pir {
class Operation;
using PropertiesDeleter = void (*)(void *);
using Property = std::pair<void *, PropertiesDeleter>;

namespace detail {
class ValueImpl;
Expand Down Expand Up @@ -116,6 +118,10 @@ class IR_API Value {

void set_attribute(const std::string &key, Attribute value);

void set_property(const std::string &key, const Property &value);

void *property(const std::string &name) const;

protected:
detail::ValueImpl *impl_{nullptr};
};
Expand Down
8 changes: 8 additions & 0 deletions paddle/pir/src/core/op_result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ void OpResult::set_attribute(const std::string &key, Attribute value) {
return IMPL_->set_attribute(key, value);
}

void *OpResult::property(const std::string &key) const {
return impl_ ? IMPL_->property(key) : nullptr;
}
void OpResult::set_property(const std::string &key, const Property &value) {
CHECK_OPRESULT_NULL_IMPL(set_property);
return IMPL_->set_property(key, value);
}

OpResult::OpResult(detail::OpResultImpl *impl) : Value(impl) {}

} // namespace pir
9 changes: 9 additions & 0 deletions paddle/pir/src/core/op_result_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,15 @@ void OpResultImpl::set_attribute(const std::string &key, Attribute value) {
owner->set_attribute(key, ArrayAttribute::get(owner->ir_context(), vec));
}

void *OpResultImpl::property(const std::string &key) const {
return owner()->value_property(key, index());
}

void OpResultImpl::set_property(const std::string &key, const Property &value) {
auto owner = this->owner();
owner->set_value_property(key, value, index());
}

OpInlineResultImpl::OpInlineResultImpl(Type type, uint32_t result_index)
: OpResultImpl(type, result_index) {
PADDLE_ENFORCE_LE(
Expand Down
3 changes: 3 additions & 0 deletions paddle/pir/src/core/op_result_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class OpResultImpl : public ValueImpl {
Attribute attribute(const std::string &key) const;
void set_attribute(const std::string &key, Attribute value);

void *property(const std::string &key) const;
void set_property(const std::string &key, const Property &value);

private:
int32_t ComputeOperationOffset() const;
};
Expand Down
39 changes: 35 additions & 4 deletions paddle/pir/src/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,26 +199,35 @@ void Operation::Destroy() {
}
}

// 3. Deconstruct Operation.
// 3. Deconstruct Properties.
for (auto &value_property : value_properties_) {
for (auto &property_map : value_property) {
if (property_map.second.second) {
property_map.second.second((property_map.second.first));
}
}
}

// 4. Deconstruct Operation.
this->~Operation();

// 4. Deconstruct OpOperand.
// 5. Deconstruct OpOperand.
for (size_t idx = 0; idx < num_operands_; idx++) {
detail::OpOperandImpl *op_operand_impl = operand(idx).impl_;
if (op_operand_impl) {
op_operand_impl->~OpOperandImpl();
}
}

// 5. Deconstruct BlockOperand.
// 6. Deconstruct BlockOperand.
for (size_t idx = 0; idx < num_successors_; idx++) {
detail::BlockOperandImpl *block_operand_impl = block_operands_ + idx;
if (block_operand_impl) {
block_operand_impl->~BlockOperandImpl();
}
}

// 5. Free memory.
// 7. Free memory.
size_t result_mem_size =
num_results_ > OUTLINE_RESULT_IDX
? sizeof(detail::OpOutlineResultImpl) *
Expand Down Expand Up @@ -399,6 +408,28 @@ int32_t Operation::ComputeOpOperandOffset(uint32_t index) const {
sizeof(Operation));
}

void Operation::set_value_property(const std::string &key,
const Property &value,
size_t index) {
if (value_properties_.size() < index + 1) {
value_properties_.resize(index + 1);
}
auto &property_map = value_properties_[index];
if (property_map.count(key)) {
property_map[key].second(property_map[key].first);
}
property_map[key] = value;
}

void *Operation::value_property(const std::string &key, size_t index) const {
if (value_properties_.size() < (index + 1)) {
return nullptr;
}
auto &property_map = value_properties_[index];
auto iter = property_map.find(key);
return iter == property_map.end() ? nullptr : iter->second.first;
}

#define COMPONENT_IMPL(component_lower, component_upper) \
component_upper##Impl *Operation::component_lower##_impl(uint32_t index) \
const { \
Expand Down
18 changes: 18 additions & 0 deletions paddle/pir/src/core/value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,22 @@ void Value::set_attribute(const std::string &key, Attribute value) {
return dyn_cast<BlockArgument>().set_attribute(key, value);
}

void Value::set_property(const std::string &key, const Property &value) {
auto op_result = dyn_cast<OpResult>();
PADDLE_ENFORCE_NE(op_result,
nullptr,
common::errors::PreconditionNotMet(
"The Value is not an OpResult, we can set property "
"only for OpResult currently"));
return op_result.set_property(key, value);
}

void *Value::property(const std::string &key) const {
auto op_result = dyn_cast<OpResult>();
if (op_result) {
return op_result.property(key);
} else {
return nullptr;
}
}
} // namespace pir
7 changes: 7 additions & 0 deletions python/paddle/amp/auto_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,15 @@ def _pir_transform(t, dtype):
paddle.pir.reset_insertion_point_to_start()
block = main.global_block()
cast_param = paddle._pir_ops.parameter(t.name)
cast_param.trainable = t.trainable
cast_param.stop_gradient = t.stop_gradient
cast_param.persistable = t.persistable
cast_param.optimize_attr = t.optimize_attr
cast_param.regularizer = t.regularizer
cast_param.do_model_average = t.do_model_average
cast_param.need_clip = t.need_clip
cast_param.is_distributed = t.is_distributed
cast_param.is_parameter = t.is_parameter
op = t.get_defining_op()
t.replace_all_uses_with(cast_param)
block.remove_op(op)
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/autograd/ir_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,7 +1167,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
ops = loss.get_defining_op().get_parent_block().ops
parameter_list = []
for op in ops:
if not op.has_attr("is_persistable"):
if not op.has_attr("persistable"):
continue
persist_value = [
result for result in op.results() if result.persistable
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/base/data_feeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def __init__(self, feed_list, place, program=None):
raise TypeError("Feed list should contain a list of Value")
self.feed_dtypes.append(each_var.dtype)
self.feed_names.append(each_var.name)
self.feed_lod_level.append(each_var.lod_level)
self.feed_lod_level.append(0)
self.feed_shapes.append(each_var.shape)
else:
if program is None:
Expand Down
18 changes: 8 additions & 10 deletions python/paddle/pir/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,16 +288,10 @@ def create_parameter(
name=None,
**kwargs,
):
regularizer = None
need_clip = None
if 'initializer' not in kwargs:
raise ValueError(
"initializer is None, if you want to create parameter, please pass its initializer."
)
if 'regularizer' in kwargs:
regularizer = kwargs['regularizer']
if 'need_clip' in kwargs:
need_clip = kwargs['need_clip']
if dtype is not None:
if not isinstance(dtype, DataType):
dtype = convert_np_dtype_to_dtype_(dtype)
Expand All @@ -320,12 +314,16 @@ def create_parameter(
with program_guard(default_main_program()):
reset_insertion_point_to_start()
param = parameter(value_name)
trainable = kwargs.get('trainable', True)
param.stop_gradient = not trainable
param.persistable = True

param.regularizer = regularizer
param.need_clip = need_clip
param.trainable = kwargs.get('trainable', True)
param.stop_gradient = not param.trainable
param.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
param.regularizer = kwargs.get('regularizer', None)
param.do_model_average = kwargs.get('do_model_average', None)
param.need_clip = kwargs.get('need_clip', True)
param.is_distributed = False
param.is_parameter = True
return param


Expand Down
1 change: 0 additions & 1 deletion python/paddle/static/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def _reset_data_op_insertion_point():
prev_insertion_point = get_current_insertion_point()
_reset_data_op_insertion_point()
out = paddle._pir_ops.data(name, shape, ir_dtype, core.Place())
out.lod_level = lod_level
set_insertion_point(prev_insertion_point)
return out

Expand Down
7 changes: 7 additions & 0 deletions test/dygraph_to_static/test_tensor_attr_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@
'dist_attr',
'value_assign',
'replace_grad_users_with',
'do_model_average',
'is_distributed',
'is_parameter',
'need_clip',
'optimize_attr',
'regularizer',
'trainable',
]
)

Expand Down
Loading

0 comments on commit c3f5747

Please sign in to comment.