Skip to content

Commit

Permalink
fix baichuan2-7b, deepseek-vl and xcomposer2d5-4bit
Browse files (browse the repository at this point in the history)
  • Loading branch information
lzhangzz committed Aug 6, 2024
1 parent 69f90ae commit bf33b6d
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 6 deletions.
10 changes: 7 additions & 3 deletions lmdeploy/turbomind/deploy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ def unpack_awq_gemm(x: torch.Tensor) -> torch.Tensor:
return torch.stack(ys, dim=-1).view(*x.shape[:-1], -1)


def process_awq_gemm(x: torch.Tensor, *args):
def process_awq_gemm(x: torch.Tensor, kind: str):
x = x.cuda()
if x.dtype == torch.int32:
x = unpack_awq_gemm(x)
return x.t()
if kind in ['qweight', 'qzeros', 'scales']:
x = x.t()
return x


def process_gptq(x: torch.Tensor, kind: str):
Expand All @@ -39,7 +41,9 @@ def process_gptq(x: torch.Tensor, kind: str):
x = torch.stack(xs, dim=1).view(-1, x.size(-1))
else: # 'qzeros' (k/g,n/8)
x = torch.stack(xs, dim=-1).view(x.size(0), -1) + 1
return x.t()
if kind in ['qweight', 'qzeros', 'scales']:
x = x.t()
return x


def get_input_policy(model_format):
Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/turbomind/deploy/source_model/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def _attn(self, i: int, kind: str):
q, k, v, o = (None, ) * 4
pack_key = f'model.layers.{i}.self_attn.W_pack.{kind}'
qkv = self.transform(self.params.get(pack_key), kind)
if qkv:
if qkv is not None:
q, k, v = torch.split(qkv, qkv.shape[0] // 3, dim=0)
o = self.params.get(f'model.layers.{i}.self_attn.o_proj.{kind}')
o = self.transform(o, kind)
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/source_model/deepseek_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
class DeepSeekVLReader(LlamaReader):
"""DeepSeekVL model reader."""

attn_layer_prefix = 'language_model.model.layers'
attn_layer_patten = r'language_model.model.layers.([0-9]+).'
tok_embeddings_key = 'language_model.model.embed_tokens.weight'
norm_weight_key = 'language_model.model.norm.weight'
Expand Down
4 changes: 2 additions & 2 deletions src/turbomind/kernels/gemm/tune/stopping_criterion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ class Optimistic: public StoppingCriterion {
public:
Optimistic(int min_iter, int max_iter, float max_ms)
{
min_iter_ = min_iter ? min_iter > 0 : 1;
max_iter_ = max_iter ? max_iter > 0 : std::numeric_limits<int>::max();
min_iter_ = std::max(min_iter, 1);
max_iter_ = max_iter > 0 ? max_iter : std::numeric_limits<int>::max();
max_ms_ = max_ms > 0 ? max_ms : std::numeric_limits<float>::infinity();
}
bool should_stop(const Stats& stats) override
Expand Down

0 comments on commit bf33b6d

Please sign in to comment.