Efficient Inference Kernel for SpQR #34976

Open · wants to merge 30 commits into main
Conversation


elvircrn commented Nov 27, 2024

What does this PR do?

Adds support for efficient single-batch inference for SpQR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@SunMarc @MekkCyber

elvircrn marked this pull request as draft on November 27, 2024
elvircrn changed the title from Spqr quantizer to Efficient Inference Kernel for SpQR on November 27, 2024
SunMarc requested a review from MekkCyber on November 28, 2024
@MekkCyber (Contributor):

Hey @elvircrn, thanks for adding this quantization method! A very smooth integration 🔥! Just left a few comments.

src/transformers/integrations/spqr.py (resolved)
raise ValueError(
f"SpQR requires Accelerate to be installed: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
)

Contributor:

No need to add the checks here, validate_environment in the quantizer will take care of that

Author:

Resolved by removing the checks.


requires_calibration = True
required_packages = ["spqr_quant"]
optimum_quantizer = None
Contributor:

Not used, to be removed

Author:

Resolved by removing

raise ImportError("Using `spqr` quantization requires SpQR: `pip install spqr_quant`")

def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
return torch.float16
Contributor:

let's raise a warning here if you need to override the torch_dtype like in other quantizers

Author:

Since SpQR really only supports float16, I raise an error if it's None or anything other than float16. Is this the desired behavior?

src/transformers/utils/quantization_config.py (resolved)
@require_spqr
@require_accelerate
class SpQRTest(unittest.TestCase):
model_name = "BlackSamorez/Llama-2-7b-SPQR-2Bit-1x16-hf"
Contributor:

Sorry I can't find this model on the hub 🤔

Author:

Ah, didn't expect a review so soon. Thank you!

I've uploaded the model and updated this part of the test.

Contributor:

No worries 😊 take your time and ping me when you finish the integration

@unittest.skipUnless(
is_spqr_available(),
"test requires `spqr_quant`",
)
Contributor:

No need to add the requirement here, it's already verified with the decorator require_spqr

Author:

Resolved by removing.

archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
Contributor:

We can also add a small section about how to load a model quantized using SpQR using transformers

Author:

Resolved by updating the README.
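
For reference, a minimal sketch of the kind of loading snippet that README section could contain; the checkpoint id below is a hypothetical placeholder, not necessarily the one uploaded for this PR:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical checkpoint id -- substitute the actual SpQR checkpoint from the Hub.
model_id = "elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf"

quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # SpQR inference only supports float16
    device_map="auto",          # SpQR kernels require a CUDA device
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(quantized_model.device)
outputs = quantized_model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```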

raise ImportError("Using `spqr` quantization requires Accelerate: `pip install accelerate`")

if not is_spqr_available():
raise ImportError("Using `spqr` quantization requires SpQR: `pip install spqr_quant`")
Contributor:

if the quantization method requires a GPU, we need to check that one is available here

Author:

Resolved by adding an is_cuda check.
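
Roughly, the resulting validate_environment could look like the sketch below (combining the existing import errors with a CUDA check; a sketch, not the exact diff):

```python
import torch

# is_spqr_available is the availability helper added by this PR; the import path is assumed.
from transformers.utils import is_accelerate_available, is_spqr_available


def validate_environment(self, *args, **kwargs):
    # SpQR kernels only run on CUDA devices.
    if not torch.cuda.is_available():
        raise RuntimeError("GPU is required to run SpQR quantized model.")

    if not is_accelerate_available():
        raise ImportError("Using `spqr` quantization requires Accelerate: `pip install accelerate`")

    if not is_spqr_available():
        raise ImportError("Using `spqr` quantization requires SpQR: `pip install spqr_quant`")
```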



@slow
@require_torch_gpu
Contributor:

Is it possible to add tests for multi_gpu setups?

Author:

Resolved by bringing this back from AQLM. Will the CI be able to pick this up? I don't have the means of testing this.
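
For context, the AQLM-style multi-GPU test being brought over looks roughly like this sketch (it is a method inside the SpQRTest class; attribute names such as self.model_name, self.input_text and self.EXPECTED_OUTPUT mirror the single-GPU test and are placeholders here):

```python
from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_torch_multi_gpu, slow


@slow
@require_torch_multi_gpu
def test_quantized_model_multi_gpu(self):
    """Check that the quantized model works when sharded across multiple GPUs."""
    quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto")
    self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})

    input_ids = self.tokenizer(self.input_text, return_tensors="pt").to("cuda")
    output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
    self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
```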

Contributor:

Yes it does 😄

Member:

+1 on that, you can add the test and we will check to make sure it passes with the right expected output.


elvircrn commented Dec 2, 2024

@MekkCyber I've removed the Draft tag from the PR.

It seems that I am unable to pass the test_raise_if_non_quantized. Could you give me some pointers on this?

elvircrn force-pushed the spqr-quantizer branch 4 times, most recently from 56c016d to e1a91fa on December 5, 2024
@MekkCyber (Contributor):

Could you rebase @elvircrn to update the branch :)


elvircrn commented Dec 6, 2024

@MekkCyber Rebase done.

@MekkCyber (Contributor):

Hey @elvircrn, thanks for iterating! It looks great, I just left some minor comments.

Comment on lines 49 to 56
if modules_to_not_convert is None:
modules_to_not_convert = []

from accelerate import init_empty_weights
from spqr_quant import QuantizedLinear

Contributor:

We need to check here whether the accelerate and spqr_quant packages are available before doing the import; what we don't need to do is raise errors, because the packages are assumed to be installed when this function is executed.

Author:

Not sure if I understood this, please double-check my newest change.

Comment on lines 53 to 57
if torch_dtype is None:
torch_dtype = torch.float16
logger.info(
"Assuming CUDA is available. Assuming SpQR inference on GPU and loading the model in `torch.float16`."
)
Contributor:

We are not assuming here because we already validated the environment 😉

Comment on lines 58 to 61
elif torch_dtype != torch.float16:
raise ValueError(
"You cannot any type other than torch.float16 for SpQR. Please either leave it None or set it to"
"torch.float16 explicitly."
)
return torch_dtype
Contributor:

Suggested change ("You cannot use any type" instead of "You cannot any type"):

elif torch_dtype != torch.float16:
raise ValueError(
"You cannot use any type other than torch.float16 for SpQR. Please either leave it None or set it to"
"torch.float16 explicitly."
)
return torch_dtype

Comment on lines +73 to +105
model._modules[name] = QuantizedLinear.create_placehodler(
rows=out_features,
cols=in_features,
bits=quantization_config.bits,
beta1=quantization_config.beta1,
beta2=quantization_config.beta2,
dense_weights_shape=dense_weights_shape,
row_offsets_shape=row_offsets_shape,
col_vals_shape=col_vals_shape,
in_perm_shape=in_perm_shape,
)
has_been_replaced = True

Contributor:

Just FMI, are the beta and shapes parameters used solely for loading, or do they also play a role in quantization? In other words, if a model is quantized using specific beta or shapes values, can it only be loaded with those same parameters?

Author:

SpQR quantized weights are the product of tile-based bi-level quantization, where beta1 and beta2 are the tile width and height respectively. You always specify the tile dimensions before quantization starts.

Each SpQR weight also comes with a sparse tensor holding the outlier weights. As this tensor is unstructured, its size has to be tracked explicitly so it can be loaded.
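
A rough illustration of why only the outlier shapes need to be recorded, assuming a CSR-like layout for the unstructured outliers (a sketch, not the kernel's actual packing):

```python
# Only the number of non-zeros depends on the data, which is why col_vals.shape must be
# recorded per tensor while the other shapes are derivable from the weight dimensions.
rows, cols = 4096, 4096          # dense weight dimensions (m, n)
nnz = 123_456                    # number of outlier entries: data dependent, unknown a priori

row_offsets_shape = (rows + 1,)  # CSR row pointers, derivable from the matrix height
col_vals_shape = (nnz,)          # packed (column index, value) entries, must be stored explicitly
print(row_offsets_shape, col_vals_shape)
```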

Author:

The following is a visualization of the compression format from the original publication (https://arxiv.org/pdf/2306.03078):

[figure: visualization of the SpQR compression format from the paper]

Contributor:

Thanks for the detailed explanation

Comment on lines 64 to 90
tensor_name = ".".join(current_key_name)
dense_weights_shape = quantization_config.shapes[f"{tensor_name}.dense_weights.shape"]
row_offsets_shape = quantization_config.shapes[f"{tensor_name}.row_offsets.shape"]
col_vals_shape = quantization_config.shapes[f"{tensor_name}.col_vals.shape"]
in_perm_shape = quantization_config.shapes[f"{tensor_name}.in_perm.shape"]

in_features = module.in_features
Contributor:

in the shapes attribute do we need to specify the shape for all the tensors ?

Author:

The col_vals size needs to be specified ahead of time for all weights. The rest can be computed from m, n and bits. Can we keep it the way it is now?

Author:

I can do it both ways easily, not sure what is preferred here.

Contributor:

I'm just worried about the quantization config: I'm not sure we can verify that shapes contains all the necessary elements in post_init. Maybe before trying to access the key in replace_with_spqr_linear we should verify that it's there, and raise an error otherwise to inform the user that something is wrong with the config.json they are using. wdyt @SunMarc

Author:

In my estimate, these configurations are directly tied to the quantized weights: if the shapes don't cover the full set of quantized weights, that is a fatal error in quantization serialization. Perhaps this check belongs on the SpQR side (https://pypi.org/project/spqr-quant/)?

If we were to do it on huggingface/transformers-side:

a) during replace_with_spqr_linear, we raise a fatal error (please let me know which one would be appropriate here) if a quantized weight doesn't have a corresponding key in the shapes config.
b) we do a full check before conducting the replacement in order to conserve memory/device cycles.

Let me know what best suits transformers here.
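
A minimal sketch of option (a), with a hypothetical helper name and placeholder error wording:

```python
# Hypothetical helper for replace_with_spqr_linear; exception type and message are placeholders.
def _get_shape(quantization_config, tensor_name: str, suffix: str):
    key = f"{tensor_name}.{suffix}.shape"
    if key not in quantization_config.shapes:
        raise ValueError(
            f"`quantization_config.shapes` has no entry for `{key}`; the checkpoint's config.json "
            "does not match its SpQR-quantized weights."
        )
    return quantization_config.shapes[key]


# Usage inside the replacement loop:
# dense_weights_shape = _get_shape(quantization_config, tensor_name, "dense_weights")
# col_vals_shape = _get_shape(quantization_config, tensor_name, "col_vals")
```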

Author:

I didn't see the tag to @SunMarc - I thought the wdyt referred to me.




tests/quantization/spqr_integration/test_spqr.py (outdated; resolved)
Comment on lines +113 to +116
model_id = "meta-llama/Llama-2-7b-hf"
config = AutoConfig.from_pretrained(model_id)
quantization_config = AutoConfig.from_pretrained(self.model_name, return_dict=False).quantization_config
quantization_config = SpQRConfig.from_dict(quantization_config)
Contributor:

If possible it's better to use small models for the CI like Qwen 0.5, Bloom 560M, SmolLM 135M...

Author:

Will have to get back to you on this (might require a significant time investment).

Author:

Is this a blocker for the merge?

Contributor:

Not really a blocker for now


elvircrn commented Dec 9, 2024

@MekkCyber The latest batch of comments are addressed.


MekkCyber commented Dec 9, 2024

Thanks for the quick iteration @elvircrn 🔥! Can you re-rebase please 😅



Comment on lines +52 to +55
if is_accelerate_available():
from accelerate import init_empty_weights
if is_spqr_available():
from spqr_quant import QuantizedLinear
Contributor:

Yes that's exactly what I meant, thank you for fixing it !



elvircrn commented Dec 9, 2024

@MekkCyber rebased.

MekkCyber requested a review from SunMarc on December 9, 2024

elvircrn commented Dec 9, 2024

@MekkCyber Resolved the latest batch of comments.


elvircrn commented Dec 9, 2024

@SunMarc @MekkCyber
I've pushed a check for keys() in shapes in replace_with_spqr_linear. Let me know if something in this ballpark works.
