Commit 829e9a7

Support p-tuning v2 for ChatGLM family & fix rope theta for 32k/128k seqlen (#289)

li-plus authored Apr 23, 2024
1 parent 04910ce commit 829e9a7
Showing 9 changed files with 573 additions and 220 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -15,6 +15,7 @@ C++ implementation of [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGL
Highlights:
* Pure C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp).
* Accelerated memory-efficient CPU inference with int4/int8 quantization, optimized KV cache and parallel computing.
* Support for P-Tuning v2 and LoRA finetuned models.
* Streaming generation with typewriter effect.
* Python binding, web demo, api servers and more possibilities.

@@ -68,7 +69,9 @@ You are free to try any of the below quantization types by specifying `-t <type>
* `f16`: half precision floating point weights without quantization.
* `f32`: single precision floating point weights without quantization.

For LoRA model, add `-l <lora_model_name_or_path>` flag to merge your LoRA weights into the base model.
For LoRA models, add the `-l <lora_model_name_or_path>` flag to merge your LoRA weights into the base model. For example, run `python3 chatglm_cpp/convert.py -i THUDM/chatglm3-6b -t q4_0 -o chatglm3-ggml-lora.bin -l shibing624/chatglm3-6b-csc-chinese-lora` to merge public LoRA weights from Hugging Face.

For P-Tuning v2 models finetuned with the [official finetuning script](https://github.com/THUDM/ChatGLM3/tree/main/finetune_demo), the additional prefix weights are automatically detected by `convert.py`. If `past_key_values` appears in the output weight list, the P-Tuning checkpoint has been converted successfully.
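
As an illustration of what the converted prefix amounts to, here is a hedged numpy sketch: the HF-side layout (one row per virtual token holding the keys and values of every layer) and the example sizes are assumptions for illustration, not the exact `convert.py` code. Only the layer/slot indexing of the resulting `past_key_values` tensor is taken from `load_prefix_cache` in the `chatglm.h` diff below.

```python
import numpy as np

# Illustrative sizes only (roughly ChatGLM3-6B with a 128-token prefix).
num_layers, num_kv_heads, head_size, num_virtual_tokens = 28, 2, 128, 128

# Assumed P-Tuning v2 checkpoint layout: one row per virtual token holding the
# keys and values of every layer, [v, num_layers * 2 * num_kv_heads * head_size].
prefix_embedding = np.random.randn(
    num_virtual_tokens, num_layers * 2 * num_kv_heads * head_size
).astype(np.float32)

# Rearrange into [num_layers * 2, num_kv_heads, v, head_size]: slot 2*i holds the
# key prefix of layer i and slot 2*i + 1 holds its value prefix, which is how the
# C++ runtime indexes the tensor when filling its KV cache.
past_key_values = prefix_embedding.reshape(
    num_virtual_tokens, num_layers * 2, num_kv_heads, head_size
).transpose(1, 2, 0, 3)

print(past_key_values.shape)  # (56, 2, 128, 128)
```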

**Build & Run**

201 changes: 125 additions & 76 deletions chatglm.cpp

Large diffs are not rendered by default.

136 changes: 101 additions & 35 deletions chatglm.h
@@ -76,10 +76,27 @@ struct ConfigRecordV1 {
};

// For compatibility
struct ConfigRecordV2 : public ConfigRecordV1 {
struct ConfigRecordV1GQA : public ConfigRecordV1 {
int num_kv_heads;
};

// TODO: use json to serialize config
struct ConfigRecordV2 {
ggml_type dtype;
int vocab_size;
int hidden_size;
int num_attention_heads;
int num_key_value_heads;
int num_hidden_layers;
int intermediate_size;
float norm_eps;
int num_virtual_tokens;
float rope_theta;
int max_length;
int eos_token_id;
int pad_token_id;
};
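
Every field of the new record is a 4-byte int or float (the `ggml_type` enum included), so on typical platforms the struct carries no padding and can be serialized field by field. A hedged sketch of how such a record could be packed on the writer side — `pack_config_record_v2` is a hypothetical helper whose field order simply mirrors the struct above; the actual `convert.py` writer is not part of this header:

```python
import struct

def pack_config_record_v2(
    dtype: int, vocab_size: int, hidden_size: int, num_attention_heads: int,
    num_key_value_heads: int, num_hidden_layers: int, intermediate_size: int,
    norm_eps: float, num_virtual_tokens: int, rope_theta: float,
    max_length: int, eos_token_id: int, pad_token_id: int,
) -> bytes:
    # Little-endian, 4-byte fields in declaration order: 7 ints, a float,
    # an int, a float, then 3 ints -- matching ConfigRecordV2.
    return struct.pack(
        "<7ifif3i",
        dtype, vocab_size, hidden_size, num_attention_heads,
        num_key_value_heads, num_hidden_layers, intermediate_size,
        norm_eps, num_virtual_tokens, rope_theta,
        max_length, eos_token_id, pad_token_id,
    )
```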

enum class ActivationType {
GELU,
SILU,
@@ -89,6 +106,7 @@ enum class RopeType {
GPTJ = 0,
NEOX = 2,
CHATGLM = 4,
CHATGLM2 = 8,
DISABLED = 10000,
};
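
`rope_theta` now flows from the config record into the attention RoPE. For reference, a sketch of the standard RoPE frequency computation, showing why the long-context (32k/128k) checkpoints ship a larger base than the default 10000 — `rope_inv_freq` is an illustrative helper and the 500000 value below is purely an example, not a value taken from this commit:

```python
import numpy as np

def rope_inv_freq(head_size: int, rope_theta: float) -> np.ndarray:
    # Standard RoPE: dimension pair j rotates at frequency theta^(-2j / head_size).
    return rope_theta ** (-np.arange(0, head_size, 2, dtype=np.float64) / head_size)

# A larger base slows every rotation, stretching the longest period so that
# positions tens of thousands of tokens apart still map to distinct angles.
print(rope_inv_freq(128, 10000.0)[-1])    # lowest frequency with the default base
print(rope_inv_freq(128, 500000.0)[-1])   # same dimension with an enlarged base
```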

@@ -105,33 +123,44 @@ class ModelConfig {
ModelConfig(ModelType model_type, ggml_type dtype, int vocab_size, int hidden_size, int num_attention_heads,
int num_kv_heads, int num_hidden_layers, int intermediate_size, float norm_eps,
ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, bool use_alibi,
RopeType rope_type, int rope_dim_scale, AttentionMaskType attn_mask_type, int max_length,
int bos_token_id, int eos_token_id, int pad_token_id, int sep_token_id,
std::vector<int> extra_eos_token_ids)
RopeType rope_type, float rope_theta, int rope_dim_scale, AttentionMaskType attn_mask_type,
int num_virtual_tokens, int max_length, int bos_token_id, int eos_token_id, int pad_token_id,
int sep_token_id, std::vector<int> extra_eos_token_ids)
: model_type(model_type), dtype(dtype), vocab_size(vocab_size), hidden_size(hidden_size),
num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), num_hidden_layers(num_hidden_layers),
intermediate_size(intermediate_size), norm_eps(norm_eps), hidden_act(hidden_act), use_qkv_bias(use_qkv_bias),
use_dense_bias(use_dense_bias), interleaved_qkv(interleaved_qkv), use_alibi(use_alibi), rope_type(rope_type),
rope_dim_scale(rope_dim_scale), attn_mask_type(attn_mask_type), max_length(max_length),
bos_token_id(bos_token_id), eos_token_id(eos_token_id), pad_token_id(pad_token_id),
sep_token_id(sep_token_id), extra_eos_token_ids(std::move(extra_eos_token_ids)) {}
rope_theta(rope_theta), rope_dim_scale(rope_dim_scale), attn_mask_type(attn_mask_type),
num_virtual_tokens(num_virtual_tokens), max_length(max_length), bos_token_id(bos_token_id),
eos_token_id(eos_token_id), pad_token_id(pad_token_id), sep_token_id(sep_token_id),
extra_eos_token_ids(std::move(extra_eos_token_ids)) {}

ModelConfig(ModelType model_type, const ConfigRecordV1 &rec, float norm_eps, ActivationType hidden_act,
bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, bool use_alibi, RopeType rope_type,
int rope_dim_scale, AttentionMaskType attn_mask_type)
float rope_theta, int rope_dim_scale, AttentionMaskType attn_mask_type, int num_virtual_tokens)
: ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads,
rec.num_attention_heads, rec.num_hidden_layers, rec.intermediate_size, norm_eps, hidden_act,
use_qkv_bias, use_dense_bias, interleaved_qkv, use_alibi, rope_type, rope_dim_scale,
attn_mask_type, rec.max_length, rec.bos_token_id, rec.eos_token_id, rec.pad_token_id,
rec.sep_token_id, {}) {}
use_qkv_bias, use_dense_bias, interleaved_qkv, use_alibi, rope_type, rope_theta, rope_dim_scale,
attn_mask_type, num_virtual_tokens, rec.max_length, rec.bos_token_id, rec.eos_token_id,
rec.pad_token_id, rec.sep_token_id, {}) {}

ModelConfig(ModelType model_type, const ConfigRecordV2 &rec, float norm_eps, ActivationType hidden_act,
ModelConfig(ModelType model_type, const ConfigRecordV1GQA &rec, float norm_eps, ActivationType hidden_act,
bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, bool use_alibi, RopeType rope_type,
int rope_dim_scale, AttentionMaskType attn_mask_type)
float rope_theta, int rope_dim_scale, AttentionMaskType attn_mask_type, int num_virtual_tokens)
: ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, rec.num_kv_heads,
rec.num_hidden_layers, rec.intermediate_size, norm_eps, hidden_act, use_qkv_bias, use_dense_bias,
interleaved_qkv, use_alibi, rope_type, rope_dim_scale, attn_mask_type, rec.max_length,
rec.bos_token_id, rec.eos_token_id, rec.pad_token_id, rec.sep_token_id, {}) {}
interleaved_qkv, use_alibi, rope_type, rope_theta, rope_dim_scale, attn_mask_type,
num_virtual_tokens, rec.max_length, rec.bos_token_id, rec.eos_token_id, rec.pad_token_id,
rec.sep_token_id, {}) {}

ModelConfig(ModelType model_type, const ConfigRecordV2 &rec, ActivationType hidden_act, bool use_qkv_bias,
bool use_dense_bias, bool interleaved_qkv, bool use_alibi, RopeType rope_type, int rope_dim_scale,
AttentionMaskType attn_mask_type)
: ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads,
rec.num_key_value_heads, rec.num_hidden_layers, rec.intermediate_size, rec.norm_eps, hidden_act,
use_qkv_bias, use_dense_bias, interleaved_qkv, use_alibi, rope_type, rec.rope_theta,
rope_dim_scale, attn_mask_type, rec.num_virtual_tokens, rec.max_length, -1, rec.eos_token_id,
rec.pad_token_id, -1, {}) {}

std::string model_type_name() const { return to_string(model_type); }

@@ -151,8 +180,10 @@ class ModelConfig {
bool interleaved_qkv;
bool use_alibi;
RopeType rope_type;
float rope_theta;
int rope_dim_scale;
AttentionMaskType attn_mask_type;
int num_virtual_tokens;
int max_length;
int bos_token_id;
int eos_token_id;
@@ -388,16 +419,17 @@ class BasicAttention {
BasicAttention() = default;
BasicAttention(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length,
bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, bool use_alibi, RopeType rope_type,
int rope_dim_scale, AttentionMaskType attn_mask_type)
float rope_theta, int rope_dim_scale, AttentionMaskType attn_mask_type, int num_virtual_tokens)
: num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), interleaved_qkv(interleaved_qkv),
use_alibi(use_alibi), rope_type(rope_type), rope_dim_scale(rope_dim_scale), attn_mask_type(attn_mask_type),
use_alibi(use_alibi), rope_type(rope_type), rope_theta(rope_theta), rope_dim_scale(rope_dim_scale),
attn_mask_type(attn_mask_type), num_virtual_tokens(num_virtual_tokens),
query_key_value(ctx, hidden_size, hidden_size + 2 * (hidden_size / num_attention_heads) * num_kv_heads,
use_qkv_bias),
dense(ctx, hidden_size, hidden_size, use_dense_bias),
k_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, hidden_size / num_attention_heads, max_length,
num_kv_heads)),
v_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, max_length, hidden_size / num_attention_heads,
num_kv_heads)) {}
k_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, hidden_size / num_attention_heads,
max_length + num_virtual_tokens, num_kv_heads)),
v_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, max_length + num_virtual_tokens,
hidden_size / num_attention_heads, num_kv_heads)) {}

ggml_tensor *forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *position_ids, int n_past,
int n_ctx) const;
@@ -408,12 +440,14 @@ class BasicAttention {
bool interleaved_qkv;
bool use_alibi;
RopeType rope_type;
float rope_theta;
int rope_dim_scale;
AttentionMaskType attn_mask_type;
int num_virtual_tokens;
Linear query_key_value;
Linear dense;
ggml_tensor *k_cache; // [kv_heads, max_len, head_size]
ggml_tensor *v_cache; // [kv_heads, head_size, max_len]
ggml_tensor *k_cache; // [#kvh, s, d]
ggml_tensor *v_cache; // [#kvh, d, s]
};

template <typename Norm, typename Attention, typename MLP>
@@ -422,11 +456,12 @@ class BasicBlock {
BasicBlock() = default;
BasicBlock(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size,
int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias,
bool interleaved_qkv, bool use_alibi, RopeType rope_type, int rope_dim_scale,
AttentionMaskType attn_mask_type)
bool interleaved_qkv, bool use_alibi, RopeType rope_type, float rope_theta, int rope_dim_scale,
AttentionMaskType attn_mask_type, int num_virtual_tokens)
: input_layernorm(ctx, hidden_size, false, norm_eps),
attention(ctx, hidden_size, num_attention_heads, num_kv_heads, max_length, use_qkv_bias, use_dense_bias,
interleaved_qkv, use_alibi, rope_type, rope_dim_scale, attn_mask_type),
interleaved_qkv, use_alibi, rope_type, rope_theta, rope_dim_scale, attn_mask_type,
num_virtual_tokens),
post_attention_layernorm(ctx, hidden_size, false, norm_eps),
mlp(ctx, hidden_size, intermediate_size, hidden_act) {}

@@ -517,16 +552,44 @@ class BasicModel {
return hidden_states;
}

void load_prefix_cache(const ModelConfig &config, ggml_tensor *past_key_values) {
ggml_cgraph gf{};
auto ctx = make_unique_ggml_context(config.num_hidden_layers * 7 * ggml_tensor_overhead(), nullptr, false);
const int head_size = config.hidden_size / config.num_attention_heads;
for (size_t i = 0; i < layers.size(); i++) {
auto &attn = layers[i].attention;
ggml_tensor *virtual_key = ggml_view_3d(ctx.get(), past_key_values, head_size, config.num_virtual_tokens,
config.num_kv_heads, past_key_values->nb[1], past_key_values->nb[2],
i * 2 * past_key_values->nb[3]); // [#h, v, d]
ggml_tensor *k_cache_view =
ggml_view_3d(ctx.get(), attn.k_cache, head_size, config.num_virtual_tokens, config.num_kv_heads,
attn.k_cache->nb[1], attn.k_cache->nb[2], 0); // [#h, v, d]
ggml_build_forward_expand(&gf, ggml_cpy(ctx.get(), virtual_key, k_cache_view));

ggml_tensor *virtual_value = ggml_view_3d(
ctx.get(), past_key_values, head_size, config.num_virtual_tokens, config.num_kv_heads,
past_key_values->nb[1], past_key_values->nb[2], (i * 2 + 1) * past_key_values->nb[3]); // [#h, v, d]
virtual_value = ggml_permute(ctx.get(), virtual_value, 1, 0, 2, 3); // [#h, d, v]
ggml_tensor *v_cache_view =
ggml_view_3d(ctx.get(), attn.v_cache, config.num_virtual_tokens, head_size, config.num_kv_heads,
attn.v_cache->nb[1], attn.v_cache->nb[2], 0); // [#h, d, v]
ggml_build_forward_expand(&gf, ggml_cpy(ctx.get(), virtual_value, v_cache_view));
}
CHATGLM_CHECK(ggml_used_mem(ctx.get()) == ggml_get_mem_size(ctx.get())) << "corrupted prefix cache context";
std::vector<uninitialized_char> compute_buffer;
ggml_graph_compute_helper(compute_buffer, &gf, 0);
}
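
A minimal numpy analogue of what this graph computes (illustrative only; the real code builds a ggml graph over F16 caches): the per-layer key prefix is copied into the first `num_virtual_tokens` rows of `k_cache`, and the value prefix is transposed into the first `num_virtual_tokens` columns of `v_cache`, so those cache slots already hold the trained prefix before any real token is processed.

```python
import numpy as np

# Tiny illustrative sizes.
num_layers, num_kv_heads, head_size = 2, 2, 4
num_virtual_tokens, max_length = 3, 8

# [num_layers * 2, #kvh, v, d]: slot 2*i = keys of layer i, slot 2*i + 1 = values.
past_key_values = np.random.randn(num_layers * 2, num_kv_heads, num_virtual_tokens, head_size)

# Caches sized for max_length + num_virtual_tokens positions, matching the
# BasicAttention constructor: k_cache is [#kvh, s, d], v_cache is [#kvh, d, s].
k_cache = np.zeros((num_layers, num_kv_heads, max_length + num_virtual_tokens, head_size))
v_cache = np.zeros((num_layers, num_kv_heads, head_size, max_length + num_virtual_tokens))

for i in range(num_layers):
    k_cache[i, :, :num_virtual_tokens, :] = past_key_values[2 * i]
    v_cache[i, :, :, :num_virtual_tokens] = past_key_values[2 * i + 1].transpose(0, 2, 1)
```

Real tokens are then appended after these slots, so every query attends to the trained virtual prefix, which is exactly the P-Tuning v2 mechanism.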

private:
std::vector<Block> build_layers(ModelContext *ctx, const ModelConfig &config) {
std::vector<Block> layers;
layers.reserve(config.num_hidden_layers);
for (int layer_id = 0; layer_id < config.num_hidden_layers; layer_id++) {
// TODO: reduce max length? 32k might be too large for cpu inference
layers.emplace_back(ctx, config.hidden_size, config.num_attention_heads, config.num_kv_heads,
config.intermediate_size, config.max_length, config.norm_eps, config.hidden_act,
config.use_qkv_bias, config.use_dense_bias, config.interleaved_qkv, config.use_alibi,
config.rope_type, config.rope_dim_scale, config.attn_mask_type);
config.rope_type, config.rope_theta, config.rope_dim_scale, config.attn_mask_type,
config.num_virtual_tokens);
}
return layers;
}
@@ -745,6 +808,8 @@ class BasicModelForCausalLM : public BaseModelForCausalLM {
return lm_logits;
}

void load_prefix_cache(ggml_tensor *past_key_values) { transformer.load_prefix_cache(config, past_key_values); }

protected:
void to_cpu() {
for (auto &item : state_dict_) {
@@ -818,13 +883,14 @@ class GLMBlock : public BasicBlock<LayerNorm, BasicAttention, BasicMLP> {
GLMBlock() = default;
GLMBlock(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size,
int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias,
bool interleaved_qkv, bool use_alibi, RopeType rope_type, int rope_dim_scale,
AttentionMaskType attn_mask_type)
: BasicBlock(
LayerNorm(ctx, hidden_size, false, norm_eps),
BasicAttention(ctx, hidden_size, num_attention_heads, num_attention_heads, max_length, use_qkv_bias,
use_dense_bias, interleaved_qkv, use_alibi, rope_type, rope_dim_scale, attn_mask_type),
LayerNorm(ctx, hidden_size, false, norm_eps), BasicMLP(ctx, hidden_size, intermediate_size, hidden_act)),
bool interleaved_qkv, bool use_alibi, RopeType rope_type, float rope_theta, int rope_dim_scale,
AttentionMaskType attn_mask_type, int num_virtual_tokens)
: BasicBlock(LayerNorm(ctx, hidden_size, false, norm_eps),
BasicAttention(ctx, hidden_size, num_attention_heads, num_attention_heads, max_length,
use_qkv_bias, use_dense_bias, interleaved_qkv, use_alibi, rope_type, rope_theta,
rope_dim_scale, attn_mask_type, num_virtual_tokens),
LayerNorm(ctx, hidden_size, false, norm_eps),
BasicMLP(ctx, hidden_size, intermediate_size, hidden_act)),
alpha_value(std::sqrt(2.f * 28)) {}

ggml_tensor *forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *position_ids, int n_past,

2 changes: 1 addition & 1 deletion chatglm_cpp/__init__.py
@@ -6,7 +6,7 @@
import chatglm_cpp._C as _C
from chatglm_cpp._C import ChatMessage

__version__ = "0.3.1"
__version__ = "0.3.2"


@dataclass