Skip to content

Commit

Permalink
metal : limit kernels to not use more than the allowed threads
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Dec 13, 2023
1 parent ab558ac commit 109e7aa
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,8 @@ void ggml_metal_graph_compute(

int64_t nb = ne00;

id<MTLComputePipelineState> pipeline = nil;

if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(src0));

Expand All @@ -1088,21 +1090,23 @@ void ggml_metal_graph_compute(

nb = ne00 / 4;
switch (dst->op) {
case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
default: GGML_ASSERT(false);
}

bcast_row = true;
} else {
switch (dst->op) {
case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
default: GGML_ASSERT(false);
}
}

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
Expand Down Expand Up @@ -1137,7 +1141,7 @@ void ggml_metal_graph_compute(

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else {
const int nth = MIN(1024, ne0);
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
Expand Down

0 comments on commit 109e7aa

Please sign in to comment.