
Commit

update
AnnaTrainingG committed Sep 14, 2023
1 parent 124caf8 commit f9571ac
Showing 1 changed file with 48 additions and 115 deletions.
163 changes: 48 additions & 115 deletions paddle/phi/kernels/funcs/broadcast_function.h
@@ -910,13 +910,14 @@ void BroadcastKernelApply(const KPDevice &ctx,
}
#ifndef PADDLE_WITH_XPU_KP
constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && kArity <= 3);
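// the chunked int64-index path below is only wired up for a single output
// and at most three inputs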
auto loader_classifier = LoaderTypeClassifier<OutT, kArity, Functor>();
// check whether broadcast is needed
auto compute_size = std::numeric_limits<int32_t>::max();
bool use_int64_index_kernel =
kEnabledInt64IndexKernel && (*outs)[0]->numel() >= compute_size;
std::cout << " out_num " << (*outs)[0]->numel() << " " << ins[0]->numel()
<< " scalar " << ins[0]->dims().size() << std::endl;
if (((*outs)[0])->numel() > compute_size) { // use_int64_index_kernel) {
auto loader_classifier = LoaderTypeClassifier<OutT, kArity, Functor>();
bool use_int64_index_kernel = kEnabledInt64IndexKernel &&
(*outs)[0]->numel() >= compute_size &&
(!loader_classifier.all_elementwise);
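// enter the chunked path only when the output numel no longer fits int32
// indexing and the loads are not purely elementwise (the all-elementwise
// case presumably keeps its existing fast path even at int64 sizes)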

if (use_int64_index_kernel) {
const auto dims_simplifier =
BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
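// (BroadcastDimsSimplifier presumably folds adjacent dims that share a
// broadcast pattern, so the rank used below is the simplified rank)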
if (VLOG_IS_ON(6)) {
@@ -925,40 +926,36 @@ void BroadcastKernelApply(const KPDevice &ctx,
}

std::vector<int64_t> old_out_strides;
int all_rank = dims_simplifier.rank;
auto old_in_dims = dims_simplifier.in_dims;
auto old_out_dims = dims_simplifier.out_dims;

int all_rank = dims_simplifier.rank;
auto old_in_strides = dims_simplifier.in_dims;

old_out_strides.resize(all_rank);
old_out_strides[0] = 1;

std::vector<int64_t> ins_compute_size;

// for split
std::vector<int64_t> loop_num_out;
std::vector<int64_t> loop_num_out_stride;
loop_num_out.resize(all_rank);
loop_num_out_stride.resize(all_rank);

// for input's offset
std::vector<int64_t> ins_offset;
ins_offset.resize(kArity);
ins_compute_size.resize(kArity);
std::cout << " check 946 " << std::endl;

std::vector<int64_t> ins_scale_for_dim;

ins_offset.resize(kArity);
ins_scale_for_dim.resize(kArity);
for (int k = 0; k < kArity; k++) {
ins_scale_for_dim[k] = ins[k]->dims().size() == 0 ? 0 : 1;
ins_offset[k] = 0;
}

// init per-input offsets and check each input's dims
for (int k = 0; k < kArity; k++) {
ins_offset[k] = 0;
ins_scale_for_dim[k] = ins[k]->dims().size() == 0 ? 0 : 1;
if (ins_scale_for_dim[k]) {
old_in_strides[k][0] = 1;
}
}
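// a 0-D (scalar) input has ins_scale_for_dim == 0, so its offset stays
// pinned at 0 throughout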

// out_dims has already been reversed in BroadcastDimsSimplifier
for (int i = 1; i < all_rank; i++) {
loop_num_out[i] = 1;
loop_num_out_stride[i] = 1;
@@ -971,6 +968,7 @@ void BroadcastKernelApply(const KPDevice &ctx,
}
}

// reverse old_in_dims and old_in_strides if the input's dim_size > 0
for (int k = 0; k < kArity; k++) {
if (ins_scale_for_dim[k]) {
std::reverse(old_in_dims[k].begin(), old_in_dims[k].end());
@@ -980,36 +978,33 @@ void BroadcastKernelApply(const KPDevice &ctx,
std::reverse(old_out_dims.begin(), old_out_dims.end());
std::reverse(old_out_strides.begin(), old_out_strides.end());
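// after these reversals, index 0 is the outermost dim, i.e. the largest
// output stride; the split loop below walks outward-in and stops at the
// first dim whose stride fits within compute_size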

std::cout << "out old stride : ";
for (int i = 0; i < all_rank; i++) {
std::cout << " " << old_out_strides[i];
}
std::cout << std::endl;

std::cout << "out_dims : ";
for (int i = 0; i < all_rank; i++) {
std::cout << " " << old_out_dims[i];
}
std::cout << std::endl;

// init out_split_dim and in_split_dims
auto out_split_dim = old_out_dims;
auto in_split_dims = old_in_dims;

// init
int64_t loop_num = 1;
int64_t split_idx = 0;
int64_t out_offset = 0;

for (int r = 0; r < all_rank; r++) {
// compute the split_dims
int64_t split_size = compute_size / old_out_strides[r];
// if compute_size is too small, split_size will be 0, but the
// dim num must be >= 1
out_split_dim[r] = std::max(split_size, static_cast<int64_t>(1));
// get the split num of current dim
loop_num_out[r] =
(old_out_dims[r] + out_split_dim[r] - 1) / out_split_dim[r];
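// ceil division: the number of chunks needed along dim r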
loop_num *= loop_num_out[r];

for (int k = 0; k < kArity; k++) {
// compute the input's split_dim if its dim_size > 0
if (ins_scale_for_dim[k]) {
in_split_dims[k][r] = std::min(old_in_dims[k][r], out_split_dim[r]);
}
}

// split_idx is the index of the last split dim
if (split_size != 0) {
split_idx = r;
break;
@@ -1020,128 +1015,66 @@ void BroadcastKernelApply(const KPDevice &ctx,
for (int r = all_rank - 2; r >= 0; r--) {
loop_num_out_stride[r] = loop_num_out_stride[r + 1] * loop_num_out[r + 1];
}
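// loop_num_out_stride[r] is the suffix product of loop_num_out[r+1..];
// dividing the flat chunk index iter by it (and keeping remainders)
// recovers the per-dim repeat indices in the loop below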
std::cout << "split_idx " << split_idx << "loop_num_out : ";
for (int r = 0; r < all_rank; r++) {
std::cout << loop_num_out[r] << " ";
}
std::cout << std::endl;

std::cout << "loop_num_out_stride : ";
for (int r = 0; r < all_rank; r++) {
std::cout << loop_num_out_stride[r] << " ";
}
std::cout << std::endl;

std::cout << "out_split_dim : ";
for (int r = 0; r < all_rank; r++) {
std::cout << out_split_dim[r] << " ";
}
std::cout << std::endl;

for (int k = 0; k < kArity; k++) {
std::cout << k << " in_split_dims : ";
for (int r = 0; r < all_rank; r++) {
std::cout << in_split_dims[k][r] << " ";
}
std::cout << std::endl;
}
std::cout << std::endl;

for (int k = 0; k < kArity; k++) {
std::cout << k << " in_split_dims : ";
for (int r = 0; r < all_rank; r++) {
if (ins_scale_for_dim[k]) {
std::cout << old_in_dims[k][r] << " ";
}
}
std::cout << std::endl;
}
std::cout << std::endl;

// compute
int64_t out_compute_size = 0;
DenseTensor tmp_in[kArity];
DenseTensor tmp_out[NumOuts];

for (int iter = 0; iter < loop_num; iter++) {
std::vector<const DenseTensor *> new_ins = {};
std::vector<DenseTensor *> new_outs = {};
phi::DenseTensor tmp_in[kArity];

int64_t tmp_size = iter;
out_offset = 0;
// out
int64_t out_offset = 0;
// compute the offset before the last split dim
for (int i = 0; i < split_idx; i++) {
auto rp = tmp_size / loop_num_out_stride[i];
out_offset += rp * old_out_strides[i];
auto repeat_times = tmp_size / loop_num_out_stride[i];
out_offset += repeat_times * old_out_strides[i];
for (int k = 0; k < kArity; k++) {
if (ins_scale_for_dim[k]) {
std::cout << k << " kArity " << kArity << "ins_offset "
<< ins_offset[k] << " rp " << rp % old_in_dims[k][i]
<< " " << old_in_dims[k][i] << std::endl;
ins_offset[k] += (rp % old_in_dims[k][i]) * old_in_strides[k][i];
ins_offset[k] +=
(repeat_times % old_in_dims[k][i]) * old_in_strides[k][i];
}
}
tmp_size = tmp_size % loop_num_out_stride[i];
}
std::cout << "out_offset " << out_offset << "tmp_size * out_off"
<< tmp_size * out_split_dim[split_idx] *
old_out_strides[split_idx]
<< std::endl;
// for split_idx
out_offset +=
tmp_size * out_split_dim[split_idx] * old_out_strides[split_idx];
// tmp_size is the repeat idx along the last split dim
auto pre_deal_size = tmp_size * out_split_dim[split_idx];
out_offset += pre_deal_size * old_out_strides[split_idx];
// compute_size
auto tmp_out_compute_size =
old_out_dims[split_idx] - tmp_size * out_split_dim[split_idx];
out_compute_size =
std::min(out_split_dim[split_idx], tmp_out_compute_size);
std::cout << "out_compute_size " << out_compute_size
<< " out_split_dim[split_idx] " << out_split_dim[split_idx]
<< " tmp_out_compute_size " << tmp_out_compute_size
<< " tmp_size " << tmp_size << std::endl;
auto remainder_size = old_out_dims[split_idx] - pre_deal_size;

// get current compute size
auto out_compute_dims = out_split_dim;
out_compute_dims[split_idx] =
std::min(out_split_dim[split_idx], remainder_size);

// per-input chunk dims for this iteration
auto in_compute_dims = in_split_dims;
for (int k = 0; k < kArity; k++) {
if (ins_scale_for_dim[k]) {
std::cout << "ins_offset " << ins_offset[k] << " tmp_size "
<< tmp_size << " " << old_in_dims[k][split_idx]
<< " in_split_dims " << in_split_dims[k][split_idx] << " "
<< old_in_strides[k][split_idx] << std::endl;
auto split_repeat =
old_in_dims[k][split_idx] == old_out_dims[split_idx] ? tmp_size
: 0;
ins_offset[k] += split_repeat * in_split_dims[k][split_idx] *
old_in_strides[k][split_idx];
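// when an input is broadcast along the split dim, its size there differs
// from the output's, so split_repeat stays 0 and this input's offset does
// not advance across chunks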
ins_compute_size[k] =
std::min(in_split_dims[k][split_idx], tmp_out_compute_size);
in_compute_dims[k][split_idx] = ins_compute_size[k];

std::cout << "out_offset " << out_offset << std::endl;
std::cout << "ins_offset " << ins_offset[k] << " " << k << " dtype "
<< SizeOf(ins[k]->dtype()) << " " << sizeof(ins[k]->dtype())
<< " in_compute_dims : ";
for (int r = 0; r < all_rank; r++)
std::cout << in_compute_dims[k][r] << " ";
std::cout << std::endl;
in_compute_dims[k][split_idx] = std::min(in_split_dims[k][split_idx],
out_compute_dims[split_idx]);
}
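// below, tmp_in[k] is set up as a zero-copy view of ins[k]: the meta
// offset is in bytes, and ShareBufferWith reuses the original allocation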
auto new_dim = make_ddim(in_compute_dims[k]);
DenseTensorMeta meta(
ins[k]->dtype(),
new_dim,
ins[k]->layout(),
ins_scale_for_dim[k] * ins_offset[k] * SizeOf(ins[k]->dtype()));
std::cout << " new offset " << meta.offset << std::endl;
tmp_in[k].set_meta(meta);
tmp_in[k].ShareBufferWith(*(ins[k]), true);
tmp_in[k].Resize(new_dim);
new_ins.emplace_back(&tmp_in[k]);
ins_offset[k] = 0;
}

auto out_compute_dims = out_split_dim;
out_compute_dims[split_idx] = out_compute_size;
for (int r = 0; r < all_rank; r++)
std::cout << out_compute_dims[r] << " ";
std::cout << std::endl;
for (int n = 0; n < NumOuts; n++) {
auto new_dim = make_ddim(out_compute_dims);
DenseTensorMeta meta((*outs)[n]->dtype(),
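For readers following the arithmetic, here is a minimal, self-contained sketch (not part of this commit; all names are illustrative) of the unflattening scheme the iter loop uses: suffix-product strides turn a flat chunk counter into per-dim repeat indices, exactly as loop_num_out_stride does above.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical chunk counts per dim, analogous to loop_num_out.
  std::vector<int64_t> loop_num_out = {2, 3, 4};
  int rank = static_cast<int>(loop_num_out.size());

  // Suffix products, analogous to loop_num_out_stride:
  // stride[r] = loop_num_out[r+1] * ... * loop_num_out[rank-1].
  std::vector<int64_t> stride(rank, 1);
  for (int r = rank - 2; r >= 0; r--) {
    stride[r] = stride[r + 1] * loop_num_out[r + 1];
  }

  const int64_t loop_num = stride[0] * loop_num_out[0];  // 24 chunks in total
  for (int64_t iter = 0; iter < loop_num; iter++) {
    int64_t tmp = iter;
    std::cout << "chunk " << iter << " -> repeats:";
    for (int r = 0; r < rank; r++) {
      const int64_t repeat = tmp / stride[r];  // repeat index along dim r
      tmp %= stride[r];
      std::cout << " " << repeat;
    }
    std::cout << "\n";
  }
  return 0;
}

In the kernel, each repeat index is further multiplied by the corresponding output stride (and, modulo the input dims, by the input strides) to produce out_offset and ins_offset.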
