diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index ed5c193aa2eca..fae7a3c9fc283 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -82,6 +82,7 @@ 'generate_sequence_xpu', 'layer_norm_act_xpu', 'memcpy', + 'batch_norm_', 'multi_encoder_xpu', 'multihead_matmul', 'squeeze_excitation_block', @@ -104,7 +105,6 @@ 'add_n_', 'add_n_with_kernel', 'assign_value', - 'batch_norm_', 'c_allgather', 'c_allreduce_max', 'c_allreduce_sum', diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 4cd13ec19846a..f7d32bb61908b 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -42,6 +42,7 @@ _global_flags, get_default_dtype, in_dynamic_or_pir_mode, + in_pir_mode, no_grad, ) from .. import functional as F @@ -1056,7 +1057,7 @@ def __init__( self._trainable_statistics = trainable_statistics def forward(self, input): - if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): batch_norm_out, t1, t2, t3, t4, _ = _C_ops.batch_norm( input, self._mean, @@ -1072,13 +1073,29 @@ def forward(self, input): ) if self._act is None: return batch_norm_out - if in_dynamic_mode(): - return dygraph_utils._append_activation_in_dygraph( - batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn - ) - else: - act_op = getattr(_C_ops, self._act) - return act_op(batch_norm_out) + + return dygraph_utils._append_activation_in_dygraph( + batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn + ) + elif in_pir_mode(): + batch_norm_out, t1, t2, t3, t4, _ = _C_ops.batch_norm_( + input, + self._mean, + self._variance, + self.weight, + self.bias, + not self.training, + self._momentum, + self._epsilon, + self._data_layout, + self._use_global_stats, + self._trainable_statistics, + ) + if self._act is None: + return batch_norm_out + + act_op = getattr(_C_ops, self._act) + return act_op(batch_norm_out) else: # create output # mean and mean_out share the same memory