Skip to content

Commit

Permalink
fix and opt++
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Jun 30, 2023
1 parent 8407274 commit 2da6c82
Showing 1 changed file with 58 additions and 84 deletions.
142 changes: 58 additions & 84 deletions src/layer/vulkan/shader/convolution_pack4_gemm_cm_16_16_16.comp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ layout (push_constant) uniform parameter
int outcstep;
} p;

#define UNROLL_INCH 4
#define UNROLL_INCH 2

shared uvec2 tmp_v0[UNROLL_INCH * 16*4];
shared uvec2 tmp_v1[UNROLL_INCH * 16*4];
Expand All @@ -84,7 +84,9 @@ void main()
const int outsize = psc(outw) * psc(outh);

const int lx = int(gl_LocalInvocationID.x);
const int ly = int(gl_LocalInvocationID.y);

const int lxd16 = lx / 16; // 0 1
const int lxm16 = lx % 16; // 0 1 2 3 .... 15

fcoopmatNV<32, gl_ScopeSubgroup, 16, 16> sum0;
fcoopmatNV<32, gl_ScopeSubgroup, 16, 16> sum1;
Expand Down Expand Up @@ -118,43 +120,36 @@ void main()
int z = 0;
for (; z + (UNROLL_INCH - 1) < N; z += UNROLL_INCH)
{
if (lx == 0 && ly == 0)
{
for (int z4 = 0; z4 < UNROLL_INCH; z4++)
{
for (int i = 0; i < 16; i++)
for (int j = 0; j < 4; j++)
{
for (int j = 0; j < 4; j++)
{
const int tmp_i = z4*16*4 + i * 4 + j;
const int tmp_i = lxd16*16*4 + lxm16 * 4 + j;

const int sz = (z + z4) / maxk;
const int kk = (z + z4) % maxk;
const int sz = (z + lxd16) / maxk;
const int kk = (z + lxd16) % maxk;

const int ky = kk / kernel_w;
const int kx = kk % kernel_w;
const int ky = kk / kernel_w;
const int kx = kk % kernel_w;

const ivec2 gx16 = gx + i + ivec2(0, 16);
const ivec2 gx16 = gx + lxm16 + ivec2(0, 16);

const ivec2 sy16 = gx16 / psc(outw);
const ivec2 sx16 = gx16 % psc(outw);
const ivec2 sy16 = gx16 / psc(outw);
const ivec2 sx16 = gx16 % psc(outw);

const ivec2 sxs16 = sx16 * stride_w;
const ivec2 sys16 = sy16 * stride_h;
const ivec2 sxs16 = sx16 * stride_w;
const ivec2 sys16 = sy16 * stride_h;

const ivec2 v_offset = (sz * 4 + j) * psc(cstep) + (sys16 + ky * dilation_h) * psc(w) + sxs16 + kx * dilation_w;
const ivec2 v_offset = (sz * 4 + j) * psc(cstep) + (sys16 + ky * dilation_h) * psc(w) + sxs16 + kx * dilation_w;

tmp_v0[tmp_i] = gx16.r < psc(outcstep) ? bottom_blob_data[v_offset.r] : uvec2(0);
tmp_v1[tmp_i] = gx16.g < psc(outcstep) ? bottom_blob_data[v_offset.g] : uvec2(0);
tmp_v0[tmp_i] = gx16.r < psc(outcstep) ? bottom_blob_data[v_offset.r] : uvec2(0);
tmp_v1[tmp_i] = gx16.g < psc(outcstep) ? bottom_blob_data[v_offset.g] : uvec2(0);

int w_offset = gy * psc(c) * 4 + (z + z4) * 4 * 16 + (i * 4 + j);
int w_offset = gy * psc(c) * 4 * maxk + (z + lxd16) * 4 * 16 + (lxm16 * 4 + j);

tmp_k0[tmp_i] = weight_data[w_offset];
tmp_k1[tmp_i] = weight_data[w_offset + psc(c) * 16];
}
tmp_k0[tmp_i] = weight_data[w_offset];
tmp_k1[tmp_i] = weight_data[w_offset + psc(c) * maxk * 16];
}
}
}

barrier();

Expand Down Expand Up @@ -184,43 +179,37 @@ void main()
{
const int remain = N - z;

if (lx == 0 && ly == 0)
{
for (int z4 = 0; z4 < remain; z4++)
if (lxd16 == 0)
{
for (int i = 0; i < 16; i++)
for (int j = 0; j < 4; j++)
{
for (int j = 0; j < 4; j++)
{
const int tmp_i = z4*16*4 + i * 4 + j;
const int tmp_i = lxd16*16*4 + lxm16 * 4 + j;

const int sz = (z + z4) / maxk;
const int kk = (z + z4) % maxk;
const int sz = (z + lxd16) / maxk;
const int kk = (z + lxd16) % maxk;

const int ky = kk / kernel_w;
const int kx = kk % kernel_w;
const int ky = kk / kernel_w;
const int kx = kk % kernel_w;

const ivec2 gx16 = gx + i + ivec2(0, 16);
const ivec2 gx16 = gx + lxm16 + ivec2(0, 16);

const ivec2 sy16 = gx16 / psc(outw);
const ivec2 sx16 = gx16 % psc(outw);
const ivec2 sy16 = gx16 / psc(outw);
const ivec2 sx16 = gx16 % psc(outw);

const ivec2 sxs16 = sx16 * stride_w;
const ivec2 sys16 = sy16 * stride_h;
const ivec2 sxs16 = sx16 * stride_w;
const ivec2 sys16 = sy16 * stride_h;

const ivec2 v_offset = (sz * 4 + j) * psc(cstep) + (sys16 + ky * dilation_h) * psc(w) + sxs16 + kx * dilation_w;
const ivec2 v_offset = (sz * 4 + j) * psc(cstep) + (sys16 + ky * dilation_h) * psc(w) + sxs16 + kx * dilation_w;

tmp_v0[tmp_i] = gx16.r < psc(outcstep) ? bottom_blob_data[v_offset.r] : uvec2(0);
tmp_v1[tmp_i] = gx16.g < psc(outcstep) ? bottom_blob_data[v_offset.g] : uvec2(0);
tmp_v0[tmp_i] = gx16.r < psc(outcstep) ? bottom_blob_data[v_offset.r] : uvec2(0);
tmp_v1[tmp_i] = gx16.g < psc(outcstep) ? bottom_blob_data[v_offset.g] : uvec2(0);

int w_offset = gy * psc(c) * 4 + (z + z4) * 4 * 16 + (i * 4 + j);
int w_offset = gy * psc(c) * 4 * maxk + (z + lxd16) * 4 * 16 + (lxm16 * 4 + j);

tmp_k0[tmp_i] = weight_data[w_offset];
tmp_k1[tmp_i] = weight_data[w_offset + psc(c) * 16];
}
tmp_k0[tmp_i] = weight_data[w_offset];
tmp_k1[tmp_i] = weight_data[w_offset + psc(c) * maxk * 16];
}
}
}

barrier();

Expand Down Expand Up @@ -256,46 +245,31 @@ void main()

coopMatStoreNV(sum0_fp16, tmp_v0, 0, 4, false);
coopMatStoreNV(sum1_fp16, tmp_v1, 0, 4, false);
coopMatStoreNV(sum2_fp16, tmp_k0, 0, 4, false);
coopMatStoreNV(sum3_fp16, tmp_k1, 0, 4, false);
coopMatStoreNV(sum2_fp16, tmp_v0, 16*4, 4, false);
coopMatStoreNV(sum3_fp16, tmp_v1, 16*4, 4, false);

barrier();

if (lx == 0 && ly == 0)
{
for (int i = 0; i < 16; i++)
for (int j = 0; j < 4; j++)
{
for (int j = 0; j < 4; j++)
const int tmp_vi = lxm16 * 4 + j + lxd16*16*4;

uvec2 sum0_u2 = tmp_v0[tmp_vi];
uvec2 sum1_u2 = tmp_v1[tmp_vi];

afpvec4 sum0 = afpvec4(unpackHalf2x16(sum0_u2.x), unpackHalf2x16(sum0_u2.y));
afpvec4 sum1 = afpvec4(unpackHalf2x16(sum1_u2.x), unpackHalf2x16(sum1_u2.y));

sum0 = activation_afpvec4(sum0, activation_type, activation_param_0, activation_param_1);
sum1 = activation_afpvec4(sum1, activation_type, activation_param_0, activation_param_1);

const int gi = (gy + lxd16 * 4 + j) * psc(outcstep) + (gx + lxm16);

if (gy + lxd16 * 4 + j < psc(outc))
{
const int tmp_vi = i * 4 + j;

uvec2 sum0_u2 = tmp_v0[tmp_vi];
uvec2 sum1_u2 = tmp_v1[tmp_vi];
uvec2 sum2_u2 = tmp_k0[tmp_vi];
uvec2 sum3_u2 = tmp_k1[tmp_vi];

afpvec4 sum0 = afpvec4(unpackHalf2x16(sum0_u2.x), unpackHalf2x16(sum0_u2.y));
afpvec4 sum1 = afpvec4(unpackHalf2x16(sum1_u2.x), unpackHalf2x16(sum1_u2.y));
afpvec4 sum2 = afpvec4(unpackHalf2x16(sum2_u2.x), unpackHalf2x16(sum2_u2.y));
afpvec4 sum3 = afpvec4(unpackHalf2x16(sum3_u2.x), unpackHalf2x16(sum3_u2.y));

sum0 = activation_afpvec4(sum0, activation_type, activation_param_0, activation_param_1);
sum1 = activation_afpvec4(sum1, activation_type, activation_param_0, activation_param_1);
sum2 = activation_afpvec4(sum2, activation_type, activation_param_0, activation_param_1);
sum3 = activation_afpvec4(sum3, activation_type, activation_param_0, activation_param_1);

if (gy + j < psc(outc))
{
int gi = (gy + j) * psc(outcstep) + (gx + i);
if (gx + i < psc(outcstep)) buffer_st4(top_blob_data, gi, sum0);
if (gx + i + 16 < psc(outcstep)) buffer_st4(top_blob_data, gi + 16, sum1);
}
if (gy + 4 + j < psc(outc))
{
int gi = (gy + 4 + j) * psc(outcstep) + (gx + i);
if (gx + i < psc(outcstep)) buffer_st4(top_blob_data, gi, sum2);
if (gx + i + 16 < psc(outcstep)) buffer_st4(top_blob_data, gi + 16, sum3);
}
if (gx + lxm16 < psc(outcstep)) buffer_st4(top_blob_data, gi, sum0);
if (gx + lxm16 + 16 < psc(outcstep)) buffer_st4(top_blob_data, gi + 16, sum1);
}
}
}
Expand Down

0 comments on commit 2da6c82

Please sign in to comment.