Skip to content

Commit

Permalink
opt++
Browse files: browse the repository at this point in the history
  • Loading branch information
nihui committed Jun 30, 2023
1 parent 5a60527 commit a26de4f
Showing 1 changed file with 30 additions and 50 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -50,7 +50,7 @@ layout (push_constant) uniform parameter
int outcstep;
} p;

#define UNROLL_INCH 4
#define UNROLL_INCH 2

shared uvec2 tmp_v0[UNROLL_INCH * 16*4];
shared uvec2 tmp_v1[UNROLL_INCH * 16*4];
Expand All @@ -71,34 +71,30 @@ void main()
const int lx = int(gl_LocalInvocationID.x);
const int ly = int(gl_LocalInvocationID.y);

const int lxd16 = lx / 16; // 0 1
const int lxm16 = lx % 16; // 0 1 2 3 .... 15

const int N = psc(c) / 4;

int z = 0;
for (; z + (UNROLL_INCH - 1) < N; z += UNROLL_INCH)
{
if (lx == 0 && ly == 0)
{
for (int z4 = 0; z4 < UNROLL_INCH; z4++)
{
for (int i = 0; i < 16; i++)
for (int j = 0; j < 4; j++)
{
for (int j = 0; j < 4; j++)
{
const int tmp_i = z4*16*4 + i * 4 + j;
const int tmp_i = lxd16*16*4 + lxm16 * 4 + j;

int v_offset = gz * psc(cstep) + ((z + z4) * 4 + j) * psc(outw) + (gx + i);
const int v_offset = gz * psc(cstep) + ((z + lxd16) * 4 + j) * psc(outw) + (gx + lxm16);

tmp_v0[tmp_i] = (gx + i) < psc(outw) ? bottom_tm_blob_data[v_offset] : uvec2(0);
tmp_v1[tmp_i] = (gx + i + 16) < psc(outw) ? bottom_tm_blob_data[v_offset + 16] : uvec2(0);
tmp_v0[tmp_i] = (gx + lxm16) < psc(outw) ? bottom_tm_blob_data[v_offset] : uvec2(0);
tmp_v1[tmp_i] = (gx + lxm16 + 16) < psc(outw) ? bottom_tm_blob_data[v_offset + 16] : uvec2(0);

int w_offset = gz * psc(outc) * psc(c) * 4 + gy * psc(c) * 4 + (z + z4) * 4 * 16 + (i * 4 + j);
const int w_offset = gz * psc(outc) * psc(c) * 4 + gy * psc(c) * 4 + (z + lxd16) * 4 * 16 + (lxm16 * 4 + j);

tmp_k0[tmp_i] = weight_tm_data[w_offset];
tmp_k1[tmp_i] = weight_tm_data[w_offset + psc(c) * 16];
}
tmp_k0[tmp_i] = weight_tm_data[w_offset];
tmp_k1[tmp_i] = weight_tm_data[w_offset + psc(c) * 16];
}
}
}

barrier();

Expand Down Expand Up @@ -128,29 +124,23 @@ void main()
{
const int remain = N - z;

if (lx == 0 && ly == 0)
if (lxd16 == 0)
{
for (int z4 = 0; z4 < remain; z4++)
{
for (int i = 0; i < 16; i++)
for (int j = 0; j < 4; j++)
{
for (int j = 0; j < 4; j++)
{
const int tmp_i = z4*16*4 + i * 4 + j;
const int tmp_i = lxd16*16*4 + lxm16 * 4 + j;

int v_offset = gz * psc(cstep) + ((z + z4) * 4 + j) * psc(outw) + (gx + i);
const int v_offset = gz * psc(cstep) + ((z + lxd16) * 4 + j) * psc(outw) + (gx + lxm16);

tmp_v0[tmp_i] = (gx + i) < psc(outw) ? bottom_tm_blob_data[v_offset] : uvec2(0);
tmp_v1[tmp_i] = (gx + i + 16) < psc(outw) ? bottom_tm_blob_data[v_offset + 16] : uvec2(0);
tmp_v0[tmp_i] = (gx + lxm16) < psc(outw) ? bottom_tm_blob_data[v_offset] : uvec2(0);
tmp_v1[tmp_i] = (gx + lxm16 + 16) < psc(outw) ? bottom_tm_blob_data[v_offset + 16] : uvec2(0);

int w_offset = gz * psc(outc) * psc(c) * 4 + gy * psc(c) * 4 + (z + z4) * 4 * 16 + (i * 4 + j);
const int w_offset = gz * psc(outc) * psc(c) * 4 + gy * psc(c) * 4 + (z + lxd16) * 4 * 16 + (lxm16 * 4 + j);

tmp_k0[tmp_i] = weight_tm_data[w_offset];
tmp_k1[tmp_i] = weight_tm_data[w_offset + psc(c) * 16];
}
tmp_k0[tmp_i] = weight_tm_data[w_offset];
tmp_k1[tmp_i] = weight_tm_data[w_offset + psc(c) * 16];
}
}
}

barrier();

Expand Down Expand Up @@ -186,31 +176,21 @@ void main()

coopMatStoreNV(sum0_fp16, tmp_v0, 0, 4, false);
coopMatStoreNV(sum1_fp16, tmp_v1, 0, 4, false);
coopMatStoreNV(sum2_fp16, tmp_k0, 0, 4, false);
coopMatStoreNV(sum3_fp16, tmp_k1, 0, 4, false);
coopMatStoreNV(sum2_fp16, tmp_v0, 16*4, 4, false);
coopMatStoreNV(sum3_fp16, tmp_v1, 16*4, 4, false);

barrier();

if (lx == 0 && ly == 0)
{
for (int i = 0; i < 16; i++)
for (int j = 0; j < 4; j++)
{
for (int j = 0; j < 4; j++)
const int tmp_vi = lxm16 * 4 + j + lxd16*16*4;
const int gi = gz * psc(outcstep) + (gy + lxd16 * 4 + j) * psc(outw) + (gx + lxm16);

if (gy + lxd16 * 4 + j < psc(outc))
{
const int tmp_vi = i * 4 + j;

if (gy + j < psc(outc))
{
int gi = gz * psc(outcstep) + (gy + j) * psc(outw) + (gx + i);
if (gx + i < psc(outw)) top_tm_blob_data[gi] = tmp_v0[tmp_vi];
if (gx + i + 16 < psc(outw)) top_tm_blob_data[gi + 16] = tmp_v1[tmp_vi];
}
if (gy + 4 + j < psc(outc))
{
int gi = gz * psc(outcstep) + (gy + 4 + j) * psc(outw) + (gx + i);
if (gx + i < psc(outw)) top_tm_blob_data[gi] = tmp_k0[tmp_vi];
if (gx + i + 16 < psc(outw)) top_tm_blob_data[gi + 16] = tmp_k1[tmp_vi];
}
if (gx + lxm16 < psc(outw)) top_tm_blob_data[gi] = tmp_v0[tmp_vi];
if (gx + lxm16 + 16 < psc(outw)) top_tm_blob_data[gi + 16] = tmp_v1[tmp_vi];
}
}
}
Expand Down

0 comments on commit a26de4f

Please sign in to comment.