diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
index 13896b66f3c55..bd63d20c21510 100644
--- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
+++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
@@ -957,7 +957,7 @@ void BuildOpFuncList(
     if (op_name == "builtin.combine" || op_name == "pd.feed" ||
         op_name == "builtin.set_parameter" ||
-        op_name == "builtin.get_parameter") {
+        op_name == "builtin.get_parameter" || op_name == "builtin.slice") {
       VLOG(6) << "skip process " << op_name;
       continue;
     }
@@ -977,6 +977,7 @@ void BuildOpFuncList(
             phi::MetaTensor,
             phi::MetaTensor,
             paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
+            paddle::small_vector<phi::MetaTensor, phi::kOutputSmallVectorSize>,
             false>((*it),
                    value_2_name_map,
                    scope,
@@ -1003,6 +1004,7 @@ void BuildOpFuncList(
             const phi::TensorBase*,
             phi::TensorBase*,
             paddle::small_vector<const phi::TensorBase*>,
+            paddle::small_vector<phi::TensorBase*>,
             true>((*it),
                   value_2_name_map,
                   scope,
diff --git a/paddle/fluid/framework/tensor_ref_array.h b/paddle/fluid/framework/tensor_ref_array.h
index 516d76150840f..d5f5e0b61f2f9 100644
--- a/paddle/fluid/framework/tensor_ref_array.h
+++ b/paddle/fluid/framework/tensor_ref_array.h
@@ -20,11 +20,11 @@ namespace paddle {
 namespace framework {
 
 template <>
-struct PhiVectorType<const phi::DenseTensor*> {
-  const char* type_name = "PhiTensorRefArray";
+struct PhiVectorType<const framework::Variable*> {
+  const char* type_name = "VariableRefArray";
 };
 
-using TensorRefArray = PhiVector<const phi::DenseTensor*>;
+using VariableRefArray = PhiVector<const framework::Variable*>;
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/type_info.cc b/paddle/fluid/framework/type_info.cc
index e75e77c194c8f..c2be243552704 100644
--- a/paddle/fluid/framework/type_info.cc
+++ b/paddle/fluid/framework/type_info.cc
@@ -41,6 +41,6 @@ template class TypeInfoTraits<phi::TensorBase, paddle::framework::FeedList>;
 template class TypeInfoTraits<phi::TensorBase, egr::VariableCompatTensor>;
 template class TypeInfoTraits<phi::TensorBase, paddle::prim::DescTensor>;
 template class TypeInfoTraits<phi::TensorBase,
-                              paddle::framework::TensorRefArray>;
+                              paddle::framework::VariableRefArray>;
 
 }  // namespace phi
diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h
index 2e188e6caa076..ea6341d34ce16 100644
--- a/paddle/fluid/framework/var_type_traits.h
+++ b/paddle/fluid/framework/var_type_traits.h
@@ -212,7 +212,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
     std::vector<int>,
     std::vector<float>,
     RawTensor,
-    TensorRefArray>;
+    VariableRefArray>;
 template <typename T>
 struct VarTypeTrait {
   static_assert(VarTypeRegistry::IsRegistered<T>(), "Must be registered type");
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
index 9dec0cb52b266..0c67837648dc5 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
@@ -87,6 +87,7 @@ class PhiKernelAdaptor {
             phi::MetaTensor,
             phi::MetaTensor,
             paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
+            paddle::small_vector<phi::MetaTensor, phi::kOutputSmallVectorSize>,
             false>((*it), name_map, scope_, nullptr, op_yaml_info_parser, &ctx);
 
     infer_meta_impl->infer_meta_(&ctx);
@@ -106,6 +107,7 @@ class PhiKernelAdaptor {
             const phi::TensorBase*,
             phi::TensorBase*,
             paddle::small_vector<const phi::TensorBase*>,
+            paddle::small_vector<phi::TensorBase*>,
             true>(
         (*it), name_map, scope_, nullptr, op_yaml_info_parser, &kernel_ctx);
     kernel_fn(&kernel_ctx);
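[editorial note, not part of the patch] The `TensorRefArray` -> `VariableRefArray` rename above is more than cosmetic: the ref array now stores pointers to whole `Variable`s rather than bare `DenseTensor`s, so each element keeps an identity that can be mapped back to a scope name. A minimal standalone sketch of that idea; `Variable`, `DenseTensor`, and the name map below are invented stand-ins, not Paddle APIs:

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

struct DenseTensor {  // stand-in for phi::DenseTensor
  std::vector<float> data;
};

struct Variable {  // stand-in for paddle::framework::Variable
  DenseTensor tensor;
  const DenseTensor& Get() const { return tensor; }
};

// Stand-in for VariableRefArray = PhiVector<const Variable*>.
using VariableRefArray = std::vector<const Variable*>;

int main() {
  Variable v0, v1;
  v0.tensor.data = {1.f};
  v1.tensor.data = {2.f};

  VariableRefArray array = {&v0, &v1};

  // Because elements are whole Variables, a side map from Variable* to its
  // scope name stays meaningful -- the role VariableNameMap plays in this patch.
  std::unordered_map<const Variable*, std::string> variable_name_map = {
      {&v0, "inner_var_0"}, {&v1, "inner_var_1"}};

  // A builtin.slice-style lookup: pick element `index`, recover its name,
  // and still reach the tensor when one is needed.
  int index = 1;
  assert(variable_name_map.at(array[index]) == "inner_var_1");
  assert(array[index]->Get().data[0] == 2.f);
  return 0;
}
```

Keeping `Variable*` as the element type is what lets the `builtin.slice` handling later in this patch recover an element's scope name instead of copying tensors.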
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
index 601106037e6c7..1542e3a3379fd 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
@@ -43,6 +43,9 @@
 
 namespace ir {
 
+using VariableNameMap =
+    std::unordered_map<const paddle::framework::Variable*, std::string>;
+
 paddle::framework::Variable* CreateVar(ir::Value value,
                                        const std::string& name,
                                        paddle::framework::Scope* scope,
@@ -89,6 +92,7 @@ void BuildValue(ir::Value value,
                 paddle::framework::Scope* scope,
                 paddle::framework::Scope* local_scope,
                 std::unordered_map<ir::Value, std::string>* name_map,
+                VariableNameMap* variable_name_map,
                 int& count) {  // NOLINT
   auto inner_local_scope = local_scope != nullptr ? local_scope : scope;
   std::string name;
@@ -107,7 +111,7 @@ void BuildValue(ir::Value value,
   } else if (value.type().isa<paddle::dialect::AllocatedSelectedRowsType>()) {
     var->GetMutable<phi::SelectedRows>();
   } else if (value.type().isa<ir::VectorType>()) {
-    auto tensor_array = var->GetMutable<paddle::framework::TensorRefArray>();
+    auto tensor_array = var->GetMutable<paddle::framework::VariableRefArray>();
     for (size_t i = 0; i < value.type().dyn_cast<ir::VectorType>().size();
          i++) {
       PADDLE_ENFORCE(value.type()
@@ -118,7 +122,9 @@ void BuildValue(ir::Value value,
                          "DenseTensorType"));
       std::string name_i = "inner_var_" + std::to_string(count++);
       auto var_i = CreateVar(value, name_i, scope, inner_local_scope);
-      tensor_array->emplace_back(var_i->GetMutable<phi::DenseTensor>());
+      var_i->GetMutable<phi::DenseTensor>();
+      tensor_array->emplace_back(var_i);
+      variable_name_map->emplace(var_i, name_i);
     }
   } else {
     PADDLE_THROW(phi::errors::PreconditionNotMet(
@@ -127,6 +133,7 @@ void BuildValue(ir::Value value,
 }
 
 void HandleForSpecialOp(ir::Operation* op,
+                        const VariableNameMap& variable_name_map,
                         paddle::framework::Scope* scope,
                         paddle::framework::Scope* local_scope,
                         std::unordered_map<ir::Value, std::string>* name_map,
@@ -179,7 +186,7 @@ void HandleForSpecialOp(ir::Operation* op,
     }
 
     auto var = CreateVar(out_value, name, scope, local_scope);
-    auto tensor_array = var->GetMutable<paddle::framework::TensorRefArray>();
+    auto tensor_array = var->GetMutable<paddle::framework::VariableRefArray>();
     // clear tensor array
     tensor_array->clear();
 
@@ -191,8 +198,7 @@ void HandleForSpecialOp(ir::Operation* op,
           true,
           phi::errors::PreconditionNotMet("can not found input of combine op"));
       tensor_array->emplace_back(
-          &(CreateVar(value, name_map->at(value), scope, local_scope)
-                ->Get<phi::DenseTensor>()));
+          CreateVar(value, name_map->at(value), scope, local_scope));
     }
   }
 
@@ -222,6 +228,34 @@ void HandleForSpecialOp(ir::Operation* op,
     auto out_ptr = op->result(0);
     name_map->emplace(out_ptr, param_name);
   }
+
+  if (op_name == "builtin.slice") {
+    VLOG(6) << "Handle for builtin.slice";
+    auto out_value = op->result(0);
+
+    auto in_value = op->operand(0);
+
+    PADDLE_ENFORCE_EQ(name_map->count(in_value),
+                      true,
+                      phi::errors::PreconditionNotMet(
+                          "input of builtin.slice not in name map"));
+
+    int index =
+        op->attributes().at("index").dyn_cast<ir::Int32Attribute>().data();
+    auto in_var = scope->FindVar(name_map->at(in_value));
+    auto variable_array = in_var->Get<paddle::framework::VariableRefArray>();
+
+    PADDLE_ENFORCE_EQ(
+        variable_name_map.count(variable_array[index]),
+        true,
+        phi::errors::PreconditionNotMet("[%d] the variable in builtin.slice "
+                                        "input MUST be in variable name map",
+                                        index));
+
+    std::string var_name = variable_name_map.at(variable_array[index]);
+
+    name_map->emplace(out_value, var_name);
+  }
 }
 
 void HandleForInplaceOp(ir::Operation* op,
@@ -241,7 +275,7 @@ void HandleForInplaceOp(ir::Operation* op,
   paddle::dialect::OpYamlInfoParser yaml_parser(
       op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>()
           ->get_op_info_());
-
+  VariableNameMap variable_name_map;
   for (size_t i = 0; i < op->num_results(); ++i) {
     ir::Value value = op->result(i);
     std::string value_name = yaml_parser.OutputNames()[i];
@@ -254,7 +288,8 @@ void HandleForInplaceOp(ir::Operation* op,
             << " (var: " << var_name << ")";
       name_map->emplace(value, var_name);
     } else {
-      BuildValue(value, scope, local_scope, name_map, count);
+      BuildValue(
+          value, scope, local_scope, name_map, &variable_name_map, count);
     }
   }
 }
@@ -272,8 +307,11 @@ void BuildScope(const ir::Block& block,
   VLOG(6) << "Build: scope [" << scope << "] inner_local_scope ["
          << inner_local_scope << "]";
 
+  std::unordered_map<const paddle::framework::Variable*, std::string>
+      variable_name_map;
+
   // int count = name_map->size();
-  int count = inner_local_scope->Size();
+  int count = name_map->size();
 
   for (auto it = block.begin(); it != block.end(); ++it) {
     ir::Operation* op = *it;
@@ -287,9 +325,10 @@ void BuildScope(const ir::Block& block,
     if (op_name == "pd.feed" || op_name == "pd.fetch" ||
         op_name == "builtin.combine" || op_name == "builtin.set_parameter" ||
-        op_name == "builtin.get_parameter") {
-      VLOG(4) << "HandleForSpecialOp: " << op_name;
-      HandleForSpecialOp(op, scope, inner_local_scope, name_map, count);
+        op_name == "builtin.get_parameter" || op_name == "builtin.slice") {
+      VLOG(6) << "HandleForSpecialOp: " << op_name;
+      HandleForSpecialOp(
+          op, variable_name_map, scope, inner_local_scope, name_map, count);
       continue;
     }
 
@@ -305,7 +344,12 @@ void BuildScope(const ir::Block& block,
       continue;
     } else {
       for (size_t i = 0; i < op->num_results(); ++i) {
-        BuildValue(op->result(i), scope, local_scope, name_map, count);
+        BuildValue(op->result(i),
+                   scope,
+                   local_scope,
+                   name_map,
+                   &variable_name_map,
+                   count);
       }
     }
   }
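[editorial note, not part of the patch] The new `builtin.slice` branch in `HandleForSpecialOp` allocates nothing: it reads the `index` attribute, looks up the index-th element of the input `VariableRefArray`, and aliases the op's result to that element's existing name. A compressed standalone sketch of that wiring, with `Value` modeled as an int id and all other types invented stand-ins (not Paddle APIs):

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

using Value = int;   // stand-in for ir::Value (an int id here)
struct Variable {};  // stand-in for paddle::framework::Variable

using VariableRefArray = std::vector<const Variable*>;

int main() {
  // State after builtin.combine: two inner variables collected in a ref array.
  Variable inner0, inner1;
  VariableRefArray combined = {&inner0, &inner1};

  // Value -> name map and Variable* -> name map, as maintained by BuildScope.
  std::unordered_map<Value, std::string> name_map = {{2, "combined_var"}};
  std::unordered_map<const Variable*, std::string> variable_name_map = {
      {&inner0, "inner_var_0"}, {&inner1, "inner_var_1"}};

  // builtin.slice(%2) {index = 1} -> %3: pure aliasing, no new variable.
  Value slice_out = 3;
  int index = 1;
  name_map.emplace(slice_out, variable_name_map.at(combined[index]));

  // Downstream code now resolves %3 directly to the existing inner variable.
  assert(name_map.at(slice_out) == "inner_var_1");
  return 0;
}
```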
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
index 917c7270217f2..32d3f92bc935d 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
@@ -75,7 +75,8 @@ void BuildScope(const ir::Block& block,
 template <typename Context,
           typename InType,
           typename OutType,
-          typename ListType,
+          typename InListType,
+          typename OutListType,
           bool is_kernel>
 void BuildPhiContext(
     ir::Operation* op,
@@ -121,11 +122,12 @@ void BuildPhiContext(
     if (var->IsType<phi::DenseTensor>()) {
       const phi::TensorBase* tensor_in = &(var->Get<phi::DenseTensor>());
       ctx->EmplaceBackInput(InType(tensor_in));
-    } else if (var->IsType<paddle::framework::TensorRefArray>()) {
-      ListType inputs;
-      auto& tensor_array = var->Get<paddle::framework::TensorRefArray>();
-      for (size_t i = 0; i < tensor_array.size(); ++i) {
-        inputs.emplace_back(InType(tensor_array[i]));
+    } else if (var->IsType<paddle::framework::VariableRefArray>()) {
+      InListType inputs;
+      auto& variable_array = var->Get<paddle::framework::VariableRefArray>();
+      for (size_t i = 0; i < variable_array.size(); ++i) {
+        inputs.emplace_back(InType(const_cast<phi::DenseTensor*>(
+            &(variable_array[i]->Get<phi::DenseTensor>()))));
       }
       ctx->EmplaceBackInputs(inputs);
     } else {
@@ -157,18 +159,21 @@ void BuildPhiContext(
       VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t" << in_var_name;
       if (tensor_attr_type == "paddle::dialect::IntArrayAttribute") {
         if (ptr.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
-          phi::Attribute r1 = phi::TensorRef(
+          phi::Attribute attr = phi::TensorRef(
               &(inner_scope->FindVar(in_var_name)->Get<phi::DenseTensor>()));
-          ctx->EmplaceBackAttr(r1);
+          ctx->EmplaceBackAttr(attr);
         } else if (ptr.type().isa<ir::VectorType>()) {
           auto& tensor_array = inner_scope->FindVar(in_var_name)
-                                   ->Get<paddle::framework::TensorRefArray>();
+                                   ->Get<paddle::framework::VariableRefArray>();
           if (tensor_array.size() == 1) {
-            ctx->EmplaceBackAttr(phi::TensorRef(tensor_array[0]));
+            phi::Attribute attr =
+                phi::TensorRef(&(tensor_array[0]->Get<phi::DenseTensor>()));
+            ctx->EmplaceBackAttr(attr);
           } else {
             std::vector<phi::TensorRef> vec_ref;
             for (size_t i = 0; i < tensor_array.size(); ++i) {
-              vec_ref.emplace_back(phi::TensorRef(tensor_array[i]));
+              vec_ref.emplace_back(
+                  phi::TensorRef(&(tensor_array[i]->Get<phi::DenseTensor>())));
             }
             ctx->EmplaceBackAttr(vec_ref);
           }
@@ -328,8 +333,18 @@ void BuildPhiContext(
     } else if (out_type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
       ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
           &(scope->Var(name)->Get<phi::DenseTensor>()))));
+    } else if (out_type.isa<ir::VectorType>()) {
+      OutListType outputs;
+      auto& variable_array =
+          scope->Var(name)->Get<paddle::framework::VariableRefArray>();
+      for (size_t i = 0; i < variable_array.size(); ++i) {
+        outputs.emplace_back(OutType(const_cast<phi::DenseTensor*>(
+            &(variable_array[i]->Get<phi::DenseTensor>()))));
+      }
+      ctx->EmplaceBackOutputs(outputs);
     } else {
-      PADDLE_THROW("not support type");
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "only support DenseTensor and vector<DenseTensor>"));
     }
 
     if (output_map != nullptr) {
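[editorial note, not part of the patch] The `ListType` -> `InListType`/`OutListType` split exists because, for the kernel context, input lists carry `const phi::TensorBase*` while output lists carry mutable `phi::TensorBase*`; one list type can no longer serve both sides. A self-contained sketch of the shape of that template, with stand-in types and a reduced parameter list (not the real `BuildPhiContext` signature):

```cpp
#include <vector>

struct TensorBase {  // stand-in for phi::TensorBase
  float value = 0.f;
};

// Shape of the new template: separate element and list types for each side.
template <typename InType,
          typename OutType,
          typename InListType,
          typename OutListType>
void BuildContextSketch(const std::vector<TensorBase*>& ins,
                        const std::vector<TensorBase*>& outs) {
  InListType inputs;    // e.g. small_vector<const phi::TensorBase*> in Paddle
  OutListType outputs;  // e.g. small_vector<phi::TensorBase*> in Paddle
  for (TensorBase* t : ins) inputs.emplace_back(InType(t));
  for (TensorBase* t : outs) outputs.emplace_back(OutType(t));
  // Outputs stay writable; the input list would reject this, since its
  // element type is const TensorBase*.
  for (OutType o : outputs) o->value = 1.f;
}

int main() {
  TensorBase a, b, c;
  std::vector<TensorBase*> ins = {&a, &b}, outs = {&c};
  BuildContextSketch<const TensorBase*,
                     TensorBase*,
                     std::vector<const TensorBase*>,
                     std::vector<TensorBase*>>(ins, outs);
  return 0;
}
```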
diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc
index 03a94a14d7266..2c59cc5adf775 100644
--- a/paddle/fluid/ir_adaptor/translator/op_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -955,6 +955,104 @@ struct FeedOpTranscriber : public OpTranscriber {
   }
 };
 
+struct SplitOpTranscriber : public OpTranscriber {
+  std::vector<ir::OpResult> GenerateOperationInput(
+      ir::IrContext* ctx,
+      TranslationContext* param_map,
+      const OpDesc& op_desc,
+      const std::string& normalized_op_name,
+      const OpInputInfoList& input_infos,
+      ir::Program* program) override {
+    // input of split is [Tensor x, IntArray sections, Scalar(int) axis]
+
+    VLOG(10) << "[op:split][input] start";
+
+    std::vector<ir::OpResult> op_inputs;
+    // process first input
+    auto x_input_vars = op_desc.Input("X");
+    IR_ENFORCE(x_input_vars.size() == 1, "x input of split MUST be a tensor");
+    auto x_defining_info = (*param_map)[x_input_vars[0]];
+    op_inputs.push_back(x_defining_info.value);
+
+    // process sections
+    int num = paddle::get<int>(op_desc.GetAttr("num"));
+    if (num <= 0) {
+      if (op_desc.HasInput("SectionsTensorList")) {
+        // get SectionsTensorList from input
+
+        auto sec_tensor_list = op_desc.Input("SectionsTensorList");
+        auto* combine_op = InsertCombineOperationForTarget(
+            ctx, param_map, program, sec_tensor_list);
+        op_inputs.push_back(combine_op->result(0));
+      } else {
+        auto& attribute_translator = AttributeTranslator::instance();
+        ir::Attribute new_attr = attribute_translator(
+            "paddle::dialect::IntArrayAttribute", op_desc.GetAttr("sections"));
+        auto sec_defin_op =
+            InsertFullOperationForAttributeInput(ctx, program, new_attr);
+        op_inputs.push_back(sec_defin_op->result(0));
+      }
+    }
+
+    // process axis
+    if (op_desc.HasInput("AxisTensor") &&
+        op_desc.Input("AxisTensor").size() > 0) {
+      // get axis from input
+      auto axis_var_list = op_desc.Input("AxisTensor");
+      IR_ENFORCE(axis_var_list.size() == 1,
+                 "axis tensor input of split MUST be a tensor");
+      auto axis_defining_info = (*param_map)[axis_var_list[0]];
+      op_inputs.push_back(axis_defining_info.value);
+    } else {
+      auto& attribute_translator = AttributeTranslator::instance();
+      ir::Attribute new_attr =
+          attribute_translator("ir::Int32Attribute", op_desc.GetAttr("axis"));
+
+      auto axis_defin_op =
+          InsertFullOperationForAttributeInput(ctx, program, new_attr);
+      op_inputs.push_back(axis_defin_op->result(0));
+    }
+
+    return op_inputs;
+  }
+
+  ir::AttributeMap TranslateOpAttribute(
+      ir::IrContext* ctx,
+      const std::string& normalized_op_name,
+      const OpAttributeInfoList& op_attr_infos,
+      const OpDesc& op_desc) override {
+    int num = paddle::get<int>(op_desc.GetAttr("num"));
+    if (num > 0) {
+      ir::AttributeMap attribute_map = {
+          {"num",
+           ir::Int32Attribute::get(ctx, op_desc.GetAttrIfExists<int>("num"))},
+      };
+
+      return attribute_map;
+    }
+
+    return {};
+  }
+
+  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
+    int num = paddle::get<int>(op_desc.GetAttr("num"));
+    std::string target_op_name;
+    if (num > 0) {
+      target_op_name = "pd.split_with_num";
+
+    } else {
+      target_op_name = "pd.split";
+    }
+
+    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
+    if (!op_info) {
+      IR_THROW("Op split should have corresponding OpInfo pd.split");
+    }
+
+    return op_info;
+  }
+};
+
 struct FetchOpTranscriber : public OpTranscriber {
   ir::Operation* operator()(ir::IrContext* ctx,
                             TranslationContext* param_map,
@@ -994,6 +1092,7 @@ OpTranslator::OpTranslator() {
   special_handlers["feed"] = FeedOpTranscriber();
   special_handlers["fetch_v2"] = FetchOpTranscriber();
   special_handlers["cast"] = CastOpTranscriber();
+  special_handlers["split"] = SplitOpTranscriber();
   special_handlers["lookup_table_v2"] = EmbeddingOpTranscriber();
   special_handlers["lookup_table_v2_grad"] = EmbeddingGradOpTranscriber();
   special_handlers["assign_value"] = AssignValueOpTranscriber();
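[editorial note, not part of the patch] `SplitOpTranscriber` picks its target op from the legacy `num` attribute: a positive `num` maps to `pd.split_with_num`, otherwise the sections (attribute or `SectionsTensorList` input) select `pd.split`. The dispatch rule, distilled into a trivially testable sketch (the helper name is invented):

```cpp
#include <cassert>
#include <string>

// Dispatch rule from LoopkUpOpInfo, isolated for clarity (hypothetical helper).
std::string TargetOpForSplit(int num) {
  return num > 0 ? "pd.split_with_num" : "pd.split";
}

int main() {
  assert(TargetOpForSplit(3) == "pd.split_with_num");  // "num" attribute set
  assert(TargetOpForSplit(0) == "pd.split");           // sections path
  return 0;
}
```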
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 01baacdf3f632..4dfd7d679e190 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -2532,7 +2532,17 @@
   int_array:
     sections :
       data_type : int
-      tensor_name : AxesTensor
+  scalar :
+    axis :
+      data_type : int
+      support_tensor : true
+
+- op : split_with_num
+  scalar :
+    axis :
+      data_type : int
+      support_tensor : true
+      tensor_name : AxisTensor
 
 - op : sqrt
   backward : sqrt_grad, sqrt_double_grad (sqrt_grad_grad)
diff --git a/test/ir/new_ir/test_standalone_new_ir.py b/test/ir/new_ir/test_standalone_new_ir.py
index 41ee5151d69ab..9be4e07fddc77 100644
--- a/test/ir/new_ir/test_standalone_new_ir.py
+++ b/test/ir/new_ir/test_standalone_new_ir.py
@@ -141,5 +141,27 @@ def test_with_new_ir(self):
         np.testing.assert_array_equal(out[0], gold_res)
 
 
+class TestSplitOp(unittest.TestCase):
+    def test_with_new_ir(self):
+        place = paddle.CPUPlace()
+        exe = paddle.static.Executor(place)
+
+        main_program = paddle.static.Program()
+        new_scope = paddle.static.Scope()
+        with paddle.static.scope_guard(new_scope):
+            with paddle.static.program_guard(main_program):
+                x = paddle.static.data("x", [6, 2], dtype="float32")
+                out0, out1, out2 = paddle.split(x, num_or_sections=3, axis=0)
+
+                np_a = np.random.rand(6, 2).astype("float32")
+                out = exe.run(
+                    main_program,
+                    feed={"x": np_a},
+                    fetch_list=[out0.name],
+                )
+
+        np.testing.assert_array_equal(out[0], np_a[0:2])
+
+
 if __name__ == "__main__":
     unittest.main()