fix the support of Qwen #495

Merged · 4 commits · Jan 26, 2024
Changes from 1 commit
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions auto_gptq/modeling/_base.py
@@ -232,6 +232,14 @@ def quantize(

        examples = self._prepare_examples_for_quantization(examples, batch_size)

        def nested_move_to_device(v, device):
            if isinstance(v, torch.Tensor):
                return move_to_device(v, device)
            elif isinstance(v, (list, tuple)):
                return type(v)([nested_move_to_device(e, device) for e in v])
            else:
                return v
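For context on why this helper recurses: the isinstance(v, torch.Tensor) check it replaces (visible as commented-out lines further down) silently skipped keyword arguments that are lists or tuples of tensors, which is apparently how Qwen's blocks receive some of their inputs, so those tensors were left on the wrong device. Below is a small self-contained sketch of the behaviour; move_to_device is stubbed out here, and the rotary_pos_emb_list key is only an illustrative guess at such a kwarg, not something taken from the library.

import torch

def move_to_device(obj, device):
    # stand-in for auto_gptq's helper: relocate a single tensor
    return obj.to(device)

def nested_move_to_device(v, device):
    # recurse into lists/tuples so nested tensors are moved too
    if isinstance(v, torch.Tensor):
        return move_to_device(v, device)
    elif isinstance(v, (list, tuple)):
        return type(v)([nested_move_to_device(e, device) for e in v])
    else:
        return v

kwargs = {
    "attention_mask": torch.ones(1, 8),
    "rotary_pos_emb_list": [torch.randn(8, 64), torch.randn(8, 64)],  # hypothetical list-of-tensors kwarg
    "use_cache": False,  # non-tensor values pass through unchanged
}
moved = {k: nested_move_to_device(v, torch.device("cpu")) for k, v in kwargs.items()}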

        class LayerHijacker(nn.Module):
            """hijack layer's forward pass to cache data"""

@@ -259,10 +267,11 @@ def forward(self, inp=None, **kwargs):
                one_kwargs = dict()
                for k, v in kwargs.items():  # make sure other arguments also be captured
                    if k not in ["hidden_states", "attention_mask", "position_ids"]:
                        # if isinstance(v, torch.Tensor):
                        #     one_kwargs[k] = move_to_device(v, self.data_device)
                        # else:
                        #     one_kwargs[k] = v
                        one_kwargs[k] = nested_move_to_device(v, self.data_device)
fxmarty marked this conversation as resolved.
                layer_input_kwargs.append(one_kwargs)
                raise ValueError
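For readers skimming the hunk above: LayerHijacker wraps the first decoder block so that everything reaching it during the calibration forward pass gets cached on data_device (now via nested_move_to_device), and raise ValueError then aborts the rest of the model forward, since only the captured inputs are needed. A stripped-down, hypothetical illustration of that pattern with made-up names and a toy model:

import torch
import torch.nn as nn

class InputCatcher(nn.Module):
    # minimal sketch of the hijack pattern: record the first block's inputs, then bail out
    def __init__(self, module, cache):
        super().__init__()
        self.module = module
        self.cache = cache

    def forward(self, hidden_states, **kwargs):
        self.cache.append({"hidden_states": hidden_states, **kwargs})
        raise ValueError  # deliberately abort; the caller catches this

cache = []
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
model[0] = InputCatcher(model[0], cache)
try:
    model(torch.randn(2, 4))
except ValueError:
    pass
# cache[0]["hidden_states"] now holds exactly what the first layer would have seen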

@@ -355,10 +364,11 @@ def tmp(_, inp, out):
                if layer_position_ids is not None:
                    additional_layer_inputs["position_ids"] = layer_position_ids
                for k, v in layer_input_kwargs[j].items():
                    # if isinstance(v, torch.Tensor):
                    #     additional_layer_inputs[k] = move_to_device(v, cur_layer_device)
                    # else:
                    #     additional_layer_inputs[k] = v
                    additional_layer_inputs[k] = nested_move_to_device(v, cur_layer_device)
fxmarty marked this conversation as resolved.
                layer(layer_input, **additional_layer_inputs)
            for h in handles:
                h.remove()
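The handles removed here come from the standard torch forward-hook mechanism: before this loop, a hook like tmp(_, inp, out) is registered on each quantizable sublayer so that running layer(layer_input, **additional_layer_inputs) feeds the GPTQ statistics, and the hooks are detached right afterwards. A minimal sketch of that mechanism on a toy module (not AutoGPTQ's actual code):

import torch
import torch.nn as nn

sublayer = nn.Linear(8, 8)
captured = {}

def tmp(_, inp, out):
    # forward hooks receive (module, inputs_tuple, output)
    captured["inp"] = inp[0].detach()
    captured["out"] = out.detach()

handles = [sublayer.register_forward_hook(tmp)]
sublayer(torch.randn(2, 8))
for h in handles:
    h.remove()  # detach the hook so later forward passes are not recorded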
@@ -389,10 +399,11 @@ def tmp(_, inp, out):
                if layer_position_ids is not None:
                    additional_layer_inputs["position_ids"] = layer_position_ids
                for k, v in layer_input_kwargs[j].items():
                    # if isinstance(v, torch.Tensor):
                    #     additional_layer_inputs[k] = move_to_device(v, cur_layer_device)
                    # else:
                    #     additional_layer_inputs[k] = v
                    additional_layer_inputs[k] = nested_move_to_device(v, cur_layer_device)
fxmarty marked this conversation as resolved.
                layer_output = move_to_device(
                    layer(layer_input, **additional_layer_inputs)[0],
                    cur_layer_device if cache_examples_on_gpu else CPU
                )
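One detail worth noting in this last hunk: after a layer is quantized it is re-run on the cached inputs, and its output stays on cur_layer_device only when cache_examples_on_gpu is set, otherwise it is staged on the CPU, presumably to keep GPU memory bounded across many layers. A tiny sketch of that placement choice (toy tensors, not the library's code):

import torch

CPU = torch.device("cpu")
cur_layer_device = torch.device("cuda:0") if torch.cuda.is_available() else CPU
cache_examples_on_gpu = False  # illustrative setting

layer_output = torch.randn(2, 8, device=cur_layer_device)
# keep activations on the GPU only if we explicitly chose to cache them there
layer_output = layer_output.to(cur_layer_device if cache_examples_on_gpu else CPU)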