From c4926b245d5251f507b26055a81b823fb8a39df1 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Thu, 23 Nov 2023 09:19:27 +0000 Subject: [PATCH 01/22] fix order of static backward --- python/paddle/base/backward.py | 73 +++++++++++++++++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/python/paddle/base/backward.py b/python/paddle/base/backward.py index e62a5b9245a1b..54d21b4316eb7 100755 --- a/python/paddle/base/backward.py +++ b/python/paddle/base/backward.py @@ -511,7 +511,11 @@ def _accumulate_gradients_by_add_ops_( def _addup_repetitive_outputs_( - op_descs, block_idx, grad_var_to_var=None, grad_op_id_to_fwd_op=None + op_descs, + block_idx, + grad_var_to_var=None, + grad_op_id_to_fwd_op=None, + topo_order_for_backward=None, ): """ In backward part, an variable may be the output of more than one ops. @@ -525,12 +529,18 @@ def _addup_repetitive_outputs_( """ _MAX_ADD_NUM_ = framework._global_flags()['FLAGS_max_inplace_grad_add'] + topo_order_for_grad_name = {} # pending_sum_ops = [] pending_sum_ops = collections.OrderedDict() var_rename_count = collections.defaultdict(int) renamed_vars = collections.defaultdict(list) renamed_var_start_idx = collections.defaultdict(list) var_device = collections.defaultdict(str) + + def _change_order_by_topo_order(var_name): + origin_names = renamed_vars[var_name] + origin_names.sort(key=lambda x: topo_order_for_grad_name[x]) + for idx, op_desc in enumerate(op_descs): op_device_attr_name = ( core.op_proto_and_checker_maker.kOpDeviceAttrName() @@ -543,6 +553,7 @@ def _addup_repetitive_outputs_( continue if len(renamed_vars[var_name]) > 1: if len(renamed_vars[var_name]) > _MAX_ADD_NUM_: + _change_order_by_topo_order(var_name) _accumulate_gradients_by_sum_op_( var_name, renamed_vars, @@ -551,6 +562,7 @@ def _addup_repetitive_outputs_( var_device[var_name], ) else: + _change_order_by_topo_order(var_name) _accumulate_gradients_by_add_ops_( var_name, renamed_vars, @@ -576,6 +588,9 @@ def _addup_repetitive_outputs_( # it's the first time we get the variable renamed_vars[var_name] = [var_name] renamed_var_start_idx[var_name] = idx + topo_order_for_grad_name[ + var_name + ] = topo_order_for_backward[op_desc] else: if len(renamed_vars[var_name]) == 1: new_name = ( @@ -595,6 +610,9 @@ def _addup_repetitive_outputs_( else: grad_var_to_var[new_name] = var_name # rename original var_name + topo_order_for_grad_name[ + new_name + ] = topo_order_for_grad_name[var_name] renamed_vars[var_name][0] = new_name # before change: _rename_arg_(op_descs, var_name, # new_name, 0, idx) @@ -646,10 +664,15 @@ def _addup_repetitive_outputs_( renamed_vars[var_name].append(new_name) # record the latest device var_device[var_name] = op_device + topo_order_for_grad_name[ + new_name + ] = topo_order_for_backward[op_desc] + breakpoint() for var_name, inputs in renamed_vars.items(): if len(renamed_vars[var_name]) > 1: if len(renamed_vars[var_name]) > _MAX_ADD_NUM_: + _change_order_by_topo_order(var_name) _accumulate_gradients_by_sum_op_( var_name, renamed_vars, @@ -658,6 +681,7 @@ def _addup_repetitive_outputs_( var_device[var_name], ) else: + _change_order_by_topo_order(var_name) _accumulate_gradients_by_add_ops_( var_name, renamed_vars, @@ -1175,6 +1199,7 @@ def _append_backward_ops_with_checkpoints_( grad_to_var.update(op_grad_to_var) # 3.d. add sum op for repetitive_outputs + topo_order = _topo_order_map(block, target_vars) grad_op_descs = _addup_repetitive_outputs_( grad_op_descs, block.idx, grad_op_id_to_fwd_op=grad_op_id_to_fwd_op ) @@ -1262,6 +1287,45 @@ def _rename_grad_name_(name, grad_order): return 'grad/' * grad_order + name +def _topo_order_map(block, target_vars): + """Analysis forward block and build a mapping from: + OpDesc -> Int + """ + get_defined_op = {} # mapping from String -> OpDesc (defined op) + for op in block.ops: + for out_name in op.output_arg_names: + get_defined_op[out_name] = op + + topo_order_map = {} # mapping from OpDesc -> Topologic Order + queue = [var.name for var in target_vars] + topo_order_counter = 0 + while len(queue) > 0: + cur_var_name = queue.pop(0) + cur_op = get_defined_op[cur_var_name] + topo_order_map[cur_op] = topo_order_counter + topo_order_counter += 1 + for inp in cur_op.input_arg_names[ + ::-1 + ]: # [::-1] for reverse, in dygraph, x + y + z -> o will result in + # z@grad is first calculated. + if ( + inp not in topo_order_map and inp in get_defined_op + ): # maybe slow, find in list ! + queue.append(inp) + return topo_order_map + + +def _topo_bwd_order_map(topo_fwd_map, backward_op_map): + topo_bwd_map = {} + topo_fwd_map = {op.desc: order for op, order in topo_fwd_map.items()} + for fwd_op, bwd_ops in backward_op_map.items(): + if fwd_op not in topo_fwd_map: + continue + for bwd_op in bwd_ops: + topo_bwd_map[bwd_op] = topo_fwd_map[fwd_op] + return topo_bwd_map + + def _append_backward_ops_( block, ops, @@ -1325,6 +1389,7 @@ def update_distop_context( # grad_op_descs holds created grad_op, and will be appended to target_block grad_op_descs = [] program = block.program + get_backward_op_desc = {} # for topo order map if rename_var_map is None: rename_var_map = {} @@ -1410,6 +1475,7 @@ def find_op_index(block_desc, cur_op_desc): ) # record the mapping between fwd and bwd + get_backward_op_desc[op.desc] = grad_op_desc if grad_op_id_to_fwd_op is not None: for op_desc in grad_op_desc: grad_op_id_to_fwd_op[op_desc.original_id()] = op @@ -1526,11 +1592,16 @@ def find_op_index(block_desc, cur_op_desc): program._appending_grad_times ] # sum parameter's gradients' var given multiple var gradient + topo_order = _topo_order_map(block, target_vars) + topo_order_for_backward = _topo_bwd_order_map( + topo_order, get_backward_op_desc + ) grad_op_descs = _addup_repetitive_outputs_( grad_op_descs, block.idx, grad_var_to_var, grad_op_id_to_fwd_op=grad_op_id_to_fwd_op, + topo_order_for_backward=topo_order_for_backward, ) # if all outputs of the grad op are in no_grad_set, then just remove and fill zero From 7936b3fb204332036155098aa64e3dd9beb1b744 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Thu, 23 Nov 2023 13:10:36 +0000 Subject: [PATCH 02/22] fix some error in topo order --- python/paddle/base/backward.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/python/paddle/base/backward.py b/python/paddle/base/backward.py index 54d21b4316eb7..175106f025429 100755 --- a/python/paddle/base/backward.py +++ b/python/paddle/base/backward.py @@ -1294,23 +1294,21 @@ def _topo_order_map(block, target_vars): get_defined_op = {} # mapping from String -> OpDesc (defined op) for op in block.ops: for out_name in op.output_arg_names: + assert out_name not in get_defined_op, "Duplicated output found." get_defined_op[out_name] = op topo_order_map = {} # mapping from OpDesc -> Topologic Order queue = [var.name for var in target_vars] + visited = set() topo_order_counter = 0 while len(queue) > 0: cur_var_name = queue.pop(0) + visited.add(cur_var_name) cur_op = get_defined_op[cur_var_name] topo_order_map[cur_op] = topo_order_counter topo_order_counter += 1 - for inp in cur_op.input_arg_names[ - ::-1 - ]: # [::-1] for reverse, in dygraph, x + y + z -> o will result in - # z@grad is first calculated. - if ( - inp not in topo_order_map and inp in get_defined_op - ): # maybe slow, find in list ! + for inp in cur_op.input_arg_names: + if inp in get_defined_op and inp not in visited: queue.append(inp) return topo_order_map From 2d69d5d1a83a18bf9a51a2bd5b1cf9c59423eec3 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Thu, 23 Nov 2023 13:15:20 +0000 Subject: [PATCH 03/22] remove useless breakpoint --- python/paddle/base/backward.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/base/backward.py b/python/paddle/base/backward.py index 175106f025429..c247e81064737 100755 --- a/python/paddle/base/backward.py +++ b/python/paddle/base/backward.py @@ -668,7 +668,6 @@ def _change_order_by_topo_order(var_name): new_name ] = topo_order_for_backward[op_desc] - breakpoint() for var_name, inputs in renamed_vars.items(): if len(renamed_vars[var_name]) > 1: if len(renamed_vars[var_name]) > _MAX_ADD_NUM_: From be5418828467d2248de116d79d784763deefbe13 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 24 Nov 2023 02:58:57 +0000 Subject: [PATCH 04/22] fix --- python/paddle/base/backward.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/base/backward.py b/python/paddle/base/backward.py index c247e81064737..9bcbad3e682af 100755 --- a/python/paddle/base/backward.py +++ b/python/paddle/base/backward.py @@ -1293,7 +1293,6 @@ def _topo_order_map(block, target_vars): get_defined_op = {} # mapping from String -> OpDesc (defined op) for op in block.ops: for out_name in op.output_arg_names: - assert out_name not in get_defined_op, "Duplicated output found." get_defined_op[out_name] = op topo_order_map = {} # mapping from OpDesc -> Topologic Order From 056fb963799f2997422a65f2db8f5599a37c2b39 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Sun, 26 Nov 2023 13:40:46 +0000 Subject: [PATCH 05/22] fix --- python/paddle/base/backward.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/python/paddle/base/backward.py b/python/paddle/base/backward.py index 9bcbad3e682af..fa8ce0d7540ed 100755 --- a/python/paddle/base/backward.py +++ b/python/paddle/base/backward.py @@ -588,9 +588,11 @@ def _change_order_by_topo_order(var_name): # it's the first time we get the variable renamed_vars[var_name] = [var_name] renamed_var_start_idx[var_name] = idx - topo_order_for_grad_name[ - var_name - ] = topo_order_for_backward[op_desc] + topo_order_for_grad_name[var_name] = ( + topo_order_for_backward[op_desc] + if topo_order_for_backward + else 1 + ) else: if len(renamed_vars[var_name]) == 1: new_name = ( @@ -664,9 +666,11 @@ def _change_order_by_topo_order(var_name): renamed_vars[var_name].append(new_name) # record the latest device var_device[var_name] = op_device - topo_order_for_grad_name[ - new_name - ] = topo_order_for_backward[op_desc] + topo_order_for_grad_name[new_name] = ( + topo_order_for_backward[op_desc] + if topo_order_for_backward + else 1 + ) for var_name, inputs in renamed_vars.items(): if len(renamed_vars[var_name]) > 1: @@ -1198,7 +1202,6 @@ def _append_backward_ops_with_checkpoints_( grad_to_var.update(op_grad_to_var) # 3.d. add sum op for repetitive_outputs - topo_order = _topo_order_map(block, target_vars) grad_op_descs = _addup_repetitive_outputs_( grad_op_descs, block.idx, grad_op_id_to_fwd_op=grad_op_id_to_fwd_op ) From 52da1cd4f8893cce6576a3367de6c9aebe3620a5 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Sun, 26 Nov 2023 13:55:00 +0000 Subject: [PATCH 06/22] fix --- python/paddle/base/backward.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/base/backward.py b/python/paddle/base/backward.py index fa8ce0d7540ed..f94429ab49cb1 100755 --- a/python/paddle/base/backward.py +++ b/python/paddle/base/backward.py @@ -591,6 +591,7 @@ def _change_order_by_topo_order(var_name): topo_order_for_grad_name[var_name] = ( topo_order_for_backward[op_desc] if topo_order_for_backward + and op_desc in topo_order_for_backward else 1 ) else: @@ -669,6 +670,7 @@ def _change_order_by_topo_order(var_name): topo_order_for_grad_name[new_name] = ( topo_order_for_backward[op_desc] if topo_order_for_backward + and op_desc in topo_order_for_backward else 1 ) From 4be9cdbb2df81f84a7a554e083872b8e0258f7e0 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 27 Nov 2023 03:38:17 +0000 Subject: [PATCH 07/22] fix --- python/paddle/base/backward.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/base/backward.py b/python/paddle/base/backward.py index f94429ab49cb1..cd697cc0bbd5b 100755 --- a/python/paddle/base/backward.py +++ b/python/paddle/base/backward.py @@ -1307,6 +1307,8 @@ def _topo_order_map(block, target_vars): while len(queue) > 0: cur_var_name = queue.pop(0) visited.add(cur_var_name) + if cur_var_name not in get_defined_op: + continue cur_op = get_defined_op[cur_var_name] topo_order_map[cur_op] = topo_order_counter topo_order_counter += 1 From 14ec0123efceb24051dd91862aa740b4bbeee6cd Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 27 Nov 2023 06:16:05 +0000 Subject: [PATCH 08/22] fix --- python/paddle/base/backward.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/python/paddle/base/backward.py b/python/paddle/base/backward.py index cd697cc0bbd5b..fbd5a737b4e84 100755 --- a/python/paddle/base/backward.py +++ b/python/paddle/base/backward.py @@ -15,6 +15,7 @@ import collections import copy import logging +import os import re import warnings from collections.abc import Sequence @@ -1596,9 +1597,16 @@ def find_op_index(block_desc, cur_op_desc): ] # sum parameter's gradients' var given multiple var gradient topo_order = _topo_order_map(block, target_vars) - topo_order_for_backward = _topo_bwd_order_map( - topo_order, get_backward_op_desc - ) + if os.environ.get("FLAGS_program_topo_reorder", "False") in [ + 'True', + '1', + 'true', + ]: + topo_order_for_backward = _topo_bwd_order_map( + topo_order, get_backward_op_desc + ) + else: + topo_order_for_backward = None grad_op_descs = _addup_repetitive_outputs_( grad_op_descs, block.idx, From e0900d22646031e3d047181d9aa697e12eb93251 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 28 Nov 2023 09:03:09 +0000 Subject: [PATCH 09/22] adjustly ir backward prune routine. --- python/paddle/autograd/ir_backward.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index f2ad53b0254d8..7f097780051bc 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -703,8 +703,12 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): ) state.turn_map() + for bwd_op in inverse_sort_op(remove_ops): - remove_op(block, bwd_op, state) + if bwd_op.result(0) in grad_outputs: + continue + if bwd_op.result(0).use_empty(): + remove_op(block, bwd_op, state) state.turn_map() input_grad_map = state.value_to_valuegrad From 6e52a6fd73deae71f6210cb77e9d65e735174186 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 28 Nov 2023 09:11:49 +0000 Subject: [PATCH 10/22] fix --- python/paddle/base/backward.py | 84 +--------------------------------- 1 file changed, 1 insertion(+), 83 deletions(-) diff --git a/python/paddle/base/backward.py b/python/paddle/base/backward.py index fbd5a737b4e84..e62a5b9245a1b 100755 --- a/python/paddle/base/backward.py +++ b/python/paddle/base/backward.py @@ -15,7 +15,6 @@ import collections import copy import logging -import os import re import warnings from collections.abc import Sequence @@ -512,11 +511,7 @@ def _accumulate_gradients_by_add_ops_( def _addup_repetitive_outputs_( - op_descs, - block_idx, - grad_var_to_var=None, - grad_op_id_to_fwd_op=None, - topo_order_for_backward=None, + op_descs, block_idx, grad_var_to_var=None, grad_op_id_to_fwd_op=None ): """ In backward part, an variable may be the output of more than one ops. @@ -530,18 +525,12 @@ def _addup_repetitive_outputs_( """ _MAX_ADD_NUM_ = framework._global_flags()['FLAGS_max_inplace_grad_add'] - topo_order_for_grad_name = {} # pending_sum_ops = [] pending_sum_ops = collections.OrderedDict() var_rename_count = collections.defaultdict(int) renamed_vars = collections.defaultdict(list) renamed_var_start_idx = collections.defaultdict(list) var_device = collections.defaultdict(str) - - def _change_order_by_topo_order(var_name): - origin_names = renamed_vars[var_name] - origin_names.sort(key=lambda x: topo_order_for_grad_name[x]) - for idx, op_desc in enumerate(op_descs): op_device_attr_name = ( core.op_proto_and_checker_maker.kOpDeviceAttrName() @@ -554,7 +543,6 @@ def _change_order_by_topo_order(var_name): continue if len(renamed_vars[var_name]) > 1: if len(renamed_vars[var_name]) > _MAX_ADD_NUM_: - _change_order_by_topo_order(var_name) _accumulate_gradients_by_sum_op_( var_name, renamed_vars, @@ -563,7 +551,6 @@ def _change_order_by_topo_order(var_name): var_device[var_name], ) else: - _change_order_by_topo_order(var_name) _accumulate_gradients_by_add_ops_( var_name, renamed_vars, @@ -589,12 +576,6 @@ def _change_order_by_topo_order(var_name): # it's the first time we get the variable renamed_vars[var_name] = [var_name] renamed_var_start_idx[var_name] = idx - topo_order_for_grad_name[var_name] = ( - topo_order_for_backward[op_desc] - if topo_order_for_backward - and op_desc in topo_order_for_backward - else 1 - ) else: if len(renamed_vars[var_name]) == 1: new_name = ( @@ -614,9 +595,6 @@ def _change_order_by_topo_order(var_name): else: grad_var_to_var[new_name] = var_name # rename original var_name - topo_order_for_grad_name[ - new_name - ] = topo_order_for_grad_name[var_name] renamed_vars[var_name][0] = new_name # before change: _rename_arg_(op_descs, var_name, # new_name, 0, idx) @@ -668,17 +646,10 @@ def _change_order_by_topo_order(var_name): renamed_vars[var_name].append(new_name) # record the latest device var_device[var_name] = op_device - topo_order_for_grad_name[new_name] = ( - topo_order_for_backward[op_desc] - if topo_order_for_backward - and op_desc in topo_order_for_backward - else 1 - ) for var_name, inputs in renamed_vars.items(): if len(renamed_vars[var_name]) > 1: if len(renamed_vars[var_name]) > _MAX_ADD_NUM_: - _change_order_by_topo_order(var_name) _accumulate_gradients_by_sum_op_( var_name, renamed_vars, @@ -687,7 +658,6 @@ def _change_order_by_topo_order(var_name): var_device[var_name], ) else: - _change_order_by_topo_order(var_name) _accumulate_gradients_by_add_ops_( var_name, renamed_vars, @@ -1292,44 +1262,6 @@ def _rename_grad_name_(name, grad_order): return 'grad/' * grad_order + name -def _topo_order_map(block, target_vars): - """Analysis forward block and build a mapping from: - OpDesc -> Int - """ - get_defined_op = {} # mapping from String -> OpDesc (defined op) - for op in block.ops: - for out_name in op.output_arg_names: - get_defined_op[out_name] = op - - topo_order_map = {} # mapping from OpDesc -> Topologic Order - queue = [var.name for var in target_vars] - visited = set() - topo_order_counter = 0 - while len(queue) > 0: - cur_var_name = queue.pop(0) - visited.add(cur_var_name) - if cur_var_name not in get_defined_op: - continue - cur_op = get_defined_op[cur_var_name] - topo_order_map[cur_op] = topo_order_counter - topo_order_counter += 1 - for inp in cur_op.input_arg_names: - if inp in get_defined_op and inp not in visited: - queue.append(inp) - return topo_order_map - - -def _topo_bwd_order_map(topo_fwd_map, backward_op_map): - topo_bwd_map = {} - topo_fwd_map = {op.desc: order for op, order in topo_fwd_map.items()} - for fwd_op, bwd_ops in backward_op_map.items(): - if fwd_op not in topo_fwd_map: - continue - for bwd_op in bwd_ops: - topo_bwd_map[bwd_op] = topo_fwd_map[fwd_op] - return topo_bwd_map - - def _append_backward_ops_( block, ops, @@ -1393,7 +1325,6 @@ def update_distop_context( # grad_op_descs holds created grad_op, and will be appended to target_block grad_op_descs = [] program = block.program - get_backward_op_desc = {} # for topo order map if rename_var_map is None: rename_var_map = {} @@ -1479,7 +1410,6 @@ def find_op_index(block_desc, cur_op_desc): ) # record the mapping between fwd and bwd - get_backward_op_desc[op.desc] = grad_op_desc if grad_op_id_to_fwd_op is not None: for op_desc in grad_op_desc: grad_op_id_to_fwd_op[op_desc.original_id()] = op @@ -1596,23 +1526,11 @@ def find_op_index(block_desc, cur_op_desc): program._appending_grad_times ] # sum parameter's gradients' var given multiple var gradient - topo_order = _topo_order_map(block, target_vars) - if os.environ.get("FLAGS_program_topo_reorder", "False") in [ - 'True', - '1', - 'true', - ]: - topo_order_for_backward = _topo_bwd_order_map( - topo_order, get_backward_op_desc - ) - else: - topo_order_for_backward = None grad_op_descs = _addup_repetitive_outputs_( grad_op_descs, block.idx, grad_var_to_var, grad_op_id_to_fwd_op=grad_op_id_to_fwd_op, - topo_order_for_backward=topo_order_for_backward, ) # if all outputs of the grad op are in no_grad_set, then just remove and fill zero From ddb654d3e32ea79097c6f49a4f3dad4dbf23f8fd Mon Sep 17 00:00:00 2001 From: chenzhiyang <1792266893@qq.com> Date: Tue, 28 Nov 2023 09:23:28 +0000 Subject: [PATCH 11/22] fix cross_entropy_with_softmax vjp bug --- .../dialect/op_generator/op_interface_gen.py | 2 +- .../test_softmax_with_cross_entropy_op.py | 24 +++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py index 92881b5d48523..dd2323cd7c3e9 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py @@ -149,7 +149,7 @@ def gen_op_vjp_str( index_0 = fwd_outputs_list.index(bw_input_name) else: vjp_param_name = 'out_grads' - grad_idx += 1 + grad_idx = fwd_outputs_list.index(bw_input_name[:-5]) index_0 = grad_idx if op_grad_info.input_optional_list[idx] == 'true': if input_type == 'Tensor': diff --git a/test/legacy_test/test_softmax_with_cross_entropy_op.py b/test/legacy_test/test_softmax_with_cross_entropy_op.py index e2d512707e57d..b0a79084ed118 100644 --- a/test/legacy_test/test_softmax_with_cross_entropy_op.py +++ b/test/legacy_test/test_softmax_with_cross_entropy_op.py @@ -160,11 +160,11 @@ def test_check_grad(self): if core.is_compiled_with_rocm(): if self.python_api is not None: self.check_grad( - ["Logits"], "Loss", max_relative_error=5e-1, check_pir=False + ["Logits"], "Loss", max_relative_error=5e-1, check_pir=True ) # HIP will have accuracy fail when using float32 in CPU place self.check_grad( - ["Logits"], "Loss", max_relative_error=5e-1, check_pir=False + ["Logits"], "Loss", max_relative_error=5e-1, check_pir=True ) else: if self.python_api is not None: @@ -172,10 +172,10 @@ def test_check_grad(self): ["Logits"], "Loss", numeric_grad_delta=0.001, - check_pir=False, + check_pir=True, ) self.check_grad( - ["Logits"], "Loss", numeric_grad_delta=0.001, check_pir=False + ["Logits"], "Loss", numeric_grad_delta=0.001, check_pir=True ) @@ -517,9 +517,9 @@ def test_check_output(self): def test_check_grad(self): if self.python_api is not None: - self.check_grad(["Logits"], "Loss", check_pir=False) + self.check_grad(["Logits"], "Loss", check_pir=True) self.check_grad( - ["Logits"], "Loss", max_relative_error=0.1, check_pir=False + ["Logits"], "Loss", max_relative_error=0.1, check_pir=True ) @@ -540,10 +540,10 @@ def initParams(self): def test_check_grad(self): if self.python_api is not None: self.check_grad( - ["Logits"], "Loss", max_relative_error=0.1, check_pir=False + ["Logits"], "Loss", max_relative_error=0.1, check_pir=True ) self.check_grad( - ["Logits"], "Loss", max_relative_error=0.1, check_pir=False + ["Logits"], "Loss", max_relative_error=0.1, check_pir=True ) @@ -574,15 +574,15 @@ def test_check_grad(self): # HIP will have accuracy fail when using float32 in CPU place if self.python_api is not None: self.check_grad( - ["Logits"], "Loss", max_relative_error=0.1, check_pir=False + ["Logits"], "Loss", max_relative_error=0.1, check_pir=True ) self.check_grad( - ["Logits"], "Loss", max_relative_error=0.1, check_pir=False + ["Logits"], "Loss", max_relative_error=0.1, check_pir=True ) else: if self.python_api is not None: - self.check_grad(["Logits"], "Loss", check_pir=False) - self.check_grad(["Logits"], "Loss", check_pir=False) + self.check_grad(["Logits"], "Loss", check_pir=True) + self.check_grad(["Logits"], "Loss", check_pir=True) class TestSoftmaxWithCrossEntropyOp3(TestSoftmaxWithCrossEntropyOp): From f57fc8b202fa7c138e35ab688761465078386878 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 28 Nov 2023 13:23:15 +0000 Subject: [PATCH 12/22] fix pre-commit! --- test/dygraph_to_static/test_mnist.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/test/dygraph_to_static/test_mnist.py b/test/dygraph_to_static/test_mnist.py index 34ad272a27d68..ccf37ac7faeac 100644 --- a/test/dygraph_to_static/test_mnist.py +++ b/test/dygraph_to_static/test_mnist.py @@ -20,14 +20,13 @@ import numpy as np from dygraph_to_static_utils import ( Dy2StTestBase, - test_default_mode_only, + test_legacy_and_pir, ) from predictor_utils import PredictorTools import paddle from paddle import base from paddle.base.dygraph import to_variable -from paddle.base.dygraph.base import switch_to_static_graph from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX from paddle.nn import Linear from paddle.optimizer import Adam @@ -162,7 +161,7 @@ def train_static(self): def train_dygraph(self): return self.train(to_static=False) - @test_default_mode_only + @test_legacy_and_pir def test_mnist_to_static(self): dygraph_loss = self.train_dygraph() static_loss = self.train_static() @@ -173,7 +172,7 @@ def test_mnist_to_static(self): err_msg=f'dygraph is {dygraph_loss}\n static_res is \n{static_loss}', ) - @test_default_mode_only + @test_legacy_and_pir def test_mnist_declarative_cpu_vs_mkldnn(self): dygraph_loss_cpu = self.train_dygraph() base.set_flags({'FLAGS_use_mkldnn': True}) @@ -239,14 +238,16 @@ def train(self, to_static=False): prediction, acc, avg_loss = mnist(img, label) loss_data.append(float(avg_loss)) # new save load check - self.check_jit_save_load( - mnist, - [dy_x_data], - [img, label], - to_static, - prediction, - [img.name], - ) + # TODO(@xiongkun): enable this after new save load is supported in pir. + if not paddle.base.framework.use_pir_api(): + self.check_jit_save_load( + mnist, + [dy_x_data], + [img, label], + to_static, + prediction, + [img.name], + ) break return loss_data @@ -298,7 +299,6 @@ def check_jit_save_load( gt_out.numpy(), predictor_infer_out, rtol=1e-05 ) - @switch_to_static_graph def jit_load_and_run_inference_static( self, model_path, model_filename, params_filename, inputs ): @@ -320,6 +320,7 @@ def jit_load_and_run_inference_static( feed=dict(zip(feed_target_names, inputs)), fetch_list=fetch_targets, ) + paddle.disable_static() return np.array(results[0]) From 432c6403484a7c8811ef5b65ad4c6d7bcf461dba Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 29 Nov 2023 03:01:34 +0000 Subject: [PATCH 13/22] fix --- paddle/phi/infermeta/multiary.cc | 1 + .../dygraph_to_static_utils.py | 7 +++++ test/dygraph_to_static/test_layer_hook.py | 26 +++++++++++-------- test/dygraph_to_static/test_mnist.py | 6 ++--- test/dygraph_to_static/test_word2vec.py | 6 ++++- 5 files changed, 31 insertions(+), 15 deletions(-) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 7106aaaad5df9..cbb1100e89fa8 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -675,6 +675,7 @@ void BatchNormInferMeta(const MetaTensor& x, } if (reserve_space) { reserve_space->set_dims({-1}); + reserve_space->set_dtype(DataType::UINT8); } y->share_lod(x); y->set_dtype(x.dtype()); diff --git a/test/dygraph_to_static/dygraph_to_static_utils.py b/test/dygraph_to_static/dygraph_to_static_utils.py index 2047e5ea5da14..2d79add5a0387 100644 --- a/test/dygraph_to_static/dygraph_to_static_utils.py +++ b/test/dygraph_to_static/dygraph_to_static_utils.py @@ -398,6 +398,13 @@ def test_default_mode_only(fn): return fn +def test_default_and_pir(fn): + # Some unittests has high time complexity, we only test them with default mode + fn = set_to_static_mode(ToStaticMode.SOT)(fn) + fn = set_ir_mode(IrMode.PT | IrMode.PIR)(fn) + return fn + + # NOTE: This is a special decorator for comparing legacy and pt def compare_legacy_with_pt(fn): @wraps(fn) diff --git a/test/dygraph_to_static/test_layer_hook.py b/test/dygraph_to_static/test_layer_hook.py index 4ae73e450573f..e2a73fa4c59dd 100644 --- a/test/dygraph_to_static/test_layer_hook.py +++ b/test/dygraph_to_static/test_layer_hook.py @@ -17,7 +17,10 @@ import unittest import numpy as np -from dygraph_to_static_utils import Dy2StTestBase, compare_legacy_with_pt +from dygraph_to_static_utils import ( + Dy2StTestBase, + test_legacy_and_pt_and_pir, +) import paddle @@ -66,7 +69,6 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() - @compare_legacy_with_pt def train_net(self, to_static=False): paddle.seed(2022) net = SimpleNet() @@ -74,7 +76,8 @@ def train_net(self, to_static=False): net = paddle.jit.to_static(net) out = net(self.x) - if to_static: + # TODO(xiongkun) save / load unitest. + if to_static and not paddle.base.framework.use_pir_api(): paddle.jit.save(net, self.path) return float(out) @@ -84,23 +87,24 @@ def load_train(self): out = net(self.x) return float(out) + @test_legacy_and_pt_and_pir def test_hook(self): dy_out = self.train_net(to_static=False) st_out = self.train_net(to_static=True) - load_out = self.load_train() - print(st_out, dy_out, load_out) np.testing.assert_allclose( st_out, dy_out, rtol=1e-05, err_msg=f'dygraph_res is {dy_out}\nstatic_res is {st_out}', ) - np.testing.assert_allclose( - st_out, - load_out, - rtol=1e-05, - err_msg=f'load_out is {load_out}\nstatic_res is {st_out}', - ) + if not paddle.base.framework.use_pir_api(): + load_out = self.load_train() + np.testing.assert_allclose( + st_out, + load_out, + rtol=1e-05, + err_msg=f'load_out is {load_out}\nstatic_res is {st_out}', + ) if __name__ == "__main__": diff --git a/test/dygraph_to_static/test_mnist.py b/test/dygraph_to_static/test_mnist.py index ccf37ac7faeac..6d76052fa973f 100644 --- a/test/dygraph_to_static/test_mnist.py +++ b/test/dygraph_to_static/test_mnist.py @@ -20,7 +20,7 @@ import numpy as np from dygraph_to_static_utils import ( Dy2StTestBase, - test_legacy_and_pir, + test_default_and_pir, ) from predictor_utils import PredictorTools @@ -161,7 +161,7 @@ def train_static(self): def train_dygraph(self): return self.train(to_static=False) - @test_legacy_and_pir + @test_default_and_pir def test_mnist_to_static(self): dygraph_loss = self.train_dygraph() static_loss = self.train_static() @@ -172,7 +172,7 @@ def test_mnist_to_static(self): err_msg=f'dygraph is {dygraph_loss}\n static_res is \n{static_loss}', ) - @test_legacy_and_pir + @test_default_and_pir def test_mnist_declarative_cpu_vs_mkldnn(self): dygraph_loss_cpu = self.train_dygraph() base.set_flags({'FLAGS_use_mkldnn': True}) diff --git a/test/dygraph_to_static/test_word2vec.py b/test/dygraph_to_static/test_word2vec.py index 9f61c540944d2..3b284b2af8426 100644 --- a/test/dygraph_to_static/test_word2vec.py +++ b/test/dygraph_to_static/test_word2vec.py @@ -17,7 +17,10 @@ import unittest import numpy as np -from dygraph_to_static_utils import Dy2StTestBase +from dygraph_to_static_utils import ( + Dy2StTestBase, + test_legacy_and_pt_and_pir, +) import paddle from paddle import base @@ -317,6 +320,7 @@ def train(to_static): class TestWord2Vec(Dy2StTestBase): + @test_legacy_and_pt_and_pir def test_dygraph_static_same_loss(self): dygraph_loss = train(to_static=False) static_loss = train(to_static=True) From 413c948a56e68fa9ecc1ba9a17abf05725220115 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 29 Nov 2023 15:57:47 +0000 Subject: [PATCH 14/22] fix 3 unittest --- python/paddle/nn/layer/norm.py | 2 +- test/dygraph_to_static/test_bmn.py | 31 +++++---------------- test/dygraph_to_static/test_to_tensor.py | 35 ++++++++++++------------ 3 files changed, 25 insertions(+), 43 deletions(-) diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 4a192fd48c84b..4cd13ec19846a 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -1078,7 +1078,7 @@ def forward(self, input): ) else: act_op = getattr(_C_ops, self._act) - return act_op(input) + return act_op(batch_norm_out) else: # create output # mean and mean_out share the same memory diff --git a/test/dygraph_to_static/test_bmn.py b/test/dygraph_to_static/test_bmn.py index e0ac834d67290..ad14ea74d4a1e 100644 --- a/test/dygraph_to_static/test_bmn.py +++ b/test/dygraph_to_static/test_bmn.py @@ -21,7 +21,7 @@ from dygraph_to_static_utils import ( Dy2StTestBase, static_guard, - test_pt_only, + test_legacy_and_pt_and_pir, ) from predictor_utils import PredictorTools @@ -625,16 +625,6 @@ def val_bmn(model, args): float(pem_cls_loss), ] - print( - f'[VALID] iter {batch_id} ' - + '\tLoss = {}, \ttem_loss = {}, \tpem_reg_loss = {}, \tpem_cls_loss = {}'.format( - '%f' % float(avg_loss), - '%f' % float(tem_loss), - '%f' % float(pem_reg_loss), - '%f' % float(pem_cls_loss), - ) - ) - if batch_id == args.valid_batch_num: break return loss_data @@ -722,17 +712,6 @@ def train_bmn(self, args, to_static): float(pem_cls_loss), ] - if args.log_interval > 0 and ( - batch_id % args.log_interval == 0 - ): - print( - f'[TRAIN] Epoch {epoch}, iter {batch_id} ' - + f'\tLoss = {float(avg_loss):f}, ' - + f'\ttem_loss = {float(tem_loss):f}, ' - + f'\tpem_reg_loss = {float(pem_reg_loss):f}, ' - + f'\tpem_cls_loss = {float(pem_cls_loss):f}' - ) - # validation if batch_id % args.valid_interval == 0 and batch_id > 0: bmn.eval() @@ -741,7 +720,11 @@ def train_bmn(self, args, to_static): loss_data += val_loss_data if batch_id == args.train_batch_num: - if to_static: + # TODO(@xiongkun): open after save / load supported in pir. + if ( + to_static + and not paddle.base.framework.use_pir_api() + ): paddle.jit.save(bmn, self.model_save_prefix) else: paddle.save( @@ -751,7 +734,7 @@ def train_bmn(self, args, to_static): break return np.array(loss_data) - @test_pt_only + @test_legacy_and_pt_and_pir def test_train_pir(self): static_res = self.train_bmn(self.args, to_static=True) dygraph_res = self.train_bmn(self.args, to_static=False) diff --git a/test/dygraph_to_static/test_to_tensor.py b/test/dygraph_to_static/test_to_tensor.py index f8e295fcf6f91..9e061c46295f5 100644 --- a/test/dygraph_to_static/test_to_tensor.py +++ b/test/dygraph_to_static/test_to_tensor.py @@ -21,8 +21,6 @@ ToStaticMode, disable_test_case, test_legacy_and_pt_and_pir, - test_legacy_only, - test_pir_only, ) import paddle @@ -184,6 +182,7 @@ def test_to_tensor_err_log(self): class TestStatic(Dy2StTestBase): + @test_legacy_and_pt_and_pir def test_static(self): paddle.enable_static() main_prog = paddle.static.Program() @@ -194,6 +193,7 @@ def test_static(self): else: place = paddle.CPUPlace() + paddle.set_default_dtype("float64") x = paddle.to_tensor( paddle.randn([5, 2]), dtype='float64', @@ -201,7 +201,8 @@ def test_static(self): place=place, ) - out = paddle.static.nn.fc(x, 1) + fc_net = paddle.nn.Linear(2, 1) + out = fc_net(x) sgd = paddle.optimizer.SGD() sgd.minimize(paddle.mean(out)) @@ -212,29 +213,27 @@ def test_static(self): class TestInt16(Dy2StTestBase): - @test_legacy_only + @test_legacy_and_pt_and_pir def test_static(self): import numpy as np paddle.enable_static() data = np.array([1, 2], dtype="int16") x = paddle.to_tensor(data) - self.assertTrue(x.dtype == paddle.framework.core.VarDesc.VarType.INT16) - - y = paddle.to_tensor([1, 2], dtype="int16") - self.assertTrue(y.dtype == paddle.framework.core.VarDesc.VarType.INT16) - - @test_pir_only - def test_static_pir(self): - import numpy as np - - paddle.enable_static() - data = np.array([1, 2], dtype="int16") - x = paddle.to_tensor(data) - self.assertTrue(x.dtype == paddle.base.libpaddle.DataType.INT16) + if paddle.base.framework.use_pir_api(): + self.assertTrue(x.dtype == paddle.base.libpaddle.DataType.INT16) + else: + self.assertTrue( + x.dtype == paddle.framework.core.VarDesc.VarType.INT16 + ) y = paddle.to_tensor([1, 2], dtype="int16") - self.assertTrue(y.dtype == paddle.base.libpaddle.DataType.INT16) + if paddle.base.framework.use_pir_api(): + self.assertTrue(y.dtype == paddle.base.libpaddle.DataType.INT16) + else: + self.assertTrue( + y.dtype == paddle.framework.core.VarDesc.VarType.INT16 + ) if __name__ == '__main__': From 7afca17544a01e5437abe6e3a7433cd10928f16e Mon Sep 17 00:00:00 2001 From: xiongkun Date: Thu, 30 Nov 2023 08:49:25 +0000 Subject: [PATCH 15/22] fix code format --- test/dygraph_to_static/test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dygraph_to_static/test_mnist.py b/test/dygraph_to_static/test_mnist.py index 734f375a40999..694b6239f169e 100644 --- a/test/dygraph_to_static/test_mnist.py +++ b/test/dygraph_to_static/test_mnist.py @@ -339,4 +339,4 @@ def predictor_load_and_run_inference_analysis( if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From f089cf02703623e14ec90fb9da8338c4451d7b1a Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 5 Dec 2023 07:46:22 +0000 Subject: [PATCH 16/22] fix test_tensor_memcpy_on_gpu.py --- test/dygraph_to_static/test_tensor_memcpy_on_gpu.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/dygraph_to_static/test_tensor_memcpy_on_gpu.py b/test/dygraph_to_static/test_tensor_memcpy_on_gpu.py index 828c9874a4a25..b1274267d7bcb 100644 --- a/test/dygraph_to_static/test_tensor_memcpy_on_gpu.py +++ b/test/dygraph_to_static/test_tensor_memcpy_on_gpu.py @@ -16,7 +16,10 @@ import unittest import numpy as np -from dygraph_to_static_utils import Dy2StTestBase +from dygraph_to_static_utils import ( + Dy2StTestBase, + test_legacy_and_pt_and_pir, +) import paddle @@ -46,6 +49,7 @@ def _run(self, to_static): x2 = paddle.jit.to_static(tensor_copy_to_cpu)(x1) return x1.place, x2.place, x2.numpy() + @test_legacy_and_pt_and_pir def test_tensor_cpu_on_default_gpu(self): if paddle.base.is_compiled_with_cuda(): place = paddle.CUDAPlace( @@ -72,6 +76,7 @@ def _run(self, to_static): x2 = paddle.jit.to_static(tensor_copy_to_cuda)(x1) return x1.place, x2.place, x2.numpy() + @test_legacy_and_pt_and_pir def test_tensor_cuda_on_default_gpu(self): if paddle.is_compiled_with_cuda(): place = paddle.CUDAPlace( @@ -100,6 +105,7 @@ def _run(self, to_static): ) return x1.place, x2.place, x2.numpy() + @test_legacy_and_pt_and_pir def test_with_warning_on_gpu(self): if paddle.base.is_compiled_with_cuda(): place = paddle.CUDAPlace( From 4727c0f3fb5f8c9e970010607de86e5ae7e64d63 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 5 Dec 2023 07:51:21 +0000 Subject: [PATCH 17/22] fix test_partial_program.py --- test/dygraph_to_static/test_partial_program.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/dygraph_to_static/test_partial_program.py b/test/dygraph_to_static/test_partial_program.py index 4af1088f720a6..756b5b8299564 100644 --- a/test/dygraph_to_static/test_partial_program.py +++ b/test/dygraph_to_static/test_partial_program.py @@ -249,6 +249,7 @@ def forward(self, x): class TestPruneUnusedParamInProgram(Dy2StTestBase): + @test_legacy_and_pt_and_pir def test_prune(self): input_ids = np.array([[15, 11, 6, 3, 18, 13]]).astype("float32") From f881910090f881e79eb956762ed430323070f772 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 5 Dec 2023 08:01:16 +0000 Subject: [PATCH 18/22] fix test_ptb_lm_v2 --- test/dygraph_to_static/test_ptb_lm_v2.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/dygraph_to_static/test_ptb_lm_v2.py b/test/dygraph_to_static/test_ptb_lm_v2.py index 0dcbf5ddc07d2..ab89c20afc25f 100644 --- a/test/dygraph_to_static/test_ptb_lm_v2.py +++ b/test/dygraph_to_static/test_ptb_lm_v2.py @@ -17,7 +17,10 @@ import unittest import numpy as np -from dygraph_to_static_utils import Dy2StTestBase +from dygraph_to_static_utils import ( + Dy2StTestBase, + test_legacy_and_pt_and_pir, +) import paddle @@ -332,6 +335,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_legacy_and_pt_and_pir def test_check_result(self): loss_1, hidden_1, cell_1 = train_static(self.place) loss_2, hidden_2, cell_2 = train_dygraph(self.place) From b837fa9f01410bff1de5c617d62bf4f8004a0294 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Wed, 6 Dec 2023 07:44:46 +0000 Subject: [PATCH 19/22] [PIR]Using inplace batch norm in PIR --- .../pir/dialect/op_generator/ops_api_gen.py | 2 +- python/paddle/nn/layer/norm.py | 33 ++++++++++++++----- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index ed5c193aa2eca..fae7a3c9fc283 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -82,6 +82,7 @@ 'generate_sequence_xpu', 'layer_norm_act_xpu', 'memcpy', + 'batch_norm_', 'multi_encoder_xpu', 'multihead_matmul', 'squeeze_excitation_block', @@ -104,7 +105,6 @@ 'add_n_', 'add_n_with_kernel', 'assign_value', - 'batch_norm_', 'c_allgather', 'c_allreduce_max', 'c_allreduce_sum', diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 4cd13ec19846a..f7d32bb61908b 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -42,6 +42,7 @@ _global_flags, get_default_dtype, in_dynamic_or_pir_mode, + in_pir_mode, no_grad, ) from .. import functional as F @@ -1056,7 +1057,7 @@ def __init__( self._trainable_statistics = trainable_statistics def forward(self, input): - if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): batch_norm_out, t1, t2, t3, t4, _ = _C_ops.batch_norm( input, self._mean, @@ -1072,13 +1073,29 @@ def forward(self, input): ) if self._act is None: return batch_norm_out - if in_dynamic_mode(): - return dygraph_utils._append_activation_in_dygraph( - batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn - ) - else: - act_op = getattr(_C_ops, self._act) - return act_op(batch_norm_out) + + return dygraph_utils._append_activation_in_dygraph( + batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn + ) + elif in_pir_mode(): + batch_norm_out, t1, t2, t3, t4, _ = _C_ops.batch_norm_( + input, + self._mean, + self._variance, + self.weight, + self.bias, + not self.training, + self._momentum, + self._epsilon, + self._data_layout, + self._use_global_stats, + self._trainable_statistics, + ) + if self._act is None: + return batch_norm_out + + act_op = getattr(_C_ops, self._act) + return act_op(batch_norm_out) else: # create output # mean and mean_out share the same memory From 2f65ebacb2ec47c81778b37874bc939ed6ea2345 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 6 Dec 2023 07:53:38 +0000 Subject: [PATCH 20/22] fix apply pass error --- .../eager/to_static/run_program_op_node.h | 25 ++++++++++---- paddle/fluid/framework/executor_cache.cc | 34 +++++++++++-------- paddle/fluid/framework/executor_cache.h | 3 ++ 3 files changed, 41 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index f79649f71069d..2705daeac03df 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -17,6 +17,7 @@ #include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/grad_node_info.h" #include "paddle/fluid/eager/tensor_wrapper.h" +#include "paddle/fluid/framework/executor_cache.h" #include "paddle/fluid/framework/new_executor/interpretercore.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/ir_adaptor/translator/program_translator.h" @@ -500,10 +501,16 @@ inline void PirRunProgramAPI( details::ShareTensorsIntoScopeByValue( forward_global_block, params, param_values, global_inner_scope); // Step 2. create new interpretercore - auto kernel_forward_program = - paddle::dialect::PdOpLowerToKernelPass(forward_program, place); + auto passed_kernel_program = + paddle::framework::ApplyIrPass(forward_program, place); + if (FLAGS_print_ir) { + std::ostringstream print_stream; + print_stream << "LoweredProgram( AfterPass ) is :\n"; + passed_kernel_program->Print(print_stream); + std::cout << print_stream.str() << std::endl; + } interpreter_core = paddle::framework::CreatePirInterpreterCoreInfoToCache( - std::move(kernel_forward_program), + std::move(passed_kernel_program), place, /*is_grad=*/false, program_id, @@ -1037,10 +1044,16 @@ inline void PirRunProgramGradAPI( 1); VLOG(2) << "No interpretercore cahce, so create a new interpretercore"; // Step 1. share input_vars & parameters into scope - auto kernel_backward_program = - paddle::dialect::PdOpLowerToKernelPass(backward_program, place); + auto passed_kernel_program = + paddle::framework::ApplyIrPass(backward_program, place); + if (FLAGS_print_ir) { + std::ostringstream print_stream; + print_stream << "LoweredProgram( AfterPass | Backward ) is :\n"; + passed_kernel_program->Print(print_stream); + std::cout << print_stream.str() << std::endl; + } interpreter_core = paddle::framework::CreatePirInterpreterCoreInfoToCache( - std::move(kernel_backward_program), + std::move(passed_kernel_program), place, /*is_grad=*/true, program_id, diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc index 6a6be34c3eebc..6f64cf44bf69c 100644 --- a/paddle/fluid/framework/executor_cache.cc +++ b/paddle/fluid/framework/executor_cache.cc @@ -358,6 +358,24 @@ bool TensorSortHelper(const paddle::Tensor &t1, const paddle::Tensor &t2) { return t1.name() < t2.name(); } +std::unique_ptr<::pir::Program> ApplyIrPass(::pir::Program *program, + phi::Place place) { + auto ir_res = paddle::dialect::PdOpLowerToKernelPass(program, place); + + if (FLAGS_pir_apply_inplace_pass) { + ::pir::PassManager pm(::pir::IrContext::Instance(), 3); + pm.AddPass(::pir::CreateInplacePass()); + pm.Run(ir_res.get()); + + if (FLAGS_print_ir) { + std::cout << "IR After inplace -------------------" << std::endl; + std::cout << *ir_res << std::endl; + } + } + + return ir_res; +} + std::unique_ptr<::pir::Program> ConstructFowardIrProgram( const paddle::framework::BlockDesc *forward_global_block, const paddle::framework::BlockDesc *backward_global_block, @@ -456,21 +474,7 @@ std::unique_ptr<::pir::Program> ConstructFowardIrProgram( program.get()); program_translator.Translate(); - - auto ir_res = paddle::dialect::PdOpLowerToKernelPass(program.get(), place); - - if (FLAGS_pir_apply_inplace_pass) { - ::pir::PassManager pm(::pir::IrContext::Instance(), 3); - pm.AddPass(::pir::CreateInplacePass()); - pm.Run(ir_res.get()); - - if (FLAGS_print_ir) { - std::cout << "IR After inplace -------------------" << std::endl; - std::cout << *ir_res << std::endl; - } - } - - return ir_res; + return ApplyIrPass(program.get(), place); } std::unique_ptr<::pir::Program> ConstructBackwardIrProgram( diff --git a/paddle/fluid/framework/executor_cache.h b/paddle/fluid/framework/executor_cache.h index ad94fcbeca107..e6da435a903aa 100644 --- a/paddle/fluid/framework/executor_cache.h +++ b/paddle/fluid/framework/executor_cache.h @@ -252,6 +252,9 @@ std::shared_ptr CreatePirInterpreterCoreInfoToCache( int64_t program_id, framework::Scope* scope); +std::unique_ptr<::pir::Program> ApplyIrPass(::pir::Program* program, + phi::Place place); + std::unique_ptr<::pir::Program> ConstructFowardIrProgram( const paddle::framework::BlockDesc* forward_global_block, const paddle::framework::BlockDesc* backward_global_block, From 95d944edd58c9c329d9d4d0842bcea331cd0c665 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 6 Dec 2023 16:10:32 +0000 Subject: [PATCH 21/22] fix bn problem. --- paddle/fluid/eager/to_static/run_program_op_node.h | 7 +++---- paddle/fluid/pybind/pir.cc | 6 ++++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index 2705daeac03df..7ce5e637fcc86 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -290,6 +290,7 @@ static void ShareTensorsFromScopeByValue( for (size_t i = 0; i < tensors.size(); ++i) { auto &name = names[i]; auto &value = values[i]; + VLOG(2) << "share " << name << " from scope"; if (value.impl() == nullptr) { // skip stop_gradient. continue; @@ -307,7 +308,7 @@ static void ShareTensorsFromScopeByValue( auto &src_tensor = var->Get(); auto *dst_tensor = const_cast( dynamic_cast(tensors[i]->impl().get())); - VLOG(2) << "share " << name << " from scope"; + VLOG(2) << "actually do sharing " << name << " from scope"; *dst_tensor = src_tensor; } else if (var->IsType()) { auto &src_tensor = var->Get(); @@ -1359,9 +1360,7 @@ class PirGradNodeRunProgram : public egr::GradNodeBase { x_grad_ptr.emplace_back(&i); } for (auto &i : params_grad) { - if (i.defined()) { - params_grad_ptr.emplace_back(&i); - } + params_grad_ptr.emplace_back(&i); } } diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 58a8012e021d0..e7964c3ae3368 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -136,6 +136,9 @@ inline void SetProgramInt64Attr(std::shared_ptr program, } std::string GetValueInfo(Value v) { + if (v.impl() == nullptr) { + return "nullptr value"; + } std::stringstream ss; if (auto op_result = v.dyn_cast()) { ss << "define_op_name=" << op_result.owner()->name(); @@ -1058,12 +1061,11 @@ int AppendSetParameters(Program *forward_program, std::unordered_set added_op_result; for (const auto &result : outputs_op_result) { - if (!added_op_result.count(result)) { + if (!added_op_result.count(result) || IsFakeOpResult(result)) { std::string parameter_name = name_prefix + std::to_string(counter); AppendSetParameter( forward_program, result, parameter_name, start_point + counter); counter += 1; - added_op_result.insert(result); } } From 6677bfe86b46713304cb3a96596e0fa7674ab4a4 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 6 Dec 2023 16:21:06 +0000 Subject: [PATCH 22/22] fix test_resnet.py uniitest --- test/dygraph_to_static/test_resnet.py | 35 ++++++++++----------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/test/dygraph_to_static/test_resnet.py b/test/dygraph_to_static/test_resnet.py index d86024400272f..7f72b900133a9 100644 --- a/test/dygraph_to_static/test_resnet.py +++ b/test/dygraph_to_static/test_resnet.py @@ -21,10 +21,7 @@ import numpy as np from dygraph_to_static_utils import ( Dy2StTestBase, - test_default_mode_only, - test_legacy_only, - test_pt_only, - test_sot_only, + test_default_and_pir, ) from predictor_utils import PredictorTools @@ -341,7 +338,12 @@ def do_train(self, to_static): ) if batch_id == 10: if to_static: - paddle.jit.save(resnet, self.model_save_prefix) + # TODO(@xiongkun): open after save / load supported in pir. + if ( + to_static + and not paddle.base.framework.use_pir_api() + ): + paddle.jit.save(resnet, self.model_save_prefix) else: paddle.save( resnet.state_dict(), @@ -442,20 +444,7 @@ def verify_predict(self): err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.', ) - @test_sot_only - @test_pt_only - def test_resnet_pir(self): - static_loss = self.train(to_static=True) - dygraph_loss = self.train(to_static=False) - np.testing.assert_allclose( - static_loss, - dygraph_loss, - rtol=1e-05, - err_msg=f'static_loss: {static_loss} \n dygraph_loss: {dygraph_loss}', - ) - - @test_sot_only - @test_legacy_only + @test_default_and_pir def test_resnet(self): static_loss = self.train(to_static=True) dygraph_loss = self.train(to_static=False) @@ -465,9 +454,11 @@ def test_resnet(self): rtol=1e-05, err_msg=f'static_loss: {static_loss} \n dygraph_loss: {dygraph_loss}', ) - self.verify_predict() + # TODO(@xiongkun): open after save / load supported in pir. + if not paddle.base.framework.use_pir_api(): + self.verify_predict() - @test_default_mode_only + @test_default_and_pir def test_resnet_composite(self): core._set_prim_backward_enabled(True) core._add_skip_comp_ops("batch_norm") @@ -481,7 +472,7 @@ def test_resnet_composite(self): err_msg=f'static_loss: {static_loss} \n dygraph_loss: {dygraph_loss}', ) - @test_default_mode_only + @test_default_and_pir def test_in_static_mode_mkldnn(self): paddle.base.set_flags({'FLAGS_use_mkldnn': True}) try: