ggml : allow ggml_get_rows to use multiple threads if they are available
slaren committed Mar 13, 2024
1 parent 529e749 commit 54cdd47
Showing 1 changed file with 6 additions and 6 deletions.
diff --git a/ggml.c b/ggml.c
--- a/ggml.c
+++ b/ggml.c
@@ -17814,7 +17814,7 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
     node->perf_time_us += time_us_cur;
 }
 
-static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
+static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_threads) {
     int n_tasks = 0;
 
     switch (node->op) {
@@ -17899,7 +17899,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 // FIXME: the cost of launching additional threads decreases performance with GPU offloading
                 //n_tasks = MIN(n_threads, ggml_nelements(node->src[1]));
-                n_tasks = 1;
+                n_tasks = MIN(n_cur_threads, ggml_nelements(node->src[1]));
             } break;
         case GGML_OP_SCALE:
         case GGML_OP_SET:
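
This hunk is the substantive change: GGML_OP_GET_ROWS was previously pinned to a single task because launching extra threads hurts performance with GPU offloading (the FIXME above). The new code instead caps the task count at n_cur_threads, the threads that are already running, and at the number of rows to gather, so no additional threads need to be spawned. Below is a minimal sketch of the ith/nth row-partitioning pattern that such an n_tasks value feeds in ggml compute kernels; the function and all parameter names are illustrative, not the actual ggml implementation.

// A minimal sketch (not the actual ggml kernel) of how a get_rows-style op
// splits its rows across nth worker threads; ith/nth mirror the ggml
// convention of params->ith/params->nth. All names here are illustrative.
#include <string.h>

static void get_rows_f32_sketch(
        const float * src,      // source matrix, row-major
        const int   * row_ids,  // indices of the rows to gather
        float       * dst,      // destination, one row per index
        int           n_rows,   // number of rows to gather
        int           row_size, // elements per row
        int           ith,      // this thread's index, 0..nth-1
        int           nth) {    // nth == n_tasks for this op
    // rows per thread, rounded up so every row is covered
    const int dr  = (n_rows + nth - 1)/nth;
    // this thread's row range [ir0, ir1)
    const int ir0 = dr*ith;
    const int ir1 = ir0 + dr < n_rows ? ir0 + dr : n_rows;

    for (int i = ir0; i < ir1; i++) {
        memcpy(dst + (size_t) i*row_size,
               src + (size_t) row_ids[i]*row_size,
               sizeof(float)*row_size);
    }
}

The MIN against ggml_nelements(node->src[1]) keeps nth from exceeding the number of rows to gather.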
@@ -18125,7 +18125,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             /* FINALIZE */
             struct ggml_tensor * node = cgraph->nodes[node_n];
             if (GGML_OP_HAS_FINALIZE[node->op]) {
-                params.nth = ggml_get_n_tasks(node, n_threads);
+                params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
                 ggml_compute_forward(&params, node);
             }
             ggml_graph_compute_perf_stats_node(node, state->shared);
@@ -18135,7 +18135,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         while (++node_n < cgraph->n_nodes) {
             GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
             struct ggml_tensor * node = cgraph->nodes[node_n];
-            const int n_tasks = ggml_get_n_tasks(node, n_threads);
+            const int n_tasks = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
 
             state->shared->perf_node_start_cycles  = ggml_perf_cycles();
             state->shared->perf_node_start_time_us = ggml_perf_time_us();
@@ -18183,7 +18183,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
         /* INIT & COMPUTE */
         struct ggml_tensor * node = cgraph->nodes[node_n];
-        const int n_tasks = ggml_get_n_tasks(node, n_threads);
+        const int n_tasks = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
 
         struct ggml_compute_params params = {
             /*.type  =*/ GGML_TASK_TYPE_INIT,
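
For context on these call sites: each node runs in up to three passes (INIT, COMPUTE, FINALIZE), and the value returned by ggml_get_n_tasks reaches the kernel as params.nth in every pass. The toy model below shows that flow with reduced stand-in types, since the real ggml_compute_params struct and task-type enum are internal to ggml.c; nothing here is the actual API.

#include <stdio.h>

// Reduced stand-ins for ggml-internal types (illustrative only).
enum task_type { TASK_INIT, TASK_COMPUTE, TASK_FINALIZE };

struct compute_params {
    enum task_type type; // which pass is running
    int ith;             // this thread's index, 0..nth-1
    int nth;             // n_tasks from ggml_get_n_tasks
};

// How a typical op reacts to the three passes.
static void forward_op_sketch(const struct compute_params * params) {
    switch (params->type) {
        case TASK_INIT:
            // one-time setup; many ops run it on thread 0 only
            if (params->ith == 0) { printf("init\n"); }
            break;
        case TASK_COMPUTE:
            // all nth threads process their ith slice of the work
            printf("compute slice %d of %d\n", params->ith, params->nth);
            break;
        case TASK_FINALIZE:
            // optional cross-thread reduction step
            printf("finalize\n");
            break;
    }
}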
@@ -18248,7 +18248,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
 
-        const int n_tasks = ggml_get_n_tasks(node, n_threads);
+        const int n_tasks = ggml_get_n_tasks(node, n_threads, 1);
 
         max_tasks = MAX(max_tasks, n_tasks);
 
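At planning time the worker threads have not been launched yet, so ggml_graph_plan passes 1 for n_cur_threads. Since n_tasks feeds max_tasks here, a GET_ROWS node on its own never raises the planned thread count; its rows are split at compute time only across threads that other ops already justified launching, which is what the commit title means by "if they are available". A hedged usage sketch of the plan/compute flow follows, assuming the ggml_graph_plan/ggml_graph_compute API visible in this diff; compute_graph_sketch and its simplified buffer handling are illustrative, not part of ggml.

#include <stdlib.h>
#include "ggml.h"

// Illustrative helper (not a ggml API): plan a graph, provide the work
// buffer that ggml_graph_plan sized from the per-node n_tasks values,
// then compute. ggml_graph_compute launches the worker threads whose
// count GGML_OP_GET_ROWS can now make use of.
static void compute_graph_sketch(struct ggml_cgraph * graph, int n_threads) {
    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);

    if (plan.work_size > 0) {
        plan.work_data = malloc(plan.work_size);
    }

    ggml_graph_compute(graph, &plan);

    free(plan.work_data);
}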
