Add support for LyCORIS IA3 format #4234

Merged 2 commits on Aug 11, 2023
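For orientation: the net effect of the IA3 support in this diff (the new IA3Layer.get_weight combined with the apply_lora change in lora.py) is an elementwise rescaling of the original weight, rather than the additive low-rank delta used by LoRA. A schematic sketch of the arithmetic; the shapes and the lora_weight value are illustrative, not taken from the PR:

import torch

orig_weight = torch.randn(4, 3)   # the module's unmodified weight
ia3_vector = torch.randn(4, 1)    # learned per-output-feature scale (illustrative shape)
lora_weight = 1.0                 # user-set strength

delta = orig_weight * ia3_vector                 # what IA3Layer.get_weight(orig_weight) computes
new_weight = orig_weight + lora_weight * delta   # what apply_lora adds to module.weight
# equivalently: orig_weight * (1 + lora_weight * ia3_vector)
assert torch.allclose(new_weight, orig_weight * (1 + lora_weight * ia3_vector))

Because the delta depends on the base weight itself, get_weight() now has to receive the original tensor, which is the signature change threaded through all three files below.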
5 changes: 3 additions & 2 deletions invokeai/backend/model_management/lora.py
@@ -143,7 +143,7 @@ def apply_lora(
# with torch.autocast(device_type="cpu"):
layer.to(dtype=torch.float32)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight() * lora_weight * layer_scale
layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale

if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
@@ -361,7 +361,8 @@ def apply_lora(

layer.to(dtype=torch.float32)
layer_key = layer_key.replace(prefix, "")
layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
# TODO: rewrite to pass original tensor weight(required by ia3)
layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
if layer_key is blended_loras:
blended_loras[layer_key] += layer_weight
else:
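The two hunks above change the get_weight() call sites: the module-patching path now passes the stored original weight, while the blended-loras path still passes None (hence the TODO), which only works for layer types that ignore the argument. A minimal stand-in for the two behaviors; the class names are invented for illustration and are not the ones in models/lora.py:

import torch

class AdditiveLayer:                       # LoRA/LoHA/LoKR-style: delta is independent of the base weight
    def __init__(self, up, down):
        self.up, self.down = up, down
    def get_weight(self, orig_weight):     # accepts the argument but ignores it
        return self.up @ self.down

class MultiplicativeLayer:                 # ia3-style: delta is a rescaling of the base weight
    def __init__(self, vector):
        self.vector = vector
    def get_weight(self, orig_weight):
        return orig_weight * self.vector   # fails if orig_weight is None

lora = AdditiveLayer(torch.randn(4, 2), torch.randn(2, 3))
print(lora.get_weight(None).shape)         # torch.Size([4, 3]); None is fine here, as in the second hunk

An ia3-style layer would raise on the None path, which is why that code path is marked for a rewrite.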
89 changes: 49 additions & 40 deletions invokeai/backend/model_management/models/lora.py
@@ -122,41 +122,7 @@ def __init__(
self.rank = None # set in layer implementation
self.layer_key = layer_key

def forward(
self,
module: torch.nn.Module,
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
multiplier: float,
):
if type(module) == torch.nn.Conv2d:
op = torch.nn.functional.conv2d
extra_args = dict(
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)

else:
op = torch.nn.functional.linear
extra_args = {}

weight = self.get_weight()

bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
return (
op(
*input_h,
(weight + bias).view(module.weight.shape),
None,
**extra_args,
)
* multiplier
* scale
)

def get_weight(self):
def get_weight(self, orig_weight: torch.Tensor):
raise NotImplementedError()

def calc_size(self) -> int:
@@ -197,7 +163,7 @@ def __init__(

self.rank = self.down.shape[0]

def get_weight(self):
def get_weight(self, orig_weight: torch.Tensor):
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@@ -260,7 +226,7 @@ def __init__(

self.rank = self.w1_b.shape[0]

def get_weight(self):
def get_weight(self, orig_weight: torch.Tensor):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)

@@ -342,7 +308,7 @@ def __init__(
else:
self.rank = None # unscaled

def get_weight(self):
def get_weight(self, orig_weight: torch.Tensor):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
@@ -410,7 +376,7 @@ def __init__(

self.rank = None # unscaled

def get_weight(self):
def get_weight(self, orig_weight: torch.Tensor):
return self.weight

def calc_size(self) -> int:
@@ -428,6 +394,45 @@ def to(
self.weight = self.weight.to(device=device, dtype=dtype)


class IA3Layer(LoRALayerBase):
# weight: torch.Tensor
# on_input: torch.Tensor

def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)

self.weight = values["weight"]
self.on_input = values["on_input"]

self.rank = None # unscaled

def get_weight(self, orig_weight: torch.Tensor):
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
return orig_weight * weight

def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
model_size += self.on_input.nelement() * self.on_input.element_size()
return model_size

def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)

self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)


# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw: # (torch.nn.Module):
_name: str
@@ -547,11 +552,15 @@ def from_checkpoint(
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)

# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)

# ia3
elif "weight" in values and "on_input" in values:
layer = IA3Layer(layer_key, values)

else:
# TODO: ia3/... format
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")

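To make the new key-based dispatch and the on_input handling concrete, here is a small worked example with an invented ia3-style state-dict entry (key names and shapes follow kohya/LyCORIS conventions but are made up here):

import torch

values = {
    "weight": torch.randn(320),        # one learned scale per output feature
    "on_input": torch.tensor(False),
}

# from_checkpoint falls through to IA3Layer when both keys are present
assert "weight" in values and "on_input" in values

orig_weight = torch.randn(320, 768)    # out_features x in_features
w = values["weight"]
if not values["on_input"]:
    w = w.reshape(-1, 1)               # (320, 1): broadcast over rows, i.e. rescale each output feature
delta = orig_weight * w                # what IA3Layer.get_weight(orig_weight) returns
assert delta.shape == orig_weight.shape

With on_input set instead, the stored vector would have length 768 and broadcast over the columns directly, without the reshape.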
34 changes: 20 additions & 14 deletions invokeai/backend/model_management/util.py
@@ -12,37 +12,43 @@ def lora_token_vector_length(checkpoint: dict) -> int:
def _get_shape_1(key, tensor, checkpoint):
lora_token_vector_length = None

if "." not in key:
return lora_token_vector_length # wrong key format
model_key, lora_key = key.split(".", 1)

# check lora/locon
if ".lora_down.weight" in key:
if lora_key == "lora_down.weight":
lora_token_vector_length = tensor.shape[1]

# check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
elif ".hada_w1_b" in key or ".hada_w2_b" in key:
elif lora_key in ["hada_w1_b", "hada_w2_b"]:
lora_token_vector_length = tensor.shape[1]

# check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
elif ".lokr_" in key:
_lokr_key = key.split(".")[0]

if _lokr_key + ".lokr_w1" in checkpoint:
_lokr_w1 = checkpoint[_lokr_key + ".lokr_w1"]
elif _lokr_key + "lokr_w1_b" in checkpoint:
_lokr_w1 = checkpoint[_lokr_key + ".lokr_w1_b"]
elif "lokr_" in lora_key:
if model_key + ".lokr_w1" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1"]
elif model_key + "lokr_w1_b" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
else:
return lora_token_vector_length # unknown format

if _lokr_key + ".lokr_w2" in checkpoint:
_lokr_w2 = checkpoint[_lokr_key + ".lokr_w2"]
elif _lokr_key + "lokr_w2_b" in checkpoint:
_lokr_w2 = checkpoint[_lokr_key + ".lokr_w2_b"]
if model_key + ".lokr_w2" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2"]
elif model_key + "lokr_w2_b" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
else:
return lora_token_vector_length # unknown format

lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]

elif ".diff" in key:
elif lora_key == "diff":
lora_token_vector_length = tensor.shape[1]

# ia3 can be detected only by shape[0] in text encoder
elif lora_key == "weight" and "lora_unet_" not in model_key:
lora_token_vector_length = tensor.shape[0]

return lora_token_vector_length

lora_token_vector_length = None
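Since ia3 files expose only bare "weight"/"on_input" keys, the token vector length has to be read off shape[0] of a text-encoder entry, which is what the new branch above does. A quick illustration with invented kohya-style keys:

import torch

checkpoint = {
    "lora_te_text_model_encoder_layers_0_self_attn_k_proj.weight": torch.zeros(768),
    "lora_te_text_model_encoder_layers_0_self_attn_k_proj.on_input": torch.tensor(True),
    "lora_unet_down_blocks_0_attentions_0_proj_in.weight": torch.zeros(320),   # unet entry: ignored
}

for key, tensor in checkpoint.items():
    model_key, lora_key = key.split(".", 1)
    if lora_key == "weight" and "lora_unet_" not in model_key:
        print(key, "->", tensor.shape[0])   # 768, i.e. an SD 1.x-sized text encoder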