Commit

Merge commit 'cdd04618c19a8386bb12437a0d4e14753d7ada40' into comm_distributed_fused_lamb
BeingGod committed Sep 19, 2023
2 parents 28f8d50 + cdd0461 commit b6b8773
Showing 472 changed files with 12,017 additions and 9,602 deletions.
2 changes: 1 addition & 1 deletion .clang-tidy
@@ -156,7 +156,7 @@ cppcoreguidelines-avoid-c-arrays,
cppcoreguidelines-c-copy-assignment-signature,
cppcoreguidelines-explicit-virtual-functions,
-cppcoreguidelines-init-variables,
-cppcoreguidelines-narrowing-conversions,
cppcoreguidelines-narrowing-conversions,
-cppcoreguidelines-no-malloc,
-cppcoreguidelines-pro-type-const-cast,
-cppcoreguidelines-pro-type-member-init,
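Re-enabling cppcoreguidelines-narrowing-conversions lines up with the many explicit static_cast additions later in this diff (the reshape SPMD rule, eager_generator.cc, and others). A minimal sketch of the pattern, assuming a size_t-sized container being counted into an int (toy code, not from the repository):

// narrowing_sketch.cc -- why the explicit casts are added.
#include <cstdint>
#include <vector>

int CountDims(const std::vector<int64_t>& dims) {
  // int n = dims.size();                 // size_t -> int, flagged by the
  //                                      // narrowing-conversions check
  int n = static_cast<int>(dims.size());  // explicit narrowing, accepted
  return n;
}

int main() { return CountDims({4, 8, 16}) == 3 ? 0 : 1; }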
9 changes: 7 additions & 2 deletions .flake8
@@ -2,8 +2,6 @@
select = C,E,W
exclude =
./build,
# Exclude fluid directory
./python/paddle/base/**,
# Exclude third-party libraries
./third_party/**,
./python/paddle/utils/gast/**,
@@ -27,3 +25,10 @@ ignore =
per-file-ignores =
# These files need tabs for testing.
test/dygraph_to_static/test_error.py:E101,W191

# temp ignore base directory
python/paddle/base/*:
E713,
E712,
E266,
E714
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -66,7 +66,7 @@ repos:
- id: flake8
args: ["--config=.flake8"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.272
rev: v0.0.289
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --no-cache]
22 changes: 2 additions & 20 deletions paddle/cinn/backends/codegen_c.cc
@@ -285,31 +285,13 @@ void CodeGenC::Visit(const ir::Select *op) {
void CodeGenC::Visit(const ir::IfThenElse *op) {
str_ += "if (";
IrPrinter::Visit(op->condition);
str_ += ") {\n";
str_ += ") ";

if (!op->true_case.As<ir::Block>()) IncIndent();
DoIndent();
IrPrinter::Visit(op->true_case);
if (!op->true_case.As<ir::Block>()) str_ += ";";
str_ += "\n";

if (!op->true_case.As<ir::Block>()) DecIndent();

DoIndent();
str_ += "}";

if (op->false_case.defined()) {
str_ += " else {\n";

if (!op->true_case.As<ir::Block>()) IncIndent();
DoIndent();
str_ += " else ";
IrPrinter::Visit(op->false_case);
if (!op->false_case.As<ir::Block>()) str_ += ";";
str_ += "\n";
if (!op->true_case.As<ir::Block>()) DecIndent();

DoIndent();
str_ += "}";
}
}
void CodeGenC::Visit(const ir::Block *op) {
4 changes: 0 additions & 4 deletions paddle/cinn/backends/ir_schedule_test.cc
@@ -794,10 +794,8 @@ void test_simple_compute_at(void* _args, int32_t num_args)
for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) {
for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) {
if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) {
{
B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)];
C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)];
}
};
};
};
@@ -869,10 +867,8 @@ void test_compute_at0(void* _args, int32_t num_args)
for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) {
for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) {
if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) {
{
B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)];
C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)];
}
};
};
};
8 changes: 3 additions & 5 deletions paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc
@@ -53,9 +53,7 @@ std::vector<ir::Tensor> CollectInputTensor(
std::vector<ir::Tensor>* func_args,
std::unordered_map<::pir::Value, ir::Tensor>* tensor_map) {
std::vector<ir::Tensor> tensors;
for (auto& operand : op->operands()) {
CHECK(operand);
auto in_value = operand.source();
for (auto in_value : op->operands_source()) {
VLOG(4) << "input tensor name: " << CompatibleInfo::ValueName(in_value);
// NOTE(Aurelius84): Need always to create placeholder for input tensor.
ir::Tensor tensor = details::GetTensor(in_value);
@@ -72,7 +70,7 @@ std::vector<ir::Tensor> CollectInputTensor(
return tensors;
}

void CollectOutputInfo(const ::pir::Operation* op,
void CollectOutputInfo(::pir::Operation* op,
std::vector<Type>* out_types,
std::vector<std::vector<int>>* out_shapes) {
auto op_results = op->results();
@@ -359,7 +357,7 @@ std::vector<ir::Expr> OpLowererImpl::LowerOps(

std::vector<ir::LoweredFunc> OpLowererImpl::DoOpLower(
std::shared_ptr<hlir::framework::OpImpl> op_impl,
const ::pir::Operation* op,
::pir::Operation* op,
std::unordered_map<::pir::Value, ir::Tensor>* tensor_map,
std::vector<ir::Tensor>* op_func_arg_tensors) {
VLOG(4) << "Do lower with Compute, op: " << op->name();
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h
@@ -131,7 +131,7 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
*/
std::vector<ir::LoweredFunc> DoOpLower(
std::shared_ptr<hlir::framework::OpImpl> op_impl,
const ::pir::Operation* op,
::pir::Operation* op,
std::unordered_map<::pir::Value, ir::Tensor>* tensor_map,
std::vector<ir::Tensor>* op_func_arg_tensors);

3 changes: 1 addition & 2 deletions paddle/cinn/hlir/framework/new_ir/utils.cc
@@ -74,8 +74,7 @@ std::vector<std::string> CompatibleInfo::InputNames(const ::pir::Operation& op,
return names;
}

std::vector<std::string> CompatibleInfo::OutputNames(
const ::pir::Operation& op) {
std::vector<std::string> CompatibleInfo::OutputNames(::pir::Operation& op) {
std::vector<std::string> names;
for (int i = 0; i < op.num_results(); ++i) {
auto value = op.result(i);
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/new_ir/utils.h
@@ -40,7 +40,7 @@ struct CompatibleInfo {
static std::vector<std::string> InputNames(const ::pir::Operation& op,
bool allow_duplicate = false);

static std::vector<std::string> OutputNames(const ::pir::Operation& op);
static std::vector<std::string> OutputNames(::pir::Operation& op); // NOLINT
};

} // namespace newir
8 changes: 6 additions & 2 deletions paddle/cinn/ir/ir.cc
@@ -257,7 +257,7 @@ Expr For::Make(Var loop_var,
node->min = min;
node->extent = extent;
node->device_api = device_api;
node->body = body;
node->body = body.As<ir::Block>() ? body : ir::Block::Make({body});
node->set_for_type(for_type);
node->set_vectorize_info(vector_info);
node->set_bind_info(bind_info);
@@ -346,6 +346,10 @@ std::vector<const Expr *> ScheduleBlockRealize::expr_fields() const {
}

Expr IfThenElse::Make(Expr condition, Expr true_case, Expr false_case) {
if (true_case.defined() && (!true_case.As<Block>()))
true_case = ir::Block::Make({true_case});
if (false_case.defined() && (!false_case.As<Block>()))
false_case = ir::Block::Make({false_case});
auto node = make_shared<IfThenElse>(condition, true_case, false_case);
return Expr(node);
}
@@ -513,7 +517,7 @@ Expr PolyFor::Make(Var iterator,
n->condition = condition;
n->inc = inc;
n->device_api = device_api;
n->body = body;
n->body = body.As<ir::Block>() ? body : ir::Block::Make({body});
n->set_for_type(for_type);
n->set_vectorize_info(vectorize_info);
n->set_bind_info(bind_info);
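With For::Make, PolyFor::Make, and IfThenElse::Make now wrapping any non-Block body in an ir::Block, downstream consumers (the printers in codegen_c.cc and ir_printer.cc, and the simplifier further below) can assume loop and branch bodies are blocks and let the block emit its own braces. A standalone sketch of that normalization idea, using toy classes rather than CINN's real IR:

// ensure_block_sketch.cc -- toy AST, not CINN's ir::Block / ir::IfThenElse.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node {
  virtual ~Node() = default;
  virtual void Print(int indent) const = 0;
};
using NodePtr = std::shared_ptr<Node>;

struct Stmt : Node {
  std::string text;
  explicit Stmt(std::string t) : text(std::move(t)) {}
  void Print(int indent) const override {
    std::cout << std::string(indent, ' ') << text << ";\n";
  }
};

struct Block : Node {
  std::vector<NodePtr> stmts;
  explicit Block(std::vector<NodePtr> s) : stmts(std::move(s)) {}
  void Print(int indent) const override {
    std::cout << "{\n";                                // the block owns its braces,
    for (const auto& s : stmts) s->Print(indent + 2);  // so an if-printer only emits
    std::cout << std::string(indent, ' ') << "}";      // "if (cond) " and recurses
  }
};

// Mirrors the Make() change: wrap a single-statement body in a Block exactly once.
NodePtr EnsureBlock(NodePtr body) {
  if (std::dynamic_pointer_cast<Block>(body)) return body;
  return std::make_shared<Block>(std::vector<NodePtr>{std::move(body)});
}

int main() {
  NodePtr true_case = EnsureBlock(std::make_shared<Stmt>("B[i] = A[i]"));
  std::cout << "if ((i < 1280)) ";
  true_case->Print(0);
  std::cout << "\n";
  return 0;
}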
18 changes: 2 additions & 16 deletions paddle/cinn/ir/utils/ir_printer.cc
@@ -229,26 +229,12 @@ void IrPrinter::Visit(const PolyFor *x) {
void IrPrinter::Visit(const IfThenElse *x) {
str_ += "if (";
Visit(x->condition);
str_ += ") {\n";
IncIndent();
DoIndent();
str_ += ") ";
Visit(x->true_case);
DecIndent();
str_ += "\n";
DoIndent();
str_ += "}";

if (x->false_case.defined()) {
str_ += " else {\n";
IncIndent();

DoIndent();
str_ += " else ";
Visit(x->false_case);
str_ += "\n";

DecIndent();
DoIndent();
str_ += "}";
}
}
void IrPrinter::Visit(const Block *x) {
17 changes: 0 additions & 17 deletions paddle/cinn/optim/ir_simplify.cc
@@ -306,23 +306,6 @@ struct SimplifyBlocksMutator : public ir::IRMutator<> {
}
}

void Visit(const IfThenElse* op, Expr* expr) override {
auto* node = expr->As<IfThenElse>();
Visit(&node->condition, &node->condition);
if (node->true_case.As<Block>() &&
(node->true_case.As<Block>()->stmts.size() == 1)) {
node->true_case = node->true_case.As<Block>()->stmts[0];
}
Visit(&node->true_case, &node->true_case);
if (node->false_case.defined()) {
if (node->false_case.As<Block>() &&
(node->false_case.As<Block>()->stmts.size() == 1)) {
node->false_case = node->false_case.As<Block>()->stmts[0];
}
Visit(&node->false_case, &node->false_case);
}
}

void Visit(const ScheduleBlock* op, Expr* expr) override {
auto* node = expr->As<ScheduleBlock>();
CHECK(node);
5 changes: 4 additions & 1 deletion paddle/cinn/utils/attribute_util.h
@@ -20,6 +20,7 @@
#include "paddle/cinn/utils/type_defs.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/builtin_type.h"

namespace cinn {
@@ -61,7 +62,9 @@ AttributeMap ConvertAttributes(const NewIR_AttributeMap& src_attrs) {
AttributeMap dst_attrs;
for (auto& item : src_attrs) {
VLOG(4) << "deal with " << item.first;
if (item.second.isa<paddle::dialect::PlaceAttribute>()) {
if (item.first == ::pir::kStopGradientAttrName) {
continue;
} else if (item.second.isa<paddle::dialect::PlaceAttribute>()) {
auto is_cpu =
item.second.dyn_cast<paddle::dialect::PlaceAttribute>().data() ==
phi::CPUPlace();
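The new early continue drops the bookkeeping-only stop_gradient attribute instead of attempting to convert it. A rough standalone analogue of the filtering loop (plain strings here; the real code dispatches on pir attribute types):

// attr_filter_sketch.cc -- illustrative stand-in for ConvertAttributes' skip branch.
#include <iostream>
#include <map>
#include <string>

int main() {
  const std::map<std::string, std::string> src_attrs = {
      {"stop_gradient", "[false]"}, {"dtype", "float32"}, {"place", "Place(cpu)"}};

  std::map<std::string, std::string> dst_attrs;
  for (const auto& item : src_attrs) {
    if (item.first == "stop_gradient") continue;  // skipped, never converted
    dst_attrs[item.first] = item.second;          // real code converts per attribute kind
  }
  for (const auto& kv : dst_attrs) std::cout << kv.first << " = " << kv.second << "\n";
  return 0;
}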
@@ -117,7 +117,7 @@ std::vector<DimTrans*> MakeReshapeDimTrans(

if (tgt_splitted_shape.size() > 0) {
std::vector<DimTrans*> input_dims;
for (int64_t i = 0, n = src_dims.size(); i < n; i++) {
for (int i = 0, n = static_cast<int>(src_dims.size()); i < n; i++) {
int64_t in_dim = src_dims[i];
if (src_shape[in_dim] > 1) {
input_dims.emplace_back(new InputDim(in_dim));
@@ -141,7 +141,7 @@ paddle::distributed::auto_parallel::ReshapeSPMDRule::InferForward(
const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: Verify Input Args Based on Reshape Logic
int64_t ninputs = input_specs.size();
int64_t ninputs = static_cast<int64_t>(input_specs.size());
PADDLE_ENFORCE_EQ(
ninputs,
1,
7 changes: 4 additions & 3 deletions paddle/fluid/eager/auto_code_generator/eager_generator.cc
@@ -949,15 +949,15 @@ static bool CollectGradInformationFromOpInfo(
op_base_infos->resize(grad_node->size());
for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) {
// Each OpBase
int index = std::distance(grad_node->begin(), iter);
int index = static_cast<int>(std::distance(grad_node->begin(), iter));
paddle::imperative::OpBase& op_base = *iter;
(*op_base_infos)[index].SetOpBaseType(op_base.Type());
}

/* ------ Get Grad ins/outs/attrs ---- */
VLOG(6) << "In function size: " << grad_node->size();
for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) {
int index = std::distance(grad_node->begin(), iter);
int index = static_cast<int>(std::distance(grad_node->begin(), iter));
auto* op_base_grad_ins = (*op_base_infos)[index].GetMutableGradIns();
auto* op_base_grad_outs = (*op_base_infos)[index].GetMutableGradOuts();
auto* op_base_grad_attrs = (*op_base_infos)[index].GetMutableGradAttrs();
@@ -3160,7 +3160,8 @@ static void DygraphCodeGeneration(const std::string& output_dir,
op_info_map_need_gen.emplace(pair);
}

int each_cc_file_api_size = op_info_map_need_gen.size() / split_count;
int each_cc_file_api_size =
static_cast<int>(op_info_map_need_gen.size() / split_count);
if (op_info_map_need_gen.size() % split_count != 0) {
each_cc_file_api_size++;
}
25 changes: 24 additions & 1 deletion paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -210,6 +210,8 @@ class {} : public egr::GradNodeBase {{
// Prepare Grad function call
{}
// Runtime check if we need next grad
{}
// Set DistAttr of Out Tensor for semi-auto parallel
{}
// Inplace Check
{}
@@ -529,6 +531,12 @@ class {} : public egr::GradNodeBase {{
if( !{}.empty() ) {}_optional = paddle::make_optional<std::vector<paddle::Tensor>>({});
"""

SET_GRAD_OUT_DIST_ATTR_TEMPLATE = """
if (IsRunAutoParallel()) {{
egr::EagerUtils::SetGradOutputDistAttr(out_metas, {}, {});
}}
"""

CHECK_BACKWARD_INPLACE_TEMPLATE = """
bool can_be_inplaced = false;
if ({}.initialized()) {{
@@ -1088,7 +1096,7 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False):
for name, (ttype, pos) in forward_inputs_position_map.items():
if name in need_pre_contiguous_set:
pre_contiguous_list.append(
f"{indent}const auto& {name}_tmp = (require_any_grad && {name}.is_dense_tensor() && !std::dynamic_pointer_cast<phi::DenseTensor>({name}.impl())->meta().is_contiguous()) ? paddle::Tensor(std::make_shared<phi::DenseTensor>(std::move(paddle::experimental::Trans2Contiguous(*(std::dynamic_pointer_cast<phi::DenseTensor>({name}.impl())))))) : {name};"
f"{indent}const auto& {name}_tmp = (require_any_grad && {name}.is_dense_tensor() && !std::dynamic_pointer_cast<phi::DenseTensor>({name}.impl())->meta().is_contiguous()) ? paddle::Tensor(std::make_shared<phi::DenseTensor>(std::move(paddle::experimental::Trans2Contiguous(*(std::dynamic_pointer_cast<phi::DenseTensor>({name}.impl()))))), {name}.mutable_autograd_meta()) : {name};"
)
self.inputs_call_list_tmp[pos] = (
self.inputs_call_list_tmp[pos] + '_tmp'
@@ -2181,6 +2189,8 @@ def GenerateNodeDefinition(
)
grad_api_args = ["" for i in range(grad_api_args_len)]
get_grad_in_args_list = []
grad_api_out_args_list = []
fwd_positions_list = []

# Fill Grad Ins with Zero
fill_zero_str = ""
@@ -2388,6 +2398,8 @@ def GenerateNodeDefinition(
out_assign_str += f"{indent}*api_output_{out_index} = std::get<{out_index}>(api_output);\n"
else:
grad_api_args.append(f"api_output_{out_index}")
grad_api_out_args_list.append(f"api_output_{out_index}")
fwd_positions_list.append(f"{fwd_position}")
if inplace_grad_input_str in optional_inplace_var_name:
optional_inplace_str = "VLOG(6) << \"No Inplace should happend for wrappered input: {inplace_grad_input_str}\";"
else:
@@ -2433,6 +2445,16 @@ def GenerateNodeDefinition(
composite_grad_api_args_str = ", ".join(grad_api_args)
composite_template_name = "<paddle::Tensor>"

# Set DistAttr Func Construct
set_out_dist_attr_str = ""
if not is_invoke_forward_api:
fwd_positions_str = "{" + ", ".join(fwd_positions_list) + "}"
grad_api_out_args_str = ", ".join(grad_api_out_args_list)
set_out_dist_attr_str = SET_GRAD_OUT_DIST_ATTR_TEMPLATE.format(
fwd_positions_str,
grad_api_out_args_str,
)

if is_invoke_forward_api:
autograd_api_out = "auto"
if (
@@ -2600,6 +2622,7 @@ def GenerateNodeDefinition(
get_grad_in_args_str,
grad_function_prepare_str,
compute_require_next_grad_str,
set_out_dist_attr_str,
inplace_check_str,
inplace_for_grad_outs_str,
self.backward_api_name,
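For a hypothetical backward node with one tensor output wired to forward input position 0, the SET_GRAD_OUT_DIST_ATTR_TEMPLATE added above would expand to roughly the following generated C++ (the api_output_0 name and the {0} position list are illustrative, not taken from any specific op):

  if (IsRunAutoParallel()) {
    egr::EagerUtils::SetGradOutputDistAttr(out_metas, {0}, api_output_0);
  }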