support backward for distributed PIR.
winter-wang committed Mar 25, 2024
1 parent 8cae55d commit 15b1c85
Showing 8 changed files with 120 additions and 16 deletions.
29 changes: 26 additions & 3 deletions paddle/fluid/pir/dialect/distributed/ir/dist_interface.h
@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once

#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
#include "paddle/pir/include/core/cast_utils.h"
#include "paddle/pir/include/core/dll_decl.h"
#include "paddle/pir/include/core/type.h"
@@ -25,24 +26,46 @@ class IR_API DistTypeInterface
 public:
  struct Concept {
    /// Defined these methods with the interface.
    explicit Concept(pir::Type (*local_type)(pir::Type))
        : local_type(local_type) {}
    explicit Concept(pir::Type (*local_type)(pir::Type),
                     ProcessMeshAttribute (*process_mesh_attr)(pir::Type),
                     TensorDistAttribute (*tensor_dist_attr)(pir::Type))
        : local_type(local_type),
          process_mesh_attr(process_mesh_attr),
          tensor_dist_attr(tensor_dist_attr) {}
    pir::Type (*local_type)(pir::Type);
    ProcessMeshAttribute (*process_mesh_attr)(pir::Type);
    TensorDistAttribute (*tensor_dist_attr)(pir::Type);
  };

  template <class ConcreteType>
  struct Model : public Concept {
    static Type local_type(Type type) {
      return pir::cast<ConcreteType>(type).local_type();
    }
    Model() : Concept(local_type) {}
    static ProcessMeshAttribute process_mesh_attr(Type type) {
      return pir::cast<ConcreteType>(type).process_mesh_attr();
    }

    static TensorDistAttribute tensor_dist_attr(Type type) {
      return pir::cast<ConcreteType>(type).tensor_dist_attr();
    }

    Model() : Concept(local_type, process_mesh_attr, tensor_dist_attr) {}
  };

  DistTypeInterface(pir::Type type, Concept *impl)
      : pir::TypeInterfaceBase<DistTypeInterface>(type), impl_(impl) {}

  pir::Type local_type() { return impl_->local_type(*this); }

  ProcessMeshAttribute process_mesh_attr() {
    return impl_->process_mesh_attr(*this);
  }

  TensorDistAttribute tensor_dist_attr() {
    return impl_->tensor_dist_attr(*this);
  }

 private:
  Concept *impl_;
};
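The two new accessors make the process mesh and the tensor distribution attribute reachable through the type-erased interface, so generic code no longer has to cast to the concrete DistDenseTensorType. On the Python side the same metadata is exposed through the value bindings exercised by the test at the end of this commit; a minimal sketch (the helper name is illustrative, not part of this change):

def describe_dist_value(value):
    # Print the distributed metadata that the dist type interface exposes
    # for a pir Value whose type is a distributed dense tensor type.
    if not value.is_dist_dense_tensor_type():
        return
    dist_attr = value.dist_attr()
    print("process mesh shape:", dist_attr.process_mesh.shape)
    print("dims mapping:", dist_attr.dims_mapping)
    print("partial dims:", dist_attr.partial_dims)
    print("local (per-rank) shape:", value._local_shape)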
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_op.cc
@@ -21,6 +21,7 @@
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/pir/include/core/builtin_attribute.h"
#include "paddle/pir/include/core/builtin_op.h"
#include "paddle/pir/include/core/ir_context.h"

namespace paddle {
@@ -155,6 +156,7 @@ void ShardTensorOp::Build(pir::Builder& builder,
tensor_dist_attr,
local_shape);
argument.AddOutput(out_dist_tensor_type);
::pir::PassStopGradientsDefaultly(argument);
}

void ReShardOp::VerifySig() {
@@ -613,8 +613,8 @@ def GenDistBranch(args, op_info):
ProcessMeshAttribute op_mesh;
auto ctx = pir::IrContext::Instance();
for(auto value : input_values) {{
if (auto dist_type = value.type().dyn_cast<DistDenseTensorType>()) {{
op_mesh = dist_type.process_mesh_attr();
if (auto dist_interface = value.type().dyn_cast<DistTypeInterface>()) {{
op_mesh = dist_interface.process_mesh_attr();
break;
}}
}}"""
14 changes: 12 additions & 2 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.cc
@@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/pir/dialect/operator/ir/manual_api.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h"
@@ -63,8 +64,17 @@ void set_parameter(const pir::Value& parameter, const std::string& name) {
}

void shadow_output(const pir::Value& persist_value, const std::string& name) {
  ApiBuilder::Instance().GetBuilder()->Build<pir::ShadowOutputOp>(persist_value,
                                                                  name);
  auto& builder = ApiBuilder::Instance().GetBuilder();
  auto op = builder->Build<pir::ShadowOutputOp>(persist_value, name);
  if (auto dist_interface =
          persist_value.type().dyn_cast<DistTypeInterface>()) {
    op->set_attribute(
        kAttrOpDistAttr,
        OperationDistAttribute::get(builder->ir_context(),
                                    dist_interface.process_mesh_attr(),
                                    {dist_interface.tensor_dist_attr()},
                                    {}));
  }
}

pir::Value embedding_grad(const pir::Value& x,
19 changes: 19 additions & 0 deletions paddle/fluid/pybind/pir.cc
@@ -118,6 +118,7 @@ using pir::Block;
using pir::BlockArgument;
using pir::BoolAttribute;
using pir::CloneOptions;
using pir::IrContext;
using pir::IrMapping;
using pir::IrParser;
using pir::Operation;
@@ -223,6 +224,20 @@ std::string GetValueInfo(Value v) {
  return ss.str();
}

Value GetOutputValueByName(const Program &program, const std::string &name) {
  auto &block = *program.block();
  pir::StrAttribute name_attr =
      pir::StrAttribute::get(IrContext::Instance(), name);
  for (auto &op : block) {
    if (op.isa<pir::ShadowOutputOp>()) {
      if (op.attribute("output_name") == name_attr) {
        return op.operand_source(0);
      }
    }
  }
  return nullptr;
}

void BindProgram(py::module *m) {
  py::class_<Program, std::shared_ptr<Program>> program(
      *m, "Program", py::dynamic_attr(), R"DOC(
@@ -334,6 +349,10 @@ void BindProgram(py::module *m) {
          [](std::shared_ptr<Program> self, int64_t random_seed) {
            SetProgramInt64Attr(self, "random_seed", random_seed);
          })
      .def("get_output_value_by_name",
           [](Program &self, const std::string &name) {
             return GetOutputValueByName(self, name);
           })
      .def("num_ops", [](Program &self) { return self.num_ops(); });
}
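Together with the engine change further down, get_output_value_by_name gives Python code a way to recover the loss Value registered by a pir.ShadowOutputOp and hand it to the PIR backward pass. A hedged usage sketch that mirrors the engine.py hunk (the wrapper name is illustrative):

import paddle

def run_backward_on_named_loss(pir_program, loss_name):
    # Look up the Value that a ShadowOutputOp registered under `loss_name`,
    # then let PIR autograd append the corresponding grad ops.
    loss = pir_program.get_output_value_by_name(loss_name)
    paddle.autograd.ir_backward.append_backward(loss)
    return loss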

10 changes: 5 additions & 5 deletions python/paddle/distributed/auto_parallel/static/engine.py
@@ -638,11 +638,10 @@ def _parallel_pir(self, mode):
        dist_program = paddle.base.libpaddle.pir.apply_mix2dist_pass(
            mix_fw_program
        )

        # TODO(winter-wang) Step 1.2: pir backward
        # with program_guard(dist_program):
        #     params_grads = append_backward_pir(self._loss, parameter_list=self._parameter_list)

        # Step 1.2: pir backward
        if mode != "predict" and self._loss:
            loss = dist_program.get_output_value_by_name(self._loss_names[0])
            paddle.autograd.ir_backward.append_backward(loss)
        # TODO(winter-wang) Step 1.3: adapt opt.minimize() for pir-auto-parallel
        # with program_guard(dist_program):
        #     optimizer_ops = self._optimizer.apply_gradients(params_grads)
@@ -767,6 +766,7 @@ def _build(self, mode):
# self._process_dist_input_specs()
outputs = self.program_helper.output_vars
self._losses = self.program_helper.loss_vars
self._loss_names = self.program_helper.loss_names
metrics = self.program_helper.metric_vars

paddle.enable_static()
23 changes: 23 additions & 0 deletions python/paddle/distributed/auto_parallel/static/helper.py
@@ -58,6 +58,7 @@ def __init__(self, layer, loss_func, metrics):
self._label_vars = defaultdict(list)
self._output_vars = defaultdict(list)
self._loss_vars = defaultdict(list)
self._loss_names = defaultdict(list)
self._metric_vars = defaultdict(list)

# Consider ProxyLayer as not Paddle inner function because it contains
@@ -66,6 +67,12 @@
inspect.getmodule(ProxyLayer).__name__ + ".ProxyLayer"
)

    @paddle.jit.not_to_static
    def append_loss_to_shadow_output(self, mode):
        name = paddle.utils.unique_name.generate('loss')
        paddle._pir_ops.set_persistable_value(self._loss_vars[mode], name)
        self._loss_names[mode] = name

    def _train(self, inputs, labels):
        """
        Train process of inner_layer with forward/loss/metric logic.
@@ -81,6 +88,10 @@ def _train(self, inputs, labels):
        # step 3. calculate loss if needed
        new_inputs = self._prepare(self.output_vars, labels)
        self._loss_vars[mode] = self.call_loss(new_inputs)
        if paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
            "FLAGS_enable_pir_api"
        ]:
            self.append_loss_to_shadow_output(mode)

        # step 4. calculate metrics if needed
        self._metric_vars[mode] = self.call_metrics(new_inputs)
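The new branch only fires when the PIR API flag is enabled; the same check can be wrapped in a small predicate (an illustrative helper, not part of this change):

import paddle

def pir_api_enabled():
    # True when programs are being built through the new PIR API.
    return paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
        "FLAGS_enable_pir_api"
    ]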
@@ -103,6 +114,10 @@ def _eval(self, inputs, labels):
        # step 3. calculate loss if needed
        new_inputs = self._prepare(self.output_vars, labels)
        self._loss_vars[mode] = self.call_loss(new_inputs)
        if paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
            "FLAGS_enable_pir_api"
        ]:
            self.append_loss_to_shadow_output(mode)

        # step 4. calculate metrics if needed
        self._metric_vars[mode] = self.call_metrics(new_inputs)
@@ -180,6 +195,10 @@ def output_vars(self):
    def loss_vars(self):
        return self._loss_vars[self.mode]

    @property
    def loss_names(self):
        return self._loss_names[self.mode]

    @property
    def metric_vars(self):
        return self._metric_vars[self.mode]
@@ -521,6 +540,10 @@ def label_vars(self):
    def loss_vars(self):
        return to_list(self.proxy_layer.loss_vars)

    @property
    def loss_names(self):
        return to_list(self.proxy_layer.loss_names)

    @property
    def metric_vars(self):
        return to_list(self.proxy_layer.metric_vars)
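The name minted in append_loss_to_shadow_output is one half of a round trip: the loss Value is attached to a named, persistable shadow output while the program is built, and the engine later fetches it back by that name to run backward. A compressed sketch of the flow (identifiers are illustrative):

import paddle

def register_loss(loss_value):
    # Attach the loss to a named shadow output so it can still be found
    # after program transformation passes.
    name = paddle.utils.unique_name.generate('loss')
    paddle._pir_ops.set_persistable_value(loss_value, name)
    return name

# Later, in the engine:
#   loss = dist_program.get_output_value_by_name(name)
#   paddle.autograd.ir_backward.append_backward(loss)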
35 changes: 31 additions & 4 deletions test/auto_parallel/pir/test_to_static_pir_program.py
@@ -97,6 +97,8 @@ def test_to_static_program(self):
        main_program = dist_model._engine._pir_main_progs["eval"]

        for op in main_program.global_block().ops:
            if op.num_results() == 0:
                continue
            tensor = op.result(0)
            if op.name() == 'pd_op.data':
                self.assertTrue(tensor.is_dist_dense_tensor_type())
@@ -128,9 +130,22 @@

        relu_idx = 0
        matmul_idx = 0

        for op in main_program.global_block().ops:
        matmul_grad_idx = 0
        ops = main_program.global_block().ops
        self.assertEqual(ops[-1].name(), "pd_op.matmul_grad")
        self.assertEqual(ops[-2].name(), "pd_op.relu_grad")
        self.assertEqual(ops[-3].name(), "pd_op.matmul_grad")
        self.assertEqual(ops[-4].name(), "pd_op.relu_grad")
        self.assertEqual(ops[-5].name(), "pd_op.subtract_grad")
        self.assertEqual(ops[-6].name(), "pd_op.square_grad")
        self.assertEqual(ops[-7].name(), "pd_op.mean_grad")

        for op in ops:
            if op.num_results() == 0:
                continue
            tensor = op.result(0)
            if not tensor.initialized():
                continue
            self.assertTrue(tensor.is_dist_dense_tensor_type())
            self.assertEqual(tensor.dist_attr().process_mesh.shape, [2])
            self.assertEqual(
@@ -143,8 +158,6 @@ def test_to_static_program(self):
            elif op.name() == 'builtin.parameter':
                self.assertTrue(tensor.is_dense_tensor_type())
                self.assertTrue(tensor.is_dist_dense_tensor_type())
                self.assertTrue(tensor.has_one_use())

                self.assertTrue(tensor.is_dist_dense_tensor_type())
                self.assertEqual(tensor.dist_attr().process_mesh.shape, [2])
                self.assertEqual(
@@ -189,6 +202,20 @@ def test_to_static_program(self):
                        tensor._local_shape, [BATCH_SIZE, CLASS_NUM]
                    )
                matmul_idx += 1
            if op.name() == 'pd_op.matmul_grad':
                if matmul_grad_idx == 0:
                    self.assertEqual(tensor.dist_attr().dims_mapping, [-1, 0])
                    self.assertEqual(tensor.dist_attr().partial_dims, set())
                    self.assertEqual(
                        tensor._local_shape, [BATCH_SIZE, CLASS_NUM]
                    )
                elif matmul_grad_idx == 1:
                    self.assertEqual(tensor.dist_attr().dims_mapping, [-1, 0])
                    self.assertEqual(tensor.dist_attr().partial_dims, set())
                    self.assertEqual(
                        tensor._local_shape, [BATCH_SIZE, IMAGE_SIZE // 2]
                    )
                matmul_grad_idx += 1
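The tail assertions above capture the expected effect of append_backward: the grad ops come out in reverse order of their forward counterparts, from mean_grad down to the grad of the first matmul. A small illustrative helper for the same kind of check:

def tail_op_names(program, n=7):
    # Names of the last `n` ops; after backward these should be the grad ops,
    # ordered opposite to their forward counterparts.
    return [op.name() for op in program.global_block().ops[-n:]]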

# dist_model.train()
# for batch_id, (image, label) in enumerate(dist_loader()):
