-
Notifications
You must be signed in to change notification settings - Fork 12
/
ops.py
364 lines (309 loc) · 13.7 KB
/
ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
#
# This source code is licensed under the Apache 2.0 license found in the
# LICENSE file in the root directory of this source tree.
###############################################################################
from typing import Optional, Tuple
import habana_frameworks.torch as htorch
import torch
import math
import os
import torch.nn.functional as F
from vllm.logger import init_logger
logger = init_logger(__name__)

# Optional fused HPU kernels. Each alias is left as None when its
# habana_frameworks import fails; the warning emitted in that case states the
# fallback that will be used instead.
try:
    from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
except ImportError:
    HPUFusedRMSNorm = None
    logger.warning("Could not import HPU FusedRMSNorm kernel. "
                   "vLLM will use forward_native implementation of RMSNorm.")
else:
    HPUFusedRMSNorm = FusedRMSNorm

try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
    HPUFusedSDPA = None
    logger.warning("Could not import HPU FusedSDPA kernel. "
                   "vLLM will use native implementation.")
else:
    HPUFusedSDPA = FusedSDPA
# Normalization strategy used by normalize() (and therefore block_softmax())
# before exponentiation. Selected via the VLLM_PA_SOFTMAX_IMPL environment
# variable; the first option ('wsum') is the default.
VLLM_PA_SOFTMAX_IMPL_OPTIONS = ['wsum', 'amax']
VLLM_PA_SOFTMAX_IMPL = os.environ.get('VLLM_PA_SOFTMAX_IMPL',
                                      VLLM_PA_SOFTMAX_IMPL_OPTIONS[0])
# Explicit check instead of `assert` so the validation is not stripped when
# Python runs with -O.
if VLLM_PA_SOFTMAX_IMPL not in VLLM_PA_SOFTMAX_IMPL_OPTIONS:
    raise ValueError(
        f'Unsupported pa softmax impl - {VLLM_PA_SOFTMAX_IMPL} . '
        f'Supported values: {VLLM_PA_SOFTMAX_IMPL_OPTIONS}')
def batch2block(tensor, block_mapping):
    """Map per-batch rows onto per-block rows.

    ``tensor`` is flattened to 2-D over its trailing dims, left-multiplied by
    ``block_mapping`` and reshaped back to the original trailing dims. The
    matmul requires ``block_mapping.size(1) == tensor.size(0)``; presumably
    ``block_mapping`` is a (num_blocks, batch_size) block->sequence
    assignment matrix (confirm with callers).
    """
    dims = tuple(tensor.shape)
    rows = tensor.view(dims[0], -1)
    per_block = torch.matmul(block_mapping, rows)
    return per_block.view(-1, *dims[1:])
def block2batch(tensor, block_mapping):
    """Reduce per-block rows back to per-batch rows.

    Inverse direction of ``batch2block``: ``tensor`` is flattened to 2-D over
    its trailing dims and left-multiplied by ``block_mapping.t()``. With a 0/1
    mapping matrix this sums every block belonging to the same batch entry
    (assumption about the mapping's contents -- confirm with callers).
    """
    dims = tuple(tensor.shape)
    rows = tensor.view(dims[0], -1)
    per_batch = torch.matmul(block_mapping.t(), rows)
    return per_batch.view(-1, *dims[1:])
def normalize_amax(batch_size, attn, **rest):
    """Subtract the single global maximum of ``attn`` in place and return it.

    Because the same constant is subtracted everywhere, a subsequent
    softmax is mathematically unchanged; the shift only keeps ``exp()``
    numerically safe. ``batch_size`` and any extra keyword arguments are
    ignored -- they exist so ``normalize`` can pass one keyword set to every
    implementation.
    """
    # Reduce over all non-leading dims first, then over what is left.
    per_row_max = attn.amax(tuple(range(1, attn.dim())))
    global_max = per_row_max.amax()
    attn.sub_(global_max)
    return attn
def normalize_wsum(batch_size, attn, block_mapping, block_scales, **rest):
    """Shift ``attn`` in place by a per-sequence offset and return it.

    For each block, the maximum over the last axis is multiplied by that
    block's entry in ``block_scales``. These products are summed per batch
    entry (``block2batch``) and broadcast back to every block of that entry
    (``batch2block``), then subtracted from ``attn``. All blocks of one
    sequence therefore receive the same offset, which cancels in the
    block-wise softmax and only serves numerical stability. Presumably
    ``block_scales`` are per-block weights so the offset approximates a
    weighted mean of block maxima -- confirm with callers. ``batch_size``
    and extra keyword arguments are ignored.
    """
    block_max = attn.amax(-1)
    # Align block_scales' leading (block) dim with block_max, padding any
    # remaining dims with size-1 axes for broadcasting.
    extra_dims = block_max.dim() - block_scales.dim()
    scales = block_scales.reshape(-1, *([1] * extra_dims))
    block_max.mul_(scales)
    per_batch_offset = block2batch(block_max, block_mapping)
    per_block_offset = batch2block(per_batch_offset, block_mapping)
    attn.sub_(per_block_offset.unsqueeze(-1))
    return attn
def normalize(**kwargs):
    """Run the attention normalization selected by VLLM_PA_SOFTMAX_IMPL.

    All keyword arguments are forwarded unchanged to ``normalize_amax`` or
    ``normalize_wsum``. If the setting matches neither (which the
    module-level validation of VLLM_PA_SOFTMAX_IMPL is meant to prevent),
    None is returned.
    """
    impl = VLLM_PA_SOFTMAX_IMPL
    if impl == 'amax':
        return normalize_amax(**kwargs)
    if impl == 'wsum':
        return normalize_wsum(**kwargs)
    return None
def block_softmax(batch_size, attn, block_mapping, block_scales):
    """Softmax over the last axis whose denominator spans all blocks of a sequence.

    Operates in place on ``attn`` (``normalize``, ``exp_`` and ``div_`` all
    mutate it) and returns it. Per-block sums of ``exp(attn)`` are summed per
    batch entry via ``block2batch`` and broadcast back with ``batch2block``,
    so every block is divided by its whole sequence's total.
    """
    scores = normalize(batch_size=batch_size,
                       attn=attn,
                       block_mapping=block_mapping,
                       block_scales=block_scales)
    exps = scores.exp_()
    block_totals = exps.sum(dim=-1).unsqueeze(-1)
    seq_totals = batch2block(block2batch(block_totals, block_mapping),
                             block_mapping)
    # Smallest positive normal value: guards against division by zero.
    seq_totals.add_(torch.finfo(seq_totals.dtype).tiny)
    return exps.div_(seq_totals)
def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
            block_bias, block_scales, scale, matmul_qk_op, matmul_av_op,
            keys_fetch_func, values_fetch_func):
    """Paged decode attention over a flat list of KV-cache blocks.

    The scaled query of each batch entry is copied onto every block it owns
    (``batch2block``). Keys and values for ``block_list`` are gathered with
    the supplied fetch functions. Scores are computed per block with
    ``matmul_qk_op`` plus ``block_bias`` and normalized with
    ``block_softmax``, then applied to values with ``matmul_av_op``. Block
    results are summed back per batch entry (``block2batch``).

    ``query.size(1)`` is read as the number of query heads and
    ``key_cache.size(2)`` as the number of KV heads. When they differ
    (grouped-query attention), the query heads are split into
    ``(kv_heads, -1)`` groups, and K/V get a singleton group axis so they
    broadcast across each group. Presumably the fetch functions return
    (num_blocks, block_size, kv_heads, head_dim) -- TODO confirm.
    """
    batch_size = query.size(0)
    num_q_heads = query.size(1)
    num_kv_heads = key_cache.size(2)
    grouped = num_kv_heads != num_q_heads

    q = batch2block(scale * query, block_mapping).unsqueeze(-2)
    k = keys_fetch_func(key_cache, block_list).transpose(1, 2)
    v = values_fetch_func(value_cache, block_list).transpose(1, 2)
    bias = block_bias.view(k.size(0), 1, 1, -1)

    if grouped:
        bias = bias.unsqueeze(1)
        q = q.unflatten(1, (num_kv_heads, -1))
        k = k.unflatten(1, (num_kv_heads, 1)).transpose(3, 4)
        v = v.unflatten(1, (num_kv_heads, 1))
    else:
        k = k.transpose(2, 3)

    scores = matmul_qk_op(q, k) + bias
    probs = block_softmax(batch_size, scores, block_mapping, block_scales)
    out = block2batch(matmul_av_op(probs, v), block_mapping).squeeze(-2)
    return out.flatten(1, 2) if grouped else out
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    """Gate the second half of the last axis by SiLU of the first half.

    The last dimension is split at ``x.shape[-1] // 2``; it is expected to be
    even.
    """
    half = x.shape[-1] // 2
    gate = x[..., :half]
    up = x[..., half:]
    return F.silu(gate) * up
# TODO: remove after fusedsdpa fix for query_head != kv_head
def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(kv, repeats=n_rep, dim=1)``:
    (batch, num_key_value_heads, seqlen, head_dim) ->
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    ``kv`` must be 4-D; it is returned as-is when ``n_rep == 1``.
    """
    b, kv_heads, seq_len, head_dim = kv.shape
    if n_rep == 1:
        return kv
    expanded = kv.unsqueeze(2).expand(b, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(b, kv_heads * n_rep, seq_len, head_dim)
def prompt_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    matmul_qk_op=torch.matmul,
    softmax_op=torch.softmax,
    matmul_av_op=torch.matmul,
    valid_seq_lengths: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Attention for a prompt (prefill) batch.

    Inputs are transposed with ``transpose(1, 2)`` on entry and the result is
    transposed back on exit. The inputs are therefore presumably
    (batch, seq_len, num_heads, head_dim) -- TODO confirm with callers.

    Two execution paths:

    * Native, used when ``attn_bias`` is given or the HPU FusedSDPA kernel
      could not be imported: ``softmax(scale * Q @ K^T + bias) @ V`` built
      from the injectable ``matmul_qk_op`` / ``softmax_op`` /
      ``matmul_av_op``. When the query and KV head counts differ, the query
      heads are grouped as ``(kv_heads, -1)`` and K/V get a singleton group
      axis, so no K/V copy is made.
    * Fused: K/V heads are expanded with ``repeat_kv`` when counts differ,
      then ``FusedSDPA.apply`` is called with dropout 0.0 and the positional
      ``True`` flag (presumably ``is_causal``), 'fast' softmax and recompute
      mode.

    Args:
        p: Dropout probability. Currently unused on both paths.
        scale: Multiplier applied to the query before ``Q @ K^T``. If None,
            the native path uses ``1 / sqrt(head_dim)``. The fused path
            forwards None to FusedSDPA unchanged.
        valid_seq_lengths: Forwarded to FusedSDPA only.

    Returns:
        Attention output in the same layout as ``query``.
    """
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
    query_heads = query.size(1)
    kv_heads = key.size(1)
    if attn_bias is not None or HPUFusedSDPA is None:
        if scale is None:
            # Standard scaled-dot-product default; `query * None` would
            # otherwise raise a TypeError below.
            scale = 1.0 / math.sqrt(query.size(-1))
        if query_heads != kv_heads:
            query = query.unflatten(1, (kv_heads, -1))
            key = key.unflatten(1, (kv_heads, 1))
            value = value.unflatten(1, (kv_heads, 1))
            if attn_bias is not None:
                # Extra axis to broadcast the bias over each query group.
                attn_bias = attn_bias.unsqueeze(2)
        attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
        if attn_bias is not None:
            attn_weights.add_(attn_bias)
        attn_weights = softmax_op(attn_weights, dim=-1)
        attn_weights = matmul_av_op(attn_weights, value)
        if query_heads != kv_heads:
            attn_weights = attn_weights.flatten(1, 2)
    else:
        # TODO: remove after fusedsdpa fix for query_heads != kv_heads
        if query_heads != kv_heads:
            key = repeat_kv(key, int(query_heads // kv_heads))
            value = repeat_kv(value, int(query_heads // kv_heads))
        softmax_mode = 'fast'
        recompute_mode = True
        attn_weights = HPUFusedSDPA.apply(query, key, value, None, 0.0, True,
                                          scale, softmax_mode, recompute_mode,
                                          valid_seq_lengths, 'right')
    attn_weights = attn_weights.transpose(1, 2)
    return attn_weights
class LoraMask:
    """Class-level holder for the LoRA mask read by the dispatch_bgmv_* helpers.

    The mask is stored on the class itself, so every caller in the process
    shares the same value.
    """

    # Current mask; None until setLoraMask is called.
    lora_mask = None

    @staticmethod
    def setLoraMask(mask):
        """Replace the shared mask with ``mask``."""
        LoraMask.lora_mask = mask

    @staticmethod
    def getLoraMask():
        """Return the shared mask, or None if it was never set."""
        return LoraMask.lora_mask
def dispatch_bgmv_linear(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_t_all: torch.Tensor,
    wb_t_all: torch.Tensor,
    layer_idx: int,
    scale: float,
):
    """Apply every LoRA adapter to ``x`` at once and accumulate into ``y`` in place.

    ``wa_t_all`` / ``wb_t_all`` stack all LoRA A / B matrices along dim 0
    (same rank for every adapter). Index 0 of dim 1 is selected from each.
    A is flattened and transposed to (h_in, max_loras * lora_rank), and B is
    transposed and flattened to (max_loras * lora_rank, h_out). The
    projection ``x @ A`` is multiplied by ``LoraMask.getLoraMask()``
    (presumably zeroing the columns of inactive adapters). The result is then
    projected with ``B`` and added to ``y`` times ``scale``.

    Only ``layer_idx == 0`` is supported.
    """
    assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}'
    mask = LoraMask.getLoraMask()

    lora_b = wb_t_all[:, 0, :, :].transpose(1, 2)
    lora_a = wa_t_all[:, 0, :, :]
    lora_a = lora_a.reshape(lora_a.shape[0] * lora_a.shape[1],
                            lora_a.shape[2]).transpose(0, 1)
    lora_b = lora_b.reshape(lora_b.shape[0] * lora_b.shape[1],
                            lora_b.shape[2])

    shrunk = x @ lora_a
    assert (shrunk.shape == mask.shape)
    expanded = (shrunk * mask) @ lora_b
    y += expanded * scale
def dispatch_bgmv_embedding(
    y: torch.Tensor,
    x: torch.Tensor,
    wb_t_all: torch.Tensor,
    layer_idx: int,
    scale: float,
):
    """Apply the LoRA-B half of every embedding adapter and accumulate into ``y`` in place.

    ``wb_t_all`` stacks all LoRA-B matrices along dim 0 (same rank for
    every adapter). Index 0 of dim 1 is selected, then it is transposed and
    flattened to (num_loras * lora_rank, embedding_dim). ``x`` (the LoRA-A
    embedding output) is tiled ``max_loras`` times along dim 1 to match, and
    multiplied by ``LoraMask.getLoraMask()`` (presumably zeroing inactive
    adapters). It is then projected with B and added to ``y`` times
    ``scale``.

    Only ``layer_idx == 0`` is supported.
    """
    assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}'
    max_loras = wb_t_all.size(0)
    masked = x.repeat(1, max_loras) * LoraMask.getLoraMask()
    lora_b = wb_t_all[:, 0, :, :].transpose(1, 2)
    lora_b = lora_b.reshape(lora_b.shape[0] * lora_b.shape[1],
                            lora_b.shape[2])
    y += (masked @ lora_b) * scale
class MoeMatmul(torch.nn.Module):
    """Matmul wrapper for one MoE expert projection.

    The right-hand operand is stored on ``self.weight``. It can be set
    directly (``set_weight``) or selected per call from a stacked expert
    tensor (``calc``). Presumably this is a module so per-expert matmuls can
    be hooked or patched individually -- confirm.
    """

    def __init__(self):
        super().__init__()

    def set_weight(self, w):
        """Use ``w`` as the right-hand operand of ``forward``."""
        self.weight = w

    def calc(self, state, expert_id, w):
        """Store ``w[expert_id]`` transposed as the weight, then return ``forward(state)``."""
        expert_weight = w[expert_id]
        self.weight = expert_weight.transpose(0, 1)
        return self.forward(state)

    def forward(self, state):
        """Return ``torch.matmul(state, self.weight)``."""
        return torch.matmul(state, self.weight)
class StaticFusedMOE(torch.nn.Module):
    """Dense ("static") mixture-of-experts layer.

    Every expert runs on every token. Tokens not routed to an expert get a
    zero routing weight, so that expert's output contributes nothing for
    them. The work done is therefore independent of the routing decision
    (presumably to keep HPU graph shapes static -- confirm).
    """

    def __init__(self, num_total_experts):
        super().__init__()
        # One matmul module per expert for the w1 projection (whose output
        # feeds silu_and_mul) and for the w2 projection.
        self.w13_list = torch.nn.ModuleList(
            [MoeMatmul() for _ in range(num_total_experts)])
        self.w2_list = torch.nn.ModuleList(
            [MoeMatmul() for _ in range(num_total_experts)])
        self.num_total_experts = num_total_experts

    def forward(self, hidden_states, w1, w2, score, topk):
        """Route each token to its top-k experts and return the weighted sum.

        Args:
            hidden_states: 2-D (B, D) token activations.
            w1, w2: Stacked expert weights indexed by expert id; each expert's
                slice is transposed before use (see ``MoeMatmul.calc``).
            score: Router logits; softmax is taken over dim 1 (presumably
                (B, num_total_experts)).
            topk: Number of experts kept per token.

        Returns:
            (B, D) tensor in ``hidden_states.dtype``.
        """
        B, D = hidden_states.shape
        routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
        routing_weights, selected_experts = torch.topk(routing_weights,
                                                       topk,
                                                       dim=-1)
        # Renormalize so each token's kept top-k weights sum to 1.
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)
        final_hidden_states = torch.zeros((1, B, D),
                                          dtype=hidden_states.dtype,
                                          device=hidden_states.device)
        # Dense (B, E) routing matrix: top-k weights per row, zeros elsewhere.
        padded_weights = torch.zeros((B, self.num_total_experts),
                                     dtype=hidden_states.dtype,
                                     device=hidden_states.device)
        padded_weights.scatter_(-1, selected_experts, routing_weights)
        # (B, E) -> (1, B, E) -> (E, 1, B) -> (E, 1, B, 1), so that
        # padded_weights[e] broadcasts over a (B, D) expert output.
        padded_weights = padded_weights.reshape(-1, B, self.num_total_experts)
        padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
        htorch.core.mark_step()

        # Every expert consumes the same flattened input; compute it once.
        current_state_static = hidden_states.reshape(-1, D)
        for expert_idx in range(self.num_total_experts):
            padded_weight = padded_weights[expert_idx]
            w_output = self.w13_list[expert_idx].calc(current_state_static,
                                                      expert_idx, w1)
            w_output = silu_and_mul(w_output)
            w_output = self.w2_list[expert_idx].calc(w_output, expert_idx, w2)
            current_hidden_states_static = w_output * padded_weight
            final_hidden_states += current_hidden_states_static

        return final_hidden_states.view(-1, D)
# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    batch_dim_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    Only static quantization (an explicit ``scale``) is implemented on HPU.
    The tensor is cast with ``torch.ops.hpu.cast_to_fp8_v2(input, 1 / scale,
    ...)`` to ``torch.float8_e4m3fn``. Dynamic quantization (``scale`` is
    None) raises NotImplementedError.

    Args:
        input: The input tensor to be quantized to FP8.
        scale: Scaling factor for the FP8 quantization. Required on HPU.
        batch_dim_padding: Accepted for signature compatibility. It is NOT
            applied here: the returned tensor is the HPU cast's own output
            (presumably the same shape as ``input`` -- NOTE(review): confirm
            whether any caller relies on padding).
        scale_ub: Only meaningful for dynamic per-token quantization, which
            is not implemented on HPU. Ignored.
        use_per_token_if_dynamic: Only meaningful for dynamic quantization,
            which is not implemented on HPU. Ignored.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and the
        scaling factor that was passed in.

    Raises:
        NotImplementedError: If ``scale`` is None (dynamic quantization).
    """
    if scale is None:
        # TODO: calculate scale to match gaudi2 240 range instead of 448
        raise NotImplementedError(
            "dynamic scaled_fp8_quant not implemented for HPU")
    output = torch.ops.hpu.cast_to_fp8_v2(input,
                                          1 / scale,
                                          False,
                                          False,
                                          dtype=torch.float8_e4m3fn)[0]
    return output, scale