Skip to content

Commit

Permalink
[PIR] Support select complex kernel for real_grad (PaddlePaddle#58665)
Browse files Browse the repository at this point in the history
* fix

* fix

* fix

* fix

* fix

* fix

* fux

* fix

* add ut
  • Loading branch information
zhangbo9674 authored and SecretXV committed Nov 28, 2023
1 parent 2fcdbd1 commit 6ecda61
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 4 deletions.
28 changes: 24 additions & 4 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,11 +1435,31 @@ def OpGenerator(
'data_type' in op_kernel_map
and op_kernel_map['data_type']
):
kernel_key_dtype = '", "'.join(
op_kernel_map['data_type']['candidates']
)
for idx in range(
len(op_kernel_map['data_type']['candidates'])
):
if (
'to_complex_flag' in op_kernel_map['data_type']
and op_kernel_map['data_type'][
'to_complex_flag'
][idx]
):
kernel_key_dtype += (
'complex:'
+ op_kernel_map['data_type']['candidates'][
idx
]
+ '", "'
)
else:
kernel_key_dtype += (
op_kernel_map['data_type']['candidates'][
idx
]
+ '", "'
)
if kernel_key_dtype != "":
kernel_key_dtype = '"' + kernel_key_dtype + '"'
kernel_key_dtype = '"' + kernel_key_dtype[:-3]
if 'backend' in op_kernel_map and op_kernel_map['backend']:
kernel_key_backend = '", "'.join(
op_kernel_map['backend']['candidates']
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/pir/core/builtin_op.h"
Expand Down Expand Up @@ -343,6 +344,12 @@ phi::DataType GetKernelDataTypeByYamlInfo(
auto slot_name = data_type_info[i];
auto& input_map = op_info_parser->InputName2Id();

bool is_complex_tag = false;
if (slot_name.find("complex:") == 0) {
slot_name = slot_name.substr(8);
is_complex_tag = true;
}

auto find_it = Str2PhiDataType.find(slot_name);
if (find_it != Str2PhiDataType.end()) {
kernel_data_type = find_it->second;
Expand Down Expand Up @@ -383,6 +390,9 @@ phi::DataType GetKernelDataTypeByYamlInfo(
PADDLE_THROW(phi::errors::Unimplemented(
"Only support DenseTensorType, SelectedRows, VectorType"));
}
if (is_complex_tag) {
kernel_data_type = phi::dtype::ToComplex(kernel_data_type);
}

} else {
PADDLE_ENFORCE_EQ(attr_map.count(slot_name),
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_real_imag_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def test_check_grad(self):
'Out',
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out],
check_pir=True,
)


Expand Down

0 comments on commit 6ecda61

Please sign in to comment.