Initial GGUF support for flux models #6890
base: main
Conversation
Wonderful! Functionally tested and working in my hands with several Civitai GGUF models.
- You will need to add the `gguf` module to the pyproject dependencies.
- There are a number of typing mismatches flagged by pyright. I have suggested fixes for several of them, but there were others that I couldn't easily fix.
- I notice that Civitai is distributing some of the GGUF models as a packed zip file. For network model install, we'll have to add an additional step of unpacking the contents and adding them onto the install queue. I can do that work if it is in scope; see the sketch below.
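For the zip case, the unpacking step could be as simple as this sketch (the helper name and the install-queue wiring are assumptions, not existing code):

```python
import zipfile
from pathlib import Path


def extract_gguf_members(archive: Path, dest: Path) -> list[Path]:
    # Hypothetical helper: pull any .gguf files out of a downloaded archive
    # so each one can be appended to the install queue individually.
    extracted: list[Path] = []
    with zipfile.ZipFile(archive) as zf:
        for name in zf.namelist():
            if name.endswith(".gguf"):
                extracted.append(Path(zf.extract(name, dest)))
    return extracted
```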
```python
new = super().to(*args, **kwargs)
new.tensor_type = getattr(self, "tensor_type", None)
```
Suggested change:

```suggestion
new = super().to(*args, **kwargs)
assert isinstance(new, self.__class__)
new.tensor_type = getattr(self, "tensor_type", None)
```
```python
def __new__(cls, *args, tensor_type, tensor_shape, patches=None, **kwargs):
    return super().__new__(cls, *args, **kwargs)

def to(self, *args, **kwargs):
```
Suggested change:

```suggestion
def to(self, *args, **kwargs) -> Tensor:
```
Could return `GGUFTensor` here instead, but there would have to be an intermediate class defined.
```python
def __deepcopy__(self, *args, **kwargs):
    # Intel Arc fix, ref#50
    new = super().__deepcopy__(*args, **kwargs)
```
Suggested change:

```suggestion
new = super().__deepcopy__(*args, **kwargs)  # type: ignore
```
```python
patch_dtype = None
torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

def is_ggml_quantized(self, *, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None):
```
Suggested change:

```suggestion
def is_ggml_quantized(self, *, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None) -> bool:
```
```python
    return False
return is_quantized(weight) or is_quantized(bias)

def _load_from_state_dict(self, state_dict: dict[str, Tensor], prefix: str, *args, **kwargs):
```
Suggested change:

```suggestion
def _load_from_state_dict(self, state_dict: dict[str, Tensor], prefix: str, *args, **kwargs) -> Tensor:
```
```python
    if arch_str not in {"flux"}:
        raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}")
else:
    arch_str = detect_arch({val[0] for val in tensors})
```
The first argument in the call signature for `detect_arch()` was originally `dict[str, Tensor]`, but here you're passing a set of str. Since the function is only called here, I proposed a change to the `detect_arch` function below.
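For reference, the shape of that change, as a rough sketch (`MATCH_LISTS_BY_ARCH` is a stand-in name for however the match lists are actually stored):

```python
def detect_arch(state_dict_keys: set[str]) -> str:
    # Accept the set of tensor names directly rather than dict[str, Tensor],
    # since matching only needs the keys.
    for arch, match_lists in MATCH_LISTS_BY_ARCH.items():
        for match_list in match_lists:
            if all(key in state_dict_keys for key in match_list):
                return arch
    raise ValueError("Unable to detect architecture from GGUF tensor names.")
```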
```python
class Conv2d(GGUFLayer, nn.Conv2d):
    def forward(self, input: Tensor) -> Tensor:
        weight, bias = self.cast_bias_weight(input)
        return self._conv_forward(input, weight, bias)
```
Suggested change:

```suggestion
result = self._conv_forward(input, weight, bias)
assert isinstance(result, Tensor)
return result
```
```python
original_class = getattr(nn, attr_name)

# Define a helper function to bind the current patcher_attr for each iteration
def create_patch_function(patcher_attr):
```
This is giving a pyright error about missing type annotation.
```python
def create_patch_function(patcher_attr):
    # Return a new patch_class function specific to this patcher_attr
    @wrapt.decorator
    def patch_class(wrapped, instance, args, kwargs):
```
Here as well.
```python
original_classes[attr_name] = original_class

# Apply the patch
setattr(nn, attr_name, create_patch_function(patcher_class)(original_class))
```
And here pyright is complaining about calling the untyped `create_patch_function` in a typed context. `create_patch_function()` should be returning a Callable, but I wasn't sure of the exact signature needed. I've taken the liberty of adding …
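One possible annotation, in case it helps (a sketch only, not verified against pyright; it keeps `wrapt` for now):

```python
from typing import Any, Callable

import wrapt


def create_patch_function(patcher_attr: type) -> Callable[[type], type]:
    # Bind patcher_attr so each patched nn class redirects construction
    # to its _patcher counterpart.
    @wrapt.decorator
    def patch_class(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any:
        return patcher_attr(*args, **kwargs)

    return patch_class
```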
After further testing, I did find a GGUF quantized model on Civitai that does not load properly. The URL is https://civitai.com/models/705823/ggufk-flux-unchained-km-quants (warning: NSFW). This is a Q4_K_M model; I guess Q4_K_M quantization is not yet supported? It loads and installs as expected, but when generating it gives this error:

I also tried installing a quantized GGUF-format T5 encoder, and it failed as expected.
Initial comments from a partial review. I'll continue tomorrow.
```python
if hasattr(nn, attr_name):
    # Get the original torch.nn class
    original_class = getattr(nn, attr_name)

    # Define a helper function to bind the current patcher_attr for each iteration
    def create_patch_function(patcher_attr):
        # Return a new patch_class function specific to this patcher_attr
        @wrapt.decorator
        def patch_class(wrapped, instance, args, kwargs):
            # Call the _patcher version of the class
            return patcher_attr(*args, **kwargs)

        return patch_class

    # Save the original class for restoration later
    original_classes[attr_name] = original_class

    # Apply the patch
    setattr(nn, attr_name, create_patch_function(patcher_class)(original_class))
```
Suggested change:

```suggestion
# Check if torch.nn has a class with the same name
if hasattr(nn, attr_name):
    original_class = getattr(nn, attr_name)
    original_classes[attr_name] = original_class
    setattr(nn, attr_name, patcher_class)
```
I don't fully follow what we were trying to accomplish with `wrapt`, but I think it can be ripped out entirely.
```python
quantized_sd = {
    "linear.weight": torch.load("tests/assets/gguf_qweight.pt"),
    "linear.bias": torch.load("tests/assets/gguf_qbias.pt"),
}
```
We should just construct the state_dict in a pytest fixture rather than loading from a file. By storing the weights in a file, we have inflated the repo size by 30MB. Plus, it makes the expected state dict structure unclear to a reader unless they manually inspect the contents of the .pt files.
Also, let's remove all references to the large binary files from the git history.
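For example, a fixture along these lines (a sketch, not tested; it assumes `gguf.quants.quantize` is available in the pinned `gguf` version and uses this PR's `GGUFTensor` constructor):

```python
import gguf
import numpy as np
import pytest
import torch


@pytest.fixture
def quantized_sd() -> dict[str, GGUFTensor]:
    # Build a small quantized state dict in memory instead of shipping .pt files.
    qtype = gguf.GGMLQuantizationType.Q8_0
    weight = np.random.rand(32, 64).astype(np.float32)  # sizes divisible by the Q8_0 block size
    bias = np.random.rand(32).astype(np.float32)
    return {
        "linear.weight": GGUFTensor(
            torch.from_numpy(gguf.quants.quantize(weight, qtype)),
            tensor_type=qtype,
            tensor_shape=torch.Size(weight.shape),
        ),
        "linear.bias": GGUFTensor(
            torch.from_numpy(gguf.quants.quantize(bias, qtype)),
            tensor_type=qtype,
            tensor_shape=torch.Size(bias.shape),
        ),
    }
```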
```python
# Ensure nn.Linear is restored
assert nn.Linear is not TestGGUFPatcher.Linear
assert isinstance(nn.Linear(4, 8), nn.Linear)
```
This line is not checking anything: both the constructor call and the class in the `isinstance` check resolve to whatever `nn.Linear` currently is, so it passes whether or not the patch was removed. For this to be a meaningful check, we'd have to store an `original_linear = nn.Linear` before patching and then compare against that:
Suggested change:

```suggestion
assert isinstance(nn.Linear(4, 8), original_linear)
```
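The full flow might look like this (a sketch; the patch/unpatch step depends on how the test drives the patcher):

```python
original_linear = nn.Linear  # capture before any patching happens

# ... apply the patcher, exercise the patched class, then trigger the
# restore path under test ...

# Restoration should put back the exact original class object.
assert nn.Linear is original_linear
assert isinstance(nn.Linear(4, 8), original_linear)
```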
```python
for match_list in match_lists:
    if all(key in state_dict for key in match_list):
        return arch
breakpoint()
```
Remove this breakpoint.
```python
def gguf_sd_loader(
    path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16
) -> dict[str, GGUFTensor]:
```
`gguf_sd_loader()` is pretty messy (unnecessary complexity, unused vars, and it mixes general state-dict loading with model-specific key modifications). I know it's mostly copied from elsewhere, but it feels like we should clean it up. I'd propose having a simple utility function for loading GGUF files, and then separate functions for key prefix handling and model-specific modifications (if we even need those for our use case).

The utility loader would be as simple as (not tested):

```python
def load_gguf_sd(path: Path) -> dict[str, GGUFTensor]:
    reader = gguf.GGUFReader(path)
    sd: dict[str, GGUFTensor] = {}
    for tensor in reader.tensors:
        torch_tensor = torch.from_numpy(tensor.data)
        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
        if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
            torch_tensor = torch_tensor.view(*shape)
        sd[tensor.name] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
    return sd
```

Also note that I replaced `{gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}` with `TORCH_COMPATIBLE_QTYPES`.
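Here `TORCH_COMPATIBLE_QTYPES` would just be a module-level constant mirroring the set already defined on the layer class earlier in this PR:

```python
TORCH_COMPATIBLE_QTYPES = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}
```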
Another chunk of review. Next up I'll tackle GGUFTensor and GGUFLayer - looks like there's still quite a bit of room for cleanup there.
```python
if not isinstance(config, CheckpointConfigBase):
    raise ValueError("Only CheckpointConfigBase models are currently supported here.")
```
nit: This check is redundant given this stricter check in `_load_from_singlefile(...)`: `assert isinstance(config, MainGGUFCheckpointConfig)`
Mostly wrote it this way in case we choose to use other "submodels" on the same loader.
```python
if str(path).endswith(".gguf"):
    checkpoint = gguf_sd_loader(Path(path))
else:
    checkpoint = torch.load(path, map_location=torch.device("meta"))
```
Can we load GGUF files onto the 'meta' device here to match the behaviour of the `torch.load(...)` call? It seems risky to load such large models into memory without registering them with the model cache. The following might work, or modifications may be necessary inside `gguf_sd_loader(...)`:

```python
with torch.device("meta"):
    checkpoint = gguf_sd_loader(Path(path))
```
Makes sense to me.
```python
from typing import Optional, Union

import gguf
from torch import Tensor, device, dtype, float32, nn, zeros_like
```
nit: The convention throughout most of the codebase is to `import torch` and then reference the submodule like this: `torch.Tensor`, `torch.device`, `torch.nn.Linear`, etc.
```python
PATCH_TYPES = Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]


def move_patch_to_device(item: PATCH_TYPES, device: torch.device) -> PATCH_TYPES:
```
`move_patch_to_device(...)` appears to be unused. Can we remove it?
Yeah, I didn't pull in the patch logic from that node, so this isn't necessary.
```python
elif model_path.suffix.endswith(".gguf"):
    return gguf_sd_loader(model_path)
```
There have been a number of security vulnerabilities reported in the GGUF file format in the past (https://www.databricks.com/blog/ggml-gguf-file-format-vulnerabilities). We should make sure that we have set the minimum `gguf` version in pyproject.toml accordingly. Can you think of anything else we should be doing to protect users? I'm calling it out here since the method name `_scan_and_load_checkpoint()` implies that any necessary security checks are happening here.
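Beyond pinning the version, one cheap guard we could add is rejecting files that don't start with the GGUF magic bytes before handing them to the reader (a sketch; where exactly this check belongs is up for discussion):

```python
from pathlib import Path

GGUF_MAGIC = b"GGUF"  # first four bytes of every valid GGUF file


def looks_like_gguf(path: Path) -> bool:
    # Cheap sanity check before parsing: reject obviously mislabeled files
    # without reading the whole model into memory.
    with open(path, "rb") as f:
        return f.read(4) == GGUF_MAGIC
```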
```python
def _save_to_state_dict(self, *args, **kwargs):
    if self.is_ggml_quantized():
        return self.ggml_save_to_state_dict(*args, **kwargs)
    return super()._save_to_state_dict(*args, **kwargs)

def ggml_save_to_state_dict(self, destination: dict[str, Tensor], prefix: str):
    # This is a fake state dict for vram estimation
    weight = zeros_like(self.weight, device=device("meta"))
    destination[prefix + "weight"] = weight
    if self.bias is not None:
        bias = zeros_like(self.bias, device=device("meta"))
        destination[prefix + "bias"] = bias
    return
```
This behavior seems weird to me: if we are using a torch-compatible quantization format we save a real state_dict, otherwise we save dummy meta weights. Should we just rip this out since we aren't using it? I can't really think of a valid use case for it.
I'll rip it out and do testing. Unsure if there's some reason the weight/bias keys wouldn't be set properly otherwise.
```python
if weight is None or bias is None:
    return False
return is_quantized(weight) or is_quantized(bias)
```
This logic seems wrong. We should still check for quantization even if one of `weight` or `bias` is None; see the sketch below.
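Something like this (a sketch of what I mean):

```python
def is_ggml_quantized(self, *, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None) -> bool:
    # A layer with no bias can still have a quantized weight (and vice versa),
    # so check whichever tensors are actually present.
    if weight is None and bias is None:
        return False
    return (weight is not None and is_quantized(weight)) or (bias is not None and is_quantized(bias))
```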
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
Summary
Support for GGUF quantized models within the FLUX ecosystem
QA Instructions
Attempt to install and run with multiple GGUF quantized flux models
Merge Plan
After thorough reviews, can be merged when approved
Checklist