Initial GGUF support for flux models #6890

Open · brandonrising wants to merge 6 commits into main
Conversation

brandonrising (Collaborator)

Summary

Support for GGUF-quantized models within the FLUX ecosystem.

QA Instructions

Attempt to install and run multiple GGUF-quantized FLUX models.

Merge Plan

After thorough review, this can be merged once approved.

Checklist

  • The PR has a short but descriptive title, suitable for a changelog
  • Tests added / updated (if applicable)
  • Documentation added / updated (if applicable)

@github-actions bot added the python and backend labels on Sep 19, 2024
@github-actions bot added the python-tests label on Sep 20, 2024
@lstein (Collaborator) left a comment

Wonderful! Functionally tested and working in my hands with several Civitai GGUF models.

  • You will need to add the gguf module to the pyproject dependencies.

  • There are a number of typing mismatches flagged by pyright. I have suggested fixes for several of them, but there were others that I couldn't easily fix.

  • I notice that Civitai is distributing some of the GGUF models as a packed zip file. For network model install, we'll have to add an additional step of unpacking the contents and adding them onto the install queue. I can do that work if it is in scope.
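
For the zip case, the unpacking step could be as small as this (an untested sketch; the function name and destination handling are illustrative, not the actual install-queue API):

import zipfile
from pathlib import Path


def unpack_gguf_archive(archive: Path, dest: Path) -> list[Path]:
    # Extract a downloaded .zip and return any .gguf files it contains so they can be
    # queued for install individually.
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest)
    return sorted(dest.rglob("*.gguf"))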

Comment on lines +28 to +29
new = super().to(*args, **kwargs)
new.tensor_type = getattr(self, "tensor_type", None)

Suggested change
- new = super().to(*args, **kwargs)
- new.tensor_type = getattr(self, "tensor_type", None)
+ new = super().to(*args, **kwargs)
+ assert isinstance(new, self.__class__)
+ new.tensor_type = getattr(self, "tensor_type", None)

def __new__(cls, *args, tensor_type, tensor_shape, patches=None, **kwargs):
    return super().__new__(cls, *args, **kwargs)

def to(self, *args, **kwargs):

Suggested change
- def to(self, *args, **kwargs):
+ def to(self, *args, **kwargs) -> Tensor:

Could return GGUFTensor here instead, but there would have to be an intermediate class defined.


def __deepcopy__(self, *args, **kwargs):
    # Intel Arc fix, ref#50
    new = super().__deepcopy__(*args, **kwargs)

Suggested change
- new = super().__deepcopy__(*args, **kwargs)
+ new = super().__deepcopy__(*args, **kwargs)  # type: ignore

patch_dtype = None
torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

def is_ggml_quantized(self, *, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None):

Suggested change
- def is_ggml_quantized(self, *, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None):
+ def is_ggml_quantized(self, *, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None) -> bool:

return False
return is_quantized(weight) or is_quantized(bias)

def _load_from_state_dict(self, state_dict: dict[str, Tensor], prefix: str, *args, **kwargs):

Suggested change
- def _load_from_state_dict(self, state_dict: dict[str, Tensor], prefix: str, *args, **kwargs):
+ def _load_from_state_dict(self, state_dict: dict[str, Tensor], prefix: str, *args, **kwargs) -> Tensor:

    if arch_str not in {"flux"}:
        raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}")
else:
    arch_str = detect_arch({val[0] for val in tensors})

The first argument in the call signature for detect_arch() was originally dict[str, Tensor], but here you're passing a set of str. Since the function is only called here, I proposed a change in the detect_arch function below.
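
Roughly the signature I have in mind (a hedged sketch; MODEL_ARCHITECTURES is a placeholder for whatever lookup table detect_arch already iterates over, and the key lists shown are illustrative):

MODEL_ARCHITECTURES: dict[str, list[list[str]]] = {
    "flux": [["img_in.weight", "txt_in.weight"]],  # illustrative key lists only
}


def detect_arch(state_dict_keys: set[str]) -> str:
    # Match the GGUF tensor names against each architecture's expected key lists.
    for arch, match_lists in MODEL_ARCHITECTURES.items():
        for match_list in match_lists:
            if all(key in state_dict_keys for key in match_list):
                return arch
    raise ValueError("Unable to detect model architecture from GGUF tensor names")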

class Conv2d(GGUFLayer, nn.Conv2d):
    def forward(self, input: Tensor) -> Tensor:
        weight, bias = self.cast_bias_weight(input)
        return self._conv_forward(input, weight, bias)

Suggested change
- return self._conv_forward(input, weight, bias)
+ result = self._conv_forward(input, weight, bias)
+ assert isinstance(result, Tensor)
+ return result

original_class = getattr(nn, attr_name)

# Define a helper function to bind the current patcher_attr for each iteration
def create_patch_function(patcher_attr):

This is giving a pyright error about missing type annotation.

def create_patch_function(patcher_attr):
    # Return a new patch_class function specific to this patcher_attr
    @wrapt.decorator
    def patch_class(wrapped, instance, args, kwargs):

Here as well.

original_classes[attr_name] = original_class

# Apply the patch
setattr(nn, attr_name, create_patch_function(patcher_class)(original_class))

And here pyright is complaining about calling the untyped create_patch_function in a typed context. create_patch_function() should be returning a Callable, but I wasn't sure of the exact signature needed.
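
Something along these lines might satisfy pyright (an untested sketch; the Any-heavy signature could certainly be tightened further):

from typing import Any, Callable

import wrapt


def create_patch_function(patcher_attr: type) -> Callable[..., Any]:
    # Return a patch_class wrapper bound to this patcher_attr.
    @wrapt.decorator
    def patch_class(wrapped: Callable[..., Any], instance: Any, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
        # Instantiate the _patcher version of the class instead of the wrapped torch.nn class.
        return patcher_attr(*args, **kwargs)

    return patch_class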

@lstein (Collaborator) commented on Sep 22, 2024

I've taken the liberty of adding .gguf to the list of model suffixes that get searched for when scanning a folder for model import.

@lstein (Collaborator) commented on Sep 22, 2024

After further testing, I did find a GGUF quantized model on Civitai that does not load properly:

The URL is https://civitai.com/models/705823/ggufk-flux-unchained-km-quants (warning: NSFW). This is a Q4_KM model. I guess KM quantization is not yet supported?

It downloads and installs as expected, but generating with it gives this error:

[2024-09-21 21:04:25,679]::[InvokeAI]::ERROR --> Error while invoking session 82cd8bc8-9036-41f8-b524-4a2796f279c7, invocation 7e498e87-44e4-4d63-91a7-f9e4e65e6ed2 (flux_denoise): Error(s) in loading state_dict for Flux:
        size mismatch for img_in.weight: copying a param with shape torch.Size([768, 256]) from checkpoint, the shape in current model is torch.Size([3072, 64]).
[2024-09-21 21:04:25,679]::[InvokeAI]::ERROR --> Traceback (most recent call last):
  File "/home/lstein/Projects/InvokeAI/invokeai/app/services/session_processor/session_processor_default.py", line 129, in run_node
    output = invocation.invoke_internal(context=context, services=self._services)
  File "/home/lstein/Projects/InvokeAI/invokeai/app/invocations/baseinvocation.py", line 288, in invoke_internal
    output = self.invoke(context)
  File "/home/lstein/invokeai-main/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lstein/Projects/InvokeAI/invokeai/app/invocations/flux_denoise.py", line 88, in invoke
    latents = self._run_diffusion(context)
  File "/home/lstein/Projects/InvokeAI/invokeai/app/invocations/flux_denoise.py", line 124, in _run_diffusion
    transformer_info = context.models.load(self.transformer.transformer)
  File "/home/lstein/Projects/InvokeAI/invokeai/app/services/shared/invocation_context.py", line 369, in load
    return self._services.model_manager.load.load_model(model, _submodel_type)
  File "/home/lstein/Projects/InvokeAI/invokeai/app/services/model_load/model_load_default.py", line 70, in load_model
    ).load_model(model_config, submodel_type)
  File "/home/lstein/Projects/InvokeAI/invokeai/backend/model_manager/load/load_default.py", line 56, in load_model
    locker = self._load_and_cache(model_config, submodel_type)
  File "/home/lstein/Projects/InvokeAI/invokeai/backend/model_manager/load/load_default.py", line 77, in _load_and_cache
    loaded_model = self._load_model(config, submodel_type)
  File "/home/lstein/Projects/InvokeAI/invokeai/backend/model_manager/load/model_loaders/flux.py", line 224, in _load_model
    return self._load_from_singlefile(config)
  File "/home/lstein/Projects/InvokeAI/invokeai/backend/model_manager/load/model_loaders/flux.py", line 248, in _load_from_singlefile
    model.load_state_dict(sd, assign=True)
  File "/home/lstein/invokeai-main/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Flux:
        size mismatch for img_in.weight: copying a param with shape torch.Size([768, 256]) from checkpoint, the shape in current model is torch.Size([3072, 64]).

I also tried installing a quantized GGUF-format T5 encoder, and it failed as expected.

@RyanJDick (Collaborator) left a comment

Initial comments from a partial review. I'll continue tomorrow.

Comment on lines +27 to +45
if hasattr(nn, attr_name):
    # Get the original torch.nn class
    original_class = getattr(nn, attr_name)

    # Define a helper function to bind the current patcher_attr for each iteration
    def create_patch_function(patcher_attr):
        # Return a new patch_class function specific to this patcher_attr
        @wrapt.decorator
        def patch_class(wrapped, instance, args, kwargs):
            # Call the _patcher version of the class
            return patcher_attr(*args, **kwargs)

        return patch_class

    # Save the original class for restoration later
    original_classes[attr_name] = original_class

    # Apply the patch
    setattr(nn, attr_name, create_patch_function(patcher_class)(original_class))

Suggested change
- if hasattr(nn, attr_name):
-     # Get the original torch.nn class
-     original_class = getattr(nn, attr_name)
-     # Define a helper function to bind the current patcher_attr for each iteration
-     def create_patch_function(patcher_attr):
-         # Return a new patch_class function specific to this patcher_attr
-         @wrapt.decorator
-         def patch_class(wrapped, instance, args, kwargs):
-             # Call the _patcher version of the class
-             return patcher_attr(*args, **kwargs)
-         return patch_class
-     # Save the original class for restoration later
-     original_classes[attr_name] = original_class
-     # Apply the patch
-     setattr(nn, attr_name, create_patch_function(patcher_class)(original_class))
+ # Check if torch.nn has a class with the same name
+ if hasattr(nn, attr_name):
+     original_class = getattr(nn, attr_name)
+     original_classes[attr_name] = original_class
+     setattr(nn, attr_name, patcher_class)


I don't fully follow what we were trying to accomplish with wrapt, but I think it can be ripped out entirely.

Comment on lines +8 to +11
quantized_sd = {
    "linear.weight": torch.load("tests/assets/gguf_qweight.pt"),
    "linear.bias": torch.load("tests/assets/gguf_qbias.pt"),
}

We should just construct the state_dict in a pytest fixture rather than loading from a file. By storing the weights in a file, we have inflated the repo size by 30MB. Plus, it makes the expected state dict structure unclear to a reader unless they manually inspect the contents of the .pt files.

Also, let's remove all references to the large binary files from the git history.
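
On the first point, something like this would keep the test self-describing (an untested sketch; it assumes the gguf package's quants.quantize helper is available in the pinned version, and the layer name/shape are illustrative; the fixture may still need to wrap the tensors in this PR's GGUFTensor type):

import gguf
import pytest
import torch


@pytest.fixture
def quantized_sd() -> dict[str, torch.Tensor]:
    # Build a small GGML-quantized state dict on the fly instead of shipping .pt assets.
    weight = torch.randn(4, 32, dtype=torch.float32)
    bias = torch.randn(4, dtype=torch.float32)
    q_weight = gguf.quants.quantize(weight.numpy(), gguf.GGMLQuantizationType.Q8_0)
    return {
        "linear.weight": torch.from_numpy(q_weight),
        "linear.bias": bias,
    }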


# Ensure nn.Linear is restored
assert nn.Linear is not TestGGUFPatcher.Linear
assert isinstance(nn.Linear(4, 8), nn.Linear)

This line is not checking anything. For this to be a meaningful check, we'd have to store an original_linear = nn.Linear before patching and then compare against that:

Suggested change
- assert isinstance(nn.Linear(4, 8), nn.Linear)
+ assert isinstance(nn.Linear(4, 8), original_linear)

for match_list in match_lists:
    if all(key in state_dict for key in match_list):
        return arch
breakpoint()

Remove this breakpoint.

Comment on lines +12 to +14
def gguf_sd_loader(
    path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16
) -> dict[str, GGUFTensor]:
@RyanJDick (Collaborator) commented on Sep 25, 2024

gguf_sd_loader() is pretty messy (unnecessary complexity, unused vars, mixes general sd loading with model-specific key modifications, etc.). I know it's mostly copied from elsewhere, but feels like we should clean it up. I'd propose having a simple utility function for loading GGUF files, and then separate functions for key prefix handling and model-specific modifications (if we even need those for our use case).

The utility loader would be as simple as (not tested):

def load_gguf_sd(path: Path) -> dict[str, GGUFTensor]:
    reader = gguf.GGUFReader(path)
    sd: dict[str, GGUFTensor] = {}
    for tensor in reader.tensors:
        torch_tensor = torch.from_numpy(tensor.data)
        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
        if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
            torch_tensor = torch_tensor.view(*shape)
        sd[tensor.name] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)

    return sd

Also note that I replaced {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16} with TORCH_COMPATIBLE_QTYPES.

@RyanJDick (Collaborator) left a comment

Another chunk of review. Next up I'll tackle GGUFTensor and GGUFLayer - looks like there's still quite a bit of room for cleanup there.

Comment on lines +219 to +220
if not isinstance(config, CheckpointConfigBase):
    raise ValueError("Only CheckpointConfigBase models are currently supported here.")

nit: This check is redundant given this stricter check in _load_from_singlefile(...): assert isinstance(config, MainGGUFCheckpointConfig)

brandonrising (author) replied:

Mostly wrote it this way in case we choose to use other "submodels" on the same loader.

invokeai/backend/model_manager/load/model_loaders/flux.py (outdated, resolved)
invokeai/backend/model_manager/load/model_loaders/flux.py (outdated, resolved)
Comment on lines +59 to +62
if str(path).endswith(".gguf"):
    checkpoint = gguf_sd_loader(Path(path))
else:
    checkpoint = torch.load(path, map_location=torch.device("meta"))

Can we load GGUF files onto the 'meta' device here to match the behaviour of the torch.load(...) call? It seems risky to load such large models into memory without registering them with the model cache.

The following might work. Or modifications may be necessary inside gguf_sd_loader(...):

with torch.device("meta"):
    checkpoint = gguf_sd_loader(Path(path))

brandonrising (author) replied:

Makes sense to me.

from typing import Optional, Union

import gguf
from torch import Tensor, device, dtype, float32, nn, zeros_like

nit: The convention throughout most of the codebase is to import torch and then reference the submodule like this: torch.Tensor, torch.device, torch.nn.Linear, etc.

PATCH_TYPES = Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]


def move_patch_to_device(item: PATCH_TYPES, device: torch.device) -> PATCH_TYPES:

move_patch_to_device(...) appears to be unused. Can we remove it?

brandonrising (author) replied:

Yeah, I didn't pull in the patch logic from that node, so this isn't necessary.

Comment on lines +408 to +409
elif model_path.suffix.endswith(".gguf"):
    return gguf_sd_loader(model_path)

There have been a number of security vulnerabilities reported in the GGUF file format in the past (https://www.databricks.com/blog/ggml-gguf-file-format-vulnerabilities).

We should make sure that we have set the minimum gguf version in pyproject.toml accordingly.

Can you think of anything else we should be doing to protect users? I'm calling it out here since the method name _scan_and_load_checkpoint() implies that any necessary security checks are happening here.
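
Beyond pinning the version in pyproject.toml, a cheap runtime guard is one option; a hedged sketch (the version floor here is a placeholder and should be set to whichever gguf release addresses the reported parser issues):

from importlib.metadata import version

from packaging.version import Version

MIN_SAFE_GGUF_VERSION = Version("0.0.0")  # placeholder floor; set to the first release with the parser fixes


def assert_safe_gguf_version() -> None:
    # Refuse to parse untrusted GGUF files with a gguf version older than the known-safe floor.
    installed = Version(version("gguf"))
    if installed < MIN_SAFE_GGUF_VERSION:
        raise RuntimeError(f"gguf>={MIN_SAFE_GGUF_VERSION} is required to safely parse GGUF files; found {installed}")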

Comment on lines +100 to +112
def _save_to_state_dict(self, *args, **kwargs):
    if self.is_ggml_quantized():
        return self.ggml_save_to_state_dict(*args, **kwargs)
    return super()._save_to_state_dict(*args, **kwargs)

def ggml_save_to_state_dict(self, destination: dict[str, Tensor], prefix: str):
    # This is a fake state dict for vram estimation
    weight = zeros_like(self.weight, device=device("meta"))
    destination[prefix + "weight"] = weight
    if self.bias is not None:
        bias = zeros_like(self.bias, device=device("meta"))
        destination[prefix + "bias"] = bias
    return

This behavior seems weird to me. If we are using a torch-compatible quantization format we save a state_dict, otherwise we save dummy meta weights. Should we just rip this out since we aren't using it? I can't really think of a valid use case for it.

brandonrising (author) replied:

I'll rip it out and do testing. Unsure if there's some reason the weight/bias keys wouldn't be set properly otherwise.

Comment on lines +72 to +74
if weight is None or bias is None:
    return False
return is_quantized(weight) or is_quantized(bias)

This logic seems wrong. We should still check for quantization even if one of weight or bias is None.
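
In other words, something like this (an untested sketch reusing the existing is_quantized helper):

def is_ggml_quantized(self, *, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None) -> bool:
    # Check whichever tensors were provided; a missing bias should not mask a quantized weight.
    return any(is_quantized(t) for t in (weight, bias) if t is not None)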

@github-actions bot added the Root and python-deps labels on Sep 26, 2024
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>