Skip to content

Commit

Permalink
generic sycl: refactor kernel mainloops
Browse files: browse the repository at this point in the history
  • Loading branch information
t4c1 authored and densamoilov committed Sep 20, 2024
1 parent 6924ec9 commit 7131c1b
Show file tree
Hide file tree
Showing 24 changed files with 276 additions and 528 deletions.
76 changes: 31 additions & 45 deletions src/gpu/generic/sycl/binary_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,6 @@ struct binary_kernel_vec_t {
memory_plain_t src0_scale_mem(src0_scale_, scales_dt_);
memory_plain_t src1_scale_mem(src1_scale_, scales_dt_);

auto sg = item.get_sub_group();
size_t wg_offset_t = item.get_group(0) * conf_.wg_size;
size_t sg_offset_t = sg.get_group_id()[0] * sg.get_local_range()[0];
size_t wi_offset_t = sg.get_local_id();
size_t offset_t = wg_offset_t + sg_offset_t + wi_offset_t;

size_t base_idx = offset_t * conf_.block_size;
size_t vec_base_idx = base_idx / vec_len;

size_t sg_base_idx = (wg_offset_t + sg_offset_t) * conf_.block_size;

const float sm_0 = (conf_.do_scale_src0 ? src0_scale_mem.load(0) : 1.f);

const float sm_1 = (conf_.do_scale_src1 ? src1_scale_mem.load(0) : 1.f);
Expand Down Expand Up @@ -98,12 +87,12 @@ struct binary_kernel_vec_t {

if (!any_broadcast && !is_blocked_fmt
&& conf_.post_ops.get_post_op() == 0
&& sg_base_idx + (sg.get_local_range()[0] * conf_.block_size)
< conf_.wk_size
&& is_same_tag) {
for (int i = 0; i < conf_.block_size / vec_len; i++) {
auto src0_vec = src0_mem.load_vec<vec_len>(vec_base_idx + i);
auto src1_vec = src1_mem.load_vec<vec_len>(vec_base_idx + i);
&& conf_.wk_size % vec_len == 0 && is_same_tag) {
for (int vec_idx = item.get_global_id(0);
vec_idx < conf_.wk_size / vec_len;
vec_idx += item.get_global_range(0)) {
auto src0_vec = src0_mem.load_vec<vec_len>(vec_idx);
auto src1_vec = src1_mem.load_vec<vec_len>(vec_idx);

if (conf_.do_scale_src0)
src0_vec *= ::sycl::vec<float, vec_len>(sm_0);
Expand All @@ -114,37 +103,34 @@ struct binary_kernel_vec_t {
// TODO: Adding post-ops seems to be interfering with the compiler's
// optimizations. Figure out how to make the compiler generate
// the right code.
dst_mem.store_vec(acc_vec, vec_base_idx + i);
dst_mem.store_vec(acc_vec, vec_idx);
}
} else {
for (int i = 0; i < conf_.block_size; i++) {
int idx = base_idx + i;
if (idx < conf_.wk_size) {
auto l_offset = idx;
for (int i = 0; i < conf_.ndims; i++) {
const int d = conf_.ndims - 1 - i;
const dim_t cur_dim = conf_.dst_md.dims()[d];
off_dst[d] = l_offset % cur_dim;
l_offset = l_offset / cur_dim;
}

for (int i = 0; i < max_supported_ndims; i++) {
off0[i] = conf_.broadcast_dims0[i] ? 0 : off_dst[i];
off1[i] = conf_.broadcast_dims1[i] ? 0 : off_dst[i];
}

auto src0 = src0_mem.load_md(off0);
auto src1 = src1_mem.load_md(off1);

if (conf_.do_scale_src0) src0 *= sm_0;
if (conf_.do_scale_src1) src1 *= sm_1;

auto acc = compute_alg_n(src0, src1, conf_.alg_kind);

acc = conf_.post_ops.apply(
acc, dst_, idx, po_args_, off_dst);
dst_mem.store_md(acc, off_dst);
for (int idx = item.get_global_id(0); idx < conf_.wk_size;
idx += item.get_global_range(0)) {
auto l_offset = idx;
for (int i = 0; i < conf_.ndims; i++) {
const int d = conf_.ndims - 1 - i;
const dim_t cur_dim = conf_.dst_md.dims()[d];
off_dst[d] = l_offset % cur_dim;
l_offset = l_offset / cur_dim;
}

for (int i = 0; i < max_supported_ndims; i++) {
off0[i] = conf_.broadcast_dims0[i] ? 0 : off_dst[i];
off1[i] = conf_.broadcast_dims1[i] ? 0 : off_dst[i];
}

auto src0 = src0_mem.load_md(off0);
auto src1 = src1_mem.load_md(off1);

if (conf_.do_scale_src0) src0 *= sm_0;
if (conf_.do_scale_src1) src1 *= sm_1;

auto acc = compute_alg_n(src0, src1, conf_.alg_kind);

acc = conf_.post_ops.apply(acc, dst_, idx, po_args_, off_dst);
dst_mem.store_md(acc, off_dst);
}
}
}
Expand Down
158 changes: 73 additions & 85 deletions src/gpu/generic/sycl/convolution_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,6 @@ struct convolution_kernel_fwd_t {
: data_type_t::dnnl_f32) {}

void operator()(::sycl::nd_item<1> item) const {
auto sg = item.get_sub_group();
size_t wg_offset_t = item.get_group(0) * conf_.wg_size;
size_t sg_offset_t = sg.get_group_id()[0] * sg.get_local_range()[0];
size_t wi_offset_t = sg.get_local_id();
size_t offset_t = wg_offset_t + sg_offset_t + wi_offset_t;

size_t base_idx = offset_t * conf_.block_size;

const float sm_data = (conf_.do_scale_data
? load_float_value(scales_data_dt_, data_scale_ptr(), 0)
: 1.f);
Expand Down Expand Up @@ -132,94 +124,90 @@ struct convolution_kernel_fwd_t {
const int DH = conf_.dilation[1];
const int DW = conf_.dilation[2];

for (int i = 0; i < conf_.block_size; i++) {
int idx = base_idx + i;
if (idx < conf_.wk_size) {
for (int i = 0; i < max_supported_ndims; i++) {
off[i] = idx / dst_strides[i] % dst_dims[i];
}
for (int idx = item.get_global_id(0); idx < conf_.wk_size;
idx += item.get_global_range(0)) {
for (int i = 0; i < max_supported_ndims; i++) {
off[i] = idx / dst_strides[i] % dst_dims[i];
}

const int n = off[0];
const int oc_tot = off[1];
const int oc = oc_tot % OC;
const int g = oc_tot / OC;

const int od = off[2];
const int oh = off[3];
const int ow = off[4];

float accumulator = 0;
for (int ic = 0; ic < IC; ++ic) {
for (int kd = 0; kd < KD; ++kd) {
for (int kh = 0; kh < KH; ++kh) {
for (int kw = 0; kw < KW; ++kw) {
const int id = od * SD - PD + kd * (1 + DD);
const int ih = oh * SH - PH + kh * (1 + DH);
const int iw = ow * SW - PW + kw * (1 + DW);

if (id < 0 || id >= data_dims[2] || ih < 0
|| ih >= data_dims[3] || iw < 0
|| iw >= data_dims[4]) {
continue;
}
const int n = off[0];
const int oc_tot = off[1];
const int oc = oc_tot % OC;
const int g = oc_tot / OC;

dims_t off_data {n, g * IC + ic, id, ih, iw};
const int data_idx = data_md().off_v(off_data);
dims_t off_weights {g, oc, ic, kd, kh, kw};
dims_t off_weights_no_groups {
oc, ic, kd, kh, kw};
const int weights_idx = weights_md().off_v(
no_groups ? off_weights_no_groups
: off_weights);

auto data = load_float_value(
data_md().data_type(), data_ptr(),
data_idx);
auto weight = load_float_value(
weights_md().data_type(), weights_ptr(),
weights_idx);

if (conf_.use_data_zeropoints) {
int zpoint_idx = conf_.single_data_zeropoint
? 0
: g * IC + ic;
auto data_zeropoint = load_float_value(
zeropoints_data_dt_,
data_zeropoint_ptr(), zpoint_idx);
data -= data_zeropoint;
}
accumulator += data * weight;
const int od = off[2];
const int oh = off[3];
const int ow = off[4];

float accumulator = 0;
for (int ic = 0; ic < IC; ++ic) {
for (int kd = 0; kd < KD; ++kd) {
for (int kh = 0; kh < KH; ++kh) {
for (int kw = 0; kw < KW; ++kw) {
const int id = od * SD - PD + kd * (1 + DD);
const int ih = oh * SH - PH + kh * (1 + DH);
const int iw = ow * SW - PW + kw * (1 + DW);

if (id < 0 || id >= data_dims[2] || ih < 0
|| ih >= data_dims[3] || iw < 0
|| iw >= data_dims[4]) {
continue;
}

dims_t off_data {n, g * IC + ic, id, ih, iw};
const int data_idx = data_md().off_v(off_data);
dims_t off_weights {g, oc, ic, kd, kh, kw};
dims_t off_weights_no_groups {oc, ic, kd, kh, kw};
const int weights_idx = weights_md().off_v(no_groups
? off_weights_no_groups
: off_weights);

auto data = load_float_value(data_md().data_type(),
data_ptr(), data_idx);
auto weight
= load_float_value(weights_md().data_type(),
weights_ptr(), weights_idx);

if (conf_.use_data_zeropoints) {
int zpoint_idx = conf_.single_data_zeropoint
? 0
: g * IC + ic;
auto data_zeropoint = load_float_value(
zeropoints_data_dt_,
data_zeropoint_ptr(), zpoint_idx);
data -= data_zeropoint;
}
accumulator += data * weight;
}
}
}
if (conf_.do_scale_data) { accumulator *= sm_data; }
if (conf_.do_scale_weights) {
if (!conf_.single_weight_scale) {
sm_weights = load_float_value(scales_weights_dt_,
weights_scale_ptr(), oc_tot);
}
accumulator *= sm_weights;
}
if (conf_.do_scale_data) { accumulator *= sm_data; }
if (conf_.do_scale_weights) {
if (!conf_.single_weight_scale) {
sm_weights = load_float_value(
scales_weights_dt_, weights_scale_ptr(), oc_tot);
}
accumulator *= sm_weights;
}

if (bias_md().ndims() != 0) {
auto bias = load_float_value(
bias_md().data_type(), bias_ptr(), oc_tot);
accumulator += bias;
}
if (bias_md().ndims() != 0) {
auto bias = load_float_value(
bias_md().data_type(), bias_ptr(), oc_tot);
accumulator += bias;
}

accumulator = conf_.post_ops.apply(accumulator, dst_, idx);
accumulator = conf_.post_ops.apply(accumulator, dst_, idx);

if (conf_.do_scale_dst) { accumulator /= sm_dst; }
if (conf_.use_dst_zeropoints) {
int zpoint_idx = conf_.single_dst_zeropoint ? 0 : oc_tot;
auto dst_zeropoint = load_float_value(zeropoints_dst_dt_,
dst_zeropoint_ptr(), zpoint_idx);
accumulator += dst_zeropoint;
}
store_float_value(
dst_md().data_type(), accumulator, dst_ptr(), idx);
if (conf_.do_scale_dst) { accumulator /= sm_dst; }
if (conf_.use_dst_zeropoints) {
int zpoint_idx = conf_.single_dst_zeropoint ? 0 : oc_tot;
auto dst_zeropoint = load_float_value(
zeropoints_dst_dt_, dst_zeropoint_ptr(), zpoint_idx);
accumulator += dst_zeropoint;
}
store_float_value(
dst_md().data_type(), accumulator, dst_ptr(), idx);
}
}

Expand Down
63 changes: 22 additions & 41 deletions src/gpu/generic/sycl/eltwise_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,6 @@ struct eltwise_fwd_kernel_vec_t {
memory_tensor_t src_mem(src_, conf_.src_md);
memory_tensor_t dst_mem(dst_, conf_.dst_md);

auto sg = item.get_sub_group();
size_t wg_offset_t = item.get_group(0) * conf_.wg_size;
size_t sg_offset_t = sg.get_group_id()[0] * sg.get_local_range()[0];
size_t wi_offset_t = sg.get_local_id();
size_t offset_t = wg_offset_t + sg_offset_t + wi_offset_t;

size_t base_idx = offset_t * conf_.block_size;

auto operation = [&](dim_t &idx, dim_t &n, dim_t &c, dim_t &d, dim_t &h,
dim_t &w) {
dim_t src_offset = data_offset(src_mem.md(), n, c, d, h, w);
Expand All @@ -72,22 +64,20 @@ struct eltwise_fwd_kernel_vec_t {
dst_mem.store(acc, src_offset);
};

for (dim_t blk_idx = 0; blk_idx < conf_.block_size; blk_idx++) {
dim_t idx = base_idx + blk_idx;
if (idx < conf_.wk_size) {
dim_t N = conf_.mb;
dim_t C = conf_.c;
dim_t D = conf_.d;
dim_t H = conf_.h;
dim_t W = conf_.w;

dim_t n = (idx / (C * D * H * W)) % N;
dim_t c = (idx / (D * H * W)) % C;
dim_t d = (idx / (H * W)) % D;
dim_t h = (idx / (W)) % H;
dim_t w = (idx / (1)) % W;
operation(idx, n, c, d, h, w);
}
for (dim_t idx = item.get_global_id(0); idx < conf_.wk_size;
idx += item.get_global_range(0)) {
dim_t N = conf_.mb;
dim_t C = conf_.c;
dim_t D = conf_.d;
dim_t H = conf_.h;
dim_t W = conf_.w;

dim_t n = (idx / (C * D * H * W)) % N;
dim_t c = (idx / (D * H * W)) % C;
dim_t d = (idx / (H * W)) % D;
dim_t h = (idx / (W)) % H;
dim_t w = (idx / (1)) % W;
operation(idx, n, c, d, h, w);
}
}

Expand Down Expand Up @@ -221,23 +211,14 @@ struct eltwise_bwd_kernel_vec_t {
memory_tensor_t diff_src_mem(diff_src_, conf_.diff_src_md);
memory_tensor_t diff_dst_mem(diff_dst_, conf_.diff_dst_md);

auto sg = item.get_sub_group();
size_t wg_offset_t = item.get_group(0) * conf_.wg_size;
size_t sg_offset_t = sg.get_group_id()[0] * sg.get_local_range()[0];
size_t wi_offset_t = sg.get_local_id();
size_t offset_t = wg_offset_t + sg_offset_t + wi_offset_t;
size_t base_idx = offset_t * conf_.block_size;

for (dim_t i = 0; i < conf_.block_size; i++) {
dim_t idx = base_idx + i;
if (idx < conf_.wk_size) {
auto diff_src = diff_src_mem.load(idx);
auto src = src_mem.load(idx);

auto dst = compute_alg_n(
diff_src, src, conf_.alpha, conf_.beta, conf_.alg_kind);
diff_dst_mem.store(dst, idx);
}
for (dim_t idx = item.get_global_id(0); idx < conf_.wk_size;
idx += item.get_global_range(0)) {
auto diff_src = diff_src_mem.load(idx);
auto src = src_mem.load(idx);

auto dst = compute_alg_n(
diff_src, src, conf_.alpha, conf_.beta, conf_.alg_kind);
diff_dst_mem.store(dst, idx);
}
}

Expand Down
Loading

0 comments on commit 7131c1b

Please sign in to comment.