Skip to content

Commit

Permalink
TL/UCP: use pipelining in SRA allreduce for CUDA
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Nov 14, 2023
1 parent 885bf53 commit 0db620d
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 60 deletions.
5 changes: 4 additions & 1 deletion src/components/mc/base/ucc_mc_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ typedef struct ucc_mem_attr {
* UCC memory component attributes field mask
*/
typedef enum ucc_mc_attr_field {
UCC_MC_ATTR_FIELD_THREAD_MODE = UCC_BIT(0)
UCC_MC_ATTR_FIELD_THREAD_MODE = UCC_BIT(0),
/* size of memory pool chunk element */
UCC_MC_ATTR_FIELD_FAST_ALLOC_SIZE = UCC_BIT(1),
} ucc_mc_attr_field_t;

typedef struct ucc_mc_attr {
Expand All @@ -81,6 +83,7 @@ typedef struct ucc_mc_attr {
*/
uint64_t field_mask;
ucc_thread_mode_t thread_mode;
size_t fast_alloc_size;
} ucc_mc_attr_t;

/**
Expand Down
7 changes: 7 additions & 0 deletions src/components/mc/cuda/mc_cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ static ucc_status_t ucc_mc_cuda_get_attr(ucc_mc_attr_t *mc_attr)
if (mc_attr->field_mask & UCC_MC_ATTR_FIELD_THREAD_MODE) {
mc_attr->thread_mode = ucc_mc_cuda.thread_mode;
}
if (mc_attr->field_mask & UCC_MC_ATTR_FIELD_FAST_ALLOC_SIZE) {
if (MC_CUDA_CONFIG->mpool_max_elems > 0) {
mc_attr->fast_alloc_size = MC_CUDA_CONFIG->mpool_elem_size;
} else {
mc_attr->fast_alloc_size = 0;
}
}
return UCC_OK;
}

Expand Down
11 changes: 11 additions & 0 deletions src/components/mc/ucc_mc.c
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,17 @@ ucc_status_t ucc_mc_get_mem_attr(const void *ptr, ucc_mem_attr_t *mem_attr)
return UCC_OK;
}

ucc_status_t ucc_mc_get_attr(ucc_mc_attr_t *attr, ucc_memory_type_t mem_type)
{
ucc_memory_type_t mt = (mem_type == UCC_MEMORY_TYPE_CUDA_MANAGED) ?
UCC_MEMORY_TYPE_CUDA : mem_type;
ucc_mc_base_t *mc;

UCC_CHECK_MC_AVAILABLE(mt);
mc = ucc_container_of(mc_ops[mt], ucc_mc_base_t, ops);
return mc->get_attr(attr);
}

UCC_MC_PROFILE_FUNC(ucc_status_t, ucc_mc_alloc, (h_ptr, size, mem_type),
ucc_mc_buffer_header_t **h_ptr, size_t size,
ucc_memory_type_t mem_type)
Expand Down
2 changes: 2 additions & 0 deletions src/components/mc/ucc_mc.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ ucc_status_t ucc_mc_available(ucc_memory_type_t mem_type);
*/
ucc_status_t ucc_mc_get_mem_attr(const void *ptr, ucc_mem_attr_t *mem_attr);

ucc_status_t ucc_mc_get_attr(ucc_mc_attr_t *attr, ucc_memory_type_t mem_type);

ucc_status_t ucc_mc_alloc(ucc_mc_buffer_header_t **h_ptr, size_t len,
ucc_memory_type_t mem_type);

Expand Down
145 changes: 87 additions & 58 deletions src/components/tl/ucp/allreduce/allreduce_sra_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "coll_patterns/sra_knomial.h"
#include "utils/ucc_math.h"
#include "utils/ucc_coll_utils.h"
#include "components/mc/ucc_mc.h"
#include "../reduce_scatter/reduce_scatter.h"
#include "../allgather/allgather.h"

Expand Down Expand Up @@ -53,41 +54,40 @@ ucc_tl_ucp_allreduce_sra_knomial_frag_finalize(ucc_coll_task_t *task)
return status;
}

static ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_frag_setup(
ucc_schedule_pipelined_t *schedule_p, ucc_schedule_t *frag, int frag_num)
static ucc_status_t
ucc_tl_ucp_allreduce_sra_knomial_frag_setup(ucc_schedule_pipelined_t *schedule_p,
ucc_schedule_t *frag, int frag_num)
{
ucc_coll_args_t *args = &schedule_p->super.super.bargs.args;
ucc_datatype_t dt = args->dst.info.datatype;
size_t dt_size = ucc_dt_size(dt);
ucc_coll_args_t *targs;
ucc_coll_args_t *args = &schedule_p->super.super.bargs.args;
ucc_datatype_t dt = args->dst.info.datatype;
size_t dt_size = ucc_dt_size(dt);
int n_frags = schedule_p->super.n_tasks;
size_t frag_count = ucc_buffer_block_count(args->dst.info.count,
n_frags, frag_num);
size_t offset = ucc_buffer_block_offset(args->dst.info.count,
n_frags, frag_num);
ucc_coll_args_t *targs;

targs = &frag->tasks[0]->bargs.args; //REDUCE_SCATTER
targs->src.info.buffer =
PTR_OFFSET(args->src.info.buffer, offset * dt_size);
targs->dst.info.buffer =
PTR_OFFSET(args->dst.info.buffer, offset * dt_size);
targs->src.info.count = frag_count;
targs->dst.info.count = frag_count;
targs = &frag->tasks[0]->bargs.args; /* REDUCE_SCATTER */
targs->src.info.buffer = PTR_OFFSET(args->src.info.buffer, offset * dt_size);
targs->src.info.count = frag_count;
targs->dst.info.buffer = PTR_OFFSET(args->dst.info.buffer, offset * dt_size);
targs->dst.info.count = frag_count;

targs = &frag->tasks[1]->bargs.args; //ALLGATHER
targs = &frag->tasks[1]->bargs.args; /* ALLGATHER */
targs->src.info.buffer = NULL;
targs->dst.info.buffer =
PTR_OFFSET(args->dst.info.buffer, offset * dt_size);
targs->src.info.count = 0;
targs->dst.info.count = frag_count;
targs->src.info.count = 0;
targs->dst.info.buffer = PTR_OFFSET(args->dst.info.buffer, offset * dt_size);
targs->dst.info.count = frag_count;

return UCC_OK;
}

static ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_frag_init(
ucc_base_coll_args_t *coll_args,
ucc_schedule_pipelined_t *sp, //NOLINT
ucc_base_team_t *team, ucc_schedule_t **frag_p)
static ucc_status_t
ucc_tl_ucp_allreduce_sra_knomial_frag_init(ucc_base_coll_args_t *coll_args,
ucc_schedule_pipelined_t *sp, //NOLINT
ucc_base_team_t *team,
ucc_schedule_t **frag_p)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_datatype_t dtype = coll_args->args.dst.info.datatype;
Expand Down Expand Up @@ -166,55 +166,84 @@ ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_start(ucc_coll_task_t *task)
return ucc_schedule_pipelined_post(task);
}

ucc_status_t
ucc_tl_ucp_allreduce_sra_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
static void
ucc_tl_ucp_allreduce_sra_knomial_get_pipeline_params(ucc_tl_ucp_team_t *team,
ucc_coll_args_t *args,
ucc_pipeline_params_t *pp)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_tl_ucp_lib_config_t *cfg = &tl_team->cfg;
int n_frags, pipeline_depth;
ucc_schedule_pipelined_t *schedule_p;
ucc_status_t status;
ucc_base_coll_args_t bargs;
size_t max_frag_count, dt_size;
ucc_tl_ucp_lib_config_t *cfg = &team->cfg;

dt_size = ucc_dt_size(coll_args->args.dst.info.datatype);
status = ucc_tl_ucp_get_schedule(tl_team, coll_args,
(ucc_tl_ucp_schedule_t **)&schedule_p);
if (ucc_unlikely(UCC_OK != status)) {
return status;
if (!ucc_pipeline_params_is_auto(&cfg->allreduce_sra_kn_pipeline)) {
*pp = cfg->allreduce_sra_kn_pipeline;
return;
}
bargs = *coll_args;

if (bargs.mask & UCC_BASE_CARGS_MAX_FRAG_COUNT) {
max_frag_count = bargs.max_frag_count;
if ((args->src.info.mem_type == UCC_MEMORY_TYPE_CUDA) &&
(UCC_IS_INPLACE(*args))) {
ucc_mc_attr_t mc_attr;
mc_attr.field_mask = UCC_MC_ATTR_FIELD_FAST_ALLOC_SIZE;
ucc_mc_get_attr(&mc_attr, UCC_MEMORY_TYPE_CUDA);
pp->threshold = mc_attr.fast_alloc_size;
pp->n_frags = 2;
pp->frag_size = mc_attr.fast_alloc_size;
pp->order = UCC_PIPELINE_PARALLEL;
pp->pdepth = 2;
} else {
max_frag_count = coll_args->args.dst.info.count;
pp->threshold = SIZE_MAX;
pp->n_frags = 0;
pp->frag_size = 0;
pp->pdepth = 1;
pp->order = UCC_PIPELINE_PARALLEL;

}
}

ucc_pipeline_nfrags_pdepth(&cfg->allreduce_sra_kn_pipeline,
max_frag_count * dt_size, &n_frags,
&pipeline_depth);
ucc_status_t
ucc_tl_ucp_allreduce_sra_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_coll_args_t *args = &coll_args->args;
size_t dt_size = ucc_dt_size(args->dst.info.datatype);
int n_frags, pipeline_depth;
ucc_schedule_pipelined_t *schedule_p;
ucc_status_t st;
ucc_base_coll_args_t bargs;
size_t max_frag_count;
ucc_pipeline_params_t pipeline_params;

st = ucc_tl_ucp_get_schedule(tl_team, coll_args,
(ucc_tl_ucp_schedule_t **)&schedule_p);
if (ucc_unlikely(UCC_OK != st)) {
return st;
}

bargs = *coll_args;
max_frag_count = (bargs.mask & UCC_BASE_CARGS_MAX_FRAG_COUNT) ?
bargs.max_frag_count: args->dst.info.count;
ucc_tl_ucp_allreduce_sra_knomial_get_pipeline_params(tl_team, args,
&pipeline_params);
ucc_pipeline_nfrags_pdepth(&pipeline_params, max_frag_count * dt_size,
&n_frags, &pipeline_depth);
if (n_frags > 1) {
bargs.mask |= UCC_BASE_CARGS_MAX_FRAG_COUNT;
bargs.max_frag_count =
ucc_buffer_block_count(max_frag_count, n_frags, 0);
bargs.mask |= UCC_BASE_CARGS_MAX_FRAG_COUNT;
bargs.max_frag_count = ucc_buffer_block_count(max_frag_count, n_frags, 0);
}

status = ucc_schedule_pipelined_init(
&bargs, team, ucc_tl_ucp_allreduce_sra_knomial_frag_init,
ucc_tl_ucp_allreduce_sra_knomial_frag_setup, pipeline_depth, n_frags,
cfg->allreduce_sra_kn_pipeline.order, schedule_p);
if (UCC_OK != status) {
st = ucc_schedule_pipelined_init(&bargs, team,
ucc_tl_ucp_allreduce_sra_knomial_frag_init,
ucc_tl_ucp_allreduce_sra_knomial_frag_setup,
pipeline_depth, n_frags,
pipeline_params.order, schedule_p);
if (ucc_unlikely(UCC_OK != st)) {
tl_error(team->context->lib, "failed to init pipelined schedule");
ucc_tl_ucp_put_schedule(&schedule_p->super);
return status;
return st;
}
schedule_p->super.super.finalize =
ucc_tl_ucp_allreduce_sra_knomial_finalize;
schedule_p->super.super.post = ucc_tl_ucp_allreduce_sra_knomial_start;
*task_h = &schedule_p->super.super;

schedule_p->super.super.finalize = ucc_tl_ucp_allreduce_sra_knomial_finalize;
schedule_p->super.super.post = ucc_tl_ucp_allreduce_sra_knomial_start;
*task_h = &schedule_p->super.super;
return UCC_OK;
}
2 changes: 1 addition & 1 deletion src/components/tl/ucp/tl_ucp.c
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ ucc_config_field_t ucc_tl_ucp_lib_config_table[] = {
ucc_offsetof(ucc_tl_ucp_lib_config_t, allreduce_sra_kn_radix),
UCC_CONFIG_TYPE_UINT_RANGED},

{"ALLREDUCE_SRA_KN_PIPELINE", "n",
{"ALLREDUCE_SRA_KN_PIPELINE", "auto",
"Pipelining settings for SRA Knomial allreduce algorithm",
ucc_offsetof(ucc_tl_ucp_lib_config_t, allreduce_sra_kn_pipeline),
UCC_CONFIG_TYPE_PIPELINE_PARAMS},
Expand Down

0 comments on commit 0db620d

Please sign in to comment.