From 2da6c82c4c5e01f9426315225f2e7d8fff4e5ca5 Mon Sep 17 00:00:00 2001 From: nihuini Date: Fri, 30 Jun 2023 18:44:59 +0800 Subject: [PATCH] fix and opt++ --- .../convolution_pack4_gemm_cm_16_16_16.comp | 142 +++++++----------- 1 file changed, 58 insertions(+), 84 deletions(-) diff --git a/src/layer/vulkan/shader/convolution_pack4_gemm_cm_16_16_16.comp b/src/layer/vulkan/shader/convolution_pack4_gemm_cm_16_16_16.comp index bc3d951aeff..71cef19638c 100644 --- a/src/layer/vulkan/shader/convolution_pack4_gemm_cm_16_16_16.comp +++ b/src/layer/vulkan/shader/convolution_pack4_gemm_cm_16_16_16.comp @@ -69,7 +69,7 @@ layout (push_constant) uniform parameter int outcstep; } p; -#define UNROLL_INCH 4 +#define UNROLL_INCH 2 shared uvec2 tmp_v0[UNROLL_INCH * 16*4]; shared uvec2 tmp_v1[UNROLL_INCH * 16*4]; @@ -84,7 +84,9 @@ void main() const int outsize = psc(outw) * psc(outh); const int lx = int(gl_LocalInvocationID.x); - const int ly = int(gl_LocalInvocationID.y); + + const int lxd16 = lx / 16; // 0 1 + const int lxm16 = lx % 16; // 0 1 2 3 .... 15 fcoopmatNV<32, gl_ScopeSubgroup, 16, 16> sum0; fcoopmatNV<32, gl_ScopeSubgroup, 16, 16> sum1; @@ -118,43 +120,36 @@ void main() int z = 0; for (; z + (UNROLL_INCH - 1) < N; z += UNROLL_INCH) { - if (lx == 0 && ly == 0) - { - for (int z4 = 0; z4 < UNROLL_INCH; z4++) { - for (int i = 0; i < 16; i++) + for (int j = 0; j < 4; j++) { - for (int j = 0; j < 4; j++) - { - const int tmp_i = z4*16*4 + i * 4 + j; + const int tmp_i = lxd16*16*4 + lxm16 * 4 + j; - const int sz = (z + z4) / maxk; - const int kk = (z + z4) % maxk; + const int sz = (z + lxd16) / maxk; + const int kk = (z + lxd16) % maxk; - const int ky = kk / kernel_w; - const int kx = kk % kernel_w; + const int ky = kk / kernel_w; + const int kx = kk % kernel_w; - const ivec2 gx16 = gx + i + ivec2(0, 16); + const ivec2 gx16 = gx + lxm16 + ivec2(0, 16); - const ivec2 sy16 = gx16 / psc(outw); - const ivec2 sx16 = gx16 % psc(outw); + const ivec2 sy16 = gx16 / psc(outw); + const ivec2 sx16 = gx16 % psc(outw); - const ivec2 sxs16 = sx16 * stride_w; - const ivec2 sys16 = sy16 * stride_h; + const ivec2 sxs16 = sx16 * stride_w; + const ivec2 sys16 = sy16 * stride_h; - const ivec2 v_offset = (sz * 4 + j) * psc(cstep) + (sys16 + ky * dilation_h) * psc(w) + sxs16 + kx * dilation_w; + const ivec2 v_offset = (sz * 4 + j) * psc(cstep) + (sys16 + ky * dilation_h) * psc(w) + sxs16 + kx * dilation_w; - tmp_v0[tmp_i] = gx16.r < psc(outcstep) ? bottom_blob_data[v_offset.r] : uvec2(0); - tmp_v1[tmp_i] = gx16.g < psc(outcstep) ? bottom_blob_data[v_offset.g] : uvec2(0); + tmp_v0[tmp_i] = gx16.r < psc(outcstep) ? bottom_blob_data[v_offset.r] : uvec2(0); + tmp_v1[tmp_i] = gx16.g < psc(outcstep) ? bottom_blob_data[v_offset.g] : uvec2(0); - int w_offset = gy * psc(c) * 4 + (z + z4) * 4 * 16 + (i * 4 + j); + int w_offset = gy * psc(c) * 4 * maxk + (z + lxd16) * 4 * 16 + (lxm16 * 4 + j); - tmp_k0[tmp_i] = weight_data[w_offset]; - tmp_k1[tmp_i] = weight_data[w_offset + psc(c) * 16]; - } + tmp_k0[tmp_i] = weight_data[w_offset]; + tmp_k1[tmp_i] = weight_data[w_offset + psc(c) * maxk * 16]; } } - } barrier(); @@ -184,43 +179,37 @@ void main() { const int remain = N - z; - if (lx == 0 && ly == 0) - { - for (int z4 = 0; z4 < remain; z4++) + if (lxd16 == 0) { - for (int i = 0; i < 16; i++) + for (int j = 0; j < 4; j++) { - for (int j = 0; j < 4; j++) - { - const int tmp_i = z4*16*4 + i * 4 + j; + const int tmp_i = lxd16*16*4 + lxm16 * 4 + j; - const int sz = (z + z4) / maxk; - const int kk = (z + z4) % maxk; + const int sz = (z + lxd16) / maxk; + const int kk = (z + lxd16) % maxk; - const int ky = kk / kernel_w; - const int kx = kk % kernel_w; + const int ky = kk / kernel_w; + const int kx = kk % kernel_w; - const ivec2 gx16 = gx + i + ivec2(0, 16); + const ivec2 gx16 = gx + lxm16 + ivec2(0, 16); - const ivec2 sy16 = gx16 / psc(outw); - const ivec2 sx16 = gx16 % psc(outw); + const ivec2 sy16 = gx16 / psc(outw); + const ivec2 sx16 = gx16 % psc(outw); - const ivec2 sxs16 = sx16 * stride_w; - const ivec2 sys16 = sy16 * stride_h; + const ivec2 sxs16 = sx16 * stride_w; + const ivec2 sys16 = sy16 * stride_h; - const ivec2 v_offset = (sz * 4 + j) * psc(cstep) + (sys16 + ky * dilation_h) * psc(w) + sxs16 + kx * dilation_w; + const ivec2 v_offset = (sz * 4 + j) * psc(cstep) + (sys16 + ky * dilation_h) * psc(w) + sxs16 + kx * dilation_w; - tmp_v0[tmp_i] = gx16.r < psc(outcstep) ? bottom_blob_data[v_offset.r] : uvec2(0); - tmp_v1[tmp_i] = gx16.g < psc(outcstep) ? bottom_blob_data[v_offset.g] : uvec2(0); + tmp_v0[tmp_i] = gx16.r < psc(outcstep) ? bottom_blob_data[v_offset.r] : uvec2(0); + tmp_v1[tmp_i] = gx16.g < psc(outcstep) ? bottom_blob_data[v_offset.g] : uvec2(0); - int w_offset = gy * psc(c) * 4 + (z + z4) * 4 * 16 + (i * 4 + j); + int w_offset = gy * psc(c) * 4 * maxk + (z + lxd16) * 4 * 16 + (lxm16 * 4 + j); - tmp_k0[tmp_i] = weight_data[w_offset]; - tmp_k1[tmp_i] = weight_data[w_offset + psc(c) * 16]; - } + tmp_k0[tmp_i] = weight_data[w_offset]; + tmp_k1[tmp_i] = weight_data[w_offset + psc(c) * maxk * 16]; } } - } barrier(); @@ -256,46 +245,31 @@ void main() coopMatStoreNV(sum0_fp16, tmp_v0, 0, 4, false); coopMatStoreNV(sum1_fp16, tmp_v1, 0, 4, false); - coopMatStoreNV(sum2_fp16, tmp_k0, 0, 4, false); - coopMatStoreNV(sum3_fp16, tmp_k1, 0, 4, false); + coopMatStoreNV(sum2_fp16, tmp_v0, 16*4, 4, false); + coopMatStoreNV(sum3_fp16, tmp_v1, 16*4, 4, false); barrier(); - if (lx == 0 && ly == 0) { - for (int i = 0; i < 16; i++) + for (int j = 0; j < 4; j++) { - for (int j = 0; j < 4; j++) + const int tmp_vi = lxm16 * 4 + j + lxd16*16*4; + + uvec2 sum0_u2 = tmp_v0[tmp_vi]; + uvec2 sum1_u2 = tmp_v1[tmp_vi]; + + afpvec4 sum0 = afpvec4(unpackHalf2x16(sum0_u2.x), unpackHalf2x16(sum0_u2.y)); + afpvec4 sum1 = afpvec4(unpackHalf2x16(sum1_u2.x), unpackHalf2x16(sum1_u2.y)); + + sum0 = activation_afpvec4(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afpvec4(sum1, activation_type, activation_param_0, activation_param_1); + + const int gi = (gy + lxd16 * 4 + j) * psc(outcstep) + (gx + lxm16); + + if (gy + lxd16 * 4 + j < psc(outc)) { - const int tmp_vi = i * 4 + j; - - uvec2 sum0_u2 = tmp_v0[tmp_vi]; - uvec2 sum1_u2 = tmp_v1[tmp_vi]; - uvec2 sum2_u2 = tmp_k0[tmp_vi]; - uvec2 sum3_u2 = tmp_k1[tmp_vi]; - - afpvec4 sum0 = afpvec4(unpackHalf2x16(sum0_u2.x), unpackHalf2x16(sum0_u2.y)); - afpvec4 sum1 = afpvec4(unpackHalf2x16(sum1_u2.x), unpackHalf2x16(sum1_u2.y)); - afpvec4 sum2 = afpvec4(unpackHalf2x16(sum2_u2.x), unpackHalf2x16(sum2_u2.y)); - afpvec4 sum3 = afpvec4(unpackHalf2x16(sum3_u2.x), unpackHalf2x16(sum3_u2.y)); - - sum0 = activation_afpvec4(sum0, activation_type, activation_param_0, activation_param_1); - sum1 = activation_afpvec4(sum1, activation_type, activation_param_0, activation_param_1); - sum2 = activation_afpvec4(sum2, activation_type, activation_param_0, activation_param_1); - sum3 = activation_afpvec4(sum3, activation_type, activation_param_0, activation_param_1); - - if (gy + j < psc(outc)) - { - int gi = (gy + j) * psc(outcstep) + (gx + i); - if (gx + i < psc(outcstep)) buffer_st4(top_blob_data, gi, sum0); - if (gx + i + 16 < psc(outcstep)) buffer_st4(top_blob_data, gi + 16, sum1); - } - if (gy + 4 + j < psc(outc)) - { - int gi = (gy + 4 + j) * psc(outcstep) + (gx + i); - if (gx + i < psc(outcstep)) buffer_st4(top_blob_data, gi, sum2); - if (gx + i + 16 < psc(outcstep)) buffer_st4(top_blob_data, gi + 16, sum3); - } + if (gx + lxm16 < psc(outcstep)) buffer_st4(top_blob_data, gi, sum0); + if (gx + lxm16 + 16 < psc(outcstep)) buffer_st4(top_blob_data, gi + 16, sum1); } } }