diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 95c3095dd04da..f271e862e8a81 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -178,6 +178,8 @@ struct cmd_params { std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; + //I'm not sure if it's safe to call llama_numa_init multiple times, so this isn't a vector. + ggml_numa_strategy numa; int reps; bool verbose; output_formats output_format; @@ -200,6 +202,7 @@ static const cmd_params cmd_params_defaults = { /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, + /* numa */ GGML_NUMA_STRATEGY_DISABLED, /* reps */ 5, /* verbose */ false, /* output_format */ MARKDOWN @@ -224,6 +227,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); + printf(" --numa (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); printf(" -ts, --tensor-split (default: 0)\n"); printf(" -r, --repetitions (default: %d)\n", cmd_params_defaults.reps); @@ -396,6 +400,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = split(argv[i], split_delim); params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end()); + } else if (arg == "--numa") { + if (++i >= argc) { + invalid_param = true; + break; + } else { + std::string value(argv[i]); + /**/ if (value == "distribute" || value == "" ) { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; } + else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; } + else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; } + else { invalid_param = true; break; } + } } else if (arg == "-fa" || arg == "--flash-attn") { if (++i >= argc) { invalid_param = true; @@ -1215,6 +1230,7 @@ int main(int argc, char ** argv) { llama_log_set(llama_null_log_callback, NULL); } llama_backend_init(); + llama_numa_init(params.numa); // initialize printer std::unique_ptr p;