From 0a8116b3de98243a234680d8cda869d2f20dd178 Mon Sep 17 00:00:00 2001 From: "Wierschem, Keola" Date: Tue, 3 Jan 2023 13:19:44 -0800 Subject: [PATCH] cpu: x64: fix postops handling in brgconv postops kernel --- src/cpu/x64/jit_brgemm_conv.cpp | 5 ++--- src/cpu/x64/jit_brgemm_post_ops.hpp | 5 +++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/cpu/x64/jit_brgemm_conv.cpp b/src/cpu/x64/jit_brgemm_conv.cpp index 89659e5d09e..41e2ace91df 100644 --- a/src/cpu/x64/jit_brgemm_conv.cpp +++ b/src/cpu/x64/jit_brgemm_conv.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2022 Intel Corporation +* Copyright 2021-2023 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -411,8 +411,7 @@ status_t brgemm_convolution_fwd_t::add_po_kernel( bcfg->LDD = (is_init && jcp.use_buffer) ? jcp.LDC : jcp.LDD; bcfg->dt_c = (!is_init && jcp.use_buffer) ? jcp.acc_dt : jcp.dst_dt; // inp bcfg->dt_d = (is_init && jcp.use_buffer) ? jcp.acc_dt : jcp.dst_dt; // out - bcfg->alpha = (!is_init) - && (IMPLICATION(jcp.with_sum, jcp.use_buffer) || jcp.with_eltwise); + bcfg->alpha = is_init ? 0 : 1; bcfg->beta = is_init ? 0 : 1; CHECK(safe_ptr_assign(kernels_po_[ker_idx], new jit_brgemm_kernel_post_ops(jcp, *bcfg, *_pd->attr()))); diff --git a/src/cpu/x64/jit_brgemm_post_ops.hpp b/src/cpu/x64/jit_brgemm_post_ops.hpp index a20438d9e85..2a50ea46934 100644 --- a/src/cpu/x64/jit_brgemm_post_ops.hpp +++ b/src/cpu/x64/jit_brgemm_post_ops.hpp @@ -681,6 +681,11 @@ struct jit_brgemm_kernel_post_ops : public jit_generator { // if sum then have to init vmm each time uni_vpxor(vector(m, n), vector(m, n), vector(m, n)); } + } else if (!IMPLICATION(jcp.with_sum, jcp.use_buffer)) { + if (sum_idx != -1 && brg.beta != 0) { + // if sum without buffer then have to init vmm each time + uni_vpxor(vector(m, n), vector(m, n), vector(m, n)); + } } else { auto inp_addr = ptr[aux_reg_in + inp_typesize_ * (m * brg.LDC + n * brg.ld_block)];