-
Notifications
You must be signed in to change notification settings - Fork 10k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ggml-quants : ternary packing for TriLMs and BitNet b1.58 #8151
Conversation
Not using a lookup table anymore makes it match q4_0 speed. * gguf-py : fix formatting * llama : remove spaces on empty line
This makes the 1.625 bpw type go faster than q4_0. Still not the fastest.
This still results in the exact same tensor weights and scales, but it reveals some weirdness in the current algorithm.
Its FFN size is 5460 which is not convenient. The offending tensors are kept in F16, which makes the final model 5.01 bpw.
4522ed7
to
0996149
Compare
Wonderful job! I'm wondering if this PR can merge into the master branch, it would be so good if users of llama.cpp can use Q2_2 and Q1_3 conveniently. |
examples/quantize/quantize.cpp
Outdated
{ "Q1_3", LLAMA_FTYPE_MOSTLY_Q1_3, " 1.63 bpw for BitNet b1.58", }, | ||
{ "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2.00 bpw for BitNet b1.58", }, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Regarding the names of the new quant types, since these are quite specific to BitNet models, I was thinking to name them with something starting with QB
, a bit like suggested in #5761 (comment).
I'll first be describing what I want from the naming scheme, then I'll attempt to make it work.
The naming scheme should have room for:
- Ternary types in
{-1, 0, 1}
1.625 bpw
quant with a block size of 64, with 13 bytes per block- To make the smallest possible lossless BitNet b1.58 model files
- Uses
Q8_0
as itsvec_dot_type
(for the activations) - (It's technically possible to store a
float16
scale in the leftover bits in the last byte of 16 consecutive blocks (this means 1024 elements minimum per row), although it can't really be extracted with SIMD)
2.000 bpw
quant with a block size of 32, with 8 bytes per block- For maximal performance
- Uses
Q8_0
as itsvec_dot_type
(for the activations)
2.000 bpw
quant with a block size of 64, with 16 bytes per block, and a float16 scale- Values would be packed similarly to the
1.625 bpw
type, but with an extra byte and a row-wisefloat16
scale duplicated in each block.
- Values would be packed similarly to the
2.000 bpw
quant with a block size of 4, with 1 byte per block- For weirdly-shaped models like the 1.3B BitNet b1.58 model
- Needs a compatible
vec_dot_type
- float types are slower than integer types for this
- Binary types in
{-1, 1}
1 bpw
type
- Binary types in
{0, 1}
- Are there models which use this?
- 8-bit activation with a row-wise scale
8.5 bpw
likeQ8_0
, but all the scales of a row are the same- Would allow reducing the number of
float32
operations in the vec_dot of the above types.
- Would allow reducing the number of
10 bpw
, 5 bytes per block of 4 elements, with a weird layout which only uses blocks to get a big enough buffer, with a single float32 scale and some padding before all row elements, aligned and contiguous.- For use with the weird
2.000 bpw
type, and also maybe the other ones for best performance.
- For use with the weird
So the naming scheme could be:
QB<x>_<y>
- where
<x>
is the floor of the expected bpw of the type - where
<y>
is0
binary type,{0, 1}
- except for
QB8_0
which is likeQ8_0
but with a guaranteed duplicated row-wise scale
- except for
1
binary type,{-1, 1}
2
ternary type using some kind of binary-coded ternary3
ternary type with fixed-point packed values4
weird type with a block size of 4
- where
Which for the previously-mentioned possible BitNet types would mean:
proposed name | Range | bits per weight | block size | bytes | row-wise scale | current name |
---|---|---|---|---|---|---|
QB1_3 |
{-1, 0, 1} |
1.625 | 64 | 13 | 1.0f | Q1_3 |
QB2_2 |
{-2, -1, 0, 1} |
2.000 | 32 | 8 | 1.0f | Q2_2 |
QB2_3 |
{-1, 0, 1} |
2.000 | 64 | 16 | f16 |
|
QB2_4 |
{-2, -1, 0, 1} |
2.000 | 4 | 1 | 1.0f | |
QB1_1 |
{-1, 1} |
1.000 | ? | ?/8 | 1.0f | |
QB1_0 |
{0, 1} |
1.000 | ? | ?/8 | 1.0f | |
QB8_0 |
[-127, 127] |
8.5 | 32 | 34 | f16 |
|
QB8_4 |
[-127, 127] |
10 | 4 | 5 | f32 , weird layout |
I'm not saying these should all exist, though, only that the naming scheme should not be too limiting for possible future extensions (which might not exist anyway due to lack of time).
So I think I'll rename Q1_3
to QB1_3
, and Q2_2
to QB2_2
. Anyone has comments on this? Or a better naming scheme for the new BitNet quant types?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it were me, considering this only works with bitnet models and nothing else, I'd want the designations to be exceptionally clear that they are different and shouldn't be used on just anything. "QB" is good, but I'd take it a step further and remove the Q entirely. As bitnet is being colloquially referred to as a "1-bit" model, B1 makes more sense. Considering the plausible range for weights, I'd cut it off at tenths and ditch the decimal. This leaves plenty of room for variations, while making the native BPW very clear. I feel this is superior to the arbitrary "_2" and "_3" subtypes.
So what I would propose is:
1.625bpw = B1_16
2.000bpw = B1_20
@compilade and @Eddie-Wang1120 continuing the race to the bottom 🥳 , glorious. Did some quick testing with the 3B model and it looks very good.
What surprises me a little, after reading about edit: also updated the files at https://huggingface.co/Green-Sky/bitnet_b1_58-3B-GGUF , for anyone else willing to test. |
Did a bit of testing myself, it runs and generates well but unfortunately it's the undertrained models rather than our implementation that's holding back BitNet adoption. For me Q1_3 is slower but this computer is CPU rather than memory bound.
I wonder if Q2_2 could be made faster if we used a block size of say 256 like the K-quants so that we can handle more than 64 bits of Q2_2 quants in each dot product loop. Aside from that I can't find any further way to improve that AVX implementation, and while it's ironic that we're using a |
Can't go with bigger blocks than 64 elements or else the 3B model won't be fully quantizable. (Its FFN size is 8640 (which factors into Its current block size is 32, which is the same as its vec_dot_type, Q8_0. What would also help with performance would be to somehow use an 8-bit vec_dot_type having a single float scale per row. Might be interesting to explore later, but
Yeah, with AVX2, there are no good widening addition instructions like on ARM NEON, so Meanwhile, NEON doesn't have the equivalent of |
ggml/src/ggml-quants.c
Outdated
const uint8_t xi0 = x0 < 0 ? 1 : x0 == 0 ? 2 : 3; | ||
const uint8_t xi1 = x1 < 0 ? 1 : x1 == 0 ? 2 : 3; | ||
const uint8_t xi2 = x2 < 0 ? 1 : x2 == 0 ? 2 : 3; | ||
const uint8_t xi3 = x3 < 0 ? 1 : x3 == 0 ? 2 : 3; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As proposed, the type utilizes only 3 of the 4 possible values. I was thinking that the Q2_2
type would work the same as Q4_0
, but assumes amax == 1.0f
:
void quantize_row_q2_2_reference(const float * restrict x, block_q2_2 * restrict y, int64_t k) {
static const int qk = QK2_2;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < qk; j++) {
const float v = x[i*qk + j];
if (amax < fabsf(v)) {
amax = fabsf(v);
max = v;
}
}
// assume amax = 1.0f
max /= amax;
const float d = max / -2;
const float id = d ? 1.0f/d : 0.0f;
for (int j = 0; j < qk/4; ++j) {
const float x0 = x[i*qk + 0*qk/4 + j]*id;
const float x1 = x[i*qk + 1*qk/4 + j]*id;
const float x2 = x[i*qk + 2*qk/4 + j]*id;
const float x3 = x[i*qk + 3*qk/4 + j]*id;
const uint8_t xi0 = MIN(3, (int8_t)(x0 + 2.5f));
const uint8_t xi1 = MIN(3, (int8_t)(x1 + 2.5f));
const uint8_t xi2 = MIN(3, (int8_t)(x2 + 2.5f));
const uint8_t xi3 = MIN(3, (int8_t)(x3 + 2.5f));
y[i].qs[j] = xi0;
y[i].qs[j] |= xi1 << 2;
y[i].qs[j] |= xi2 << 4;
y[i].qs[j] |= xi3 << 6;
}
}
}
(not tested, just pattern matching the existing quantize_row_q4_0_reference()
)
Edit: just realized the above would not work. We have assume that max == 1.0f
, not amax
, so:
const float max = 1.0f;
...
I am trying to test the TriLM_3.9B_Unpacked with both Using this exact pull request, I am building The issue comes when I am trying to perform inference on the Apple GPU: /path_to_built_llama/llama_cli -m quants/TriLM_3.9B_Unpacked_quant_TQ2_0.gguf -p "hey there"
Log start
main: build = 3610 (35cc5567)
main: built with Apple clang version 15.0.0 (clang-1500.3.9.4) for arm64-apple-darwin23.6.0
main: seed = 1724234806
llama_model_loader: loaded meta data with 28 key-value pairs and 273 tensors from quants/TriLM_3.9B_Unpacked_quant_TQ2_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = llama
llama_model_loader: - kv 1: general.type str = model
llama_model_loader: - kv 2: general.size_label str = 4.0B
llama_model_loader: - kv 3: general.license str = apache-2.0
llama_model_loader: - kv 4: llama.block_count u32 = 30
llama_model_loader: - kv 5: llama.context_length u32 = 2048
llama_model_loader: - kv 6: llama.embedding_length u32 = 3072
llama_model_loader: - kv 7: llama.feed_forward_length u32 = 9216
llama_model_loader: - kv 8: llama.attention.head_count u32 = 24
llama_model_loader: - kv 9: llama.attention.head_count_kv u32 = 24
llama_model_loader: - kv 10: llama.rope.freq_base f32 = 10000.000000
llama_model_loader: - kv 11: llama.attention.layer_norm_rms_epsilon f32 = 0.000010
llama_model_loader: - kv 12: llama.attention.layer_norm_epsilon f32 = 0.000010
llama_model_loader: - kv 13: general.file_type u32 = 37
llama_model_loader: - kv 14: llama.vocab_size u32 = 50688
llama_model_loader: - kv 15: llama.rope.dimension_count u32 = 128
llama_model_loader: - kv 16: tokenizer.ggml.add_space_prefix bool = false
llama_model_loader: - kv 17: tokenizer.ggml.model str = gpt2
llama_model_loader: - kv 18: tokenizer.ggml.pre str = olmo
llama_model_loader: - kv 19: tokenizer.ggml.tokens arr[str,50688] = ["<|endoftext|>", "<|padding|>", "!",...
llama_model_loader: - kv 20: tokenizer.ggml.token_type arr[i32,50688] = [3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv 21: tokenizer.ggml.merges arr[str,50009] = ["Ġ Ġ", "Ġ t", "Ġ a", "h e", "i n...
llama_model_loader: - kv 22: tokenizer.ggml.bos_token_id u32 = 0
llama_model_loader: - kv 23: tokenizer.ggml.eos_token_id u32 = 0
llama_model_loader: - kv 24: tokenizer.ggml.unknown_token_id u32 = 0
llama_model_loader: - kv 25: tokenizer.ggml.add_bos_token bool = false
llama_model_loader: - kv 26: tokenizer.ggml.add_eos_token bool = false
llama_model_loader: - kv 27: general.quantization_version u32 = 2
llama_model_loader: - type f32: 61 tensors
llama_model_loader: - type q4_K: 1 tensors
llama_model_loader: - type q6_K: 1 tensors
llama_model_loader: - type tq2_0: 210 tensors
llm_load_vocab: special tokens cache size = 25
llm_load_vocab: token to piece cache size = 0.2984 MB
llm_load_print_meta: format = GGUF V3 (latest)
llm_load_print_meta: arch = llama
llm_load_print_meta: vocab type = BPE
llm_load_print_meta: n_vocab = 50688
llm_load_print_meta: n_merges = 50009
llm_load_print_meta: vocab_only = 0
llm_load_print_meta: n_ctx_train = 2048
llm_load_print_meta: n_embd = 3072
llm_load_print_meta: n_layer = 30
llm_load_print_meta: n_head = 24
llm_load_print_meta: n_head_kv = 24
llm_load_print_meta: n_rot = 128
llm_load_print_meta: n_swa = 0
llm_load_print_meta: n_embd_head_k = 128
llm_load_print_meta: n_embd_head_v = 128
llm_load_print_meta: n_gqa = 1
llm_load_print_meta: n_embd_k_gqa = 3072
llm_load_print_meta: n_embd_v_gqa = 3072
llm_load_print_meta: f_norm_eps = 0.0e+00
llm_load_print_meta: f_norm_rms_eps = 1.0e-05
llm_load_print_meta: f_clamp_kqv = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale = 0.0e+00
llm_load_print_meta: n_ff = 9216
llm_load_print_meta: n_expert = 0
llm_load_print_meta: n_expert_used = 0
llm_load_print_meta: causal attn = 1
llm_load_print_meta: pooling type = 0
llm_load_print_meta: rope type = 0
llm_load_print_meta: rope scaling = linear
llm_load_print_meta: freq_base_train = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn = 2048
llm_load_print_meta: rope_finetuned = unknown
llm_load_print_meta: ssm_d_conv = 0
llm_load_print_meta: ssm_d_inner = 0
llm_load_print_meta: ssm_d_state = 0
llm_load_print_meta: ssm_dt_rank = 0
llm_load_print_meta: model type = ?B
llm_load_print_meta: model ftype = TQ2_0 - 2.06 bpw ternary
llm_load_print_meta: model params = 3.99 B
llm_load_print_meta: model size = 1.08 GiB (2.33 BPW)
llm_load_print_meta: general.name = n/a
llm_load_print_meta: BOS token = 0 '<|endoftext|>'
llm_load_print_meta: EOS token = 0 '<|endoftext|>'
llm_load_print_meta: UNK token = 0 '<|endoftext|>'
llm_load_print_meta: LF token = 128 'Ä'
llm_load_print_meta: EOT token = 0 '<|endoftext|>'
llm_load_print_meta: max token length = 1024
llm_load_tensors: ggml ctx size = 0.26 MiB
ggml_backend_metal_log_allocated_size: allocated buffer, size = 1027.47 MiB, ( 1027.55 / 12288.02)
llm_load_tensors: offloading 30 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 31/31 layers to GPU
llm_load_tensors: Metal buffer size = 1027.46 MiB
llm_load_tensors: CPU buffer size = 83.53 MiB
....................................................................................
llama_new_context_with_model: n_ctx = 2048
llama_new_context_with_model: n_batch = 2048
llama_new_context_with_model: n_ubatch = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base = 10000.0
llama_new_context_with_model: freq_scale = 1
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M3 Pro
ggml_metal_init: picking default device: Apple M3 Pro
ggml_metal_init: using embedded metal library
ggml_metal_init: GPU name: Apple M3 Pro
ggml_metal_init: GPU family: MTLGPUFamilyApple9 (1009)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3 (5001)
ggml_metal_init: simdgroup reduction support = true
ggml_metal_init: simdgroup matrix mul. support = true
ggml_metal_init: hasUnifiedMemory = true
ggml_metal_init: recommendedMaxWorkingSetSize = 12884.92 MB
llama_kv_cache_init: Metal KV buffer size = 720.00 MiB
llama_new_context_with_model: KV self size = 720.00 MiB, K (f16): 360.00 MiB, V (f16): 360.00 MiB
llama_new_context_with_model: CPU output buffer size = 0.19 MiB
llama_new_context_with_model: Metal compute buffer size = 124.00 MiB
llama_new_context_with_model: CPU compute buffer size = 10.01 MiB
llama_new_context_with_model: graph nodes = 966
llama_new_context_with_model: graph splits = 2
ggml/src/ggml-metal.m:1619: MUL MAT-MAT not implemented
ggml/src/ggml-metal.m:1619: MUL MAT-MAT not implemented
[1] 36927 abort /Users/basavyr/Repos/external/llama.cpp/llama-cli -m -p "hey there" This error does not occur with the GPU inference explicitly disabled, via the Q: Am I missing something ? Did anyone else try to test this on M1/2/3 GPUs? |
@basavyr Could you share the quantified files on Huggingface? Then I'll happily give it a try on my Macbook Pro M1. |
I don't think TQ packing support GPU inference yet |
It does not (yet). But for at least I'll see what I can do, but I can't test Metal kernels directly, so I'll likely postpone full support to a follow-up pull-request. |
@flatsiedatsie The two quantized models are available here. Feel free to try them :) However, as per @sorasoras and @ggerganov, we might have to wait until support for Metal inference is officially confirmed. @ggerganov I will give that a try and see if it works. Thanks a lot guys for the support 🎉🙏 Edit: In the meantime, I have managed to perform Metal quantization ( |
Can you test whether https://github.com/compilade/llama.cpp/tree/compilade/bitnet-ternary-metal allows you to run
For the Metal kernels, I've mostly used the code from ikawrakow/ik_llama.cpp#13, but like I said, I can't test it directly because I don't have Apple hardware. If it does not work, then I'll leave that unimplemented here because debugging over comments would not really be convenient. Also, I am not Georgi :) |
It runs and looks correct.
FULL LOG
|
@compilade Don't worry about the Metal implementation. I can add this in a separate PR |
@compilade Sorry for the late answer... I have managed to compile your fork of Moreover, I have also tried to quantize all three versions of Answering your questions:
GPU:
CPU (
|
Is there any reason to not merge this? I'm already using it in a project built on Wllama, and now I have to take extra steps to compile it each time. Don't let the perfect be the enemy of the good. |
It would otherwise conflict with the more general optimization coming with Mamba-2. * ggml : handle TQ1_0 and TQ2_0 in dequantization-based operators
Not yet adding uncommented, because some backends like SYCL and Metal do not properly handle unknown types in supports_op for GGML_OP_MUL_MAT. (and Metal also doesn't handle it with GGML_OP_GET_ROWS) Support for TQ1_0 and TQ2_0 for other backends than CPU will be added in follow-up pull requests.
e4dc48a
to
75b3a09
Compare
Not really. It's pretty much ready (apart from support in other backends than CPU-only, and quantization to (the indexing experiment)Indices to extract 4 values per This tests the read order. tq1_0_pattern: list[list[int]] = [[32*v + b for v in range(5)] for b in range(32)] + [[16*v + b + 160 for v in range(5)] for b in range(16)] + [[4*v + b + 240 for v in range(4)] for b in range(4)]
print(f"{tq1_0_pattern=}")
for tid in range(64):
n = tid // 40; # 0 or 1
nh = tid // 60; # 0 or 1
il = tid // 5; # 0..13
ir = tid % 5; # 0..5
l = 32 - 16*n - 12*nh; # 32, 16 or 4
q: list[int] = [4*il + j for j in range(4)]
y: list[int] = [128*n + 64*nh + 4*il + l*ir + j for j in range(4)];
status = "good"
for a, b in zip(q, y):
if a >= len(tq1_0_pattern) or b != tq1_0_pattern[a][ir]:
status = "bad"
break
print(f"{tid=}, {q=}, {y=}, {status=}") Support for other backends than CPU will be added in separate pull requests. The only non-CPU backends I can possibly implement (and test) with the hardware I have are Vulkan (and CUDA, but only in evenings and weekends).
Right. And the recent commits (8d61607 and 75b3a09) should not be controversial (reducing the changes to I will merge this soon, either today or tomorrow if I forget. |
Thank you very much for this great job. Do you have any plans to further support risc-v devices? |
…8151) * ggml-quants : 1.625 bpw ternary packing for BitNet 1.58b * ggml-quants : faster 1.625 bpw AVX2 vec_dot Not using a lookup table anymore makes it match q4_0 speed. * gguf-py : fix formatting * llama : remove spaces on empty line * ggml-quants : subtract 1 when back in epi8 This makes the 1.625 bpw type go faster than q4_0. Still not the fastest. * ggml-quants : Q2_2 now faster than Q4_K on with AVX2 * ggml-quants : cleanup Q1_3 code formatting * ggml-quants : ARM NEON vec_dot for q2_2 and q1_3 * ggml-quants : use ceiling division when quantizing q1_3 * convert-hf : simplify BitNet pre-quantization This still results in the exact same tensor weights and scales, but it reveals some weirdness in the current algorithm. * convert-hf : allow converting the weird BitNet 1.3B Its FFN size is 5460 which is not convenient. The offending tensors are kept in F16, which makes the final model 5.01 bpw. * bitnet : replace 1.58b with b1.58, as in the paper * ggml-quants : fix build failure on Windows * ggml-quants : attempt to fix Arm 32-bit support * ggml : add some informative comments in q1_3 vec_dot * ggml : add TQ1_0 and TQ2_0 ternary quantization types * ggml : even faster TQ2_0 * ggml : also faster TQ1_0 Same optimization as for TQ2_0 by offsetting the sum instead of the weights. This makes TQ1_0 almost as fast as Q8_0 on AVX2. * ggml : fix build issues in certain environments * ggml : add NEON vec_dot implementation for TQ1_0 and TQ2_0 * ggml : avoid directly using vmlal_high_s8, for 32-bit ARM compat The compiler seems smart enough to use the same instruction even when using vget_high_s8 instead. * ggml : remove q1_3 and q2_2 No more 1.625 bpw and 2.000 bpw, now instead using 1.6875 bpw and 2.0625 bpw with TQ1_0 and TQ2_0, respectively. * llama : remove the separate scale tensors of BitNet b1.58 They won't be needed, since the remaining ternary quant types have built-in scales. * ggml-quants : rename fields of TQ1_0 and TQ2_0 structs for consistency * ggml-quants : allow using vdotq_s32 in TQ2_0 vec_dot Not yet tested on hardware which supports it, might not work or might not even compile. But also it might. It should make the performance better on recent ARM CPUs. * ggml-quants : remove comment about possible format change of TQ2_0 Making it slightly more convenient for AVX512 but less convenient for everything else is not worth the trouble. * gguf-py : Numpy (de)quantization for TQ1_0 and TQ2_0 * ggml-quants : use roundf instead of nearest_int for TQ1_0 and TQ2_0 This does not change anything for ternary models, since their values should never end up being in halfway cases anyway. * convert : allow direct conversion to TQ1_0 and TQ2_0 The token embeddings and output tensors are kept in F16 to allow quantizing them to Q4_K and Q6_K with llama-quantize. * llama : handle fallback for TQ1_0 and TQ2_0 with Q4_0 Q4_0 is not completely symmetric (so not lossless for ternary models), but it should be good enough. * ggml-quants : allow using ARM dot product instructions for TQ1_0 * ggml-quants : deduplicate TQ1_0 and TQ2_0 __ARM_FEATURE_DOTPROD support * ggml : remove unused ggml_mul special case It would otherwise conflict with the more general optimization coming with Mamba-2. * ggml : handle TQ1_0 and TQ2_0 in dequantization-based operators * test-backend-ops : add TQ1_0 and TQ2_0 comments for later Not yet adding uncommented, because some backends like SYCL and Metal do not properly handle unknown types in supports_op for GGML_OP_MUL_MAT. (and Metal also doesn't handle it with GGML_OP_GET_ROWS) Support for TQ1_0 and TQ2_0 for other backends than CPU will be added in follow-up pull requests.
…8151) * ggml-quants : 1.625 bpw ternary packing for BitNet 1.58b * ggml-quants : faster 1.625 bpw AVX2 vec_dot Not using a lookup table anymore makes it match q4_0 speed. * gguf-py : fix formatting * llama : remove spaces on empty line * ggml-quants : subtract 1 when back in epi8 This makes the 1.625 bpw type go faster than q4_0. Still not the fastest. * ggml-quants : Q2_2 now faster than Q4_K on with AVX2 * ggml-quants : cleanup Q1_3 code formatting * ggml-quants : ARM NEON vec_dot for q2_2 and q1_3 * ggml-quants : use ceiling division when quantizing q1_3 * convert-hf : simplify BitNet pre-quantization This still results in the exact same tensor weights and scales, but it reveals some weirdness in the current algorithm. * convert-hf : allow converting the weird BitNet 1.3B Its FFN size is 5460 which is not convenient. The offending tensors are kept in F16, which makes the final model 5.01 bpw. * bitnet : replace 1.58b with b1.58, as in the paper * ggml-quants : fix build failure on Windows * ggml-quants : attempt to fix Arm 32-bit support * ggml : add some informative comments in q1_3 vec_dot * ggml : add TQ1_0 and TQ2_0 ternary quantization types * ggml : even faster TQ2_0 * ggml : also faster TQ1_0 Same optimization as for TQ2_0 by offsetting the sum instead of the weights. This makes TQ1_0 almost as fast as Q8_0 on AVX2. * ggml : fix build issues in certain environments * ggml : add NEON vec_dot implementation for TQ1_0 and TQ2_0 * ggml : avoid directly using vmlal_high_s8, for 32-bit ARM compat The compiler seems smart enough to use the same instruction even when using vget_high_s8 instead. * ggml : remove q1_3 and q2_2 No more 1.625 bpw and 2.000 bpw, now instead using 1.6875 bpw and 2.0625 bpw with TQ1_0 and TQ2_0, respectively. * llama : remove the separate scale tensors of BitNet b1.58 They won't be needed, since the remaining ternary quant types have built-in scales. * ggml-quants : rename fields of TQ1_0 and TQ2_0 structs for consistency * ggml-quants : allow using vdotq_s32 in TQ2_0 vec_dot Not yet tested on hardware which supports it, might not work or might not even compile. But also it might. It should make the performance better on recent ARM CPUs. * ggml-quants : remove comment about possible format change of TQ2_0 Making it slightly more convenient for AVX512 but less convenient for everything else is not worth the trouble. * gguf-py : Numpy (de)quantization for TQ1_0 and TQ2_0 * ggml-quants : use roundf instead of nearest_int for TQ1_0 and TQ2_0 This does not change anything for ternary models, since their values should never end up being in halfway cases anyway. * convert : allow direct conversion to TQ1_0 and TQ2_0 The token embeddings and output tensors are kept in F16 to allow quantizing them to Q4_K and Q6_K with llama-quantize. * llama : handle fallback for TQ1_0 and TQ2_0 with Q4_0 Q4_0 is not completely symmetric (so not lossless for ternary models), but it should be good enough. * ggml-quants : allow using ARM dot product instructions for TQ1_0 * ggml-quants : deduplicate TQ1_0 and TQ2_0 __ARM_FEATURE_DOTPROD support * ggml : remove unused ggml_mul special case It would otherwise conflict with the more general optimization coming with Mamba-2. * ggml : handle TQ1_0 and TQ2_0 in dequantization-based operators * test-backend-ops : add TQ1_0 and TQ2_0 comments for later Not yet adding uncommented, because some backends like SYCL and Metal do not properly handle unknown types in supports_op for GGML_OP_MUL_MAT. (and Metal also doesn't handle it with GGML_OP_GET_ROWS) Support for TQ1_0 and TQ2_0 for other backends than CPU will be added in follow-up pull requests.
…8151) * ggml-quants : 1.625 bpw ternary packing for BitNet 1.58b * ggml-quants : faster 1.625 bpw AVX2 vec_dot Not using a lookup table anymore makes it match q4_0 speed. * gguf-py : fix formatting * llama : remove spaces on empty line * ggml-quants : subtract 1 when back in epi8 This makes the 1.625 bpw type go faster than q4_0. Still not the fastest. * ggml-quants : Q2_2 now faster than Q4_K on with AVX2 * ggml-quants : cleanup Q1_3 code formatting * ggml-quants : ARM NEON vec_dot for q2_2 and q1_3 * ggml-quants : use ceiling division when quantizing q1_3 * convert-hf : simplify BitNet pre-quantization This still results in the exact same tensor weights and scales, but it reveals some weirdness in the current algorithm. * convert-hf : allow converting the weird BitNet 1.3B Its FFN size is 5460 which is not convenient. The offending tensors are kept in F16, which makes the final model 5.01 bpw. * bitnet : replace 1.58b with b1.58, as in the paper * ggml-quants : fix build failure on Windows * ggml-quants : attempt to fix Arm 32-bit support * ggml : add some informative comments in q1_3 vec_dot * ggml : add TQ1_0 and TQ2_0 ternary quantization types * ggml : even faster TQ2_0 * ggml : also faster TQ1_0 Same optimization as for TQ2_0 by offsetting the sum instead of the weights. This makes TQ1_0 almost as fast as Q8_0 on AVX2. * ggml : fix build issues in certain environments * ggml : add NEON vec_dot implementation for TQ1_0 and TQ2_0 * ggml : avoid directly using vmlal_high_s8, for 32-bit ARM compat The compiler seems smart enough to use the same instruction even when using vget_high_s8 instead. * ggml : remove q1_3 and q2_2 No more 1.625 bpw and 2.000 bpw, now instead using 1.6875 bpw and 2.0625 bpw with TQ1_0 and TQ2_0, respectively. * llama : remove the separate scale tensors of BitNet b1.58 They won't be needed, since the remaining ternary quant types have built-in scales. * ggml-quants : rename fields of TQ1_0 and TQ2_0 structs for consistency * ggml-quants : allow using vdotq_s32 in TQ2_0 vec_dot Not yet tested on hardware which supports it, might not work or might not even compile. But also it might. It should make the performance better on recent ARM CPUs. * ggml-quants : remove comment about possible format change of TQ2_0 Making it slightly more convenient for AVX512 but less convenient for everything else is not worth the trouble. * gguf-py : Numpy (de)quantization for TQ1_0 and TQ2_0 * ggml-quants : use roundf instead of nearest_int for TQ1_0 and TQ2_0 This does not change anything for ternary models, since their values should never end up being in halfway cases anyway. * convert : allow direct conversion to TQ1_0 and TQ2_0 The token embeddings and output tensors are kept in F16 to allow quantizing them to Q4_K and Q6_K with llama-quantize. * llama : handle fallback for TQ1_0 and TQ2_0 with Q4_0 Q4_0 is not completely symmetric (so not lossless for ternary models), but it should be good enough. * ggml-quants : allow using ARM dot product instructions for TQ1_0 * ggml-quants : deduplicate TQ1_0 and TQ2_0 __ARM_FEATURE_DOTPROD support * ggml : remove unused ggml_mul special case It would otherwise conflict with the more general optimization coming with Mamba-2. * ggml : handle TQ1_0 and TQ2_0 in dequantization-based operators * test-backend-ops : add TQ1_0 and TQ2_0 comments for later Not yet adding uncommented, because some backends like SYCL and Metal do not properly handle unknown types in supports_op for GGML_OP_MUL_MAT. (and Metal also doesn't handle it with GGML_OP_GET_ROWS) Support for TQ1_0 and TQ2_0 for other backends than CPU will be added in follow-up pull requests.
This adds
1.6875 bpw
and2.0625 bpw
quant types for TriLMs and BitNet b1.58 models. For now, these are namedTQ1_0
andTQ2_0
, respectively.I had given glimpses of this idea starting from #7931 (comment).
The
1.6875 bpw
type mostly relies on the fact that35 == 243 < 256 == 28
to pack 5 trits per byte.(I also made a blog post about ternary packing in an attempt to explain the core idea a bit more (storing the values in fixed-point to extract the most significant digit first with multiplications))
Huge thanks to @Eddie-Wang1120, who motivated this by adding initial BitNet b1.58 support in #7931.
How to try it
Using TriLM models is the easiest because all of their models have row sizes divisible by 256.
Important
To quantize the token embeddings and the output tensor to
Q4_K
andQ6_K
, you need to usellama-quantize
on the model files produced byconvert_hf_to_gguf.py --outtype tq1_0
(and also fortq2_0
). Otherwise these two tensors are kept asf16
and are responsible for most of the size of the models.If you want to try
TQ2_0
, which is faster (but bigger) thanTQ1_0
on compute-bound hardware, you can replacetq1_0
withtq2_0
in the above example, but it's also possible to convert from theTQ1_0
model file.The two ternary formats hold the same values, so round-trip quantizing between the two should result in the same files.
$ ./build/bin/llama-quantize --allow-requantize /somewhere/TriLM-3.9B-TQ1_0.gguf /somewhere/TriLM-3.9B-TQ2_0.gguf tq2_0
Speed
TQ2_0
is twice as fast asQ4_K
on my laptop. It's the fastest quant on compute-bound AVX2-capable computers.This is a table of the float32-equivalent throughput of the
vec_dot_q
operation for each of these quant types.t4g
(NEON)t4g
(DOTPROD)From this, it's easy to see that
TQ1_0
is usually slightly faster thanQ4_K
, and thatTQ2_0
is by far the fastest quant on AVX2.Note
There might be a way to make a similar type as
TQ2_0
like some sort ofQ2_1
, which could be almost as fast but still usable by non-ternary models, but this will probably require something like LQER to help with keeping some precision.Raw data (click to expand)
Intel Core m3-8100Y:
Arm Cortex A72 (Raspberry Pi 4):
Arm Cortex A53 (Some Android phone from 2017):
AWS
t4g.small
instance (Arm Neoverse N1) using NEON:AWS
t4g.small
(Arm Neoverse N1) with-march=native
:Size
The token embeddings are kept at
Q4_K
and the output projection atQ6_K
, which means the smaller models might be slightly bigger than 2 bits per weight.All of the TriLM models should work, because their row sizes are multiples of 256. I did not try them all yet, but those I tried are in the table below.
The BitNet b1.58 models from the 1bitLLM team however are not all compatible; only the 700M model has dimensions divisible by 256. The others are not supported (yet), unless when padding them.
Note
The 1.3B BitNet b1.58 model has a FFN size of 5460 which factors into
2 2 3 5 7 13
, which is not convenient for any block-wise types based on powers of 2, so these tensors are kept asF16
. My hypothesis is that 5460 was a typo for 5440 (factors into2 2 2 2 2 2 5 17
), but it was kept for some reason, and reproduced by the 1bitLLM team. If anyone training ternary models reads this, PLEASE DON'T USE5460
FOR THE FFN SIZE! Please use multiples of 256 for your row sizes.Perplexity
Quality seems good. I don't have a powerful machine, so my tests only include the first 16 chunks of wikitext-2-raw with https://huggingface.co/SpectraSuite/TriLM_390M_Unpacked.
The tests below use
Q4_K
token embeddings andQ6_K
output tensor forTQ1_0
andTQ2_0
, whileF16
token embeddings and output tensor is used inTQ1_0_L
andTQ2_0_L
.From this it seems like there is no significant quality loss for the ternary quants for TriLM models (I think the difference with pure
f16
comes from the 8-bit activations), and thatTQ1_0
andTQ2_0
are completely equivalent in quality (and they should be, because lossless conversion between the two is possible).Structure of
TQ1_0
This type relies on the fact that
3^5 == 243 < 256 == 2^8
.In a block of 256 elements, there are 240 elements encoded in 5 elements per byte, while the last 16 elements are encoded in 4 elements per byte.
This means
(240 / 5) + (16 / 4) == 48 + 4 == 52
bytes are used to pack 256 ternary weights (this is1.625
bits per weight).But there is also one
float16
scale per block, so the size of a block is 54 bytes making it a1.6875 bpw
type. Even though it's not ideal, this is still1.6875 / (log(3) / log(2)) ≈ 94%
of the best ternary packing efficiency.In the table below I'm describing the order of the elements within the bytes. I'm using ranges to make this shorter, with the notation
start..end
where the start is inclusive and the end is exclusive. (So0..3
is{0, 1, 2}
)Read this as if the ranges of a row are zipped together. A byte never contains more than 5 ternary values.
The ternary values are stored unsigned, so
{-1, 0, 1}
is stored as{0, 1, 2}
.0..32
0..32
32..64
64..96
96..128
128..160
32..48
160..176
176..192
192..208
208..224
224..240
48..52
240..244
244..248
248..252
252..256
And then byte
52
and53
contain thefloat16
scale in little-endian.Values are stored in fixed point to allow extracting the most significant digit first. This is explained in https://compilade.net/blog/ternary-packing.
Structure of
TQ2_0
This type was originally inspired by the
Q2_2
type made by @Eddie-Wang1120, but the block size, the order, and the mapping of the values are different.TQ2_0
started as an experiment to see how fast a 2-bit type can be compared to a 1.6-bit type on compute-bound hardware.This packs each ternary value in 2 bits, which means each byte contains 4 values.
The ternary values are stored unsigned, so
{-1, 0, 1}
is stored as{0, 1, 2}
.Again, the ranges use the
start..end
notation where the start is inclusive and the end is exclusive, and the ranges of a row should be read as being zipped together (they advance in parallel in lockstep).0..32
96..128
64..96
32..64
0..32
32..64
224..256
192..224
160..192
128..160
And then byte
64
and65
contain thefloat16
scale in little-endian.TODO
TQ1_0
andTQ2_0
convert_hf_to_gguf.py
to directly convert a ternary model to a ternary encodingf16
for the token embeddings and output tensor becauseQ4_K
andQ6_K
quantization is not yet supported bygguf-py
. This meansllama-quantize
needs to be used to quantize these tensors.llama-quantize
afterwards.TQ1_0_L
or something?float16
scale should be before or after the packed weightsQ2_K
,Q3_K
andQ6_K
) is to keep the scale before.llama-quantize
Q4_0
as a fallback type, because the smallest symmetric quant type isQ8_0
but it's a bit big, soQ4_0
it is (even though it's not ideal). Only relevant when row sizes are not multiples of 256.__ARM_FEATURE_DOTPROD
variants of the dot products ofTQ1_0
andTQ2_0
with their bare__ARM_NEON
variants to reduce code duplication.TQ1_0
andTQ2_0
for correctness on an ARM CPU which supports dot product instructionst4g.small
instance.TQ1_0
's first 48 bytes be divided in 3 sub-blocks of 16 bytes (80 elements) instead of one of 32 bytes (140 elements) and one of 16 bytes?Q1_3
andQ2_2
TQ1_0
andTQ2_0
.ggml_mul
when the broadcasted tensor only has a single element