benchdnn: matmul: make reference code more generic
dzarukin committed Sep 19, 2024
Parent: 65c5086 · Commit: 6924ec9

1 changed file: tests/benchdnn/matmul/ref_matmul.cpp (51 additions, 90 deletions)
@@ -61,130 +61,91 @@ void compute_ref_matmul(const prb_t *prb, const args_t &args) {
     const int dst_zp_mask = attr_t::get_default_mask(
             prb->attr.zero_points.get(DNNL_ARG_DST).policy);
 
+    const auto &src_scale_groups = prb->attr.scales.get(DNNL_ARG_SRC).groups;
+    const auto &wei_scale_groups
+            = prb->attr.scales.get(DNNL_ARG_WEIGHTS).groups;
+    const auto &wei_zp_groups
+            = prb->attr.zero_points.get(DNNL_ARG_WEIGHTS).groups;
+
     const int64_t M = prb->m;
     const int64_t N = prb->n;
     const int64_t K = prb->k;
     const int64_t MB = prb->mb;
     const int batch_ndims = dst_m.ndims() - 2;
 
-    const bool src_scale_per_k = src_scale_mask & (1 << (src_m.ndims() - 1));
-    const bool src_scale_per_m = src_scale_mask & (1 << (src_m.ndims() - 2));
-    const auto src_scale_groups = prb->attr.scales.get(DNNL_ARG_SRC).groups;
-    const int64_t src_scale_group_k
-            = !src_scale_groups.empty() ? src_scale_groups[1] : 1;
-    const int64_t src_scale_stride_k = src_scale_per_k ? 1 : 0;
-    const int64_t src_scale_stride_m = src_scale_per_m
-            ? src_scale_per_k ? (K / src_scale_group_k) : 1
-            : 0;
-
-    const bool wei_scale_per_n = wei_scale_mask & (1 << (wei_m.ndims() - 1));
-    const bool wei_scale_per_k = wei_scale_mask & (1 << (wei_m.ndims() - 2));
-    const auto wei_scale_groups = prb->attr.scales.get(DNNL_ARG_WEIGHTS).groups;
-    const int64_t wei_scale_group_k
-            = !wei_scale_groups.empty() ? wei_scale_groups[0] : 1;
-    const int64_t wei_scale_stride_n = wei_scale_per_n ? 1 : 0;
-    const int64_t wei_scale_stride_k
-            = wei_scale_per_k ? wei_scale_per_n ? N : 1 : 0;
-
-    const bool wei_zp_per_n = wei_zp_mask & (1 << (wei_m.ndims() - 1));
-    const bool wei_zp_per_k = wei_zp_mask & (1 << (wei_m.ndims() - 2));
-    const int64_t wei_zp_stride_n = wei_zp_per_n ? 1 : 0;
-    const int64_t wei_zp_stride_k = wei_zp_per_k ? wei_zp_per_n ? N : 1 : 0;
-    const auto wei_zp_groups
-            = prb->attr.zero_points.get(DNNL_ARG_WEIGHTS).groups;
-    const int64_t wei_zp_group_k
-            = !wei_zp_groups.empty() ? wei_zp_groups[0] : 1;
-
-    const bool wei_decompression = prb->weights_decompression();
-    const bool apply_scales_in_ker
-            = wei_decompression || wei_scale_per_k || src_scale_per_k;
-
     // Fast return if any dim is zero. Common logic doesn't apply because of
     // broadcast semantics.
     for (int d = 0; d < dst_m.ndims(); d++) {
         if (prb->src_dims()[d] == 0 || prb->weights_dims()[d] == 0) return;
     }
 
-    dnn_mem_t dst_tmp(dst_m, dnnl_f32, tag::abx, dst_m.engine());
-
     const auto src_broadcast_mask = prb->src_broadcast_mask();
     const auto wei_broadcast_mask = prb->weights_broadcast_mask();
+    const auto bias_broadcast_mask = prb->bias_broadcast_mask();
+    auto v_po_masks = prb->attr.post_ops.get_po_masks();
 
     benchdnn_parallel_nd(MB, M, N, [&](int64_t mb, int64_t m, int64_t n) {
-        auto src = (const float *)src_m;
-        auto wei = (const float *)wei_m;
-
         float dst = 0;
         const int64_t src_mb
                 = dst_m.get_idx(mb, src_broadcast_mask, batch_ndims);
         const int64_t wei_mb
                 = dst_m.get_idx(mb, wei_broadcast_mask, batch_ndims);
 
         for (int64_t k = 0; k < K; ++k) {
-            int src_zp = has_src_zp ? src_zps.get_elem(src_zp_mask > 0 ? k : 0)
-                                    : 0;
-            int wei_zp = has_wei_zp
-                    ? wei_zps.get_elem(wei_zp_stride_k * (k / wei_zp_group_k)
-                            + wei_zp_stride_n * n)
-                    : 0;
-            auto s = src[src_off_f(prb, src_mb, m, k)] - src_zp;
-            auto w = wei[wei_off_f(prb, wei_mb, k, n)] - wei_zp;
-            // Compression scaling happens before the matmul, unlike regular
-            // quantization, to preserve the accuracy.
-            // Also, regular quantized matmul can have per group K-dim scales
-            // which require handling inside the kernel.
-            if (apply_scales_in_ker) {
-                if (has_src_scale) {
-                    float src_scale = src_scales.get_elem(
-                            src_scale_stride_k * (k / src_scale_group_k)
-                            + src_scale_stride_m * m);
-                    s *= src_scale;
-                }
-                if (has_wei_scale) {
-                    float wei_scale = wei_scales.get_elem(
-                            wei_scale_stride_k * (k / wei_scale_group_k)
-                            + wei_scale_stride_n * n);
-                    w *= wei_scale;
-                }
+            const auto src_off = src_off_f(prb, src_mb, m, k);
+            const auto wei_off = wei_off_f(prb, wei_mb, k, n);
+
+            int src_zp = 0;
+            if (has_src_zp) {
+                const auto src_zp_idx = src_m.get_idx(src_off, src_zp_mask);
+                src_zp = src_zps.get_elem(src_zp_idx);
+            }
+            int wei_zp = 0;
+            if (has_wei_zp) {
+                const auto wei_zp_idx = wei_m.get_idx(
+                        wei_off, wei_zp_mask, wei_m.ndims(), wei_zp_groups);
+                wei_zp = wei_zps.get_elem(wei_zp_idx);
             }
-            dst += s * w;
-        }
-        ((float *)dst_tmp)[dst_off_f(prb, mb, m, n)] = dst;
-    });
-
-    auto v_po_masks = prb->attr.post_ops.get_po_masks();
-    const auto bias_broadcast_mask = prb->bias_broadcast_mask();
-    benchdnn_parallel_nd(MB, M, N, [&](int64_t mb, int64_t m, int64_t n) {
-        size_t dst_off = dst_off_f(prb, mb, m, n);
-        float &dst = ((float *)dst_m)[dst_off];
-
-        float wei_scale = 1.f;
-        float src_scale = 1.f;
-        if (!apply_scales_in_ker) {
-            assert(IMPLICATION(has_src_scale, src_scales.nelems() == 1));
-            if (has_src_scale) { src_scale = src_scales.get_elem(0); }
+            float src_scale = 1.f;
+            if (has_src_scale) {
+                const auto src_scale_idx = src_m.get_idx(src_off,
+                        src_scale_mask, src_m.ndims(), src_scale_groups);
+                src_scale = src_scales.get_elem(src_scale_idx);
+            }
+            float wei_scale = 1.f;
             if (has_wei_scale) {
-                wei_scale = wei_scales.get_elem(wei_scale_mask > 0 ? n : 0);
+                const auto wei_scale_idx = wei_m.get_idx(wei_off,
+                        wei_scale_mask, wei_m.ndims(), wei_scale_groups);
+                wei_scale = wei_scales.get_elem(wei_scale_idx);
             }
-        }
 
-        float tmp = ((float *)dst_tmp)[dst_off] * src_scale * wei_scale;
+            auto s = src_scale * (src_m.get_elem(src_off) - src_zp);
+            auto w = wei_scale * (wei_m.get_elem(wei_off) - wei_zp);
+
+            dst += s * w;
+        }
 
+        const auto dst_off = dst_off_f(prb, mb, m, n);
         if (prb->bia_dt != dnnl_data_type_undef) {
-            int64_t bia_off = dst_m.get_idx(dst_off, bias_broadcast_mask);
-            float *bia_ptr = (float *)bia_m;
-            tmp += bia_ptr[bia_off];
+            const auto bia_idx = dst_m.get_idx(dst_off, bias_broadcast_mask);
+            dst += bia_m.get_elem(bia_idx);
         }
 
        const auto v_po_vals
                 = prepare_po_vals(dst_m, args, v_po_masks, dst_off);
 
-        maybe_dropout(prb->attr, tmp, dst_off, dropout);
-        maybe_post_ops(prb->attr, tmp, dst, v_po_vals);
-
-        int dst_zp = has_dst_zp ? dst_zps.get_elem(dst_zp_mask > 0 ? n : 0) : 0;
-        dst = tmp * dst_scale + dst_zp;
-        maybe_round(prb->attr, DNNL_ARG_DST, dst, dst_off, prb->dst_dt());
+        maybe_dropout(prb->attr, dst, dst_off, dropout);
+        const auto sum_val = dst_m.get_elem(dst_off);
+        maybe_post_ops(prb->attr, dst, sum_val, v_po_vals);
+
+        int dst_zp = 0;
+        if (has_dst_zp) {
+            const auto dst_zp_idx = dst_m.get_idx(dst_off, dst_zp_mask);
+            dst_zp = dst_zps.get_elem(dst_zp_idx);
+        }
+        float dst_val = dst_scale * dst + dst_zp;
+        maybe_round(prb->attr, DNNL_ARG_DST, dst_val, dst_off, prb->dst_dt());
+        dst_m.set_elem(dst_off, dst_val);
    });
 }
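Unrolled, the rewritten single pass computes, per output point (dropout and final rounding elided):

    dst[mb, m, n] = dst_scale * post_ops( sum_k  src_scale * (src[mb, m, k] - src_zp)
                                               * wei_scale * (wei[mb, k, n] - wei_zp)
                                          + bias )
                  + dst_zp

where every scale and zero point is fetched through get_idx with its own mask and, where applicable, group sizes, so per-tensor, per-channel, and per-group (K-dim) quantization all take the same code path instead of the deleted hand-rolled stride arithmetic.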

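This diff does not show dnn_mem_t::get_idx itself, but the mask+groups index mapping it presumably implements can be sketched as below. This is a hypothetical illustration, not benchdnn's actual code: quant_param_idx is an invented name, and it assumes a dense row-major ("abx") layout with one group-size entry per dimension (benchdnn passes group sizes only for the last two dimensions).

    #include <cstdint>
    #include <vector>

    // Map a dense element offset in the data tensor to the offset of its
    // quantization parameter (scale or zero point).
    // - dims:   data tensor dimensions (row-major)
    // - mask:   bit d set means the parameter varies along dimension d
    // - groups: groups[d] consecutive points along dimension d share one
    //           parameter (empty means group size 1 everywhere)
    int64_t quant_param_idx(int64_t off, const std::vector<int64_t> &dims,
            int mask, const std::vector<int64_t> &groups) {
        const int ndims = (int)dims.size();
        // Recover per-dimension coordinates from the dense offset.
        std::vector<int64_t> pos(ndims);
        for (int d = ndims - 1; d >= 0; --d) {
            pos[d] = off % dims[d];
            off /= dims[d];
        }
        // Flatten into the parameter tensor: its extent along dim d is
        // dims[d] / groups[d] when bit d of the mask is set, 1 otherwise.
        int64_t idx = 0;
        for (int d = 0; d < ndims; ++d) {
            const bool varies = mask & (1 << d);
            const int64_t group = groups.empty() ? 1 : groups[d];
            const int64_t param_dim = varies ? dims[d] / group : 1;
            const int64_t param_pos = varies ? pos[d] / group : 0;
            idx = idx * param_dim + param_pos;
        }
        return idx;
    }

For 2D weights of shape {K, N} with both mask bits set and groups {G, 1}, this yields (k / G) * N + n, which is exactly what the deleted wei_zp_stride_k * (k / wei_zp_group_k) + wei_zp_stride_n * n computed by hand.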