Skip to content

Commit

Permalink
convert.py : shorten and simplify permute
Browse files Browse the repository at this point in the history
* idea from @KerfuffleV2
  • Loading branch information
mj-shifu committed Jul 27, 2023
1 parent 01d16e1 commit 9442c34
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,14 +323,11 @@ def __repr__(self) -> str:


def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
    """Interleave the two halves of every head's rows along axis 0.

    Axis 0 of ``weights`` is viewed as ``(heads, 2, rows_per_half)``; the
    last two of those axes are swapped and the result is flattened back to
    the original shape.  Within each head, row ``i`` of the first half and
    row ``i`` of the second half therefore end up adjacent.  All trailing
    axes are carried along untouched.

    Args:
        weights: Array whose first dimension is ``heads * 2 * rows_per_half``.
        n_head: Number of attention heads.  Used as the head count for axis 0
            when ``n_kv_head`` is not given or equals ``n_head``.
        n_kv_head: Number of key/value heads.  When given and different from
            ``n_head``, axis 0 is split into ``n_kv_head`` heads instead.

    Returns:
        An array with the same shape as ``weights`` and the rows reordered.

    Raises:
        ValueError: Raised by ``reshape`` if axis 0 cannot be split evenly
            into ``heads * 2`` groups.
    """
    # NOTE(review): assumes a tensor passed together with n_kv_head (the
    # grouped-query K projection) carries n_kv_head heads on axis 0 -- confirm
    # against the caller.  The earlier ``n_head //= n_kv_head`` used the
    # query-to-KV ratio as the head count, which only coincides with
    # n_kv_head when n_head == n_kv_head ** 2 (e.g. 64 / 8).
    if n_kv_head is not None and n_head != n_kv_head:
        n_head = n_kv_head
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape))


def dequantize_q4(qvalues_pack32: NDArray, scales: NDArray, addends: Optional[NDArray], g_idx: Optional[NDArray]) -> NDArray:
Expand Down

0 comments on commit 9442c34

Please sign in to comment.