-
Notifications
You must be signed in to change notification settings - Fork 9.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
llama : fix embeddings #5796
llama : fix embeddings #5796
Conversation
While we're modifying embeddings.cpp - shouldn't this: struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts); be this instead: struct llama_batch batch = llama_batch_init(n_batch, 0, 1); Because only one sequence ID is assigned per token? |
should the embeddings from |
Regarding removing the KV cache, I think this will give a big speedup. I did a flamegraph on CUDA, and it was spending fully 50% of the time in calls to |
ggml-ci
008f3fc
to
d034784
Compare
llama.cpp
Outdated
if (batch.logits[i] == 0) { | ||
continue; | ||
} | ||
if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@iamlemec What is the meaning of CLS
in this context? I don't associate this abbreviation with anything
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe this mode uses the embedding of the CLS token (which is the Bert equivalent of BOS) instead of averaging the embedding of all tokens.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, always confused me since I think CLoSe, but I guess it stands for (sentence) classification. There's something to be said for replacing it (and the command line flag) with something more expressive like START.
With these changes I'm getting an MSE of 1.3e-3 relative to Sentence Transformers on WikiText with nomic-embed-text-v1.f16.gguf instead of the previous 5.62e-10. Not sure what the problem is. edit: I'm also still seeing NaNs out of the embedding example. |
Does it use mean pooling? I think I got it wrong again - checking |
|
llama.cpp
Outdated
case LLAMA_POOLING_TYPE_CLS: | ||
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float)); | ||
break; | ||
case LLAMA_POOLING_TYPE_MEAN: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should have the LLAMA_POOLING_TYPE_MEAN
case join the LLAMA_POOLING_TYPE_CLS
case due to the output order of the averaging matrix.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. I keep getting confused with the sequence-based instead of token-based embedding extraction.
Will try to modify the API to make things more clear
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool, this change fixes the NaNs from embedding.cpp, and results in the MSE of nomic-embed-text-v1.f16.gguf actually being lower than before (4.71e-10 vs 5.62e-10). Also fp32 is down from 9.34e-11 to 1.18e-14.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also fp32 is down from 9.34e-11 to 1.18e-14.
This is likely due to no longer going through the KV cache.
Do you have any performance benchmarks to see if this change improved the speed?
So I've updated the API to support both sequence and token embeddings: Lines 657 to 671 in 79e4eed
The sequence embeddings (i.e. With this change, to make the llama.cpp/examples/embedding/embedding.cpp Lines 52 to 65 in fc9af15
First try to get sequence embeddings. This would return NULL if pooling is disabled and we fallback to token embeddings Hope this finally works. Let me know if you give it a try |
Working great here! Numbers look very, very similar to earlier. I'm a little surprised to see a slight performance degredation (around 20-30% on both CPU and CUDA). We're definitely getting gains from the kv-cache attention matrix stuff when that's a factor. But I wonder if it relates to the fact that we're breaking up the final copy into sequences? Either that or one of the smaller changes in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does the llama_kv_cache_clear still do anything useful?
edit: I remembered that this example is used for models with causal attention as well. I won't need the equivalent if I'm just working with embedding models.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, llama_kv_cache_clear
is not necessary for embedding models, but makes sense for causal attention models
* llama : fix embeddings ggml-ci * llama : do not use KV cache for non-causal models ggml-ci * embeddings : fix llama_batch_init arg * llama : add pooling switch * llama : distinguish token vs sequence embeddings ggml-ci * llama : assert pooling tensor * llama : simplify causal mask condition ggml-ci * llama : assert input batch with pooling enabled * readme : update API changes list
Let me know what you find. I guess the embeddings extraction is now indeed a bit slower due to the The main struggle with the API was to find a way to differentiate between token embeddings and sequence embeddings. Without the The big benefit of not using the KV cache is that now we don't need to allocate a large KV cache memory buffer. For example, you can compute 32k token embeddings with just a large enough batch size (e.g. |
* fix mul_mat fault in cpy_f32_f16 * rm unused function * add wait() for memcpy * restore ci/run.sh, rename struct defination, fix bug in ggml_sycl_op_mul_mat_sycl * fix format issue * llama : fix segfault from unknown model arch name (#5820) * llama : fix segfault from unknown model arch name * llama : make all LLM maps const This also requires using `std::map::at` instead of its `operator[]` which does not exist for const maps. * llama : name LLM_ARCH_UNKNOWN to "(unknown)" This avoids errors from `std::map::at` when getting the general name of the model architecture. Using "(unknown)" instead of an empty string as per suggestion #5820 (comment) * llama : remove redundant inner const for LLM_TENSOR_NAMES The extra const won't do anything here as const maps return const references to values. Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * llama : remove redundant nullptr check in llm_arch_from_string Since LLM_ARCH_NAMES is a const map, no spurious elements with a NULL name are inserted anymore, so this check is dead code. --------- Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * llama : refactor internal quantization functions (#5830) * scripts : add pod-llama.sh * ggml : IQ3_S improvements (#5829) * iq3_s: somewhat faster AVX2 dot product On Ryzen a 7950X TG-128 increases to 16 t/s from 15.5 t/s using 16 threads. For 8 threads it is 13.85 t/s vs 11.75 t/s. PP-512 increases to 28.5 t/s from 23.8 t/s. * iq3_s: somewhat faster ARM_NEON dot product Still dog slow - 10.7 t/s up from 9.9 t/s. * iq3_s: another small ARM_NEON improvement 10.7 -> 11.0 t/s. Using vmulq_s8 is faster than the xor - sub trick that works best on AVX2. * iq3_s: minor improvement on Metal 49.4 t/s -> 50.3 t/s * iq3_s: PPL improvement E.g., for a context of 4096 LLaMA-v2-7B goes to 5.1340 from 5.1653. * iq3_s: use new grid everywhere * Fix ARM_NEON --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com> * convert-hf : make model class definitions self-contained (#5825) * convert : automatically fall back to HfVocab if tokenizer.model doesn't exist (#5821) * ggml : fix IQ3_S AVX implementation (#5834) ggml-ci * llama : add abort_callback to interrupt computation (#5409) * using abort_callback from ggml to stop llama computation * format fix * a brief explaining comment --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * server: tests: passkey challenge / self-extend with context shift demo (#5832) * server: tests: add models endpoint scenario * server: /v1/models add some metadata * server: tests: add debug field in context before scenario * server: tests: download model from HF, add batch size * server: tests: add passkey test * server: tests: add group attention params * server: do not truncate prompt tokens if self-extend through group attention is enabled * server: logs: do not truncate log values * server: tests - passkey - first good working value of nga * server: tests: fix server timeout * server: tests: fix passkey, add doc, fix regex content matching, fix timeout * server: tests: fix regex content matching * server: tests: schedule slow tests on master * server: metrics: fix when no prompt processed * server: tests: self-extend add llama-2-7B and Mixtral-8x7B-v0.1 * server: tests: increase timeout for completion * server: tests: keep only the PHI-2 test * server: tests: passkey add a negative test * flake.lock: Update (#5842) Flake lock file updates: • Updated input 'flake-parts': 'github:hercules-ci/flake-parts/b253292d9c0a5ead9bc98c4e9a26c6312e27d69f' (2024-02-01) → 'github:hercules-ci/flake-parts/f7b3c975cf067e56e7cda6cb098ebe3fb4d74ca2' (2024-03-01) • Updated input 'flake-parts/nixpkgs-lib': 'github:NixOS/nixpkgs/97b17f32362e475016f942bbdfda4a4a72a8a652?dir=lib' (2024-01-29) → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8?dir=lib' (2024-02-29) • Updated input 'nixpkgs': 'github:NixOS/nixpkgs/cbc4211f0afffe6dfd2478a62615dd5175a13f9a' (2024-02-23) → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8' (2024-02-29) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * server : init http requests thread pool with --parallel if set (#5836) * ci : schedule slow server tests only on Release or on demand (#5839) * llama : fix llama_copy_state_data with fragmented KV cache (#5840) The row size of the saved states was based on kv_self.head while it should be based on llama_kv_cache_cell_max. Existing session files should still work. * llama : fix llama_kv_cache_cell_max inability to return 1 I've also changed its return type to uint32_t, because this function is always used to set the value of uint32_t variables, and because the index already has this type. * llama : fix state size calculation Some bytes in the state were unaccounted for in llama_get_state_size. Since the logits reserve so much space, it did not cause problems. * gguf-dump : support i-quants (#5841) Co-authored-by: Black_Fox <radekliska@gmail.com> * llama : allow for user specified embedding pooling type (#5849) * allow for user specified pooling type * llama : use enum types over int --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * readme : add API changes section * cuda : fix data race in soft max (#5853) * main : support special tokens as reverse/anti prompt (#5847) * Support special tokens as reverse/anti prompt. * Tokenize antiprompts only once. * main : minor --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * common : use LLAMA_DEFAULT_SEED (#5855) * add some new ops, fix some operators and add batch operations to certain operators. (ggml/747) * cuda: fix group_norm * cuda: add batch inference support for ggml_pad/ggml_upscale * add ggml_arrange * add ggml_timestep_embedding * update ggml_arange/ggml_timestep_embedding tests * cuda: fix im2col * add ggml_arange/ggml_timestep_embbeding support for metal backend * fix some bugs * fix some bugs * Update ggml.h Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml-cuda.cu Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml-metal.m Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml-metal.m Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml-metal.metal Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * modify according to the review comments * ggml : fix compile warnings + code style * ggml : normalize compute_forward calls + fix seg fault in debug * minor --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: slaren <slarengh@gmail.com> * sync : ggml * add alias for chat template (#5858) * speculative : implement stochastic speculative sampling (#5625) * (WIP) Implement stochastic speculative decoding * sample from residual distribution on draft accept failure * fix #5657: force greedy sampling with probs when temp is 0 * remove p_accept parameter * fix style * remove unused variables * add srand() in speculative.cpp * replace use of rand() with mt19937 sampling * fixes based on review (@JohannesGaessler) * fix r random generation * randomly select next sequence to verify + fix bug in memory freeing * fix bug in active_seqs sync * fix uniform int distribution initialization * remove warnings from comparison between int and size_t * check grammar in `llama_sample_probability_distribution_impl` * remove malloc code by utilizing vectors * add PR link to README * cmake : handle cases where git index is not found in .git (#5844) * Update CMakeLists.txt * Update CMakeLists.txt * ggml : introduce ggml_status (ggml/750) * using enum as an exit code instead of macros * update return type from enum to unsigned int * indentation fix * compound update ggml_compute_exit_code -> ggml_status changed ggml_status from a bit-field type to simple codes ggml_status to string cast * ggml_status to string cast * GGML_CALL was removed Co-authored-by: slaren <slarengh@gmail.com> --------- Co-authored-by: slaren <slarengh@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * sync : ggml ggml-ci * ggml : fix unknown status (#0) * flake : fix * llama : fix embeddings (#5796) * llama : fix embeddings ggml-ci * llama : do not use KV cache for non-causal models ggml-ci * embeddings : fix llama_batch_init arg * llama : add pooling switch * llama : distinguish token vs sequence embeddings ggml-ci * llama : assert pooling tensor * llama : simplify causal mask condition ggml-ci * llama : assert input batch with pooling enabled * readme : update API changes list * nix: static build (#5814) * fix speculative decoding build on windows (#5874) * rebase and rm tailing space --------- Co-authored-by: LiangtaoJin <liang-tao.jin@intel.com> Co-authored-by: compilade <113953597+compilade@users.noreply.github.com> Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com> Co-authored-by: Jared Van Bortel <jared@nomic.ai> Co-authored-by: Michael Podvitskiy <podvitskiymichael@gmail.com> Co-authored-by: Pierrick Hymbert <pierrick.hymbert@gmail.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Nindaleth <Nindaleth@users.noreply.github.com> Co-authored-by: Black_Fox <radekliska@gmail.com> Co-authored-by: Douglas Hanley <thesecretaryofwar@gmail.com> Co-authored-by: slaren <slarengh@gmail.com> Co-authored-by: DAN™ <dranger003@gmail.com> Co-authored-by: leejet <leejet714@gmail.com> Co-authored-by: Minsoo Cheong <54794500+mscheong01@users.noreply.github.com> Co-authored-by: Dane Madsen <dane_madsen@hotmail.com> Co-authored-by: hutli <6594598+hutli@users.noreply.github.com> Co-authored-by: Jeffrey Quesnelle <emozilla@nousresearch.com>
Oh nice, I didn't know about
It seems like the kv cache code relies more on |
Try #5891 and see if it restores the performance |
* llama : fix embeddings ggml-ci * llama : do not use KV cache for non-causal models ggml-ci * embeddings : fix llama_batch_init arg * llama : add pooling switch * llama : distinguish token vs sequence embeddings ggml-ci * llama : assert pooling tensor * llama : simplify causal mask condition ggml-ci * llama : assert input batch with pooling enabled * readme : update API changes list
* fix mul_mat fault in cpy_f32_f16 * rm unused function * add wait() for memcpy * restore ci/run.sh, rename struct defination, fix bug in ggml_sycl_op_mul_mat_sycl * fix format issue * llama : fix segfault from unknown model arch name (ggerganov#5820) * llama : fix segfault from unknown model arch name * llama : make all LLM maps const This also requires using `std::map::at` instead of its `operator[]` which does not exist for const maps. * llama : name LLM_ARCH_UNKNOWN to "(unknown)" This avoids errors from `std::map::at` when getting the general name of the model architecture. Using "(unknown)" instead of an empty string as per suggestion ggerganov#5820 (comment) * llama : remove redundant inner const for LLM_TENSOR_NAMES The extra const won't do anything here as const maps return const references to values. Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * llama : remove redundant nullptr check in llm_arch_from_string Since LLM_ARCH_NAMES is a const map, no spurious elements with a NULL name are inserted anymore, so this check is dead code. --------- Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * llama : refactor internal quantization functions (ggerganov#5830) * scripts : add pod-llama.sh * ggml : IQ3_S improvements (ggerganov#5829) * iq3_s: somewhat faster AVX2 dot product On Ryzen a 7950X TG-128 increases to 16 t/s from 15.5 t/s using 16 threads. For 8 threads it is 13.85 t/s vs 11.75 t/s. PP-512 increases to 28.5 t/s from 23.8 t/s. * iq3_s: somewhat faster ARM_NEON dot product Still dog slow - 10.7 t/s up from 9.9 t/s. * iq3_s: another small ARM_NEON improvement 10.7 -> 11.0 t/s. Using vmulq_s8 is faster than the xor - sub trick that works best on AVX2. * iq3_s: minor improvement on Metal 49.4 t/s -> 50.3 t/s * iq3_s: PPL improvement E.g., for a context of 4096 LLaMA-v2-7B goes to 5.1340 from 5.1653. * iq3_s: use new grid everywhere * Fix ARM_NEON --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com> * convert-hf : make model class definitions self-contained (ggerganov#5825) * convert : automatically fall back to HfVocab if tokenizer.model doesn't exist (ggerganov#5821) * ggml : fix IQ3_S AVX implementation (ggerganov#5834) ggml-ci * llama : add abort_callback to interrupt computation (ggerganov#5409) * using abort_callback from ggml to stop llama computation * format fix * a brief explaining comment --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * server: tests: passkey challenge / self-extend with context shift demo (ggerganov#5832) * server: tests: add models endpoint scenario * server: /v1/models add some metadata * server: tests: add debug field in context before scenario * server: tests: download model from HF, add batch size * server: tests: add passkey test * server: tests: add group attention params * server: do not truncate prompt tokens if self-extend through group attention is enabled * server: logs: do not truncate log values * server: tests - passkey - first good working value of nga * server: tests: fix server timeout * server: tests: fix passkey, add doc, fix regex content matching, fix timeout * server: tests: fix regex content matching * server: tests: schedule slow tests on master * server: metrics: fix when no prompt processed * server: tests: self-extend add llama-2-7B and Mixtral-8x7B-v0.1 * server: tests: increase timeout for completion * server: tests: keep only the PHI-2 test * server: tests: passkey add a negative test * flake.lock: Update (ggerganov#5842) Flake lock file updates: • Updated input 'flake-parts': 'github:hercules-ci/flake-parts/b253292d9c0a5ead9bc98c4e9a26c6312e27d69f' (2024-02-01) → 'github:hercules-ci/flake-parts/f7b3c975cf067e56e7cda6cb098ebe3fb4d74ca2' (2024-03-01) • Updated input 'flake-parts/nixpkgs-lib': 'github:NixOS/nixpkgs/97b17f32362e475016f942bbdfda4a4a72a8a652?dir=lib' (2024-01-29) → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8?dir=lib' (2024-02-29) • Updated input 'nixpkgs': 'github:NixOS/nixpkgs/cbc4211f0afffe6dfd2478a62615dd5175a13f9a' (2024-02-23) → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8' (2024-02-29) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * server : init http requests thread pool with --parallel if set (ggerganov#5836) * ci : schedule slow server tests only on Release or on demand (ggerganov#5839) * llama : fix llama_copy_state_data with fragmented KV cache (ggerganov#5840) The row size of the saved states was based on kv_self.head while it should be based on llama_kv_cache_cell_max. Existing session files should still work. * llama : fix llama_kv_cache_cell_max inability to return 1 I've also changed its return type to uint32_t, because this function is always used to set the value of uint32_t variables, and because the index already has this type. * llama : fix state size calculation Some bytes in the state were unaccounted for in llama_get_state_size. Since the logits reserve so much space, it did not cause problems. * gguf-dump : support i-quants (ggerganov#5841) Co-authored-by: Black_Fox <radekliska@gmail.com> * llama : allow for user specified embedding pooling type (ggerganov#5849) * allow for user specified pooling type * llama : use enum types over int --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * readme : add API changes section * cuda : fix data race in soft max (ggerganov#5853) * main : support special tokens as reverse/anti prompt (ggerganov#5847) * Support special tokens as reverse/anti prompt. * Tokenize antiprompts only once. * main : minor --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * common : use LLAMA_DEFAULT_SEED (ggerganov#5855) * add some new ops, fix some operators and add batch operations to certain operators. (ggml/747) * cuda: fix group_norm * cuda: add batch inference support for ggml_pad/ggml_upscale * add ggml_arrange * add ggml_timestep_embedding * update ggml_arange/ggml_timestep_embedding tests * cuda: fix im2col * add ggml_arange/ggml_timestep_embbeding support for metal backend * fix some bugs * fix some bugs * Update ggml.h Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml-cuda.cu Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml-metal.m Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml-metal.m Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml-metal.metal Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * modify according to the review comments * ggml : fix compile warnings + code style * ggml : normalize compute_forward calls + fix seg fault in debug * minor --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: slaren <slarengh@gmail.com> * sync : ggml * add alias for chat template (ggerganov#5858) * speculative : implement stochastic speculative sampling (ggerganov#5625) * (WIP) Implement stochastic speculative decoding * sample from residual distribution on draft accept failure * fix ggerganov#5657: force greedy sampling with probs when temp is 0 * remove p_accept parameter * fix style * remove unused variables * add srand() in speculative.cpp * replace use of rand() with mt19937 sampling * fixes based on review (@JohannesGaessler) * fix r random generation * randomly select next sequence to verify + fix bug in memory freeing * fix bug in active_seqs sync * fix uniform int distribution initialization * remove warnings from comparison between int and size_t * check grammar in `llama_sample_probability_distribution_impl` * remove malloc code by utilizing vectors * add PR link to README * cmake : handle cases where git index is not found in .git (ggerganov#5844) * Update CMakeLists.txt * Update CMakeLists.txt * ggml : introduce ggml_status (ggml/750) * using enum as an exit code instead of macros * update return type from enum to unsigned int * indentation fix * compound update ggml_compute_exit_code -> ggml_status changed ggml_status from a bit-field type to simple codes ggml_status to string cast * ggml_status to string cast * GGML_CALL was removed Co-authored-by: slaren <slarengh@gmail.com> --------- Co-authored-by: slaren <slarengh@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * sync : ggml ggml-ci * ggml : fix unknown status (#0) * flake : fix * llama : fix embeddings (ggerganov#5796) * llama : fix embeddings ggml-ci * llama : do not use KV cache for non-causal models ggml-ci * embeddings : fix llama_batch_init arg * llama : add pooling switch * llama : distinguish token vs sequence embeddings ggml-ci * llama : assert pooling tensor * llama : simplify causal mask condition ggml-ci * llama : assert input batch with pooling enabled * readme : update API changes list * nix: static build (ggerganov#5814) * fix speculative decoding build on windows (ggerganov#5874) * rebase and rm tailing space --------- Co-authored-by: LiangtaoJin <liang-tao.jin@intel.com> Co-authored-by: compilade <113953597+compilade@users.noreply.github.com> Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com> Co-authored-by: Jared Van Bortel <jared@nomic.ai> Co-authored-by: Michael Podvitskiy <podvitskiymichael@gmail.com> Co-authored-by: Pierrick Hymbert <pierrick.hymbert@gmail.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Nindaleth <Nindaleth@users.noreply.github.com> Co-authored-by: Black_Fox <radekliska@gmail.com> Co-authored-by: Douglas Hanley <thesecretaryofwar@gmail.com> Co-authored-by: slaren <slarengh@gmail.com> Co-authored-by: DAN™ <dranger003@gmail.com> Co-authored-by: leejet <leejet714@gmail.com> Co-authored-by: Minsoo Cheong <54794500+mscheong01@users.noreply.github.com> Co-authored-by: Dane Madsen <dane_madsen@hotmail.com> Co-authored-by: hutli <6594598+hutli@users.noreply.github.com> Co-authored-by: Jeffrey Quesnelle <emozilla@nousresearch.com>
* llama : fix embeddings ggml-ci * llama : do not use KV cache for non-causal models ggml-ci * embeddings : fix llama_batch_init arg * llama : add pooling switch * llama : distinguish token vs sequence embeddings ggml-ci * llama : assert pooling tensor * llama : simplify causal mask condition ggml-ci * llama : assert input batch with pooling enabled * readme : update API changes list
* fix mul_mat fault in cpy_f32_f16 * rm unused function * add wait() for memcpy * restore ci/run.sh, rename struct defination, fix bug in ggml_sycl_op_mul_mat_sycl * fix format issue * llama : fix segfault from unknown model arch name (ggerganov#5820) * llama : fix segfault from unknown model arch name * llama : make all LLM maps const This also requires using `std::map::at` instead of its `operator[]` which does not exist for const maps. * llama : name LLM_ARCH_UNKNOWN to "(unknown)" This avoids errors from `std::map::at` when getting the general name of the model architecture. Using "(unknown)" instead of an empty string as per suggestion ggerganov#5820 (comment) * llama : remove redundant inner const for LLM_TENSOR_NAMES The extra const won't do anything here as const maps return const references to values. Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * llama : remove redundant nullptr check in llm_arch_from_string Since LLM_ARCH_NAMES is a const map, no spurious elements with a NULL name are inserted anymore, so this check is dead code. --------- Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * llama : refactor internal quantization functions (ggerganov#5830) * scripts : add pod-llama.sh * ggml : IQ3_S improvements (ggerganov#5829) * iq3_s: somewhat faster AVX2 dot product On Ryzen a 7950X TG-128 increases to 16 t/s from 15.5 t/s using 16 threads. For 8 threads it is 13.85 t/s vs 11.75 t/s. PP-512 increases to 28.5 t/s from 23.8 t/s. * iq3_s: somewhat faster ARM_NEON dot product Still dog slow - 10.7 t/s up from 9.9 t/s. * iq3_s: another small ARM_NEON improvement 10.7 -> 11.0 t/s. Using vmulq_s8 is faster than the xor - sub trick that works best on AVX2. * iq3_s: minor improvement on Metal 49.4 t/s -> 50.3 t/s * iq3_s: PPL improvement E.g., for a context of 4096 LLaMA-v2-7B goes to 5.1340 from 5.1653. * iq3_s: use new grid everywhere * Fix ARM_NEON --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com> * convert-hf : make model class definitions self-contained (ggerganov#5825) * convert : automatically fall back to HfVocab if tokenizer.model doesn't exist (ggerganov#5821) * ggml : fix IQ3_S AVX implementation (ggerganov#5834) ggml-ci * llama : add abort_callback to interrupt computation (ggerganov#5409) * using abort_callback from ggml to stop llama computation * format fix * a brief explaining comment --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * server: tests: passkey challenge / self-extend with context shift demo (ggerganov#5832) * server: tests: add models endpoint scenario * server: /v1/models add some metadata * server: tests: add debug field in context before scenario * server: tests: download model from HF, add batch size * server: tests: add passkey test * server: tests: add group attention params * server: do not truncate prompt tokens if self-extend through group attention is enabled * server: logs: do not truncate log values * server: tests - passkey - first good working value of nga * server: tests: fix server timeout * server: tests: fix passkey, add doc, fix regex content matching, fix timeout * server: tests: fix regex content matching * server: tests: schedule slow tests on master * server: metrics: fix when no prompt processed * server: tests: self-extend add llama-2-7B and Mixtral-8x7B-v0.1 * server: tests: increase timeout for completion * server: tests: keep only the PHI-2 test * server: tests: passkey add a negative test * flake.lock: Update (ggerganov#5842) Flake lock file updates: • Updated input 'flake-parts': 'github:hercules-ci/flake-parts/b253292d9c0a5ead9bc98c4e9a26c6312e27d69f' (2024-02-01) → 'github:hercules-ci/flake-parts/f7b3c975cf067e56e7cda6cb098ebe3fb4d74ca2' (2024-03-01) • Updated input 'flake-parts/nixpkgs-lib': 'github:NixOS/nixpkgs/97b17f32362e475016f942bbdfda4a4a72a8a652?dir=lib' (2024-01-29) → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8?dir=lib' (2024-02-29) • Updated input 'nixpkgs': 'github:NixOS/nixpkgs/cbc4211f0afffe6dfd2478a62615dd5175a13f9a' (2024-02-23) → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8' (2024-02-29) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * server : init http requests thread pool with --parallel if set (ggerganov#5836) * ci : schedule slow server tests only on Release or on demand (ggerganov#5839) * llama : fix llama_copy_state_data with fragmented KV cache (ggerganov#5840) The row size of the saved states was based on kv_self.head while it should be based on llama_kv_cache_cell_max. Existing session files should still work. * llama : fix llama_kv_cache_cell_max inability to return 1 I've also changed its return type to uint32_t, because this function is always used to set the value of uint32_t variables, and because the index already has this type. * llama : fix state size calculation Some bytes in the state were unaccounted for in llama_get_state_size. Since the logits reserve so much space, it did not cause problems. * gguf-dump : support i-quants (ggerganov#5841) Co-authored-by: Black_Fox <radekliska@gmail.com> * llama : allow for user specified embedding pooling type (ggerganov#5849) * allow for user specified pooling type * llama : use enum types over int --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * readme : add API changes section * cuda : fix data race in soft max (ggerganov#5853) * main : support special tokens as reverse/anti prompt (ggerganov#5847) * Support special tokens as reverse/anti prompt. * Tokenize antiprompts only once. * main : minor --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * common : use LLAMA_DEFAULT_SEED (ggerganov#5855) * add some new ops, fix some operators and add batch operations to certain operators. (ggml/747) * cuda: fix group_norm * cuda: add batch inference support for ggml_pad/ggml_upscale * add ggml_arrange * add ggml_timestep_embedding * update ggml_arange/ggml_timestep_embedding tests * cuda: fix im2col * add ggml_arange/ggml_timestep_embbeding support for metal backend * fix some bugs * fix some bugs * Update ggml.h Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml-cuda.cu Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml-metal.m Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml-metal.m Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml-metal.metal Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * modify according to the review comments * ggml : fix compile warnings + code style * ggml : normalize compute_forward calls + fix seg fault in debug * minor --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: slaren <slarengh@gmail.com> * sync : ggml * add alias for chat template (ggerganov#5858) * speculative : implement stochastic speculative sampling (ggerganov#5625) * (WIP) Implement stochastic speculative decoding * sample from residual distribution on draft accept failure * fix ggerganov#5657: force greedy sampling with probs when temp is 0 * remove p_accept parameter * fix style * remove unused variables * add srand() in speculative.cpp * replace use of rand() with mt19937 sampling * fixes based on review (@JohannesGaessler) * fix r random generation * randomly select next sequence to verify + fix bug in memory freeing * fix bug in active_seqs sync * fix uniform int distribution initialization * remove warnings from comparison between int and size_t * check grammar in `llama_sample_probability_distribution_impl` * remove malloc code by utilizing vectors * add PR link to README * cmake : handle cases where git index is not found in .git (ggerganov#5844) * Update CMakeLists.txt * Update CMakeLists.txt * ggml : introduce ggml_status (ggml/750) * using enum as an exit code instead of macros * update return type from enum to unsigned int * indentation fix * compound update ggml_compute_exit_code -> ggml_status changed ggml_status from a bit-field type to simple codes ggml_status to string cast * ggml_status to string cast * GGML_CALL was removed Co-authored-by: slaren <slarengh@gmail.com> --------- Co-authored-by: slaren <slarengh@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * sync : ggml ggml-ci * ggml : fix unknown status (#0) * flake : fix * llama : fix embeddings (ggerganov#5796) * llama : fix embeddings ggml-ci * llama : do not use KV cache for non-causal models ggml-ci * embeddings : fix llama_batch_init arg * llama : add pooling switch * llama : distinguish token vs sequence embeddings ggml-ci * llama : assert pooling tensor * llama : simplify causal mask condition ggml-ci * llama : assert input batch with pooling enabled * readme : update API changes list * nix: static build (ggerganov#5814) * fix speculative decoding build on windows (ggerganov#5874) * rebase and rm tailing space --------- Co-authored-by: LiangtaoJin <liang-tao.jin@intel.com> Co-authored-by: compilade <113953597+compilade@users.noreply.github.com> Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com> Co-authored-by: Jared Van Bortel <jared@nomic.ai> Co-authored-by: Michael Podvitskiy <podvitskiymichael@gmail.com> Co-authored-by: Pierrick Hymbert <pierrick.hymbert@gmail.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Nindaleth <Nindaleth@users.noreply.github.com> Co-authored-by: Black_Fox <radekliska@gmail.com> Co-authored-by: Douglas Hanley <thesecretaryofwar@gmail.com> Co-authored-by: slaren <slarengh@gmail.com> Co-authored-by: DAN™ <dranger003@gmail.com> Co-authored-by: leejet <leejet714@gmail.com> Co-authored-by: Minsoo Cheong <54794500+mscheong01@users.noreply.github.com> Co-authored-by: Dane Madsen <dane_madsen@hotmail.com> Co-authored-by: hutli <6594598+hutli@users.noreply.github.com> Co-authored-by: Jeffrey Quesnelle <emozilla@nousresearch.com>
* llama : fix embeddings ggml-ci * llama : do not use KV cache for non-causal models ggml-ci * embeddings : fix llama_batch_init arg * llama : add pooling switch * llama : distinguish token vs sequence embeddings ggml-ci * llama : assert pooling tensor * llama : simplify causal mask condition ggml-ci * llama : assert input batch with pooling enabled * readme : update API changes list
* fix mul_mat fault in cpy_f32_f16 * rm unused function * add wait() for memcpy * restore ci/run.sh, rename struct defination, fix bug in ggml_sycl_op_mul_mat_sycl * fix format issue * llama : fix segfault from unknown model arch name (ggerganov#5820) * llama : fix segfault from unknown model arch name * llama : make all LLM maps const This also requires using `std::map::at` instead of its `operator[]` which does not exist for const maps. * llama : name LLM_ARCH_UNKNOWN to "(unknown)" This avoids errors from `std::map::at` when getting the general name of the model architecture. Using "(unknown)" instead of an empty string as per suggestion ggerganov#5820 (comment) * llama : remove redundant inner const for LLM_TENSOR_NAMES The extra const won't do anything here as const maps return const references to values. Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * llama : remove redundant nullptr check in llm_arch_from_string Since LLM_ARCH_NAMES is a const map, no spurious elements with a NULL name are inserted anymore, so this check is dead code. --------- Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * llama : refactor internal quantization functions (ggerganov#5830) * scripts : add pod-llama.sh * ggml : IQ3_S improvements (ggerganov#5829) * iq3_s: somewhat faster AVX2 dot product On Ryzen a 7950X TG-128 increases to 16 t/s from 15.5 t/s using 16 threads. For 8 threads it is 13.85 t/s vs 11.75 t/s. PP-512 increases to 28.5 t/s from 23.8 t/s. * iq3_s: somewhat faster ARM_NEON dot product Still dog slow - 10.7 t/s up from 9.9 t/s. * iq3_s: another small ARM_NEON improvement 10.7 -> 11.0 t/s. Using vmulq_s8 is faster than the xor - sub trick that works best on AVX2. * iq3_s: minor improvement on Metal 49.4 t/s -> 50.3 t/s * iq3_s: PPL improvement E.g., for a context of 4096 LLaMA-v2-7B goes to 5.1340 from 5.1653. * iq3_s: use new grid everywhere * Fix ARM_NEON --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com> * convert-hf : make model class definitions self-contained (ggerganov#5825) * convert : automatically fall back to HfVocab if tokenizer.model doesn't exist (ggerganov#5821) * ggml : fix IQ3_S AVX implementation (ggerganov#5834) ggml-ci * llama : add abort_callback to interrupt computation (ggerganov#5409) * using abort_callback from ggml to stop llama computation * format fix * a brief explaining comment --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * server: tests: passkey challenge / self-extend with context shift demo (ggerganov#5832) * server: tests: add models endpoint scenario * server: /v1/models add some metadata * server: tests: add debug field in context before scenario * server: tests: download model from HF, add batch size * server: tests: add passkey test * server: tests: add group attention params * server: do not truncate prompt tokens if self-extend through group attention is enabled * server: logs: do not truncate log values * server: tests - passkey - first good working value of nga * server: tests: fix server timeout * server: tests: fix passkey, add doc, fix regex content matching, fix timeout * server: tests: fix regex content matching * server: tests: schedule slow tests on master * server: metrics: fix when no prompt processed * server: tests: self-extend add llama-2-7B and Mixtral-8x7B-v0.1 * server: tests: increase timeout for completion * server: tests: keep only the PHI-2 test * server: tests: passkey add a negative test * flake.lock: Update (ggerganov#5842) Flake lock file updates: • Updated input 'flake-parts': 'github:hercules-ci/flake-parts/b253292d9c0a5ead9bc98c4e9a26c6312e27d69f' (2024-02-01) → 'github:hercules-ci/flake-parts/f7b3c975cf067e56e7cda6cb098ebe3fb4d74ca2' (2024-03-01) • Updated input 'flake-parts/nixpkgs-lib': 'github:NixOS/nixpkgs/97b17f32362e475016f942bbdfda4a4a72a8a652?dir=lib' (2024-01-29) → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8?dir=lib' (2024-02-29) • Updated input 'nixpkgs': 'github:NixOS/nixpkgs/cbc4211f0afffe6dfd2478a62615dd5175a13f9a' (2024-02-23) → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8' (2024-02-29) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * server : init http requests thread pool with --parallel if set (ggerganov#5836) * ci : schedule slow server tests only on Release or on demand (ggerganov#5839) * llama : fix llama_copy_state_data with fragmented KV cache (ggerganov#5840) The row size of the saved states was based on kv_self.head while it should be based on llama_kv_cache_cell_max. Existing session files should still work. * llama : fix llama_kv_cache_cell_max inability to return 1 I've also changed its return type to uint32_t, because this function is always used to set the value of uint32_t variables, and because the index already has this type. * llama : fix state size calculation Some bytes in the state were unaccounted for in llama_get_state_size. Since the logits reserve so much space, it did not cause problems. * gguf-dump : support i-quants (ggerganov#5841) Co-authored-by: Black_Fox <radekliska@gmail.com> * llama : allow for user specified embedding pooling type (ggerganov#5849) * allow for user specified pooling type * llama : use enum types over int --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * readme : add API changes section * cuda : fix data race in soft max (ggerganov#5853) * main : support special tokens as reverse/anti prompt (ggerganov#5847) * Support special tokens as reverse/anti prompt. * Tokenize antiprompts only once. * main : minor --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * common : use LLAMA_DEFAULT_SEED (ggerganov#5855) * add some new ops, fix some operators and add batch operations to certain operators. (ggml/747) * cuda: fix group_norm * cuda: add batch inference support for ggml_pad/ggml_upscale * add ggml_arrange * add ggml_timestep_embedding * update ggml_arange/ggml_timestep_embedding tests * cuda: fix im2col * add ggml_arange/ggml_timestep_embbeding support for metal backend * fix some bugs * fix some bugs * Update ggml.h Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml-cuda.cu Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml-metal.m Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml-metal.m Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml-metal.metal Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * modify according to the review comments * ggml : fix compile warnings + code style * ggml : normalize compute_forward calls + fix seg fault in debug * minor --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: slaren <slarengh@gmail.com> * sync : ggml * add alias for chat template (ggerganov#5858) * speculative : implement stochastic speculative sampling (ggerganov#5625) * (WIP) Implement stochastic speculative decoding * sample from residual distribution on draft accept failure * fix ggerganov#5657: force greedy sampling with probs when temp is 0 * remove p_accept parameter * fix style * remove unused variables * add srand() in speculative.cpp * replace use of rand() with mt19937 sampling * fixes based on review (@JohannesGaessler) * fix r random generation * randomly select next sequence to verify + fix bug in memory freeing * fix bug in active_seqs sync * fix uniform int distribution initialization * remove warnings from comparison between int and size_t * check grammar in `llama_sample_probability_distribution_impl` * remove malloc code by utilizing vectors * add PR link to README * cmake : handle cases where git index is not found in .git (ggerganov#5844) * Update CMakeLists.txt * Update CMakeLists.txt * ggml : introduce ggml_status (ggml/750) * using enum as an exit code instead of macros * update return type from enum to unsigned int * indentation fix * compound update ggml_compute_exit_code -> ggml_status changed ggml_status from a bit-field type to simple codes ggml_status to string cast * ggml_status to string cast * GGML_CALL was removed Co-authored-by: slaren <slarengh@gmail.com> --------- Co-authored-by: slaren <slarengh@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * sync : ggml ggml-ci * ggml : fix unknown status (#0) * flake : fix * llama : fix embeddings (ggerganov#5796) * llama : fix embeddings ggml-ci * llama : do not use KV cache for non-causal models ggml-ci * embeddings : fix llama_batch_init arg * llama : add pooling switch * llama : distinguish token vs sequence embeddings ggml-ci * llama : assert pooling tensor * llama : simplify causal mask condition ggml-ci * llama : assert input batch with pooling enabled * readme : update API changes list * nix: static build (ggerganov#5814) * fix speculative decoding build on windows (ggerganov#5874) * rebase and rm tailing space --------- Co-authored-by: LiangtaoJin <liang-tao.jin@intel.com> Co-authored-by: compilade <113953597+compilade@users.noreply.github.com> Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com> Co-authored-by: Jared Van Bortel <jared@nomic.ai> Co-authored-by: Michael Podvitskiy <podvitskiymichael@gmail.com> Co-authored-by: Pierrick Hymbert <pierrick.hymbert@gmail.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Nindaleth <Nindaleth@users.noreply.github.com> Co-authored-by: Black_Fox <radekliska@gmail.com> Co-authored-by: Douglas Hanley <thesecretaryofwar@gmail.com> Co-authored-by: slaren <slarengh@gmail.com> Co-authored-by: DAN™ <dranger003@gmail.com> Co-authored-by: leejet <leejet714@gmail.com> Co-authored-by: Minsoo Cheong <54794500+mscheong01@users.noreply.github.com> Co-authored-by: Dane Madsen <dane_madsen@hotmail.com> Co-authored-by: hutli <6594598+hutli@users.noreply.github.com> Co-authored-by: Jeffrey Quesnelle <emozilla@nousresearch.com>
* Updated binaries, using [this build](https://github.com/SciSharp/LLamaSharp/actions/runs/8654672719/job/23733195669) for llama.cpp commit `f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7`. - Added all new functions. - Moved some functions (e.g. `SafeLlamaModelHandle` specific functions) into `SafeLlamaModelHandle.cs` - Exposed tokens on `SafeLlamaModelHandle` and `LLamaWeights` through a `Tokens` property. As new special tokens are added in the future they can be added here. - Changed all token properties to return nullable tokens, to handle some models not having some tokens. - Fixed `DefaultSamplingPipeline` to handle no newline token in some models. * Moved native methods to more specific locations. - Context specific things have been moved into `SafeLLamaContextHandle.cs` and made private - they're exposed through C# properties and methods already. - Checking that GPU layer count is zero if GPU offload is not supported. - Moved methods for creating default structs (`llama_model_quantize_default_params` and `llama_context_default_params`) into relevant structs. * Removed exception if `GpuLayerCount > 0` when GPU is not supported. * - Added low level wrapper methods for new per-sequence state load/save in `SafeLLamaContextHandle` - Added high level wrapper methods (save/load with `State` object or memory mapped file) in `LLamaContext` - Moved native methods for per-sequence state load/save into `SafeLLamaContextHandle` * Added update and defrag methods for KV cache in `SafeLLamaContextHandle` * Updated submodule to `f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7` * Passing the sequence ID when saving a single sequence state
ref #5655, #5783
llama_batch.logits
now also indicates if embeddings are output for that tokenllama_get_embeddings_ith()
to return token embeddingsllama_get_embeddings_seq()
to return sequence embeddingsembedding
example to work both with BERT and non-BERT modelsserver
to get the resulting embeddings correctlyexamples/server-embd.py
helper scriptserver
supports--pooling
TODO:
llama_batch.logits
tollama_batch.output
(future PR)server
should not queue partial prompts when computing embeddings (future PR)