Skip to content

Commit

Permalink
【PIR Dist Op Reg No.17】 reg barrier (#62802)
Browse files Browse the repository at this point in the history
* feat(pir): reg barrier

* feat(pir): reg barrier
  • Loading branch information
xiaoyewww authored Mar 20, 2024
1 parent 294b3cf commit 87500f4
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 0 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
'add_n_',
'all_reduce',
'all_reduce_',
'barrier',
'c_allgather',
'c_allreduce_avg',
'c_allreduce_max',
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@
data_type : dtype
backend : place > output

- op : barrier
args : (Tensor x, int ring_id=0)
output : Tensor(out)
kernel :
func : barrier

- op : batch_norm
args : (Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, bool is_test, float momentum, float epsilon, str data_format, bool use_global_stats, bool trainable_statistics)
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,12 @@
outputs :
{auc : AUC, stat_pos_out : StatPosOut, stat_neg_out : StatNegOut}

- op : barrier
inputs :
{x : X}
outputs :
out : Out

- op : batch_norm
backward : batch_norm_grad, batch_norm_double_grad(batch_norm_grad_grad)
inputs:
Expand Down
1 change: 1 addition & 0 deletions test/ir/pir/translator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ file(
string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}")

set(DISTRIBUTED_OP_TRANSLATOR_TEST test_all_reduce_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_barrier_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_reduce_min_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_allreduce_min_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_allreduce_prod_translator)
Expand Down
44 changes: 44 additions & 0 deletions test/ir/pir/translator/test_barrier_translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import test_op_translator

import paddle
from paddle.base.layer_helper import LayerHelper


class TestBarrierOpTranslator(test_op_translator.TestOpTranslator):
def append_op(self):
self.op_type = "barrier"
x = paddle.ones(shape=(100, 2, 3), dtype='float32')
y = paddle.ones(shape=(100, 2, 3), dtype='float32')
attrs = {
'ring_id': 0,
}
helper = LayerHelper(self.op_type)
helper.append_op(
type=self.op_type,
inputs={"X": x},
outputs={"Out": y},
attrs=attrs,
)

def test_translator(self):
self.check()


if __name__ == "__main__":
unittest.main()

0 comments on commit 87500f4

Please sign in to comment.