Skip to content

Commit

Permalink
Merge branch 'master' into concedo_experimental
Browse files (browse the repository at this point in the history)
  • Loading branch information
LostRuins committed Dec 3, 2023
2 parents ac36aee + fbbc428 commit 8602f5a
Showing 1 changed file with 6 additions and 22 deletions.
28 changes: 6 additions & 22 deletions ggml.c
Original file line number | Diff line number | Diff line change
Expand Up @@ -15629,7 +15629,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
} break;
case GGML_OP_DIAG_MASK_ZERO:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
Expand All @@ -15645,6 +15644,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
{
n_tasks = 1; //TODO
} break;
case GGML_OP_SOFT_MAX:
{
n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
} break;
case GGML_OP_CONV_TRANSPOSE_1D:
{
n_tasks = n_threads;
Expand Down Expand Up @@ -15872,35 +15875,29 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {

// thread scheduling for the different operations + work buffer size estimation
for (int i = 0; i < cgraph->n_nodes; i++) {
int n_tasks = 1;

struct ggml_tensor * node = cgraph->nodes[i];

const int n_tasks = ggml_get_n_tasks(node, n_threads);

size_t cur = 0;

switch (node->op) {
case GGML_OP_CPY:
case GGML_OP_DUP:
{
n_tasks = n_threads;

if (ggml_is_quantized(node->type)) {
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
}
} break;
case GGML_OP_ADD:
case GGML_OP_ADD1:
{
n_tasks = n_threads;

if (ggml_is_quantized(node->src[0]->type)) {
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
}
} break;
case GGML_OP_ACC:
{
n_tasks = n_threads;

if (ggml_is_quantized(node->src[0]->type)) {
cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
}
Expand Down Expand Up @@ -15928,16 +15925,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
} break;
case GGML_OP_OUT_PROD:
{
n_tasks = n_threads;

if (ggml_is_quantized(node->src[0]->type)) {
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
}
} break;
case GGML_OP_SOFT_MAX:
{
n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));

cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
} break;
case GGML_OP_CONV_TRANSPOSE_1D:
Expand Down Expand Up @@ -15967,7 +15960,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
} break;
case GGML_OP_IM2COL:
{
n_tasks = n_threads;
} break;
case GGML_OP_CONV_TRANSPOSE_2D:
{
Expand All @@ -15985,8 +15977,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
} break;
case GGML_OP_FLASH_ATTN:
{
n_tasks = n_threads;

const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);

if (node->src[1]->type == GGML_TYPE_F32) {
Expand All @@ -15999,8 +15989,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
} break;
case GGML_OP_FLASH_FF:
{
n_tasks = n_threads;

if (node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
Expand All @@ -16011,8 +15999,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
n_tasks = n_threads;

const int64_t D = node->src[0]->ne[0];
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
Expand All @@ -16027,8 +16013,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {

case GGML_OP_CROSS_ENTROPY_LOSS:
{
n_tasks = n_threads;

cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
} break;
case GGML_OP_COUNT:
Expand Down

0 comments on commit 8602f5a

Please sign in to comment.