fixing GPTQ
Summary:

Trying to fix the issue with the kv_cache update by changing from fx tracing to a
tensor subclass. However, it seems we have less success than with the fx
tracer. The fx tracer breaks due to

k_out[:,:, input_pos] = k_val

getting traced as

new_var = torch.ops.aten.index_put_(k_out, [None, None, input_pos], k_val)

with new_var never being accessed afterward. new_var becomes the correct
MultiInput value, but is then lost.

The subclass, on the other hand, tries to use the func "<slot wrapper '__setitem__' of 'torch._C.TensorBase' objects>",
which seems not to want to mutate k_out, and so the attempt to make it a
MultiTensor fails.
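
For illustration, a minimal eager-mode sketch of the two behaviors described above (shapes are made up for the example and are not taken from the model):

import torch

k_out = torch.zeros(1, 2, 8, 4)   # [B, H, S, D], illustrative shape only
k_val = torch.ones(1, 2, 3, 4)
input_pos = torch.tensor([0, 1, 2])

# What the fx tracer records for k_out[:, :, input_pos] = k_val:
# an index_put_ whose return value becomes a new graph node ...
new_var = torch.ops.aten.index_put_(k_out, [None, None, input_pos], k_val)
# ... but downstream code keeps reading k_out, so new_var is never used.

# What the tensor subclass sees instead: Tensor.__setitem__, which mutates
# k_out in place and returns None, so there is no output to rewrap as a
# MultiTensor.
ret = torch.Tensor.__setitem__(k_out, (slice(None), slice(None), input_pos), k_val)
assert ret is None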

Test Plan: sh run.sh

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 9ed1621201317e5f655132ba11538a67c8aa5a69
Pull Request resolved: #148
HDCharles committed Mar 28, 2024
1 parent f697317 commit dfaa329
Showing 4 changed files with 244 additions and 9 deletions.
213 changes: 209 additions & 4 deletions GPTQ.py
@@ -66,10 +66,12 @@ def __init__(

def add_input(self, args):
if self.inputs is None:
self.inputs = [MultiInput([arg]) for arg in args]
# self.inputs = [MultiInput([arg]) for arg in args]
self.inputs = [GPTQMultiTensor([arg]) for arg in args]
else:
self.inputs = [
multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
# multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
multi.add_tensors(arg) for (multi, arg) in zip(self.inputs, args)
]

def get_recorded_inputs(self):
@@ -129,6 +131,199 @@ def cuda(self):
self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]


class GPTQMultiTensor(torch.Tensor):
"""
"""
# todo need default shape/dtype
@staticmethod
def __new__(cls, input, **kwargs):
if isinstance(input, (list, tuple)):
input = input[0]
kwargs["dtype"]=kwargs.get("dtype", input.dtype)
shape = kwargs.pop("shape", input.shape)
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

def __init__(self, input, **kwargs):
self.values = []
self.add_tensors(input)
self.debug = True

def __repr__(self):
return (
f"{self.__class__.__name__}(data={self.values})"
)

def add_tensors(self, input):
if isinstance(input, (tuple, list)):
for inp in input:
self.add_tensors(inp)
else:
assert isinstance(input, torch.Tensor), f"MultiTensor can only use add_input for Tensors or lists of tensors but got {type(input)}"
self.values.append(input)
return self

def count(self):
return len(self.values)

def cuda(self):
self.values = [val.cuda() for val in self.values]
return self

def cpu(self):
self.values = [val.cpu() for val in self.values]
return self

def configure_quantization_mode(
self,
get_qparams_func,
quantize_func,
dequantize_func,
combine_qparams_list_func,
make_names_and_values_dict_func,
skip_layer_func,
):
self.get_qparams_func = get_qparams_func
self.quantize_func = quantize_func
self.dequantize_func = dequantize_func
self.combine_qparams_list_func = combine_qparams_list_func
self.skip_layer_func = skip_layer_func
self.make_names_and_values_dict_func = make_names_and_values_dict_func
return self

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None, skip_gptq=False):
# with torch._C.DisableTorchFunctionSubclass():
# is_set_item = str(func)=="<slot wrapper '__setitem__' of 'torch._C.TensorBase' objects>"
# if is_set_item:
# breakpoint()
# try:
# new_arg1=[None if x == slice(None) else x for x in args[1]]
# return torch.ops.aten.index_put(args[0], new_arg1, args[2])
# except Exception as e:
# print(e)
# print("?A?")
# breakpoint()
# print("?")
# if func == torch.ops.aten.index_put_:
# breakpoint()

def tensors_to_cuda(args):
new_args = []
for x in args:
new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x)
return new_args

def flat_to_grouped(flat):
# size of biggest MultiTensor
multi_tensor_size = max(
[x.count() if isinstance(x, GPTQMultiTensor) else 1 for x in flat]
)
# convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2] [A,b3,c3]]
grouped = list(
zip(
*[x.values if isinstance(x, GPTQMultiTensor) else [x] * multi_tensor_size for x in flat]
)
)
return grouped

# convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
# where A is nontensor, b's,c's are tensors
def grouped_to_flat(grouped):
# convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [(A,A,A), (b1,b2,b3), (c1,c2,c3)]
flat_tups = list(zip(*grouped))
# convert [(A,A,A), (b1,b2,b3), (c1,c2,c3)] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
flattened = [
cls(tup).cpu() if isinstance(tup[0], torch.Tensor) else tup[0] for tup in flat_tups
]
# need to check that getting rid of all but one from each nonTensor tuple is OK
non_tensors_equal=min([True]+[
min([True]+[ # handle situation where tuples have size 0
tup[0]==x for x in tup # check all elements match
]) for tup in flat_tups if not isinstance(tup[0], torch.Tensor) # look at tuples of nonTensors
])
return flattened, non_tensors_equal

kwargs = {} if kwargs is None else kwargs
# combine args and kwargs and remove lists and tuples
flat_args, spec = tree_flatten((args, kwargs))
# move single tensors to cuda

# flat_args = tensors_to_cuda(flat_args)

# convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2] [A,b3,c3]]
grouped_args = flat_to_grouped(flat_args)

do_gptq_linear = (
func is nn.functional.linear
# and id(args[1]) in self.id_to_name
and not skip_gptq
# and not (self.skip_layer_func)
)

# run function for each of the multitensors and return a multitensor
if not do_gptq_linear:
outputs = []
with torch._C.DisableTorchFunctionSubclass():
for inp in grouped_args:
# inp = tensors_to_cuda(inp)
cur_args, cur_kwargs = tree_unflatten(inp, spec)
try:
out = func(*cur_args, **cur_kwargs)
outputs.append(out.cpu() if isinstance(out, torch.Tensor) else out)
except Exception as e:
print(e)
print("?B?")
breakpoint()
print("?")
try:
# each output
grouped_outputs = [tree_flatten(x)[0] for x in outputs]
out_spec = tree_flatten(outputs[0])[1]
# convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
flat_outputs, non_tensors_equal = grouped_to_flat(grouped_outputs)
assert non_tensors_equal, (
f"ERR: found a function in model: {func} which "
+"caused an error in GPTQMultiInput, the function dispatch only works for functions"
+" with Tensor outputs or that have the same non-Tensor output value for all across all inputs"
)
return tree_unflatten(flat_outputs, out_spec)
except Exception as e:
print(e)
print("?C?")
breakpoint()
print("?")

# do GPTQ if quantize_linear is true
total_batches = 0
H=0
for inp in grouped_args:
# inp = tensors_to_cuda(inp)
cur_args, cur_kwargs = tree_unflatten(inp, spec)
x = cur_args[0].float()
shape = x.shape
n = 1 if len(shape) == 2 else shape[0]
H*= total_batches / (total_batches + n)
total_batches += n
x = (
(2 / total_batches) ** (1 / 2) *
x.reshape(-1, shape[-1]).t().float()

)
H += x.matmul(x.t())
W = args[1].to(H.device)
DQ = W+.01
# Q, DQ, qparams = args[0].faster_quant(H, W.detach())

new_out = cls.__torch_function__(func, types, (args[0], DQ, *args[2:]), kwargs, skip_gptq = True)
# if args[0].debug:
return new_out

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
breakpoint()
pass


class GenericGPTQRunner(fx.Interpreter):
"""
This is a generic GPTQ runner that takes an existing model and applies GPTQ.
@@ -150,7 +345,7 @@ def __init__(
}

# trace model for one input
one_input = [multi.values[0].cpu() for multi in inputs]
one_input = tuple([multi.values[0].cpu() for multi in inputs])
exported_model = torch._dynamo.export(
model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
)(*one_input)
@@ -161,7 +356,7 @@ def __init__(
self.groupsize = groupsize
self.inputs = inputs
self.gptq_done = False
self.debug = False
self.debug = True

def configure_quantization_mode(
self,
@@ -312,6 +507,16 @@ def SQNR(x, y):
print(
"SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after)
) # matches
qparams_after = self.get_qparams_func(DQ)
Q_after = self.quantize_func(DQ, qparams_after)
print(
"abs difference of Q-quant(DQ)", (Q-Q_after).abs().sum()
)
DQ_after_after = self.dequantize_func(Q_after, qparams_after).to(DQ.dtype)
print(
"SQNR for DQ(Q(DQ)) vs DQ", SQNR(DQ, DQ_after_after)
)
breakpoint()

print(
"SQNR for weight (can be low)", SQNR(W, DQ.cuda())
7 changes: 6 additions & 1 deletion model.py
@@ -78,10 +78,16 @@ def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val: [B, H, S, D]
assert input_pos.shape[0] == k_val.shape[2]


k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, input_pos] = k_val
breakpoint()
v_out[:, :, input_pos] = v_val
breakpoint()
# k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
# v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)


return k_out, v_out

@@ -174,7 +180,6 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona

kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
14 changes: 10 additions & 4 deletions quantize.py
@@ -11,6 +11,8 @@
import torch.nn.functional as F
from sentencepiece import SentencePieceProcessor

from GPTQ import GenericGPTQRunner, InputRecorder

try:
from GPTQ import GenericGPTQRunner, InputRecorder
from eval import get_task_dict, evaluate, lm_eval
@@ -286,6 +288,10 @@ def create_quantized_state_dict(
pad_calibration_inputs,
) -> "StateDict":
inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs)
self.mod=self.mod.to("cpu")
inputs=[x.cpu() if hasattr(x, "cpu") else x for x in inputs]
self.mod(*inputs)

print("Tracing model for GPTQ")
GPTQ_runner = GenericGPTQRunner(
self.mod,
Expand Down Expand Up @@ -438,12 +444,12 @@ def convert_for_runtime(self, use_cuda):
return self.mod

class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
from model import find_multiple
self.mod = mod
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles
self.padding = padding
self.padding_allowed = padding_allowed
self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
self.quantize_func = lambda w, qparams: \
group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
@@ -453,7 +459,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
[torch.cat(x, dim=1) for x in zip(*qparams_list)]
# skip unless padding=True or its correctly sized
self.skip_layer_func = lambda linear_weight: not (
_check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
_check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding_allowed
)
# we need to do the padding here, both for q and the qparams if necessary
def make_names_and_values_dict_func(q, qparams):
@@ -472,7 +478,7 @@ def make_names_and_values_dict_func(q, qparams):


def convert_for_runtime(self, use_cuda):
replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
return self.mod

class WeightOnlyInt4Linear(torch.nn.Module):
19 changes: 19 additions & 0 deletions run.sh
@@ -0,0 +1,19 @@
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf

# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile # working
# echo "base"
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 1
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext --limit 5

# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
# echo "quant good"

# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5

# export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf
# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
# ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth

# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
