Skip to content

Commit

Permalink
feature: updates refiners vendored library (#458)
Browse files Browse the repository at this point in the history
* feature: updates refiners vendored library

has a small bugfix that will soon be replaced by a better fix from upstream refiners

Co-authored-by: Bryce <github20210803@accounts.brycedrennan.com>
  • Loading branch information
jaydrennan and brycedrennan authored Jan 19, 2024
1 parent fbb16f6 commit 1bf53e4
Show file tree
Hide file tree
Showing 17 changed files with 663 additions and 423 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ vendorize_normal_map:


vendorize_refiners:
export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=20c229903f53d05dc1c44659ec97603660ef964c && \
export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=ce3035923ba71bcb5044708d2f1c37fd1d6722e9 && \
make download_repo REPO=$$REPO PKG=$$PKG COMMIT=$$COMMIT && \
mkdir -p ./imaginairy/vendored/$$PKG && \
rm -rf ./imaginairy/vendored/$$PKG/* && \
Expand Down
319 changes: 227 additions & 92 deletions imaginairy/vendored/refiners/fluxion/adapters/lora.py
Original file line number Diff line number Diff line change
@@ -1,130 +1,265 @@
from typing import Any, Generic, Iterable, TypeVar
from abc import ABC, abstractmethod
from typing import Any

from torch import Tensor, device as Device, dtype as DType
from torch.nn import Parameter as TorchParameter
from torch.nn.init import normal_, zeros_

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.layers.chain import Chain

T = TypeVar("T", bound=fl.Chain)
TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PEP 673)


class Lora(fl.Chain):
class Lora(fl.Chain, ABC):
def __init__(
self,
in_features: int,
out_features: int,
rank: int = 16,
scale: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_features = in_features
self.out_features = out_features
self.rank = rank
self.scale: float = 1.0
self._scale = scale

super().__init__(
fl.Linear(in_features=in_features, out_features=rank, bias=False, device=device, dtype=dtype),
fl.Linear(in_features=rank, out_features=out_features, bias=False, device=device, dtype=dtype),
fl.Lambda(func=self.scale_outputs),
)
super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale))

normal_(tensor=self.Linear_1.weight, std=1 / self.rank)
zeros_(tensor=self.Linear_2.weight)
normal_(tensor=self.down.weight, std=1 / self.rank)
zeros_(tensor=self.up.weight)

def scale_outputs(self, x: Tensor) -> Tensor:
return x * self.scale
@abstractmethod
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.WeightedModule, fl.WeightedModule]:
...

def set_scale(self, scale: float) -> None:
self.scale = scale

def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
self.Linear_1.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
self.Linear_2.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
@property
def down(self) -> fl.WeightedModule:
down_layer = self[0]
assert isinstance(down_layer, fl.WeightedModule)
return down_layer

@property
def up_weight(self) -> Tensor:
return self.Linear_2.weight.data
def up(self) -> fl.WeightedModule:
up_layer = self[1]
assert isinstance(up_layer, fl.WeightedModule)
return up_layer

@property
def down_weight(self) -> Tensor:
return self.Linear_1.weight.data
def scale(self) -> float:
return self._scale

@scale.setter
def scale(self, value: float) -> None:
self._scale = value
self.ensure_find(fl.Multiply).scale = value

@classmethod
def from_weights(
cls,
down: Tensor,
up: Tensor,
) -> "Lora":
match (up.ndim, down.ndim):
case (2, 2):
return LinearLora.from_weights(up=up, down=down)
case (4, 4):
return Conv2dLora.from_weights(up=up, down=down)
case _:
raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")

@classmethod
def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]:
"""
Create a dictionary of LoRA layers from a state dict.
Expects the state dict to be a succession of down and up weights.
"""
state_dict = {k: v for k, v in state_dict.items() if ".weight" in k}
loras: dict[str, Lora] = {}
for down_key, down_tensor, up_tensor in zip(
list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
):
key = ".".join(down_key.split(".")[:-2])
loras[key] = cls.from_weights(down=down_tensor, up=up_tensor)
return loras

class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
@abstractmethod
def auto_attach(self, target: fl.Chain, exclude: list[str] | None = None) -> Any:
...

def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
assert down_weight.shape == self.down.weight.shape
assert up_weight.shape == self.up.weight.shape
self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))


class LinearLora(Lora):
def __init__(
self,
target: fl.Linear,
in_features: int,
out_features: int,
rank: int = 16,
scale: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_features = target.in_features
self.out_features = target.out_features
self.rank = rank
self.scale = scale
with self.setup_adapter(target):
super().__init__(
target,
Lora(
in_features=target.in_features,
out_features=target.out_features,
rank=rank,
device=target.device,
dtype=target.dtype,
),
)
self.Lora.set_scale(scale=scale)


class LoraAdapter(Generic[T], fl.Chain, Adapter[T]):
self.in_features = in_features
self.out_features = out_features

super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)

@classmethod
def from_weights(
cls,
down: Tensor,
up: Tensor,
) -> "LinearLora":
assert up.ndim == 2 and down.ndim == 2
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
lora = cls(
in_features=down.shape[1], out_features=up.shape[0], rank=down.shape[0], device=up.device, dtype=up.dtype
)
lora.load_weights(down_weight=down, up_weight=up)
return lora

def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
for layer, parent in target.walk(fl.Linear):
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
continue

if exclude is not None and any(
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
):
continue

if layer.in_features == self.in_features and layer.out_features == self.out_features:
return LoraAdapter(target=layer, lora=self), parent

def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.Linear, fl.Linear]:
return (
fl.Linear(
in_features=self.in_features,
out_features=self.rank,
bias=False,
device=device,
dtype=dtype,
),
fl.Linear(
in_features=self.rank,
out_features=self.out_features,
bias=False,
device=device,
dtype=dtype,
),
)


class Conv2dLora(Lora):
def __init__(
self,
target: T,
sub_targets: Iterable[tuple[fl.Linear, fl.Chain]],
rank: int | None = None,
in_channels: int,
out_channels: int,
rank: int = 16,
scale: float = 1.0,
weights: list[Tensor] | None = None,
kernel_size: tuple[int, int] = (1, 3),
stride: tuple[int, int] = (1, 1),
padding: tuple[int, int] = (0, 1),
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding

super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)

@classmethod
def from_weights(
cls,
down: Tensor,
up: Tensor,
) -> "Conv2dLora":
assert up.ndim == 4 and down.ndim == 4
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
down_kernel_size, up_kernel_size = down.shape[2], up.shape[2]
down_padding = 1 if down_kernel_size == 3 else 0
up_padding = 1 if up_kernel_size == 3 else 0
lora = cls(
in_channels=down.shape[1],
out_channels=up.shape[0],
rank=down.shape[0],
kernel_size=(down_kernel_size, up_kernel_size),
padding=(down_padding, up_padding),
device=up.device,
dtype=up.dtype,
)
lora.load_weights(down_weight=down, up_weight=up)
return lora

def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
for layer, parent in target.walk(fl.Conv2d):
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
continue

if exclude is not None and any(
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
):
continue

if layer.in_channels == self.in_channels and layer.out_channels == self.out_channels:
if layer.stride != (self.stride[0], self.stride[0]):
self.down.stride = layer.stride

return LoraAdapter(
target=layer,
lora=self,
), parent

def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.Conv2d, fl.Conv2d]:
return (
fl.Conv2d(
in_channels=self.in_channels,
out_channels=self.rank,
kernel_size=self.kernel_size[0],
stride=self.stride[0],
padding=self.padding[0],
use_bias=False,
device=device,
dtype=dtype,
),
fl.Conv2d(
in_channels=self.rank,
out_channels=self.out_channels,
kernel_size=self.kernel_size[1],
stride=self.stride[1],
padding=self.padding[1],
use_bias=False,
device=device,
dtype=dtype,
),
)


class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
def __init__(self, target: fl.WeightedModule, lora: Lora) -> None:
with self.setup_adapter(target):
super().__init__(target)

if weights is not None:
assert len(weights) % 2 == 0
weights_rank = weights[0].shape[1]
if rank is None:
rank = weights_rank
else:
assert rank == weights_rank

assert rank is not None, "either pass a rank or weights"

self.sub_targets = sub_targets
self.sub_adapters: list[tuple[SingleLoraAdapter, fl.Chain]] = []

for linear, parent in self.sub_targets:
self.sub_adapters.append((SingleLoraAdapter(target=linear, rank=rank, scale=scale), parent))

if weights is not None:
assert len(self.sub_adapters) == (len(weights) // 2)
for i, (adapter, _) in enumerate(self.sub_adapters):
lora = adapter.Lora
assert (
lora.rank == weights[i * 2].shape[1]
), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
adapter.Lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])

def inject(self: TLoraAdapter, parent: fl.Chain | None = None) -> TLoraAdapter:
for adapter, adapter_parent in self.sub_adapters:
adapter.inject(adapter_parent)
return super().inject(parent)

def eject(self) -> None:
for adapter, _ in self.sub_adapters:
adapter.eject()
super().eject()
super().__init__(target, lora)

@property
def weights(self) -> list[Tensor]:
return [w for adapter, _ in self.sub_adapters for w in [adapter.Lora.up_weight, adapter.Lora.down_weight]]
def lora(self) -> Lora:
return self.ensure_find(Lora)

@property
def scale(self) -> float:
return self.lora.scale

@scale.setter
def scale(self, value: float) -> None:
self.lora.scale = value
3 changes: 3 additions & 0 deletions imaginairy/vendored/refiners/fluxion/layers/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ def device(self) -> Device:
def dtype(self) -> DType:
return self.weight.dtype

def __str__(self) -> str:
return f"{super().__str__().removesuffix(')')}, device={self.device}, dtype={str(self.dtype).removeprefix('torch.')})"


class TreeNode(TypedDict):
value: str
Expand Down
Loading

0 comments on commit 1bf53e4

Please sign in to comment.