Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ggml-quants : ternary packing for TriLMs and BitNet b1.58 #8151

Merged
merged 33 commits into from
Sep 6, 2024

Conversation

compilade
Copy link
Collaborator

@compilade compilade commented Jun 27, 2024

This adds 1.6875 bpw and 2.0625 bpw quant types for TriLMs and BitNet b1.58 models. For now, these are named TQ1_0 and TQ2_0, respectively.
I had given glimpses of this idea starting from #7931 (comment).

The 1.6875 bpw type mostly relies on the fact that 35 == 243 < 256 == 28 to pack 5 trits per byte.

(I also made a blog post about ternary packing in an attempt to explain the core idea a bit more (storing the values in fixed-point to extract the most significant digit first with multiplications))

Huge thanks to @Eddie-Wang1120, who motivated this by adding initial BitNet b1.58 support in #7931.

How to try it

Using TriLM models is the easiest because all of their models have row sizes divisible by 256.

Important

To quantize the token embeddings and the output tensor to Q4_K and Q6_K, you need to use llama-quantize on the model files produced by convert_hf_to_gguf.py --outtype tq1_0 (and also for tq2_0). Otherwise these two tensors are kept as f16 and are responsible for most of the size of the models.

$ python3 convert_hf_to_gguf.py /path/to/TriLM_3.9B_Unpacked/ --outfile /somewhere/TriLM-3.9B-TQ1_0-big.gguf --outtype tq1_0
$ ./build/bin/llama-quantize /somewhere/TriLM-3.9B-TQ1_0-big.gguf /somewhere/TriLM-3.9B-TQ1_0.gguf tq1_0

If you want to try TQ2_0, which is faster (but bigger) than TQ1_0 on compute-bound hardware, you can replace tq1_0 with tq2_0 in the above example, but it's also possible to convert from the TQ1_0 model file.

The two ternary formats hold the same values, so round-trip quantizing between the two should result in the same files.

$ ./build/bin/llama-quantize --allow-requantize /somewhere/TriLM-3.9B-TQ1_0.gguf /somewhere/TriLM-3.9B-TQ2_0.gguf tq2_0

Speed

TQ2_0 is twice as fast as Q4_K on my laptop. It's the fastest quant on compute-bound AVX2-capable computers.

This is a table of the float32-equivalent throughput of the vec_dot_q operation for each of these quant types.

CPU F16 Q8_0 Q4_K Q2_K TQ1_0 TQ2_0
Intel Core m3-8100Y (AVX2) 30.60 GB/s 67.03 GB/s 64.17 GB/s 81.73 GB/s 70.31 GB/s 141.83 GB/s
Arm Cortex A72 (NEON) 3.84 GB/s 9.51 GB/s 9.26 GB/s 9.79 GB/s 11.81 GB/s 15.78 GB/s
Arm Cortex A53 (NEON) 4.30 GB/s 5.87 GB/s 5.76 GB/s 5.84 GB/s 8.97 GB/s 10.29 GB/s
AWS t4g (NEON) 8.69 GB/s 22.35 GB/s 25.34 GB/s 22.84 GB/s 33.34 GB/s 44.80 GB/s
AWS t4g (DOTPROD) 49.17 GB/s 42.63 GB/s 45.40 GB/s 29.84 GB/s 40.44 GB/s 65.76 GB/s

From this, it's easy to see that TQ1_0 is usually slightly faster than Q4_K, and that TQ2_0 is by far the fastest quant on AVX2.

Note

There might be a way to make a similar type as TQ2_0 like some sort of Q2_1, which could be almost as fast but still usable by non-ternary models, but this will probably require something like LQER to help with keeping some precision.

Raw data (click to expand)

Intel Core m3-8100Y:

$ for t in bf16 f16 q8_0 q4_0 q4_K q2_K tq1_0 tq2_0; do ./bin/test-quantize-perf --op vec_dot_q -i 10000000 --type "$t"; done
bf16
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      4.28
      avg cycles/32 vals   :      4.72
      float32 throughput   :     37.89 GB/s
      quantized throughput :     18.95 GB/s

f16
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      5.52
      avg cycles/32 vals   :      5.93
      float32 throughput   :     30.60 GB/s
      quantized throughput :     15.30 GB/s

q8_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      2.27
      avg cycles/32 vals   :      2.56
      float32 throughput   :     67.03 GB/s
      quantized throughput :     17.81 GB/s

q4_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      3.04
      avg cycles/32 vals   :      3.38
      float32 throughput   :     52.20 GB/s
      quantized throughput :      7.34 GB/s

q4_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      2.22
      avg cycles/32 vals   :      2.61
      float32 throughput   :     64.17 GB/s
      quantized throughput :      9.02 GB/s

q2_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      1.77
      avg cycles/32 vals   :      1.99
      float32 throughput   :     81.73 GB/s
      quantized throughput :      6.70 GB/s

tq1_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      2.12
      avg cycles/32 vals   :      2.33
      float32 throughput   :     70.31 GB/s
      quantized throughput :      3.71 GB/s

tq2_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.85
      avg cycles/32 vals   :      0.97
      float32 throughput   :    141.83 GB/s
      quantized throughput :      9.14 GB/s

Arm Cortex A72 (Raspberry Pi 4):

$ for t in f16 q8_0 q4_K q2_K tq1_0 tq2_0; do ./bin/test-quantize-perf --op vec_dot_q -i 2000000 --type "$t"; done                                                                                        
f16
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      3.84 GB/s
      quantized throughput :      1.92 GB/s

q8_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      9.51 GB/s
      quantized throughput :      2.53 GB/s

q4_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      9.26 GB/s
      quantized throughput :      1.30 GB/s

q2_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      9.79 GB/s
      quantized throughput :      0.80 GB/s

tq1_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     11.81 GB/s
      quantized throughput :      0.62 GB/s

tq2_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     15.78 GB/s
      quantized throughput :      1.02 GB/s

Arm Cortex A53 (Some Android phone from 2017):

$ for t in f16 q8_0 q4_K q2_K tq1_0 tq2_0; do ./bin/test-quantize-perf --op vec_dot_q -i 2000000 --type "$t"; done
f16
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      4.30 GB/s
      quantized throughput :      2.15 GB/s

q8_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      5.87 GB/s
      quantized throughput :      1.56 GB/s

q4_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      5.76 GB/s
      quantized throughput :      0.81 GB/s

q2_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      5.84 GB/s
      quantized throughput :      0.48 GB/s

tq1_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      8.97 GB/s
      quantized throughput :      0.47 GB/s

tq2_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     10.29 GB/s
      quantized throughput :      0.66 GB/s

AWS t4g.small instance (Arm Neoverse N1) using NEON:

$ for t in f16 q8_0 q4_K q2_K tq1_0 tq2_0; do ./bin/test-quantize-perf --op vec_dot_q -i 2000000 --type "$t"; done
f16
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      8.69 GB/s
      quantized throughput :      4.35 GB/s

q8_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     22.35 GB/s
      quantized throughput :      5.94 GB/s

q4_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     25.34 GB/s
      quantized throughput :      3.56 GB/s

q2_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     22.84 GB/s
      quantized throughput :      1.87 GB/s

tq1_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     33.34 GB/s
      quantized throughput :      1.76 GB/s

tq2_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     44.80 GB/s
      quantized throughput :      2.89 GB/s

AWS t4g.small (Arm Neoverse N1) with -march=native:

$ for t in f16 q8_0 q4_K q2_K tq1_0 tq2_0; do ./tests/test-quantize-perf --op vec_dot_q -i 2000000 --type "$t"; done
f16
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     49.17 GB/s
      quantized throughput :     24.59 GB/s

q8_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     42.63 GB/s
      quantized throughput :     11.32 GB/s

q4_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     45.40 GB/s
      quantized throughput :      6.38 GB/s

q2_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     29.84 GB/s
      quantized throughput :      2.45 GB/s

tq1_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     40.44 GB/s
      quantized throughput :      2.13 GB/s

tq2_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     65.76 GB/s
      quantized throughput :      4.24 GB/s

Size

The token embeddings are kept at Q4_K and the output projection at Q6_K, which means the smaller models might be slightly bigger than 2 bits per weight.

All of the TriLM models should work, because their row sizes are multiples of 256. I did not try them all yet, but those I tried are in the table below.

The BitNet b1.58 models from the 1bitLLM team however are not all compatible; only the 700M model has dimensions divisible by 256. The others are not supported (yet), unless when padding them.

Model F16 TQ1_0 TQ2_0
https://huggingface.co/1bitLLM/bitnet_b1_58-large (728.84 M) 1391.26 MiB 176.65 MiB 207.03 MiB
https://huggingface.co/SpectraSuite/TriLM_390M_Unpacked 750.39 MiB 128.04 MiB 140.98 MiB
https://huggingface.co/SpectraSuite/TriLM_1.5B_Unpacked 2892.09 MiB 401.54 MiB 460.04 MiB
https://huggingface.co/SpectraSuite/TriLM_2.4B_Unpacked 4696.86 MiB 603.59 MiB 703.26 MiB
https://huggingface.co/SpectraSuite/TriLM_3.9B_Unpacked 7616.43 MiB 948.16 MiB 1112.70 MiB

Note

The 1.3B BitNet b1.58 model has a FFN size of 5460 which factors into 2 2 3 5 7 13, which is not convenient for any block-wise types based on powers of 2, so these tensors are kept as F16. My hypothesis is that 5460 was a typo for 5440 (factors into 2 2 2 2 2 2 5 17), but it was kept for some reason, and reproduced by the 1bitLLM team. If anyone training ternary models reads this, PLEASE DON'T USE 5460 FOR THE FFN SIZE! Please use multiples of 256 for your row sizes.

Perplexity

Quality seems good. I don't have a powerful machine, so my tests only include the first 16 chunks of wikitext-2-raw with https://huggingface.co/SpectraSuite/TriLM_390M_Unpacked.

The tests below use Q4_K token embeddings and Q6_K output tensor for TQ1_0 and TQ2_0, while F16 token embeddings and output tensor is used in TQ1_0_L and TQ2_0_L.

chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p
TQ1_0 16 23.6336 ± 1.0765 0.00463 ± 0.00141 0.00187 ± 0.00002 0.860 ± 0.020 % 97.279 ± 0.255 %
TQ2_0 16 23.6336 ± 1.0765 0.00463 ± 0.00141 0.00187 ± 0.00002 0.860 ± 0.020 % 97.279 ± 0.255 %
TQ1_0_L 16 23.5758 ± 1.0746 0.00218 ± 0.00112 0.00034 ± 0.00001 0.405 ± 0.012 % 98.971 ± 0.158 %
TQ2_0_L 16 23.5758 ± 1.0746 0.00218 ± 0.00112 0.00034 ± 0.00001 0.405 ± 0.012 % 98.971 ± 0.158 %

From this it seems like there is no significant quality loss for the ternary quants for TriLM models (I think the difference with pure f16 comes from the 8-bit activations), and that TQ1_0 and TQ2_0 are completely equivalent in quality (and they should be, because lossless conversion between the two is possible).

Structure of TQ1_0

This type relies on the fact that 3^5 == 243 < 256 == 2^8.

In a block of 256 elements, there are 240 elements encoded in 5 elements per byte, while the last 16 elements are encoded in 4 elements per byte.

This means (240 / 5) + (16 / 4) == 48 + 4 == 52 bytes are used to pack 256 ternary weights (this is 1.625 bits per weight).

But there is also one float16 scale per block, so the size of a block is 54 bytes making it a 1.6875 bpw type. Even though it's not ideal, this is still 1.6875 / (log(3) / log(2)) ≈ 94% of the best ternary packing efficiency.

In the table below I'm describing the order of the elements within the bytes. I'm using ranges to make this shorter, with the notation start..end where the start is inclusive and the end is exclusive. (So 0..3 is {0, 1, 2})

Read this as if the ranges of a row are zipped together. A byte never contains more than 5 ternary values.

The ternary values are stored unsigned, so {-1, 0, 1} is stored as {0, 1, 2}.

byte x * 3-1 x * 3-2 x * 3-3 x * 3-4 x * 3-5
0..32 0..32 32..64 64..96 96..128 128..160
32..48 160..176 176..192 192..208 208..224 224..240
48..52 240..244 244..248 248..252 252..256 N/A

And then byte 52 and 53 contain the float16 scale in little-endian.

Values are stored in fixed point to allow extracting the most significant digit first. This is explained in https://compilade.net/blog/ternary-packing.

Structure of TQ2_0

This type was originally inspired by the Q2_2 type made by @Eddie-Wang1120, but the block size, the order, and the mapping of the values are different.

TQ2_0 started as an experiment to see how fast a 2-bit type can be compared to a 1.6-bit type on compute-bound hardware.

This packs each ternary value in 2 bits, which means each byte contains 4 values.

The ternary values are stored unsigned, so {-1, 0, 1} is stored as {0, 1, 2}.

Again, the ranges use the start..end notation where the start is inclusive and the end is exclusive, and the ranges of a row should be read as being zipped together (they advance in parallel in lockstep).

byte x << 6 x << 4 x << 2 x << 0
0..32 96..128 64..96 32..64 0..32
32..64 224..256 192..224 160..192 128..160

And then byte 64 and 65 contain the float16 scale in little-endian.

TODO

  • Implement Numpy (de)quantization for TQ1_0 and TQ2_0
  • Allow convert_hf_to_gguf.py to directly convert a ternary model to a ternary encoding
    • Using f16 for the token embeddings and output tensor because Q4_K and Q6_K quantization is not yet supported by gguf-py. This means llama-quantize needs to be used to quantize these tensors.
    • Make it more obvious that the models should go through llama-quantize afterwards.
      • Maybe use other type names, like TQ1_0_L or something?
  • Decide whether the float16 scale should be before or after the packed weights
    • I'd prefer it after because I feel like the scales are read after the weights in dot products, but the convention with the other types (except for Q2_K, Q3_K and Q6_K) is to keep the scale before.
    • Okay, I've decided the scales should stay at the end.
  • More graceful fallback conversion with llama-quantize
    • Using Q4_0 as a fallback type, because the smallest symmetric quant type is Q8_0 but it's a bit big, so Q4_0 it is (even though it's not ideal). Only relevant when row sizes are not multiples of 256.
  • Unify the __ARM_FEATURE_DOTPROD variants of the dot products of TQ1_0 and TQ2_0 with their bare __ARM_NEON variants to reduce code duplication.
  • Test TQ1_0 and TQ2_0 for correctness on an ARM CPU which supports dot product instructions
    • Tested on an AWS t4g.small instance.
    • Also test relative performance for fun
  • Should TQ1_0's first 48 bytes be divided in 3 sub-blocks of 16 bytes (80 elements) instead of one of 32 bytes (140 elements) and one of 16 bytes?
    • I've done the 32-16 split to use 256-bit registers on AVX2 for the pow3 shifts for at least the 32 byte part, but 16-16-16 would be more regular, although it would require using 128-bit registers for all the ternary shifts. Not sure if there's a performance difference.
  • Rename references to "BitNet 1.58b" to "BitNet b1.58". The "b" comes before in the paper.
  • Find a naming convention for BitNet quants and rename Q1_3 and Q2_2
    • They were renamed and redesigned as TQ1_0 and TQ2_0.
  • Decide to keep or to remove the optimization for ggml_mul when the broadcasted tensor only has a single element
  • Fix Android CI build issues.
    • It was apparently a problem with Arm 32-bit. Fixed in 8fbd593

Not using a lookup table anymore makes it match q4_0 speed.

* gguf-py : fix formatting

* llama : remove spaces on empty line
This makes the 1.625 bpw type go faster than q4_0. Still not the fastest.
This still results in the exact same tensor weights and scales,
but it reveals some weirdness in the current algorithm.
Its FFN size is 5460 which is not convenient.
The offending tensors are kept in F16,
which makes the final model 5.01 bpw.
@compilade compilade added enhancement New feature or request python python script changes Review Complexity : High Generally require indepth knowledge of LLMs or GPUs ggml changes relating to the ggml tensor library for machine learning Tensor Encoding Scheme https://github.com/ggerganov/llama.cpp/wiki/Tensor-Encoding-Schemes labels Jun 27, 2024
@compilade compilade force-pushed the compilade/bitnet-ternary branch from 4522ed7 to 0996149 Compare June 27, 2024 06:13
@github-actions github-actions bot added testing Everything test related examples labels Jun 27, 2024
@Eddie-Wang1120
Copy link
Contributor

Eddie-Wang1120 commented Jun 27, 2024

Wonderful job! I'm wondering if this PR can merge into the master branch, it would be so good if users of llama.cpp can use Q2_2 and Q1_3 conveniently.

@compilade compilade changed the title ggml-quants : 1.625 bpw ternary packing for BitNet 1.58b ggml-quants : 1.625 bpw ternary packing for BitNet b1.58 Jun 27, 2024
Comment on lines 29 to 30
{ "Q1_3", LLAMA_FTYPE_MOSTLY_Q1_3, " 1.63 bpw for BitNet b1.58", },
{ "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2.00 bpw for BitNet b1.58", },
Copy link
Collaborator Author

@compilade compilade Jun 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding the names of the new quant types, since these are quite specific to BitNet models, I was thinking to name them with something starting with QB, a bit like suggested in #5761 (comment).

I'll first be describing what I want from the naming scheme, then I'll attempt to make it work.

The naming scheme should have room for:

  • Ternary types in {-1, 0, 1}
    • 1.625 bpw quant with a block size of 64, with 13 bytes per block
      • To make the smallest possible lossless BitNet b1.58 model files
      • Uses Q8_0 as its vec_dot_type (for the activations)
      • (It's technically possible to store a float16 scale in the leftover bits in the last byte of 16 consecutive blocks (this means 1024 elements minimum per row), although it can't really be extracted with SIMD)
    • 2.000 bpw quant with a block size of 32, with 8 bytes per block
      • For maximal performance
      • Uses Q8_0 as its vec_dot_type (for the activations)
    • 2.000 bpw quant with a block size of 64, with 16 bytes per block, and a float16 scale
      • Values would be packed similarly to the 1.625 bpw type, but with an extra byte and a row-wise float16 scale duplicated in each block.
    • 2.000 bpw quant with a block size of 4, with 1 byte per block
      • For weirdly-shaped models like the 1.3B BitNet b1.58 model
      • Needs a compatible vec_dot_type
        • float types are slower than integer types for this
  • Binary types in {-1, 1}
    • 1 bpw type
  • Binary types in {0, 1}
    • Are there models which use this?
  • 8-bit activation with a row-wise scale
    • 8.5 bpw like Q8_0, but all the scales of a row are the same
      • Would allow reducing the number of float32 operations in the vec_dot of the above types.
    • 10 bpw, 5 bytes per block of 4 elements, with a weird layout which only uses blocks to get a big enough buffer, with a single float32 scale and some padding before all row elements, aligned and contiguous.
      • For use with the weird 2.000 bpw type, and also maybe the other ones for best performance.

So the naming scheme could be:

  • QB<x>_<y>
    • where <x> is the floor of the expected bpw of the type
    • where <y> is
      • 0 binary type, {0, 1}
        • except for QB8_0 which is like Q8_0 but with a guaranteed duplicated row-wise scale
      • 1 binary type, {-1, 1}
      • 2 ternary type using some kind of binary-coded ternary
      • 3 ternary type with fixed-point packed values
      • 4 weird type with a block size of 4

Which for the previously-mentioned possible BitNet types would mean:

proposed name Range bits per weight block size bytes row-wise scale current name
QB1_3 {-1, 0, 1} 1.625 64 13 1.0f Q1_3
QB2_2 {-2, -1, 0, 1} 2.000 32 8 1.0f Q2_2
QB2_3 {-1, 0, 1} 2.000 64 16 f16
QB2_4 {-2, -1, 0, 1} 2.000 4 1 1.0f
QB1_1 {-1, 1} 1.000 ? ?/8 1.0f
QB1_0 {0, 1} 1.000 ? ?/8 1.0f
QB8_0 [-127, 127] 8.5 32 34 f16
QB8_4 [-127, 127] 10 4 5 f32, weird layout

I'm not saying these should all exist, though, only that the naming scheme should not be too limiting for possible future extensions (which might not exist anyway due to lack of time).

So I think I'll rename Q1_3 to QB1_3, and Q2_2 to QB2_2. Anyone has comments on this? Or a better naming scheme for the new BitNet quant types?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it were me, considering this only works with bitnet models and nothing else, I'd want the designations to be exceptionally clear that they are different and shouldn't be used on just anything. "QB" is good, but I'd take it a step further and remove the Q entirely. As bitnet is being colloquially referred to as a "1-bit" model, B1 makes more sense. Considering the plausible range for weights, I'd cut it off at tenths and ditch the decimal. This leaves plenty of room for variations, while making the native BPW very clear. I feel this is superior to the arbitrary "_2" and "_3" subtypes.

So what I would propose is:

1.625bpw = B1_16
2.000bpw = B1_20

@Green-Sky
Copy link
Collaborator

Green-Sky commented Jun 29, 2024

@compilade and @Eddie-Wang1120 continuing the race to the bottom 🥳 , glorious.

Did some quick testing with the 3B model and it looks very good.

model size params backend threads test t/s
bitnet 3B Q1_3 - 1.625 bpw for BitNet b1.58 729.64 MiB 3.32 B BLAS 12 pp512 78.40 ± 0.27
bitnet 3B Q1_3 - 1.625 bpw for BitNet b1.58 729.64 MiB 3.32 B BLAS 12 tg128 38.16 ± 0.04
bitnet 3B Q2_2 - 2.000 bpw for BitNet b1.58 873.65 MiB 3.32 B BLAS 12 pp512 73.35 ± 6.23
bitnet 3B Q2_2 - 2.000 bpw for BitNet b1.58 873.65 MiB 3.32 B BLAS 12 tg128 36.86 ± 0.12

What surprises me a little, after reading about q2_2 being faster, is that q1_3 seems to be faster with the setup I used here. Will investigate further.

edit: also updated the files at https://huggingface.co/Green-Sky/bitnet_b1_58-3B-GGUF , for anyone else willing to test.

@netrunnereve
Copy link
Collaborator

Did a bit of testing myself, it runs and generates well but unfortunately it's the undertrained models rather than our implementation that's holding back BitNet adoption. For me Q1_3 is slower but this computer is CPU rather than memory bound.

model size params backend threads test t/s
bitnet 3B Q1_3 - 1.625 bpw for BitNet 1.58b 729.64 MiB 3.32 B CPU 4 pp512 15.15 ± 0.07
bitnet 3B Q1_3 - 1.625 bpw for BitNet 1.58b 729.64 MiB 3.32 B CPU 4 tg128 9.87 ± 0.65
bitnet 3B Q2_2 - 2.000 bpw for BitNet 1.58b 873.65 MiB 3.32 B CPU 4 pp512 19.25 ± 0.44
bitnet 3B Q2_2 - 2.000 bpw for BitNet 1.58b 873.65 MiB 3.32 B CPU 4 tg128 13.07 ± 0.28
bitnet 3B Q4_0 1.79 GiB 3.32 B CPU 4 pp512 18.44 ± 0.40
bitnet 3B Q4_0 1.79 GiB 3.32 B CPU 4 tg128 5.87 ± 0.12

I wonder if Q2_2 could be made faster if we used a block size of say 256 like the K-quants so that we can handle more than 64 bits of Q2_2 quants in each dot product loop. Aside from that I can't find any further way to improve that AVX implementation, and while it's ironic that we're using a madds instruction there when BitNet technically doesn't require multiplication that looks like the fastest way to dot the activations and ternary weights.

@compilade
Copy link
Collaborator Author

I wonder if Q2_2 could be made faster if we used a block size of say 256 like the K-quants

Can't go with bigger blocks than 64 elements or else the 3B model won't be fully quantizable. (Its FFN size is 8640 (which factors into 2 2 2 2 2 2 3 3 3 5))

Its current block size is 32, which is the same as its vec_dot_type, Q8_0.

What would also help with performance would be to somehow use an 8-bit vec_dot_type having a single float scale per row. Might be interesting to explore later, but ggml does not have row-wise quant types yet, although this could still be done with a block quant.

it's ironic that we're using a madds instruction

Yeah, with AVX2, there are no good widening addition instructions like on ARM NEON, so _mm256_maddubs_epi16 is used for that.

Meanwhile, NEON doesn't have the equivalent of _mm_sign_epi8, so it needs to use multiplications or conditional masks, which are both slower than a dedicated instruction doing zeroing and sign flipping like in SSSE3.

Comment on lines 675 to 678
const uint8_t xi0 = x0 < 0 ? 1 : x0 == 0 ? 2 : 3;
const uint8_t xi1 = x1 < 0 ? 1 : x1 == 0 ? 2 : 3;
const uint8_t xi2 = x2 < 0 ? 1 : x2 == 0 ? 2 : 3;
const uint8_t xi3 = x3 < 0 ? 1 : x3 == 0 ? 2 : 3;
Copy link
Owner

@ggerganov ggerganov Jul 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As proposed, the type utilizes only 3 of the 4 possible values. I was thinking that the Q2_2 type would work the same as Q4_0, but assumes amax == 1.0f:

void quantize_row_q2_2_reference(const float * restrict x, block_q2_2 * restrict y, int64_t k) {
    static const int qk = QK2_2;

    assert(k % qk == 0);

    const int nb = k / qk;

    for (int i = 0; i < nb; i++) {
        float amax = 0.0f; // absolute max
        float max  = 0.0f;

        for (int j = 0; j < qk; j++) {
            const float v = x[i*qk + j];
            if (amax < fabsf(v)) {
                amax = fabsf(v);
                max  = v;
            }
        }

        // assume amax = 1.0f
        max /= amax;

        const float d  = max / -2;
        const float id = d ? 1.0f/d : 0.0f;

        for (int j = 0; j < qk/4; ++j) {
            const float x0 = x[i*qk + 0*qk/4 + j]*id;
            const float x1 = x[i*qk + 1*qk/4 + j]*id;
            const float x2 = x[i*qk + 2*qk/4 + j]*id;
            const float x3 = x[i*qk + 3*qk/4 + j]*id;

            const uint8_t xi0 = MIN(3, (int8_t)(x0 + 2.5f));
            const uint8_t xi1 = MIN(3, (int8_t)(x1 + 2.5f));
            const uint8_t xi2 = MIN(3, (int8_t)(x2 + 2.5f));
            const uint8_t xi3 = MIN(3, (int8_t)(x3 + 2.5f));

            y[i].qs[j]  = xi0;
            y[i].qs[j] |= xi1 << 2;
            y[i].qs[j] |= xi2 << 4;
            y[i].qs[j] |= xi3 << 6;
        }
    }
}

(not tested, just pattern matching the existing quantize_row_q4_0_reference())

Edit: just realized the above would not work. We have assume that max == 1.0f, not amax, so:

const float max = 1.0f;
...

ggml/src/ggml.c Outdated Show resolved Hide resolved
@basavyr
Copy link

basavyr commented Aug 21, 2024

I am trying to test the TriLM_3.9B_Unpacked with both TQ1_0 and TQ2_0 quants. Reading this discussion, I see that these two quantization methods are still supported on TriLM models (as opposed to the abandoned quantization for BitNet).

Using this exact pull request, I am building llama.cpp on a MacBook M3 Pro. The straightforward make -j n build command should build with Metal support by default (source). After building llama.cpp with success, I am firstly converting the HF model of TriLM_3.9B_Unpacked to f16-GGUF format, then finally quantizing with llama-quantize to the aforementioned formats. Everything works fine up until here.

The issue comes when I am trying to perform inference on the Apple GPU:

/path_to_built_llama/llama_cli -m quants/TriLM_3.9B_Unpacked_quant_TQ2_0.gguf -p "hey there"
Log start
main: build = 3610 (35cc5567)
main: built with Apple clang version 15.0.0 (clang-1500.3.9.4) for arm64-apple-darwin23.6.0
main: seed  = 1724234806
llama_model_loader: loaded meta data with 28 key-value pairs and 273 tensors from quants/TriLM_3.9B_Unpacked_quant_TQ2_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                         general.size_label str              = 4.0B
llama_model_loader: - kv   3:                            general.license str              = apache-2.0
llama_model_loader: - kv   4:                          llama.block_count u32              = 30
llama_model_loader: - kv   5:                       llama.context_length u32              = 2048
llama_model_loader: - kv   6:                     llama.embedding_length u32              = 3072
llama_model_loader: - kv   7:                  llama.feed_forward_length u32              = 9216
llama_model_loader: - kv   8:                 llama.attention.head_count u32              = 24
llama_model_loader: - kv   9:              llama.attention.head_count_kv u32              = 24
llama_model_loader: - kv  10:                       llama.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  11:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  12:         llama.attention.layer_norm_epsilon f32              = 0.000010
llama_model_loader: - kv  13:                          general.file_type u32              = 37
llama_model_loader: - kv  14:                           llama.vocab_size u32              = 50688
llama_model_loader: - kv  15:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv  16:            tokenizer.ggml.add_space_prefix bool             = false
llama_model_loader: - kv  17:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  18:                         tokenizer.ggml.pre str              = olmo
llama_model_loader: - kv  19:                      tokenizer.ggml.tokens arr[str,50688]   = ["<|endoftext|>", "<|padding|>", "!",...
llama_model_loader: - kv  20:                  tokenizer.ggml.token_type arr[i32,50688]   = [3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  21:                      tokenizer.ggml.merges arr[str,50009]   = ["Ġ Ġ", "Ġ t", "Ġ a", "h e", "i n...
llama_model_loader: - kv  22:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  23:                tokenizer.ggml.eos_token_id u32              = 0
llama_model_loader: - kv  24:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  25:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  26:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  27:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   61 tensors
llama_model_loader: - type q4_K:    1 tensors
llama_model_loader: - type q6_K:    1 tensors
llama_model_loader: - type tq2_0:  210 tensors
llm_load_vocab: special tokens cache size = 25
llm_load_vocab: token to piece cache size = 0.2984 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 50688
llm_load_print_meta: n_merges         = 50009
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 2048
llm_load_print_meta: n_embd           = 3072
llm_load_print_meta: n_layer          = 30
llm_load_print_meta: n_head           = 24
llm_load_print_meta: n_head_kv        = 24
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 3072
llm_load_print_meta: n_embd_v_gqa     = 3072
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 9216
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 2048
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = ?B
llm_load_print_meta: model ftype      = TQ2_0 - 2.06 bpw ternary
llm_load_print_meta: model params     = 3.99 B
llm_load_print_meta: model size       = 1.08 GiB (2.33 BPW)
llm_load_print_meta: general.name     = n/a
llm_load_print_meta: BOS token        = 0 '<|endoftext|>'
llm_load_print_meta: EOS token        = 0 '<|endoftext|>'
llm_load_print_meta: UNK token        = 0 '<|endoftext|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOT token        = 0 '<|endoftext|>'
llm_load_print_meta: max token length = 1024
llm_load_tensors: ggml ctx size =    0.26 MiB
ggml_backend_metal_log_allocated_size: allocated buffer, size =  1027.47 MiB, ( 1027.55 / 12288.02)
llm_load_tensors: offloading 30 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 31/31 layers to GPU
llm_load_tensors:      Metal buffer size =  1027.46 MiB
llm_load_tensors:        CPU buffer size =    83.53 MiB
....................................................................................
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M3 Pro
ggml_metal_init: picking default device: Apple M3 Pro
ggml_metal_init: using embedded metal library
ggml_metal_init: GPU name:   Apple M3 Pro
ggml_metal_init: GPU family: MTLGPUFamilyApple9  (1009)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)
ggml_metal_init: simdgroup reduction support   = true
ggml_metal_init: simdgroup matrix mul. support = true
ggml_metal_init: hasUnifiedMemory              = true
ggml_metal_init: recommendedMaxWorkingSetSize  = 12884.92 MB
llama_kv_cache_init:      Metal KV buffer size =   720.00 MiB
llama_new_context_with_model: KV self size  =  720.00 MiB, K (f16):  360.00 MiB, V (f16):  360.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.19 MiB
llama_new_context_with_model:      Metal compute buffer size =   124.00 MiB
llama_new_context_with_model:        CPU compute buffer size =    10.01 MiB
llama_new_context_with_model: graph nodes  = 966
llama_new_context_with_model: graph splits = 2
ggml/src/ggml-metal.m:1619: MUL MAT-MAT not implemented
ggml/src/ggml-metal.m:1619: MUL MAT-MAT not implemented
[1]    36927 abort      /Users/basavyr/Repos/external/llama.cpp/llama-cli -m  -p "hey there"

This error does not occur with the GPU inference explicitly disabled, via the --n-gpu-layers|-ngl 0 flag.

Q: Am I missing something ? Did anyone else try to test this on M1/2/3 GPUs?

@flatsiedatsie
Copy link

@basavyr Could you share the quantified files on Huggingface? Then I'll happily give it a try on my Macbook Pro M1.

@sorasoras
Copy link

I am trying to test the TriLM_3.9B_Unpacked with both TQ1_0 and TQ2_0 quants. Reading this discussion, I see that these two quantization methods are still supported on TriLM models (as opposed to the abandoned quantization for BitNet).

Using this exact pull request, I am building llama.cpp on a MacBook M3 Pro. The straightforward make -j n build command should build with Metal support by default (source). After building llama.cpp with success, I am firstly converting the HF model of TriLM_3.9B_Unpacked to f16-GGUF format, then finally quantizing with llama-quantize to the aforementioned formats. Everything works fine up until here.

The issue comes when I am trying to perform inference on the Apple GPU:

/path_to_built_llama/llama_cli -m quants/TriLM_3.9B_Unpacked_quant_TQ2_0.gguf -p "hey there"
Log start
main: build = 3610 (35cc5567)
main: built with Apple clang version 15.0.0 (clang-1500.3.9.4) for arm64-apple-darwin23.6.0
main: seed  = 1724234806
llama_model_loader: loaded meta data with 28 key-value pairs and 273 tensors from quants/TriLM_3.9B_Unpacked_quant_TQ2_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                         general.size_label str              = 4.0B
llama_model_loader: - kv   3:                            general.license str              = apache-2.0
llama_model_loader: - kv   4:                          llama.block_count u32              = 30
llama_model_loader: - kv   5:                       llama.context_length u32              = 2048
llama_model_loader: - kv   6:                     llama.embedding_length u32              = 3072
llama_model_loader: - kv   7:                  llama.feed_forward_length u32              = 9216
llama_model_loader: - kv   8:                 llama.attention.head_count u32              = 24
llama_model_loader: - kv   9:              llama.attention.head_count_kv u32              = 24
llama_model_loader: - kv  10:                       llama.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  11:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  12:         llama.attention.layer_norm_epsilon f32              = 0.000010
llama_model_loader: - kv  13:                          general.file_type u32              = 37
llama_model_loader: - kv  14:                           llama.vocab_size u32              = 50688
llama_model_loader: - kv  15:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv  16:            tokenizer.ggml.add_space_prefix bool             = false
llama_model_loader: - kv  17:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  18:                         tokenizer.ggml.pre str              = olmo
llama_model_loader: - kv  19:                      tokenizer.ggml.tokens arr[str,50688]   = ["<|endoftext|>", "<|padding|>", "!",...
llama_model_loader: - kv  20:                  tokenizer.ggml.token_type arr[i32,50688]   = [3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  21:                      tokenizer.ggml.merges arr[str,50009]   = ["Ġ Ġ", "Ġ t", "Ġ a", "h e", "i n...
llama_model_loader: - kv  22:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  23:                tokenizer.ggml.eos_token_id u32              = 0
llama_model_loader: - kv  24:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  25:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  26:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  27:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   61 tensors
llama_model_loader: - type q4_K:    1 tensors
llama_model_loader: - type q6_K:    1 tensors
llama_model_loader: - type tq2_0:  210 tensors
llm_load_vocab: special tokens cache size = 25
llm_load_vocab: token to piece cache size = 0.2984 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 50688
llm_load_print_meta: n_merges         = 50009
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 2048
llm_load_print_meta: n_embd           = 3072
llm_load_print_meta: n_layer          = 30
llm_load_print_meta: n_head           = 24
llm_load_print_meta: n_head_kv        = 24
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 3072
llm_load_print_meta: n_embd_v_gqa     = 3072
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 9216
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 2048
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = ?B
llm_load_print_meta: model ftype      = TQ2_0 - 2.06 bpw ternary
llm_load_print_meta: model params     = 3.99 B
llm_load_print_meta: model size       = 1.08 GiB (2.33 BPW)
llm_load_print_meta: general.name     = n/a
llm_load_print_meta: BOS token        = 0 '<|endoftext|>'
llm_load_print_meta: EOS token        = 0 '<|endoftext|>'
llm_load_print_meta: UNK token        = 0 '<|endoftext|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOT token        = 0 '<|endoftext|>'
llm_load_print_meta: max token length = 1024
llm_load_tensors: ggml ctx size =    0.26 MiB
ggml_backend_metal_log_allocated_size: allocated buffer, size =  1027.47 MiB, ( 1027.55 / 12288.02)
llm_load_tensors: offloading 30 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 31/31 layers to GPU
llm_load_tensors:      Metal buffer size =  1027.46 MiB
llm_load_tensors:        CPU buffer size =    83.53 MiB
....................................................................................
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M3 Pro
ggml_metal_init: picking default device: Apple M3 Pro
ggml_metal_init: using embedded metal library
ggml_metal_init: GPU name:   Apple M3 Pro
ggml_metal_init: GPU family: MTLGPUFamilyApple9  (1009)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)
ggml_metal_init: simdgroup reduction support   = true
ggml_metal_init: simdgroup matrix mul. support = true
ggml_metal_init: hasUnifiedMemory              = true
ggml_metal_init: recommendedMaxWorkingSetSize  = 12884.92 MB
llama_kv_cache_init:      Metal KV buffer size =   720.00 MiB
llama_new_context_with_model: KV self size  =  720.00 MiB, K (f16):  360.00 MiB, V (f16):  360.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.19 MiB
llama_new_context_with_model:      Metal compute buffer size =   124.00 MiB
llama_new_context_with_model:        CPU compute buffer size =    10.01 MiB
llama_new_context_with_model: graph nodes  = 966
llama_new_context_with_model: graph splits = 2
ggml/src/ggml-metal.m:1619: MUL MAT-MAT not implemented
ggml/src/ggml-metal.m:1619: MUL MAT-MAT not implemented
[1]    36927 abort      /Users/basavyr/Repos/external/llama.cpp/llama-cli -m  -p "hey there"

This error does not occur with the GPU inference explicitly disabled, via the --n-gpu-layers|-ngl 0 flag.

Q: Am I missing something ? Did anyone else try to test this on M1/2/3 GPUs?

I don't think TQ packing support GPU inference yet

@compilade
Copy link
Collaborator Author

I don't think TQ packing support GPU inference yet

It does not (yet). But for at least TQ2_0, it should be possible to adapt the Metal kernels from ikawrakow/ik_llama.cpp#13.

I'll see what I can do, but I can't test Metal kernels directly, so I'll likely postpone full support to a follow-up pull-request.

@basavyr
Copy link

basavyr commented Aug 22, 2024

@flatsiedatsie The two quantized models are available here. Feel free to try them :)

However, as per @sorasoras and @ggerganov, we might have to wait until support for Metal inference is officially confirmed.

@ggerganov I will give that a try and see if it works.

Thanks a lot guys for the support 🎉🙏


Edit: In the meantime, I have managed to perform Metal quantization (IQ2_TN) + inference on the same TriLM variant. You can play around with the .gguf added in this HF commit. This was possible through the PR that Georgi mentioned.

@flatsiedatsie
Copy link

flatsiedatsie commented Aug 22, 2024

Thanks for sharing the .gguf files @basavyr!

I ran a succesful test using Wllama, so this is 100% browser-based BitNet (running on the CPU):

Screenshot 2024-08-22 at 22 35 10

@compilade
Copy link
Collaborator Author

compilade commented Aug 22, 2024

@basavyr

Can you test whether https://github.com/compilade/llama.cpp/tree/compilade/bitnet-ternary-metal allows you to run TQ2_0 models on Metal? (this is another branch)

  • Does it compile?
  • Does the output looks correct?
  • Is it faster than when not using Metal?

For the Metal kernels, I've mostly used the code from ikawrakow/ik_llama.cpp#13, but like I said, I can't test it directly because I don't have Apple hardware. If it does not work, then I'll leave that unimplemented here because debugging over comments would not really be convenient.

Also, I am not Georgi :)

@flatsiedatsie
Copy link

flatsiedatsie commented Aug 23, 2024

It runs and looks correct.

6) Update Your Website Design:
You need to update your website design on a regular basis in order to keep your website fresh and relevant. This includes adding new content
llama_print_timings:        load time =    5112.49 ms
llama_print_timings:      sample time =      13.70 ms /   400 runs   (    0.03 ms per token, 29197.08 tokens per second)
llama_print_timings: prompt eval time =     125.92 ms /    15 tokens (    8.39 ms per token,   119.13 tokens per second)
llama_print_timings:        eval time =    7764.58 ms /   399 runs   (   19.46 ms per token,    51.39 tokens per second)
llama_print_timings:       total time =    7939.94 ms /   414 tokens
ggml_metal_free: deallocating

FULL LOG

./llama-cli -m ./TriLM_3.9B_Unpacked_quant_TQ2_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e
Log start
main: build = 3647 (2f5e28f9)
main: built with Apple clang version 15.0.0 (clang-1500.3.9.4) for arm64-apple-darwin23.5.0
main: seed  = 1724390564
llama_model_loader: loaded meta data with 28 key-value pairs and 273 tensors from ./TriLM_3.9B_Unpacked_quant_TQ2_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                         general.size_label str              = 4.0B
llama_model_loader: - kv   3:                            general.license str              = apache-2.0
llama_model_loader: - kv   4:                          llama.block_count u32              = 30
llama_model_loader: - kv   5:                       llama.context_length u32              = 2048
llama_model_loader: - kv   6:                     llama.embedding_length u32              = 3072
llama_model_loader: - kv   7:                  llama.feed_forward_length u32              = 9216
llama_model_loader: - kv   8:                 llama.attention.head_count u32              = 24
llama_model_loader: - kv   9:              llama.attention.head_count_kv u32              = 24
llama_model_loader: - kv  10:                       llama.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  11:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  12:         llama.attention.layer_norm_epsilon f32              = 0.000010
llama_model_loader: - kv  13:                          general.file_type u32              = 37
llama_model_loader: - kv  14:                           llama.vocab_size u32              = 50688
llama_model_loader: - kv  15:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv  16:            tokenizer.ggml.add_space_prefix bool             = false
llama_model_loader: - kv  17:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  18:                         tokenizer.ggml.pre str              = olmo
llama_model_loader: - kv  19:                      tokenizer.ggml.tokens arr[str,50688]   = ["<|endoftext|>", "<|padding|>", "!",...
llama_model_loader: - kv  20:                  tokenizer.ggml.token_type arr[i32,50688]   = [3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  21:                      tokenizer.ggml.merges arr[str,50009]   = ["Ġ Ġ", "Ġ t", "Ġ a", "h e", "i n...
llama_model_loader: - kv  22:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  23:                tokenizer.ggml.eos_token_id u32              = 0
llama_model_loader: - kv  24:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  25:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  26:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  27:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   61 tensors
llama_model_loader: - type q4_K:    1 tensors
llama_model_loader: - type q6_K:    1 tensors
llama_model_loader: - type tq2_0:  210 tensors
llm_load_vocab: special tokens cache size = 25
llm_load_vocab: token to piece cache size = 0.2984 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 50688
llm_load_print_meta: n_merges         = 50009
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 2048
llm_load_print_meta: n_embd           = 3072
llm_load_print_meta: n_layer          = 30
llm_load_print_meta: n_head           = 24
llm_load_print_meta: n_head_kv        = 24
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 3072
llm_load_print_meta: n_embd_v_gqa     = 3072
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 9216
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 2048
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: ssm_dt_b_c_rms   = 0
llm_load_print_meta: model type       = ?B
llm_load_print_meta: model ftype      = TQ2_0 - 2.06 bpw ternary
llm_load_print_meta: model params     = 3.99 B
llm_load_print_meta: model size       = 1.08 GiB (2.33 BPW) 
llm_load_print_meta: general.name     = n/a
llm_load_print_meta: BOS token        = 0 '<|endoftext|>'
llm_load_print_meta: EOS token        = 0 '<|endoftext|>'
llm_load_print_meta: UNK token        = 0 '<|endoftext|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOT token        = 0 '<|endoftext|>'
llm_load_print_meta: max token length = 1024
llm_load_tensors: ggml ctx size =    0.26 MiB
ggml_backend_metal_log_allocated_size: allocated buffer, size =  1027.47 MiB, ( 1027.53 / 10922.67)
llm_load_tensors: offloading 30 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 31/31 layers to GPU
llm_load_tensors:      Metal buffer size =  1027.46 MiB
llm_load_tensors:        CPU buffer size =    83.53 MiB
....................................................................................
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M1 Pro
ggml_metal_init: picking default device: Apple M1 Pro
ggml_metal_init: using embedded metal library
ggml_metal_init: GPU name:   Apple M1 Pro
ggml_metal_init: GPU family: MTLGPUFamilyApple7  (1007)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)
ggml_metal_init: simdgroup reduction support   = true
ggml_metal_init: simdgroup matrix mul. support = true
ggml_metal_init: hasUnifiedMemory              = true
ggml_metal_init: recommendedMaxWorkingSetSize  = 11453.25 MB
llama_kv_cache_init:      Metal KV buffer size =   720.00 MiB
llama_new_context_with_model: KV self size  =  720.00 MiB, K (f16):  360.00 MiB, V (f16):  360.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.19 MiB
llama_new_context_with_model:      Metal compute buffer size =   124.00 MiB
llama_new_context_with_model:        CPU compute buffer size =    10.01 MiB
llama_new_context_with_model: graph nodes  = 966
llama_new_context_with_model: graph splits = 2

system_info: n_threads = 6 / 8 | AVX = 0 | AVX_VNNI = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 0 | NEON = 1 | SVE = 0 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
sampling: 
	repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
	top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
	mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 2048, n_batch = 2048, n_predict = 400, n_keep = 0


Building a website can be done in 10 simple steps:
Step 1: Create your website domain.
Step 2: Create your website content.
Step 3: Create your website pages.
Step 4: Create your website design.
Step 5: Build your website on your hosting server.
Step 6: Update your website design.
Step 7: Create your website login and password.
Step 8: Build your website with the help of an expert.
Step 9: Test your website and enjoy it.
Step 10: Build your website with the help of an expert.
If you want to build a website and make it successful, then follow these simple steps and you will be able to build a great website.
1) Create Your Website Domain:
The first step is to create your own website domain. You can use any domain name you like, but it should be something that's relevant to your business. You can also use a website builder like Weebly or Wix, which will help you create your website quickly and easily.
2) Create Your Website Content:
You need to create your website content in order to make it successful. This includes writing articles, blogging, and creating videos.
3) Create Your Website Pages:
Once you have created your website content, you need to create your website pages. This includes creating pages for your blog, product page, and other sections.
4) Create Your Website Design:
You need to create your website design in order to make it successful. This includes choosing a design that's relevant to your business and using a professional web design company to help you.
5) Build Your Website on Your Hosting Server:
Once you have created your website design, you need to build your website on your hosting server. This is the most important step because it is the foundation of your website.
6) Update Your Website Design:
You need to update your website design on a regular basis in order to keep your website fresh and relevant. This includes adding new content
llama_print_timings:        load time =    5112.49 ms
llama_print_timings:      sample time =      13.70 ms /   400 runs   (    0.03 ms per token, 29197.08 tokens per second)
llama_print_timings: prompt eval time =     125.92 ms /    15 tokens (    8.39 ms per token,   119.13 tokens per second)
llama_print_timings:        eval time =    7764.58 ms /   399 runs   (   19.46 ms per token,    51.39 tokens per second)
llama_print_timings:       total time =    7939.94 ms /   414 tokens
ggml_metal_free: deallocating
Log end

@ggerganov
Copy link
Owner

For the Metal kernels, I've mostly used the code from ikawrakow/ik_llama.cpp#13, but like I said, I can't test it directly because I don't have Apple hardware

@compilade Don't worry about the Metal implementation. I can add this in a separate PR

@basavyr
Copy link

basavyr commented Aug 30, 2024

@compilade Sorry for the late answer...

I have managed to compile your fork of llama.cpp and successfully run the TQ2_0 inference on Metal for SpectraSuite/TriLM_3.9B_Unpacked. It looks like inference with Metal acceleration for TriLM_3.9B_Unpacked model is now possible ✅

Moreover, I have also tried to quantize all three versions of bitnet models from 🤗 (i.e., 3B, xl, and large) with the same branch, however it is not working ❌ (as expected)

Answering your questions:

  • Does it compile? -> YES
  • Does the output looks correct? -> If by this you mean the model output during inference, It is not able to provide good prompt responses.
  • Is it faster than when not using Metal? See results below👇

GPU:

llama_print_timings:        load time =     209.00 ms
llama_print_timings:      sample time =       4.94 ms /   256 runs   (    0.02 ms per token, 51874.37 tokens per second)
llama_print_timings: prompt eval time =      64.34 ms /     7 tokens (    9.19 ms per token,   108.80 tokens per second)
llama_print_timings:        eval time =    3476.64 ms /   255 runs   (   13.63 ms per token,    73.35 tokens per second)
llama_print_timings:       total time =    3558.48 ms /   262 tokens
ggml_metal_free: deallocating

CPU (--n-gpu-layers 0):

llama_print_timings:        load time =     112.92 ms
llama_print_timings:      sample time =       5.99 ms /   256 runs   (    0.02 ms per token, 42766.46 tokens per second)
llama_print_timings: prompt eval time =      78.54 ms /     7 tokens (   11.22 ms per token,    89.13 tokens per second)
llama_print_timings:        eval time =    4509.33 ms /   255 runs   (   17.68 ms per token,    56.55 tokens per second)
llama_print_timings:       total time =    4608.16 ms /   262 tokens

@flatsiedatsie
Copy link

Is there any reason to not merge this? I'm already using it in a project built on Wllama, and now I have to take extra steps to compile it each time.

Don't let the perfect be the enemy of the good.

It would otherwise conflict with the more general
optimization coming with Mamba-2.

* ggml : handle TQ1_0 and TQ2_0 in dequantization-based operators
Not yet adding uncommented, because some backends like SYCL and Metal
do not properly handle unknown types in supports_op for GGML_OP_MUL_MAT.
(and Metal also doesn't handle it with GGML_OP_GET_ROWS)
Support for TQ1_0 and TQ2_0 for other backends than CPU
will be added in follow-up pull requests.
@compilade compilade force-pushed the compilade/bitnet-ternary branch from e4dc48a to 75b3a09 Compare September 4, 2024 19:02
@compilade
Copy link
Collaborator Author

Is there any reason to not merge this?

Not really. It's pretty much ready (apart from support in other backends than CPU-only, and quantization to Q4_K and Q6_K not being supported in gguf-py yet, but I guess this can be fixed later once reference quantization is made platform independent (ref #8939 (comment))). I did have an hesitation with the order of the values in TQ1_0, but after making some experiments with indexing, its current structure (which ended up unchanged) should be good enough for future GPU implementations (hopefully).

(the indexing experiment)

Indices to extract 4 values per tid. This relies a lot on memory access coalescing, since each 5 consecutive tid will read the same 4 bytes (the last 4 bytes are only read by 4 tid, though).

This tests the read order.

tq1_0_pattern: list[list[int]] = [[32*v + b for v in range(5)] for b in range(32)] + [[16*v + b + 160 for v in range(5)] for b in range(16)] + [[4*v + b + 240 for v in range(4)] for b in range(4)]

print(f"{tq1_0_pattern=}")

for tid in range(64):
    n = tid // 40;  # 0 or 1
    nh = tid // 60; # 0 or 1
    il = tid // 5;  # 0..13
    ir = tid % 5;   # 0..5
    l = 32 - 16*n - 12*nh; # 32, 16 or 4

    q: list[int] = [4*il + j for j in range(4)]
    y: list[int] = [128*n + 64*nh + 4*il + l*ir + j for j in range(4)];

    status = "good"
    for a, b in zip(q, y):
        if a >= len(tq1_0_pattern) or b != tq1_0_pattern[a][ir]:
            status = "bad"
            break

    print(f"{tid=}, {q=}, {y=}, {status=}")

Support for other backends than CPU will be added in separate pull requests. The only non-CPU backends I can possibly implement (and test) with the hardware I have are Vulkan (and CUDA, but only in evenings and weekends). TQ2_0 should be relatively easy to port everywhere, while TQ1_0 will require more effort, but should still be implementable.

Don't let the perfect be the enemy of the good.

Right. And the recent commits (8d61607 and 75b3a09) should not be controversial (reducing the changes to ggml_mul, handling missing TQ1_0 and TQ2_0 cases in switch statements for some ggml operators with a dequantization-based variant (ggml_add, ggml_add1, etc.), and adding some comments in tests/test-backend-ops.cpp about TQ1_0 and TQ2_0).

I will merge this soon, either today or tomorrow if I forget.

@compilade compilade added the merge ready indicates that this may be ready to merge soon and is just holding out in case of objections label Sep 4, 2024
@compilade compilade merged commit 9bc6db2 into master Sep 6, 2024
55 checks passed
@WenguoLi
Copy link

Thank you very much for this great job. Do you have any plans to further support risc-v devices?

dsx1986 pushed a commit to dsx1986/llama.cpp that referenced this pull request Oct 29, 2024
…8151)

* ggml-quants : 1.625 bpw ternary packing for BitNet 1.58b

* ggml-quants : faster 1.625 bpw AVX2 vec_dot

Not using a lookup table anymore makes it match q4_0 speed.

* gguf-py : fix formatting

* llama : remove spaces on empty line

* ggml-quants : subtract 1 when back in epi8

This makes the 1.625 bpw type go faster than q4_0. Still not the fastest.

* ggml-quants : Q2_2 now faster than Q4_K on with AVX2

* ggml-quants : cleanup Q1_3 code formatting

* ggml-quants : ARM NEON vec_dot for q2_2 and q1_3

* ggml-quants : use ceiling division when quantizing q1_3

* convert-hf : simplify BitNet pre-quantization

This still results in the exact same tensor weights and scales,
but it reveals some weirdness in the current algorithm.

* convert-hf : allow converting the weird BitNet 1.3B

Its FFN size is 5460 which is not convenient.
The offending tensors are kept in F16,
which makes the final model 5.01 bpw.

* bitnet : replace 1.58b with b1.58, as in the paper

* ggml-quants : fix build failure on Windows

* ggml-quants : attempt to fix Arm 32-bit support

* ggml : add some informative comments in q1_3 vec_dot

* ggml : add TQ1_0 and TQ2_0 ternary quantization types

* ggml : even faster TQ2_0

* ggml : also faster TQ1_0

Same optimization as for TQ2_0 by offsetting the sum instead of the weights.
This makes TQ1_0 almost as fast as Q8_0 on AVX2.

* ggml : fix build issues in certain environments

* ggml : add NEON vec_dot implementation for TQ1_0 and TQ2_0

* ggml : avoid directly using vmlal_high_s8, for 32-bit ARM compat

The compiler seems smart enough to use the same instruction
even when using vget_high_s8 instead.

* ggml : remove q1_3 and q2_2

No more 1.625 bpw and 2.000 bpw,
now instead using 1.6875 bpw and 2.0625 bpw
with TQ1_0 and TQ2_0, respectively.

* llama : remove the separate scale tensors of BitNet b1.58

They won't be needed, since the remaining ternary quant types have
built-in scales.

* ggml-quants : rename fields of TQ1_0 and TQ2_0 structs for consistency

* ggml-quants : allow using vdotq_s32 in TQ2_0 vec_dot

Not yet tested on hardware which supports it,
might not work or might not even compile. But also it might.
It should make the performance better on recent ARM CPUs.

* ggml-quants : remove comment about possible format change of TQ2_0

Making it slightly more convenient for AVX512
but less convenient for everything else is not worth the trouble.

* gguf-py : Numpy (de)quantization for TQ1_0 and TQ2_0

* ggml-quants : use roundf instead of nearest_int for TQ1_0 and TQ2_0

This does not change anything for ternary models,
since their values should never end up being in halfway cases anyway.

* convert : allow direct conversion to TQ1_0 and TQ2_0

The token embeddings and output tensors are kept in F16
to allow quantizing them to Q4_K and Q6_K with llama-quantize.

* llama : handle fallback for TQ1_0 and TQ2_0 with Q4_0

Q4_0 is not completely symmetric (so not lossless for ternary models),
but it should be good enough.

* ggml-quants : allow using ARM dot product instructions for TQ1_0

* ggml-quants : deduplicate TQ1_0 and TQ2_0 __ARM_FEATURE_DOTPROD support

* ggml : remove unused ggml_mul special case

It would otherwise conflict with the more general
optimization coming with Mamba-2.

* ggml : handle TQ1_0 and TQ2_0 in dequantization-based operators

* test-backend-ops : add TQ1_0 and TQ2_0 comments for later

Not yet adding uncommented, because some backends like SYCL and Metal
do not properly handle unknown types in supports_op for GGML_OP_MUL_MAT.
(and Metal also doesn't handle it with GGML_OP_GET_ROWS)
Support for TQ1_0 and TQ2_0 for other backends than CPU
will be added in follow-up pull requests.
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 15, 2024
…8151)

* ggml-quants : 1.625 bpw ternary packing for BitNet 1.58b

* ggml-quants : faster 1.625 bpw AVX2 vec_dot

Not using a lookup table anymore makes it match q4_0 speed.

* gguf-py : fix formatting

* llama : remove spaces on empty line

* ggml-quants : subtract 1 when back in epi8

This makes the 1.625 bpw type go faster than q4_0. Still not the fastest.

* ggml-quants : Q2_2 now faster than Q4_K on with AVX2

* ggml-quants : cleanup Q1_3 code formatting

* ggml-quants : ARM NEON vec_dot for q2_2 and q1_3

* ggml-quants : use ceiling division when quantizing q1_3

* convert-hf : simplify BitNet pre-quantization

This still results in the exact same tensor weights and scales,
but it reveals some weirdness in the current algorithm.

* convert-hf : allow converting the weird BitNet 1.3B

Its FFN size is 5460 which is not convenient.
The offending tensors are kept in F16,
which makes the final model 5.01 bpw.

* bitnet : replace 1.58b with b1.58, as in the paper

* ggml-quants : fix build failure on Windows

* ggml-quants : attempt to fix Arm 32-bit support

* ggml : add some informative comments in q1_3 vec_dot

* ggml : add TQ1_0 and TQ2_0 ternary quantization types

* ggml : even faster TQ2_0

* ggml : also faster TQ1_0

Same optimization as for TQ2_0 by offsetting the sum instead of the weights.
This makes TQ1_0 almost as fast as Q8_0 on AVX2.

* ggml : fix build issues in certain environments

* ggml : add NEON vec_dot implementation for TQ1_0 and TQ2_0

* ggml : avoid directly using vmlal_high_s8, for 32-bit ARM compat

The compiler seems smart enough to use the same instruction
even when using vget_high_s8 instead.

* ggml : remove q1_3 and q2_2

No more 1.625 bpw and 2.000 bpw,
now instead using 1.6875 bpw and 2.0625 bpw
with TQ1_0 and TQ2_0, respectively.

* llama : remove the separate scale tensors of BitNet b1.58

They won't be needed, since the remaining ternary quant types have
built-in scales.

* ggml-quants : rename fields of TQ1_0 and TQ2_0 structs for consistency

* ggml-quants : allow using vdotq_s32 in TQ2_0 vec_dot

Not yet tested on hardware which supports it,
might not work or might not even compile. But also it might.
It should make the performance better on recent ARM CPUs.

* ggml-quants : remove comment about possible format change of TQ2_0

Making it slightly more convenient for AVX512
but less convenient for everything else is not worth the trouble.

* gguf-py : Numpy (de)quantization for TQ1_0 and TQ2_0

* ggml-quants : use roundf instead of nearest_int for TQ1_0 and TQ2_0

This does not change anything for ternary models,
since their values should never end up being in halfway cases anyway.

* convert : allow direct conversion to TQ1_0 and TQ2_0

The token embeddings and output tensors are kept in F16
to allow quantizing them to Q4_K and Q6_K with llama-quantize.

* llama : handle fallback for TQ1_0 and TQ2_0 with Q4_0

Q4_0 is not completely symmetric (so not lossless for ternary models),
but it should be good enough.

* ggml-quants : allow using ARM dot product instructions for TQ1_0

* ggml-quants : deduplicate TQ1_0 and TQ2_0 __ARM_FEATURE_DOTPROD support

* ggml : remove unused ggml_mul special case

It would otherwise conflict with the more general
optimization coming with Mamba-2.

* ggml : handle TQ1_0 and TQ2_0 in dequantization-based operators

* test-backend-ops : add TQ1_0 and TQ2_0 comments for later

Not yet adding uncommented, because some backends like SYCL and Metal
do not properly handle unknown types in supports_op for GGML_OP_MUL_MAT.
(and Metal also doesn't handle it with GGML_OP_GET_ROWS)
Support for TQ1_0 and TQ2_0 for other backends than CPU
will be added in follow-up pull requests.
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 18, 2024
…8151)

* ggml-quants : 1.625 bpw ternary packing for BitNet 1.58b

* ggml-quants : faster 1.625 bpw AVX2 vec_dot

Not using a lookup table anymore makes it match q4_0 speed.

* gguf-py : fix formatting

* llama : remove spaces on empty line

* ggml-quants : subtract 1 when back in epi8

This makes the 1.625 bpw type go faster than q4_0. Still not the fastest.

* ggml-quants : Q2_2 now faster than Q4_K on with AVX2

* ggml-quants : cleanup Q1_3 code formatting

* ggml-quants : ARM NEON vec_dot for q2_2 and q1_3

* ggml-quants : use ceiling division when quantizing q1_3

* convert-hf : simplify BitNet pre-quantization

This still results in the exact same tensor weights and scales,
but it reveals some weirdness in the current algorithm.

* convert-hf : allow converting the weird BitNet 1.3B

Its FFN size is 5460 which is not convenient.
The offending tensors are kept in F16,
which makes the final model 5.01 bpw.

* bitnet : replace 1.58b with b1.58, as in the paper

* ggml-quants : fix build failure on Windows

* ggml-quants : attempt to fix Arm 32-bit support

* ggml : add some informative comments in q1_3 vec_dot

* ggml : add TQ1_0 and TQ2_0 ternary quantization types

* ggml : even faster TQ2_0

* ggml : also faster TQ1_0

Same optimization as for TQ2_0 by offsetting the sum instead of the weights.
This makes TQ1_0 almost as fast as Q8_0 on AVX2.

* ggml : fix build issues in certain environments

* ggml : add NEON vec_dot implementation for TQ1_0 and TQ2_0

* ggml : avoid directly using vmlal_high_s8, for 32-bit ARM compat

The compiler seems smart enough to use the same instruction
even when using vget_high_s8 instead.

* ggml : remove q1_3 and q2_2

No more 1.625 bpw and 2.000 bpw,
now instead using 1.6875 bpw and 2.0625 bpw
with TQ1_0 and TQ2_0, respectively.

* llama : remove the separate scale tensors of BitNet b1.58

They won't be needed, since the remaining ternary quant types have
built-in scales.

* ggml-quants : rename fields of TQ1_0 and TQ2_0 structs for consistency

* ggml-quants : allow using vdotq_s32 in TQ2_0 vec_dot

Not yet tested on hardware which supports it,
might not work or might not even compile. But also it might.
It should make the performance better on recent ARM CPUs.

* ggml-quants : remove comment about possible format change of TQ2_0

Making it slightly more convenient for AVX512
but less convenient for everything else is not worth the trouble.

* gguf-py : Numpy (de)quantization for TQ1_0 and TQ2_0

* ggml-quants : use roundf instead of nearest_int for TQ1_0 and TQ2_0

This does not change anything for ternary models,
since their values should never end up being in halfway cases anyway.

* convert : allow direct conversion to TQ1_0 and TQ2_0

The token embeddings and output tensors are kept in F16
to allow quantizing them to Q4_K and Q6_K with llama-quantize.

* llama : handle fallback for TQ1_0 and TQ2_0 with Q4_0

Q4_0 is not completely symmetric (so not lossless for ternary models),
but it should be good enough.

* ggml-quants : allow using ARM dot product instructions for TQ1_0

* ggml-quants : deduplicate TQ1_0 and TQ2_0 __ARM_FEATURE_DOTPROD support

* ggml : remove unused ggml_mul special case

It would otherwise conflict with the more general
optimization coming with Mamba-2.

* ggml : handle TQ1_0 and TQ2_0 in dequantization-based operators

* test-backend-ops : add TQ1_0 and TQ2_0 comments for later

Not yet adding uncommented, because some backends like SYCL and Metal
do not properly handle unknown types in supports_op for GGML_OP_MUL_MAT.
(and Metal also doesn't handle it with GGML_OP_GET_ROWS)
Support for TQ1_0 and TQ2_0 for other backends than CPU
will be added in follow-up pull requests.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request examples ggml changes relating to the ggml tensor library for machine learning merge ready indicates that this may be ready to merge soon and is just holding out in case of objections python python script changes Review Complexity : High Generally require indepth knowledge of LLMs or GPUs Tensor Encoding Scheme https://github.com/ggerganov/llama.cpp/wiki/Tensor-Encoding-Schemes testing Everything test related
Projects
None yet
Development

Successfully merging this pull request may close these issues.