Efficient Inference Kernel for SpQR #34976
base: main
Conversation
Hey @elvircrn, thanks for adding this quantization method! A very smooth integration 🔥! Just left a few comments.
raise ValueError(
    f"SpQR requires Accelerate to be installed: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
)
No need to add the checks here, `validate_environment` in the quantizer will take care of that.
Resolved by removing the checks.
requires_calibration = True
required_packages = ["spqr_quant"]
optimum_quantizer = None
Not used, to be removed
Resolved by removing it.
raise ImportError("Using `spqr` quantization requires SpQR: `pip install spqr_quant`")
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
    return torch.float16
Let's raise a warning here if you need to override the torch_dtype, like in other quantizers.
Since SpQR really only supports float16, I raise an error if it's None or anything other than float16. Is this the desired behavior?
@require_spqr
@require_accelerate
class SpQRTest(unittest.TestCase):
    model_name = "BlackSamorez/Llama-2-7b-SPQR-2Bit-1x16-hf"
Sorry, I can't find this model on the Hub 🤔
Ah, didn't expect a review so soon. Thank you!
I've uploaded the model and updated this part of the test.
No worries 😊 take your time and ping me when you finish the integration
@unittest.skipUnless(
    is_spqr_available(),
    "test requires `spqr_quant`",
)
No need to add the requirement here, it's already verified by the `require_spqr` decorator.
Resolved by removing.
docs/source/en/quantization/spqr.md (outdated)
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
We can also add a small section about how to load a model quantized with SpQR using transformers.
Resolved by updating the README.
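For reference, a minimal sketch of what that loading section could show, assuming the standard `from_pretrained` flow; the checkpoint id below is the one referenced in the tests at this point in the review and may have been replaced later:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint id taken from the test file in this PR; the final uploaded model may differ.
quantized_model_id = "BlackSamorez/Llama-2-7b-SPQR-2Bit-1x16-hf"

model = AutoModelForCausalLM.from_pretrained(
    quantized_model_id,
    torch_dtype=torch.float16,  # SpQR kernels only support float16
    device_map="cuda",
)
tokenizer = AutoTokenizer.from_pretrained(quantized_model_id)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```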
raise ImportError("Using `spqr` quantization requires Accelerate: `pip install accelerate`")
if not is_spqr_available():
    raise ImportError("Using `spqr` quantization requires SpQR: `pip install spqr_quant`")
If the quantization method requires a GPU, we need to check whether it's available here.
Resolved by adding an is_cuda check.
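Roughly, the resolved check looks like the following sketch (not the exact PR diff; it assumes `is_spqr_available` is exposed from `transformers.utils` like the other availability helpers):

```python
import torch
from transformers.utils import is_accelerate_available, is_spqr_available


def validate_environment(self, *args, **kwargs):
    # SpQR inference kernels require a CUDA device.
    if not torch.cuda.is_available():
        raise RuntimeError("GPU is required to run SpQR quantized model.")
    if not is_accelerate_available():
        raise ImportError("Using `spqr` quantization requires Accelerate: `pip install accelerate`")
    if not is_spqr_available():
        raise ImportError("Using `spqr` quantization requires SpQR: `pip install spqr_quant`")
```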
@slow
@require_torch_gpu
Is it possible to add tests for multi-GPU setups?
Resolving by bringing this back from AQLM. Will the CI be able to pick this up? I don't have the means of testing this.
Yes it does 😄
+1 on that, you can add the test and we will check to make sure it passes with the right expected output.
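For context, the AQLM-style multi-GPU test being ported is roughly the sketch below; attribute names such as `model_name`, `tokenizer`, `input_text`, `max_new_tokens`, and `EXPECTED_OUTPUT` are assumed to be defined on the test class:

```python
import unittest

from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_torch_multi_gpu, torch_device


class SpQRTest(unittest.TestCase):
    @require_torch_multi_gpu
    def test_quantized_model_multi_gpu(self):
        """Simple test that checks if the quantized model works properly with multiple GPUs."""
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto")
        # Make sure the checkpoint was actually sharded across the two devices.
        self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})

        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
```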
Force-pushed from fda6b66 to ee10336.
@MekkCyber I've removed the Draft tag from the PR. It seems that I am unable to pass the
Force-pushed from 56c016d to e1a91fa.
Could you rebase, @elvircrn, to update the branch :)
Force-pushed from e1a91fa to 4ea16a9.
@MekkCyber Rebase done.
Force-pushed from 4ea16a9 to 51c9f8d.
Hey @elvircrn, thanks for iterating! It looks great, I just left some minor comments.
if modules_to_not_convert is None:
    modules_to_not_convert = []

from accelerate import init_empty_weights
from spqr_quant import QuantizedLinear
We need to add checks here that the `accelerate` and `spqr_quant` packages are available before doing the import; what we don't need to do is raise errors, because the packages are assumed to be installed when executing the function.
Not sure if I understood this, please double-check my newest change.
if torch_dtype is None:
    torch_dtype = torch.float16
    logger.info(
        "Assuming CUDA is available. Assuming SpQR inference on GPU and loading the model in `torch.float16`."
    )
We are not assuming here because we already validated the environment 😉
elif torch_dtype != torch.float16:
    raise ValueError(
        "You cannot any type other than torch.float16 for SpQR. Please either leave it None or set it to"
        "torch.float16 explicitly."
    )
return torch_dtype
Suggested change:
elif torch_dtype != torch.float16:
    raise ValueError(
        "You cannot use any type other than torch.float16 for SpQR. Please either leave it None or set it to"
        "torch.float16 explicitly."
    )
return torch_dtype
model._modules[name] = QuantizedLinear.create_placehodler(
    rows=out_features,
    cols=in_features,
    bits=quantization_config.bits,
    beta1=quantization_config.beta1,
    beta2=quantization_config.beta2,
    dense_weights_shape=dense_weights_shape,
    row_offsets_shape=row_offsets_shape,
    col_vals_shape=col_vals_shape,
    in_perm_shape=in_perm_shape,
)
has_been_replaced = True
Just FMI, are the `beta` and `shapes` parameters used solely for loading, or do they also play a role in quantization? In other words, if a model is quantized using specific `beta` or `shapes` values, can it only be loaded with those same parameters?
SpQR quantized weights are the product of tile-based bi-level quantization (beta corresponds to the dimensions of this tile, where beta1 and beta2 represent the tile width and height respectively). You always specify beta before quantization starts.
Each SpQR weight also comes with a sparse tensor representing the outlier weights. As this tensor is unstructured, one has to keep track of the size of this matrix during loading.
A visualization of the compression format can be found in the original publication (https://arxiv.org/pdf/2306.03078).
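To make the shape bookkeeping concrete, here is an illustrative sketch of the per-tensor entries that `replace_with_spqr_linear` reads from the config; the tensor name, sizes, and the `bits`/`beta` values are placeholders rather than values from a real checkpoint, and it assumes `SpQRConfig` is exposed at the top level like the other quantization configs:

```python
from transformers import SpQRConfig

# Each quantized Linear contributes entries keyed "<tensor_name>.<buffer>.shape".
# row_offsets / col_vals describe the unstructured outlier tensor, whose size cannot
# be derived from the layer dimensions alone and therefore must be stored up front.
shapes = {
    "model.layers.0.self_attn.q_proj.dense_weights.shape": 1376256,  # placeholder
    "model.layers.0.self_attn.q_proj.row_offsets.shape": 4097,       # placeholder
    "model.layers.0.self_attn.q_proj.col_vals.shape": 272742,        # placeholder
    "model.layers.0.self_attn.q_proj.in_perm.shape": 4096,           # placeholder
}

quantization_config = SpQRConfig(bits=3, beta1=16, beta2=16, shapes=shapes)
```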
Thanks for the detailed explanation
tensor_name = ".".join(current_key_name)
dense_weights_shape = quantization_config.shapes[f"{tensor_name}.dense_weights.shape"]
row_offsets_shape = quantization_config.shapes[f"{tensor_name}.row_offsets.shape"]
col_vals_shape = quantization_config.shapes[f"{tensor_name}.col_vals.shape"]
in_perm_shape = quantization_config.shapes[f"{tensor_name}.in_perm.shape"]

in_features = module.in_features
In the `shapes` attribute, do we need to specify the shape for all the tensors?
The colval size needs to be specified ahead of time for all weights. The rest can be computed from m, n and bits. Can we keep it the way it is now?
I can do it both ways easily, not sure what is preferred here.
I'm just worried about the quantization config: I'm not sure if we can verify that `shapes` contains all the necessary elements in the `post_init`. Maybe before trying to access the key in `replace_with_spqr_linear` we verify it's there, or we raise an error to inform the user that there is something wrong with the config.json they are using, wdyt @SunMarc
In my estimate, these configurations are directly tied to the quantized weights. It would be a fatal error in quantization serialization if the shapes didn't cover the full set of quantized weights. Perhaps this is something that should be done on the SpQR side (https://pypi.org/project/spqr-quant/)?
If we were to do it on the huggingface/transformers side:
a) during `replace_with_spqr_linear`, we raise a fatal error (please let me know which one would be appropriate here) if the quantized weights don't have the corresponding key in the shapes config.
b) we do a full check before conducting the replacement in order to conserve memory/device cycles.
Let me know what best suits transformers here.
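A rough sketch of option (a), written as a hypothetical helper called from `replace_with_spqr_linear` right before the shape keys are read:

```python
def _check_shapes_present(quantization_config, current_key_name):
    """Hypothetical helper: fail fast if config.json's `shapes` lacks entries for this tensor."""
    tensor_name = ".".join(current_key_name)
    expected_keys = [
        f"{tensor_name}.dense_weights.shape",
        f"{tensor_name}.row_offsets.shape",
        f"{tensor_name}.col_vals.shape",
        f"{tensor_name}.in_perm.shape",
    ]
    missing = [key for key in expected_keys if key not in quantization_config.shapes]
    if missing:
        raise ValueError(
            f"The SpQR quantization config is missing shape entries {missing} for `{tensor_name}`. "
            "The `shapes` field in the checkpoint's config.json does not match the quantized weights."
        )
```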
I didn't see the tag to @SunMarc - I thought the wdyt referred to me.
model_id = "meta-llama/Llama-2-7b-hf"
config = AutoConfig.from_pretrained(model_id)
quantization_config = AutoConfig.from_pretrained(self.model_name, return_dict=False).quantization_config
quantization_config = SpQRConfig.from_dict(quantization_config)
If possible, it's better to use small models for the CI, like Qwen 0.5, Bloom 560M, SmolLM 135M...
Will have to get back to you on this (might require a significant time investment).
Is this a blocker for the merge?
Not really a blocker for now
Force-pushed from 436697e to 16c0278.
@MekkCyber The latest batch of comments is addressed.
Thanks for the quick iteration @elvircrn 🔥! Can you re-rebase please 😅
if is_accelerate_available():
    from accelerate import init_empty_weights
if is_spqr_available():
    from spqr_quant import QuantizedLinear
Yes, that's exactly what I meant, thank you for fixing it!
Force-pushed from 16c0278 to efd9f20.
@MekkCyber rebased.
@MekkCyber Resolved the latest batch of comments.
@SunMarc @MekkCyber
Force-pushed from cfca9df to 91ca265.
What does this PR do?
Adds support for efficient single-batch inference for SpQR.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@SunMarc @MekkCyber