Skip to content

Commit

Permalink
TL/UCP: allreduce sra knomial opt (openucx#823)
Browse files Browse the repository at this point in the history
* TL/UCP: allreduce sra knomial opt

* EC/CPU: fix overlap reductions
  • Loading branch information
Sergei-Lebedev authored and janjust committed Jan 31, 2024
1 parent f482d9b commit 30a6988
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 105 deletions.
48 changes: 24 additions & 24 deletions src/components/ec/cpu/ec_cpu_reduce.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
  /**
- * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See file LICENSE for terms.
  */
Expand All @@ -8,9 +8,10 @@
#include "ec_cpu.h"
#include <complex.h>

- #define DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, OP) \
+ #define DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, OP) \
  do { \
  size_t _i, _j; \
+ type _tmp; \
  switch (_n_srcs) { \
  case 2: \
  for (_i = 0; _i < _count; _i++) { \
Expand Down Expand Up @@ -53,13 +54,12 @@
  break; \
  default: \
  for (_i = 0; _i < _count; _i++) { \
- d[_i] = OP##_8(s[0][_i], s[1][_i], s[2][_i], s[3][_i], \
- s[4][_i], s[5][_i], s[6][_i], s[7][_i]); \
- } \
- for (_j = 8; _j < _n_srcs; _j++) { \
- for (_i = 0; _i < _count; _i++) { \
- d[_i] = OP##_2(d[_i], s[_j][_i]); \
+ _tmp = OP##_8(s[0][_i], s[1][_i], s[2][_i], s[3][_i], \
+ s[4][_i], s[5][_i], s[6][_i], s[7][_i]); \
+ for (_j = 8; _j < _n_srcs; _j++) { \
+ _tmp = OP##_2(_tmp, s[_j][_i]); \
  } \
+ d[_i] = _tmp; \
  } \
  break; \
  } \
Expand All @@ -80,37 +80,37 @@
  switch (_op) { \
  case UCC_OP_AVG: \
  case UCC_OP_SUM: \
- DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_SUM); \
+ DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_SUM); \
  if (flags & UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA) { \
  VEC_OP(d, _count, task->alpha); \
  } \
  break; \
  case UCC_OP_MIN: \
- DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_MIN); \
+ DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_MIN); \
  break; \
  case UCC_OP_MAX: \
- DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_MAX); \
+ DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_MAX); \
  break; \
  case UCC_OP_PROD: \
- DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_PROD); \
+ DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_PROD); \
  break; \
  case UCC_OP_LAND: \
- DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_LAND); \
+ DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_LAND); \
  break; \
  case UCC_OP_BAND: \
- DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_BAND); \
+ DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_BAND); \
  break; \
  case UCC_OP_LOR: \
- DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_LOR); \
+ DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_LOR); \
  break; \
  case UCC_OP_BOR: \
- DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_BOR); \
+ DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_BOR); \
  break; \
  case UCC_OP_LXOR: \
- DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_LXOR); \
+ DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_LXOR); \
  break; \
  case UCC_OP_BXOR: \
- DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_BXOR); \
+ DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_BXOR); \
  break; \
  default: \
  ec_error(&ucc_ec_cpu.super, \
Expand Down Expand Up @@ -176,16 +176,16 @@
  switch (_op) { \
  case UCC_OP_AVG: \
  case UCC_OP_SUM: \
- DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_SUM); \
+ DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_SUM); \
  break; \
  case UCC_OP_PROD: \
- DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_PROD); \
+ DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_PROD); \
  break; \
  case UCC_OP_MIN: \
- DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_MIN); \
+ DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_MIN); \
  break; \
  case UCC_OP_MAX: \
- DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_MAX); \
+ DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_MAX); \
  break; \
  default: \
  ec_error(&ucc_ec_cpu.super, \
Expand All @@ -206,10 +206,10 @@
  switch (_op) { \
  case UCC_OP_AVG: \
  case UCC_OP_SUM: \
- DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_SUM); \
+ DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_SUM); \
  break; \
  case UCC_OP_PROD: \
- DO_DT_REDUCE_WITH_OP(s, d, _count, _n_srcs, DO_OP_PROD); \
+ DO_DT_REDUCE_WITH_OP(type, s, d, _count, _n_srcs, DO_OP_PROD); \
  break; \
  default: \
  ec_error(&ucc_ec_cpu.super, \
Expand Down
Loading

0 comments on commit 30a6988

Please sign in to comment.