ggml-quants : ternary packing for TriLMs and BitNet b1.58 #8151

Merged
33 commits merged on Sep 6, 2024

Commits
bd80749
ggml-quants : 1.625 bpw ternary packing for BitNet 1.58b
compilade Jun 19, 2024
7ef4254
ggml-quants : faster 1.625 bpw AVX2 vec_dot
compilade Jun 19, 2024
48b73b8
ggml-quants : subtract 1 when back in epi8
compilade Jun 19, 2024
ef1e345
ggml-quants : Q2_2 now faster than Q4_K with AVX2
compilade Jun 20, 2024
638ad52
ggml-quants : cleanup Q1_3 code formatting
compilade Jun 23, 2024
9465ec6
ggml-quants : ARM NEON vec_dot for q2_2 and q1_3
compilade Jun 25, 2024
89dc3b2
ggml-quants : use ceiling division when quantizing q1_3
compilade Jun 26, 2024
961e293
convert-hf : simplify BitNet pre-quantization
compilade Jun 26, 2024
0996149
convert-hf : allow converting the weird BitNet 1.3B
compilade Jun 27, 2024
bfd2f21
bitnet : replace 1.58b with b1.58, as in the paper
compilade Jun 29, 2024
ec50944
ggml-quants : fix build failure on Windows
compilade Jun 29, 2024
8fbd593
ggml-quants : attempt to fix Arm 32-bit support
compilade Jun 29, 2024
dd3e62a
ggml : add some informative comments in q1_3 vec_dot
compilade Jul 29, 2024
79a278e
Merge branch 'master' into compilade/bitnet-ternary
compilade Jul 29, 2024
77b8f84
ggml : add TQ1_0 and TQ2_0 ternary quantization types
compilade Jul 30, 2024
560873f
ggml : even faster TQ2_0
compilade Jul 31, 2024
e971957
ggml : also faster TQ1_0
compilade Jul 31, 2024
a6dd699
ggml : fix build issues in certain environments
compilade Aug 1, 2024
5417089
ggml : add NEON vec_dot implementation for TQ1_0 and TQ2_0
compilade Aug 1, 2024
45719a2
ggml : avoid directly using vmlal_high_s8, for 32-bit ARM compat
compilade Aug 1, 2024
04eec58
ggml : remove q1_3 and q2_2
compilade Aug 2, 2024
f034aa1
ggml-quants : rename fields of TQ1_0 and TQ2_0 structs for consistency
compilade Aug 3, 2024
96b3d41
ggml-quants : allow using vdotq_s32 in TQ2_0 vec_dot
compilade Aug 7, 2024
d911cd1
Merge branch 'master' into compilade/bitnet-ternary
compilade Aug 11, 2024
3a0bf17
gguf-py : Numpy (de)quantization for TQ1_0 and TQ2_0
compilade Aug 12, 2024
895004f
convert : allow direct conversion to TQ1_0 and TQ2_0
compilade Aug 13, 2024
69f7726
ggml-quants : allow using ARM dot product instructions for TQ1_0
compilade Aug 13, 2024
82b2404
Merge branch 'master' into compilade/bitnet-ternary
compilade Aug 13, 2024
35cc556
ggml-quants : deduplicate TQ1_0 and TQ2_0 __ARM_FEATURE_DOTPROD support
compilade Aug 13, 2024
cb6d996
Merge branch 'master' into compilade/bitnet-ternary
compilade Aug 22, 2024
7f3a619
Merge branch 'master' into compilade/bitnet-ternary
compilade Sep 4, 2024
8d61607
ggml : remove unused ggml_mul special case
compilade Sep 4, 2024
75b3a09
test-backend-ops : add TQ1_0 and TQ2_0 comments for later
compilade Sep 4, 2024
47 changes: 33 additions & 14 deletions convert_hf_to_gguf.py
@@ -308,6 +308,20 @@ def prepare_tensors(self):
):
data_qtype = gguf.GGMLQuantizationType.F32

if data_qtype is False and any(
self.match_model_tensor_name(new_name, key, bid)
for key in (
gguf.MODEL_TENSOR.TOKEN_EMBD,
gguf.MODEL_TENSOR.OUTPUT,
)
):
if self.ftype in (
gguf.LlamaFileType.MOSTLY_TQ1_0,
gguf.LlamaFileType.MOSTLY_TQ2_0,
):
# TODO: use Q4_K and Q6_K
data_qtype = gguf.GGMLQuantizationType.F16

# No override (data_qtype is False), or wants to be quantized (data_qtype is True)
if isinstance(data_qtype, bool):
if self.ftype == gguf.LlamaFileType.ALL_F32:
@@ -318,6 +332,10 @@ def prepare_tensors(self):
data_qtype = gguf.GGMLQuantizationType.BF16
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
data_qtype = gguf.GGMLQuantizationType.Q8_0
elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ1_0:
data_qtype = gguf.GGMLQuantizationType.TQ1_0
elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0:
data_qtype = gguf.GGMLQuantizationType.TQ2_0
else:
raise ValueError(f"Unknown file type: {self.ftype.name}")

@@ -1623,15 +1641,16 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(1.0)

def weight_quant(self, weight):
def weight_quant(self, weight: Tensor) -> Tensor:
dtype = weight.dtype
weight = weight.float()
s = 1 / weight.abs().mean().clamp(min=1e-5)
weight = (weight * s).round().clamp(-1, 1) / s
scale = weight.abs().max().unsqueeze(0)
weight = torch.where(weight.abs().less(1e-6), 0, weight).type(dtype)
weight = torch.sign(weight).type(dtype)
return weight.type(dtype), scale.type(torch.float32)
scale = weight.abs().mean().clamp(min=1e-5)
iscale = 1 / scale
# TODO: multiply by the scale directly instead of inverting it twice
# (this is also unnecessarily doubly inverted upstream)
# ref: https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/af89e318d78a70802061246bf037199d2fb97020/utils_quant.py#L10
result = (weight * iscale).round().clamp(-1, 1) / iscale
return result.type(dtype)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
new_name = self.map_tensor_name(name)
@@ -1646,11 +1665,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
gguf.MODEL_TENSOR.FFN_GATE,
]):
# transform weight into 1/0/-1 (in fp32)
weight_torch, scale_torch = self.weight_quant(data_torch)
yield (new_name, weight_torch)
yield (new_name.removesuffix(".weight") + ".scale", scale_torch)
else:
yield (new_name, data_torch)
data_torch = self.weight_quant(data_torch)

yield (new_name, data_torch)


@Model.register("GrokForCausalLM")
@@ -4011,8 +4028,8 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
"--bigendian", action="store_true",
@@ -4099,6 +4116,8 @@ def main() -> None:
"f16": gguf.LlamaFileType.MOSTLY_F16,
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
"tq1_0": gguf.LlamaFileType.MOSTLY_TQ1_0,
"tq2_0": gguf.LlamaFileType.MOSTLY_TQ2_0,
"auto": gguf.LlamaFileType.GUESSED,
}

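For context, here is a minimal standalone sketch of the ternary pre-quantization that the updated weight_quant method above performs; the helper name ternary_prequant and the random test tensor are illustrative, not part of this PR. Each tensor is rounded to {-1, 0, +1} in units of its mean absolute value, so the returned tensor holds at most three distinct values and the per-tensor scale is recovered later when the tensor is quantized to TQ1_0 or TQ2_0.

# Sketch only: mirrors the logic of the weight_quant method shown above.
import torch
from torch import Tensor

def ternary_prequant(weight: Tensor) -> Tensor:  # hypothetical helper name
    dtype = weight.dtype
    weight = weight.float()
    scale = weight.abs().mean().clamp(min=1e-5)   # per-tensor scale
    iscale = 1 / scale
    # round to {-1, 0, +1} in units of the scale, then scale back
    return ((weight * iscale).round().clamp(-1, 1) / iscale).type(dtype)

w = torch.randn(4, 8)
q = ternary_prequant(w)
print(torch.unique(q))  # at most three distinct values: -scale, 0, +scale

With the new --outtype choices, a conversion along the lines of `python convert_hf_to_gguf.py <model_dir> --outtype tq2_0` should then store these ternary tensors directly as TQ2_0, while token embeddings and the output tensor fall back to F16 for now, per the TODO above.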
2 changes: 2 additions & 0 deletions examples/quantize/quantize.cpp
@@ -26,6 +26,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", },
{ "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", },
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
{ "TQ1_0", LLAMA_FTYPE_MOSTLY_TQ1_0, " 1.69 bpw ternarization", },
{ "TQ2_0", LLAMA_FTYPE_MOSTLY_TQ2_0, " 2.06 bpw ternarization", },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.96G, +3.5199 ppl @ Llama-3-8B", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.96G, +3.1836 ppl @ Llama-3-8B", },
{ "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", },
2 changes: 2 additions & 0 deletions ggml/include/ggml.h
@@ -395,6 +395,8 @@ extern "C" {
GGML_TYPE_Q4_0_4_4 = 31,
GGML_TYPE_Q4_0_4_8 = 32,
GGML_TYPE_Q4_0_8_8 = 33,
GGML_TYPE_TQ1_0 = 34,
GGML_TYPE_TQ2_0 = 35,
GGML_TYPE_COUNT,
};

20 changes: 20 additions & 0 deletions ggml/src/ggml-common.h
@@ -227,6 +227,25 @@ typedef struct {
} block_q8_0x8;
static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding");

//
// Ternary quantization
//

// 1.6875 bpw
typedef struct {
uint8_t qs[(QK_K - 4 * QK_K / 64) / 5]; // 5 elements per byte (3^5 = 243 < 256)
uint8_t qh[QK_K/64]; // 4 elements per byte
ggml_half d;
} block_tq1_0;
static_assert(sizeof(block_tq1_0) == sizeof(ggml_half) + QK_K / 64 + (QK_K - 4 * QK_K / 64) / 5, "wrong tq1_0 block size/padding");

// 2.0625 bpw
typedef struct {
uint8_t qs[QK_K/4]; // 2 bits per element
ggml_half d;
} block_tq2_0;
static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding");

//
// Super-block quantization structures
//
@@ -361,6 +380,7 @@ typedef struct {
} block_iq3_s;
static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");

// 1.5625 bpw
typedef struct {
ggml_half d;
uint8_t qs[QK_K/8];
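To make the sizes above concrete, here is a small Python check of the block budgets (assuming QK_K == 256 and a 2-byte ggml_half, as elsewhere in ggml-common.h), plus a simplified base-3 packing of five ternary digits per byte. The real quantize/dequantize kernels use a scaled byte encoding for fast SIMD extraction, so the pack5/unpack5 helpers below only illustrate why 3^5 = 243 fits in one byte; they are not the actual TQ1_0 layout.

# Block-size arithmetic for the structs above (assumes QK_K == 256).
QK_K = 256
tq1_0_bytes = (QK_K - 4 * QK_K // 64) // 5 + QK_K // 64 + 2   # qs + qh + d = 54
tq2_0_bytes = QK_K // 4 + 2                                   # qs + d = 66
print(tq1_0_bytes * 8 / QK_K)  # 1.6875 bpw
print(tq2_0_bytes * 8 / QK_K)  # 2.0625 bpw

# Illustrative base-3 packing: five values from {-1, 0, +1} fit in one byte,
# since 3**5 == 243 <= 255. (The actual TQ1_0 byte encoding differs in detail.)
def pack5(trits):              # trits: five values in {-1, 0, 1}
    q = 0
    for t in trits:
        q = q * 3 + (t + 1)    # map {-1, 0, 1} -> {0, 1, 2}
    return q                   # 0..242

def unpack5(q):
    out = []
    for _ in range(5):
        out.append(q % 3 - 1)
        q //= 3
    return out[::-1]

assert unpack5(pack5([1, -1, 0, 1, 1])) == [1, -1, 0, 1, 1]

The 1.6875 and 2.0625 bpw figures computed here are what the quantize.cpp table above rounds to 1.69 and 2.06 bpw.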
11 changes: 4 additions & 7 deletions ggml/src/ggml-impl.h
@@ -175,7 +175,7 @@ typedef __fp16 ggml_fp16_internal_t;

// 32-bit ARM compatibility

// vaddvq_s16
// vaddlvq_s16
// vpaddq_s16
// vpaddq_s32
// vaddvq_s32
@@ -185,12 +185,9 @@ typedef __fp16 ggml_fp16_internal_t;
// vzip1_u8
// vzip2_u8

inline static int32_t vaddvq_s16(int16x8_t v) {
return
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
(int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
(int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
(int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
inline static int32_t vaddlvq_s16(int16x8_t v) {
int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v)));
return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2);
}

inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
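As a sanity check on the 32-bit ARM shim above, here is a scalar Python model of what the vaddlvq_s16 fallback computes: pairwise widening adds of the int16 lanes, then a sum of the two 64-bit halves, so the result equals the plain widened sum of all eight lanes. The function name and test values are illustrative only.

# Scalar model of the vaddlvq_s16 compatibility shim (illustrative, not NEON code).
def vaddlvq_s16_ref(lanes):                # lanes: eight int16 values
    assert len(lanes) == 8
    s32 = [lanes[i] + lanes[i + 1] for i in range(0, 8, 2)]   # vpaddlq_s16 -> 4 x int32
    s64 = [s32[i] + s32[i + 1] for i in range(0, 4, 2)]       # vpaddlq_s32 -> 2 x int64
    return s64[0] + s64[1]                 # the NEON version adds the two low words

v = [32767, 32767, 1, 2, 3, 4, -5, -6]
assert vaddlvq_s16_ref(v) == sum(v)        # widened sum, no int16 overflow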