From a26de4fbf35fdca4f70f64e677900e2f0f875ae4 Mon Sep 17 00:00:00 2001 From: nihuini Date: Fri, 30 Jun 2023 17:16:44 +0800 Subject: [PATCH] opt++ --- ...ck4_3x3s1d1_winograd_gemm_cm_16_16_16.comp | 80 +++++++------------ 1 file changed, 30 insertions(+), 50 deletions(-) diff --git a/src/layer/vulkan/shader/convolution_pack4_3x3s1d1_winograd_gemm_cm_16_16_16.comp b/src/layer/vulkan/shader/convolution_pack4_3x3s1d1_winograd_gemm_cm_16_16_16.comp index 21517559085..5513b1f4ed5 100644 --- a/src/layer/vulkan/shader/convolution_pack4_3x3s1d1_winograd_gemm_cm_16_16_16.comp +++ b/src/layer/vulkan/shader/convolution_pack4_3x3s1d1_winograd_gemm_cm_16_16_16.comp @@ -50,7 +50,7 @@ layout (push_constant) uniform parameter int outcstep; } p; -#define UNROLL_INCH 4 +#define UNROLL_INCH 2 shared uvec2 tmp_v0[UNROLL_INCH * 16*4]; shared uvec2 tmp_v1[UNROLL_INCH * 16*4]; @@ -71,34 +71,30 @@ void main() const int lx = int(gl_LocalInvocationID.x); const int ly = int(gl_LocalInvocationID.y); + const int lxd16 = lx / 16; // 0 1 + const int lxm16 = lx % 16; // 0 1 2 3 .... 15 + const int N = psc(c) / 4; int z = 0; for (; z + (UNROLL_INCH - 1) < N; z += UNROLL_INCH) { - if (lx == 0 && ly == 0) - { - for (int z4 = 0; z4 < UNROLL_INCH; z4++) { - for (int i = 0; i < 16; i++) + for (int j = 0; j < 4; j++) { - for (int j = 0; j < 4; j++) - { - const int tmp_i = z4*16*4 + i * 4 + j; + const int tmp_i = lxd16*16*4 + lxm16 * 4 + j; - int v_offset = gz * psc(cstep) + ((z + z4) * 4 + j) * psc(outw) + (gx + i); + const int v_offset = gz * psc(cstep) + ((z + lxd16) * 4 + j) * psc(outw) + (gx + lxm16); - tmp_v0[tmp_i] = (gx + i) < psc(outw) ? bottom_tm_blob_data[v_offset] : uvec2(0); - tmp_v1[tmp_i] = (gx + i + 16) < psc(outw) ? bottom_tm_blob_data[v_offset + 16] : uvec2(0); + tmp_v0[tmp_i] = (gx + lxm16) < psc(outw) ? bottom_tm_blob_data[v_offset] : uvec2(0); + tmp_v1[tmp_i] = (gx + lxm16 + 16) < psc(outw) ? bottom_tm_blob_data[v_offset + 16] : uvec2(0); - int w_offset = gz * psc(outc) * psc(c) * 4 + gy * psc(c) * 4 + (z + z4) * 4 * 16 + (i * 4 + j); + const int w_offset = gz * psc(outc) * psc(c) * 4 + gy * psc(c) * 4 + (z + lxd16) * 4 * 16 + (lxm16 * 4 + j); - tmp_k0[tmp_i] = weight_tm_data[w_offset]; - tmp_k1[tmp_i] = weight_tm_data[w_offset + psc(c) * 16]; - } + tmp_k0[tmp_i] = weight_tm_data[w_offset]; + tmp_k1[tmp_i] = weight_tm_data[w_offset + psc(c) * 16]; } } - } barrier(); @@ -128,29 +124,23 @@ void main() { const int remain = N - z; - if (lx == 0 && ly == 0) + if (lxd16 == 0) { - for (int z4 = 0; z4 < remain; z4++) - { - for (int i = 0; i < 16; i++) + for (int j = 0; j < 4; j++) { - for (int j = 0; j < 4; j++) - { - const int tmp_i = z4*16*4 + i * 4 + j; + const int tmp_i = lxd16*16*4 + lxm16 * 4 + j; - int v_offset = gz * psc(cstep) + ((z + z4) * 4 + j) * psc(outw) + (gx + i); + const int v_offset = gz * psc(cstep) + ((z + lxd16) * 4 + j) * psc(outw) + (gx + lxm16); - tmp_v0[tmp_i] = (gx + i) < psc(outw) ? bottom_tm_blob_data[v_offset] : uvec2(0); - tmp_v1[tmp_i] = (gx + i + 16) < psc(outw) ? bottom_tm_blob_data[v_offset + 16] : uvec2(0); + tmp_v0[tmp_i] = (gx + lxm16) < psc(outw) ? bottom_tm_blob_data[v_offset] : uvec2(0); + tmp_v1[tmp_i] = (gx + lxm16 + 16) < psc(outw) ? bottom_tm_blob_data[v_offset + 16] : uvec2(0); - int w_offset = gz * psc(outc) * psc(c) * 4 + gy * psc(c) * 4 + (z + z4) * 4 * 16 + (i * 4 + j); + const int w_offset = gz * psc(outc) * psc(c) * 4 + gy * psc(c) * 4 + (z + lxd16) * 4 * 16 + (lxm16 * 4 + j); - tmp_k0[tmp_i] = weight_tm_data[w_offset]; - tmp_k1[tmp_i] = weight_tm_data[w_offset + psc(c) * 16]; - } + tmp_k0[tmp_i] = weight_tm_data[w_offset]; + tmp_k1[tmp_i] = weight_tm_data[w_offset + psc(c) * 16]; } } - } barrier(); @@ -186,31 +176,21 @@ void main() coopMatStoreNV(sum0_fp16, tmp_v0, 0, 4, false); coopMatStoreNV(sum1_fp16, tmp_v1, 0, 4, false); - coopMatStoreNV(sum2_fp16, tmp_k0, 0, 4, false); - coopMatStoreNV(sum3_fp16, tmp_k1, 0, 4, false); + coopMatStoreNV(sum2_fp16, tmp_v0, 16*4, 4, false); + coopMatStoreNV(sum3_fp16, tmp_v1, 16*4, 4, false); barrier(); - if (lx == 0 && ly == 0) { - for (int i = 0; i < 16; i++) + for (int j = 0; j < 4; j++) { - for (int j = 0; j < 4; j++) + const int tmp_vi = lxm16 * 4 + j + lxd16*16*4; + const int gi = gz * psc(outcstep) + (gy + lxd16 * 4 + j) * psc(outw) + (gx + lxm16); + + if (gy + lxd16 * 4 + j < psc(outc)) { - const int tmp_vi = i * 4 + j; - - if (gy + j < psc(outc)) - { - int gi = gz * psc(outcstep) + (gy + j) * psc(outw) + (gx + i); - if (gx + i < psc(outw)) top_tm_blob_data[gi] = tmp_v0[tmp_vi]; - if (gx + i + 16 < psc(outw)) top_tm_blob_data[gi + 16] = tmp_v1[tmp_vi]; - } - if (gy + 4 + j < psc(outc)) - { - int gi = gz * psc(outcstep) + (gy + 4 + j) * psc(outw) + (gx + i); - if (gx + i < psc(outw)) top_tm_blob_data[gi] = tmp_k0[tmp_vi]; - if (gx + i + 16 < psc(outw)) top_tm_blob_data[gi + 16] = tmp_k1[tmp_vi]; - } + if (gx + lxm16 < psc(outw)) top_tm_blob_data[gi] = tmp_v0[tmp_vi]; + if (gx + lxm16 + 16 < psc(outw)) top_tm_blob_data[gi + 16] = tmp_v1[tmp_vi]; } } }