From 6924ec973efc832f52f49ca5e390d91f1fd80d6e Mon Sep 17 00:00:00 2001
From: Dmitrii Zarukin
Date: Tue, 10 Sep 2024 16:12:23 -0700
Subject: [PATCH] benchdnn: matmul: make reference code more generic

---
 tests/benchdnn/matmul/ref_matmul.cpp | 141 ++++++++++-----------------
 1 file changed, 51 insertions(+), 90 deletions(-)

diff --git a/tests/benchdnn/matmul/ref_matmul.cpp b/tests/benchdnn/matmul/ref_matmul.cpp
index daf4e9246ad..52f30584f0e 100644
--- a/tests/benchdnn/matmul/ref_matmul.cpp
+++ b/tests/benchdnn/matmul/ref_matmul.cpp
@@ -61,59 +61,30 @@ void compute_ref_matmul(const prb_t *prb, const args_t &args) {
     const int dst_zp_mask = attr_t::get_default_mask(
             prb->attr.zero_points.get(DNNL_ARG_DST).policy);
 
+    const auto &src_scale_groups = prb->attr.scales.get(DNNL_ARG_SRC).groups;
+    const auto &wei_scale_groups
+            = prb->attr.scales.get(DNNL_ARG_WEIGHTS).groups;
+    const auto &wei_zp_groups
+            = prb->attr.zero_points.get(DNNL_ARG_WEIGHTS).groups;
+
     const int64_t M = prb->m;
     const int64_t N = prb->n;
     const int64_t K = prb->k;
     const int64_t MB = prb->mb;
     const int batch_ndims = dst_m.ndims() - 2;
 
-    const bool src_scale_per_k = src_scale_mask & (1 << (src_m.ndims() - 1));
-    const bool src_scale_per_m = src_scale_mask & (1 << (src_m.ndims() - 2));
-    const auto src_scale_groups = prb->attr.scales.get(DNNL_ARG_SRC).groups;
-    const int64_t src_scale_group_k
-            = !src_scale_groups.empty() ? src_scale_groups[1] : 1;
-    const int64_t src_scale_stride_k = src_scale_per_k ? 1 : 0;
-    const int64_t src_scale_stride_m = src_scale_per_m
-            ? src_scale_per_k ? (K / src_scale_group_k) : 1
-            : 0;
-
-    const bool wei_scale_per_n = wei_scale_mask & (1 << (wei_m.ndims() - 1));
-    const bool wei_scale_per_k = wei_scale_mask & (1 << (wei_m.ndims() - 2));
-    const auto wei_scale_groups = prb->attr.scales.get(DNNL_ARG_WEIGHTS).groups;
-    const int64_t wei_scale_group_k
-            = !wei_scale_groups.empty() ? wei_scale_groups[0] : 1;
-    const int64_t wei_scale_stride_n = wei_scale_per_n ? 1 : 0;
-    const int64_t wei_scale_stride_k
-            = wei_scale_per_k ? wei_scale_per_n ? N : 1 : 0;
-
-    const bool wei_zp_per_n = wei_zp_mask & (1 << (wei_m.ndims() - 1));
-    const bool wei_zp_per_k = wei_zp_mask & (1 << (wei_m.ndims() - 2));
-    const int64_t wei_zp_stride_n = wei_zp_per_n ? 1 : 0;
-    const int64_t wei_zp_stride_k = wei_zp_per_k ? wei_zp_per_n ? N : 1 : 0;
-    const auto wei_zp_groups
-            = prb->attr.zero_points.get(DNNL_ARG_WEIGHTS).groups;
-    const int64_t wei_zp_group_k
-            = !wei_zp_groups.empty() ? wei_zp_groups[0] : 1;
-
-    const bool wei_decompression = prb->weights_decompression();
-    const bool apply_scales_in_ker
-            = wei_decompression || wei_scale_per_k || src_scale_per_k;
-
     // Fast return if any dim is zero. Common logic doesn't apply because of
     // broadcast semantics.
     for (int d = 0; d < dst_m.ndims(); d++) {
         if (prb->src_dims()[d] == 0 || prb->weights_dims()[d] == 0) return;
     }
 
-    dnn_mem_t dst_tmp(dst_m, dnnl_f32, tag::abx, dst_m.engine());
-
     const auto src_broadcast_mask = prb->src_broadcast_mask();
     const auto wei_broadcast_mask = prb->weights_broadcast_mask();
+    const auto bias_broadcast_mask = prb->bias_broadcast_mask();
+    auto v_po_masks = prb->attr.post_ops.get_po_masks();
 
     benchdnn_parallel_nd(MB, M, N, [&](int64_t mb, int64_t m, int64_t n) {
-        auto src = (const float *)src_m;
-        auto wei = (const float *)wei_m;
-
         float dst = 0;
         const int64_t src_mb
                 = dst_m.get_idx(mb, src_broadcast_mask, batch_ndims);
@@ -121,70 +92,60 @@ void compute_ref_matmul(const prb_t *prb, const args_t &args) {
                 = dst_m.get_idx(mb, wei_broadcast_mask, batch_ndims);
 
         for (int64_t k = 0; k < K; ++k) {
-            int src_zp = has_src_zp ? src_zps.get_elem(src_zp_mask > 0 ? k : 0)
-                                    : 0;
-            int wei_zp = has_wei_zp
-                    ? wei_zps.get_elem(wei_zp_stride_k * (k / wei_zp_group_k)
-                            + wei_zp_stride_n * n)
-                    : 0;
-
-            auto s = src[src_off_f(prb, src_mb, m, k)] - src_zp;
-            auto w = wei[wei_off_f(prb, wei_mb, k, n)] - wei_zp;
-
-            // Compression scaling happens before the matmul, unlike regular
-            // quantization, to preserve the accuracy.
-            // Also, regular quantized matmul can have per group K-dim scales
-            // which require handling inside the kernel.
-            if (apply_scales_in_ker) {
-                if (has_src_scale) {
-                    float src_scale = src_scales.get_elem(
-                            src_scale_stride_k * (k / src_scale_group_k)
-                            + src_scale_stride_m * m);
-                    s *= src_scale;
-                }
-                if (has_wei_scale) {
-                    float wei_scale = wei_scales.get_elem(
-                            wei_scale_stride_k * (k / wei_scale_group_k)
-                            + wei_scale_stride_n * n);
-                    w *= wei_scale;
-                }
+            const auto src_off = src_off_f(prb, src_mb, m, k);
+            const auto wei_off = wei_off_f(prb, wei_mb, k, n);
+
+            int src_zp = 0;
+            if (has_src_zp) {
+                const auto src_zp_idx = src_m.get_idx(src_off, src_zp_mask);
+                src_zp = src_zps.get_elem(src_zp_idx);
+            }
+            int wei_zp = 0;
+            if (has_wei_zp) {
+                const auto wei_zp_idx = wei_m.get_idx(
+                        wei_off, wei_zp_mask, wei_m.ndims(), wei_zp_groups);
+                wei_zp = wei_zps.get_elem(wei_zp_idx);
             }
-            dst += s * w;
-        }
-        ((float *)dst_tmp)[dst_off_f(prb, mb, m, n)] = dst;
-    });
 
-    auto v_po_masks = prb->attr.post_ops.get_po_masks();
-    const auto bias_broadcast_mask = prb->bias_broadcast_mask();
-    benchdnn_parallel_nd(MB, M, N, [&](int64_t mb, int64_t m, int64_t n) {
-        size_t dst_off = dst_off_f(prb, mb, m, n);
-        float &dst = ((float *)dst_m)[dst_off];
-
-        float wei_scale = 1.f;
-        float src_scale = 1.f;
-        if (!apply_scales_in_ker) {
-            assert(IMPLICATION(has_src_scale, src_scales.nelems() == 1));
-            if (has_src_scale) { src_scale = src_scales.get_elem(0); }
+            float src_scale = 1.f;
+            if (has_src_scale) {
+                const auto src_scale_idx = src_m.get_idx(src_off,
+                        src_scale_mask, src_m.ndims(), src_scale_groups);
+                src_scale = src_scales.get_elem(src_scale_idx);
+            }
+            float wei_scale = 1.f;
             if (has_wei_scale) {
-                wei_scale = wei_scales.get_elem(wei_scale_mask > 0 ? n : 0);
+                const auto wei_scale_idx = wei_m.get_idx(wei_off,
+                        wei_scale_mask, wei_m.ndims(), wei_scale_groups);
+                wei_scale = wei_scales.get_elem(wei_scale_idx);
             }
-        }
 
-        float tmp = ((float *)dst_tmp)[dst_off] * src_scale * wei_scale;
+            auto s = src_scale * (src_m.get_elem(src_off) - src_zp);
+            auto w = wei_scale * (wei_m.get_elem(wei_off) - wei_zp);
+            dst += s * w;
+        }
+
+        const auto dst_off = dst_off_f(prb, mb, m, n);
         if (prb->bia_dt != dnnl_data_type_undef) {
-            int64_t bia_off = dst_m.get_idx(dst_off, bias_broadcast_mask);
-            float *bia_ptr = (float *)bia_m;
-            tmp += bia_ptr[bia_off];
+            const auto bia_idx = dst_m.get_idx(dst_off, bias_broadcast_mask);
+            dst += bia_m.get_elem(bia_idx);
         }
 
         const auto v_po_vals
                 = prepare_po_vals(dst_m, args, v_po_masks, dst_off);
-
-        maybe_dropout(prb->attr, tmp, dst_off, dropout);
-        maybe_post_ops(prb->attr, tmp, dst, v_po_vals);
-
-        int dst_zp = has_dst_zp ? dst_zps.get_elem(dst_zp_mask > 0 ? n : 0) : 0;
-        dst = tmp * dst_scale + dst_zp;
-        maybe_round(prb->attr, DNNL_ARG_DST, dst, dst_off, prb->dst_dt());
+        maybe_dropout(prb->attr, dst, dst_off, dropout);
+        const auto sum_val = dst_m.get_elem(dst_off);
+        maybe_post_ops(prb->attr, dst, sum_val, v_po_vals);
+
+        int dst_zp = 0;
+        if (has_dst_zp) {
+            const auto dst_zp_idx = dst_m.get_idx(dst_off, dst_zp_mask);
+            dst_zp = dst_zps.get_elem(dst_zp_idx);
+        }
+        float dst_val = dst_scale * dst + dst_zp;
+        maybe_round(prb->attr, DNNL_ARG_DST, dst_val, dst_off, prb->dst_dt());
+        dst_m.set_elem(dst_off, dst_val);
     });
 }
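
For context: the patch folds the former two-pass scheme (accumulate into
dst_tmp, then rescale) into a single pass, applying zero-points and scales
per element inside the K loop, so per-tensor, per-channel, and grouped
quantization all go through one code path. Under the attribute semantics
visible above, each destination element is, informally:

    dst(mb, m, n) = dst_scale
            * post_ops(sum_k [src_scale(m, k) * (src(mb, m, k) - src_zp(k))]
                    * [wei_scale(k, n) * (wei(mb, k, n) - wei_zp(k, n))]
                    + bias(mb, m, n))
            + dst_zp(n)

where any scale or zero-point whose mask is 0 collapses to a single
broadcast value.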
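
The load-bearing change is that every scale/zero-point lookup now goes
through dnn_mem_t::get_idx(off, mask, ndims, groups) instead of hand-rolled
stride arithmetic. Below is a minimal standalone sketch of what such a
mask-plus-groups mapping does, assuming oneDNN's quantization-mask
semantics: bit d of the mask set means the scale/zero-point tensor has a
real extent along dim d (divided by that dim's group size), a cleared bit
means the dim is broadcast. The name quant_idx and its exact signature are
hypothetical, for illustration only; the real helper lives in benchdnn's
dnn_mem_t.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for dnn_mem_t::get_idx(off, mask, ndims, groups):
// maps a dense element offset within a tensor of shape `dims` to an index
// into its scales/zero-points buffer described by `mask` and `groups`.
int64_t quant_idx(const std::vector<int64_t> &dims, int64_t off, int mask,
        const std::vector<int64_t> &groups) {
    const int ndims = (int)dims.size();
    // Recover per-dim coordinates from the dense row-major offset.
    std::vector<int64_t> pos(ndims);
    for (int d = ndims - 1; d >= 0; --d) {
        pos[d] = off % dims[d];
        off /= dims[d];
    }
    // Assumption: groups, when present, apply to the innermost dims.
    const int g0 = ndims - (int)groups.size();
    int64_t idx = 0;
    for (int d = 0; d < ndims; ++d) {
        const int64_t g = (d >= g0) ? groups[d - g0] : 1;
        const bool per_dim = mask & (1 << d);
        const int64_t q_dim = per_dim ? dims[d] / g : 1; // broadcast -> 1
        const int64_t q_pos = per_dim ? pos[d] / g : 0;
        idx = idx * q_dim + q_pos; // re-linearize against the grouped dims
    }
    return idx;
}

int main() {
    // Weights of shape {K = 64, N = 16}, scales per K-group of 32 and
    // per N: the scale tensor is {2, 16}, so the element at (k = 40, n = 3)
    // maps to scale (1, 3), i.e. index 19.
    const std::vector<int64_t> wei_dims {64, 16};
    const int mask = (1 << 0) | (1 << 1);
    const std::vector<int64_t> groups {32, 1};
    const int64_t off = 40 * 16 + 3; // dense offset of (k = 40, n = 3)
    assert(quant_idx(wei_dims, off, mask, groups) == 19);
    return 0;
}

With this formulation the old special cases fall out for free: a per-tensor
scale has mask 0, so every element maps to index 0, and the removed
src_scale_stride_* / wei_scale_stride_* expressions are just this
linearization specialized by hand.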