Skip to content

Commit

Permalink
metal : fix soft_max kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Dec 13, 2023
1 parent 82e4f64 commit ab558ac
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
4 changes: 3 additions & 1 deletion ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1269,6 +1269,8 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) {
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
Expand Down Expand Up @@ -1520,7 +1522,7 @@ void ggml_metal_graph_compute(
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
int64_t ny = (ne11 + nrows - 1)/nrows;
const int64_t ny = (ne11 + nrows - 1)/nrows;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
Expand Down
16 changes: 10 additions & 6 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,9 @@ kernel void kernel_soft_max(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

// parallel max
float lmax = -INFINITY;
Expand Down Expand Up @@ -386,6 +386,8 @@ kernel void kernel_soft_max(
}

float sum = simd_sum(lsum);
threadgroup_barrier(mem_flags::mem_threadgroup);

if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
Expand Down Expand Up @@ -428,9 +430,9 @@ kernel void kernel_soft_max_4(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

// parallel max
float4 lmax4 = -INFINITY;
Expand Down Expand Up @@ -468,6 +470,8 @@ kernel void kernel_soft_max_4(
}

const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
threadgroup_barrier(mem_flags::mem_threadgroup);

float sum = simd_sum(lsum);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
Expand Down

0 comments on commit ab558ac

Please sign in to comment.