Skip to content

Commit

Permalink
[NewIR]Remove compatible logic of ProgramTranslator (#55453)
Browse files Browse the repository at this point in the history
  • Loading branch information
Asthestarsfalll authored Jul 27, 2023
1 parent 147fbfe commit cbbd940
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 52 deletions.
3 changes: 0 additions & 3 deletions paddle/fluid/ir_adaptor/translator/op_compat_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,6 @@ def insert_new_mutable_attributes(
backward_op, op_compat_item["scalar"]
)

# special op mappings
op_name_mappings["fetch_v2"] = "fetch"

op_name_normailzer_template = env.get_template("op_compat_info.cc.j2")
with open(output_source_file, 'wt') as f:
op_compat_definition = op_name_normailzer_template.render(
Expand Down
6 changes: 2 additions & 4 deletions paddle/fluid/ir_adaptor/translator/op_compat_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

#include "glog/logging.h"

#include "paddle/fluid/ir_adaptor/translator/utils.h"

#pragma once

namespace paddle {
Expand Down Expand Up @@ -106,11 +104,11 @@ class OpNameNormalizer {
return legacy_name;
}
if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) {
return UnderscoreToCamelCase(arg_name);
return arg_name;
}
auto& arg_mappings = op_arg_name_mappings[op_type];
if (arg_mappings.find(arg_name) == arg_mappings.end()) {
return UnderscoreToCamelCase(arg_name);
return arg_name;
}
return arg_mappings.at(arg_name);
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ ir::OpInfo OpTranscriber::LoopkUpOpInfo(ir::IrContext* ctx,
const OpDesc& op_desc) {
std::string target_op_name =
kTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type());
if (IsInplace(op_desc)) {
if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') {
target_op_name += "_";
}
VLOG(6) << "[op name normalizing]: " << op_desc.Type() << " to "
Expand Down
42 changes: 0 additions & 42 deletions paddle/fluid/ir_adaptor/translator/utils.h

This file was deleted.

27 changes: 25 additions & 2 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@
attrs : [bool use_mkldnn = false]

- op : bilinear (bilinear_tensor_product)
backward: bilinear_grad (bilinear_tensor_product_grad)
inputs :
{x : X, y : Y,weight: Weight, bias: Bias}
outputs :
Expand Down Expand Up @@ -1838,7 +1839,7 @@
data_type : float
support_tensor : true

- op : merged_momentum_
- op : merged_momentum_ (merged_momentum)
inputs :
{param : Param, grad : Grad, velocity : Velocity, learning_rate : LearningRate, master_param : MasterParam}
outputs :
Expand Down Expand Up @@ -3038,11 +3039,27 @@
yolo_loss : GetYoloLossExpectedKernelType
yolo_loss_grad : GetYoloLossExpectedKernelType

- op: fetch
- op: channel_shuffle
inputs:
{x: X}
outputs:
{out: Out}

- op: fetch (fetch_v2)
inputs: {x: X}
outputs: {out: Out}

- op: full_batch_size_like (fill_constant_batch_size_like)
inputs:
{input: Input}
outputs:
{out: Out}

- op: logspace
inputs:
{start: Start, stop: Stop, num: Num, base: Base}
outputs:
{out: Out}

- op: lu
backward: lu_grad
Expand All @@ -3059,6 +3076,12 @@
outputs :
{reindex_src : Reindex_Src, reindex_dst : Reindex_Dst, out_nodes : Out_Nodes}

- op: rrelu
inputs:
{x: X}
outputs:
{out: Out, noise: Noise}

- op: sigmoid_cross_entropy_with_logits
backward: sigmoid_cross_entropy_with_logits_grad
inputs :
Expand Down
16 changes: 16 additions & 0 deletions test/ir/new_ir/test_special_op_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,5 +194,21 @@ def test_with_axis(self):
np.testing.assert_array_equal(out[0], np.all(arr, axis=0))


class TestIndexPutOpTranscriber(unittest.TestCase):
def test_op(self):
place = core.Place()
place.set_place(paddle.CPUPlace())
new_scope = paddle.static.Scope()
main_program = paddle.static.Program()
with paddle.static.scope_guard(new_scope):
with paddle.static.program_guard(main_program):
x = paddle.randn([2, 3])
indices = [paddle.randint(0, 2, [2]), paddle.randint(0, 1, [2])]
value = paddle.randn([2])
y = paddle.index_put(x, indices, value, False)

_ = ir.translate_to_new_ir(main_program.desc)


if __name__ == "__main__":
unittest.main()

0 comments on commit cbbd940

Please sign in to comment.