Fix int8 lora #431

Merged · 5 commits · Dec 20, 2023
Changes from all commits
39 changes: 21 additions & 18 deletions onediff_comfy_nodes/utils/diffusers_quant_utils.py
@@ -216,7 +216,9 @@ def _rewrite_attention(attn):
os.environ["ONEFLOW_FUSE_QUANT_TO_MATMUL"] = old_env


def replace_module_with_quantizable_module(diffusion_model, calibrate_info_path):
def replace_module_with_quantizable_module(
diffusion_model, calibrate_info_path, use_rewrite_attn=True
):
from diffusers_quant.utils import get_quantize_module

_use_graph()
@@ -242,24 +244,25 @@ def replace_module_with_quantizable_module(diffusion_model, calibrate_info_path)
convert_fn=maybe_allow_in_graph,
)
modify_sub_module(diffusion_model, sub_module_name, sub_mod)
if use_rewrite_attn:
print(f"{use_rewrite_attn=}, rewrite CrossAttention")
try:
# rewrite CrossAttention to use qkv
from comfy.ldm.modules.attention import CrossAttention

match_func = lambda m: isinstance(
m, CrossAttention
) and _can_use_flash_attn(m)
can_rewrite_modules = search_modules(diffusion_model, match_func)
print(f"rewrite {len(can_rewrite_modules)=} CrossAttention")
for k, v in can_rewrite_modules.items():
if f"{k}.to_q" in calibrate_info:
_rewrite_attention(v) # diffusion_model is modified in-place
else:
print(f"skip {k+'.to_q'} not in calibrate_info")

try:
# rewrite CrossAttentionPytorch to use qkv
from comfy.ldm.modules.attention import CrossAttentionPytorch

match_func = lambda m: isinstance(
m, CrossAttentionPytorch
) and _can_use_flash_attn(m)
can_rewrite_modules = search_modules(diffusion_model, match_func)
print(f"rewrite {len(can_rewrite_modules)=} CrossAttentionPytorch")
for k, v in can_rewrite_modules.items():
if f"{k}.to_q" in calibrate_info:
_rewrite_attention(v) # diffusion_model is modified in-place
else:
print(f"skip {k+'.to_q'} not in calibrate_info")

except Exception as e:
print(e)
except Exception as e:
raise RuntimeError(f"rewrite CrossAttention failed: {e}")


def find_quantizable_modules(
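For context, a minimal usage sketch of the updated helper (not part of the diff; the `unet` variable and the calibrate-info path are placeholders, and the import assumes the onediff_comfy_nodes package is importable):

    # Hypothetical usage; `unet` and the calibrate-info path are placeholders.
    from onediff_comfy_nodes.utils.diffusers_quant_utils import (
        replace_module_with_quantizable_module,
    )

    calibrate_info_path = "/path/to/calibrate_info.txt"

    # Swaps quantizable sub-modules in place; with use_rewrite_attn=True it also
    # rewrites eligible CrossAttention blocks to a fused to_qkv projection,
    # while use_rewrite_attn=False skips that rewrite entirely.
    replace_module_with_quantizable_module(
        unet, calibrate_info_path, use_rewrite_attn=True
    )
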
254 changes: 166 additions & 88 deletions onediff_comfy_nodes/utils/model_patcher.py
@@ -103,32 +103,38 @@ def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
or to_v_w_name not in patches
):
continue
to_q_w = patches[to_q_w_name]
to_k_w = patches[to_k_w_name]
to_v_w = patches[to_v_w_name]
to_q_w = patches[to_q_w_name][1]
to_k_w = patches[to_k_w_name][1]
to_v_w = patches[to_v_w_name][1]
assert to_q_w[2] == to_k_w[2] and to_q_w[2] == to_v_w[2]
to_qkv_w_name = f"diffusion_model.{name}.to_qkv.weight"

dim_head = module.to_qkv.out_features // module.heads // 3
patches[to_qkv_w_name] = tuple(
[
torch.stack((to_q_w[0], to_k_w[0], to_v_w[0]), dim=0).reshape(
3, module.heads, dim_head, -1
), # (3, H, K, (BM))
torch.stack((to_q_w[1], to_k_w[1], to_v_w[1]), dim=0),
]
+ list(to_q_w[2:])
)
tmp_list = [
torch.stack((to_q_w[0], to_k_w[0], to_v_w[0]), dim=0).reshape(
3, module.heads, dim_head, -1
), # (3, H, K, (BM))
torch.stack((to_q_w[1], to_k_w[1], to_v_w[1]), dim=0),
] + list(to_q_w[2:])

patch_type = "onediff_int8"
patch_value = tuple(tmp_list + [module])
patches[to_qkv_w_name] = (patch_type, patch_value)

if is_diffusers_quant_available:
if isinstance(
module, diffusers_quant.DynamicQuantLinearModule
) or isinstance(module, diffusers_quant.DynamicQuantConvModule):
w_name = f"diffusion_model.{name}.weight"
if w_name in patches:
patches[w_name] = tuple(list(patches[w_name]) + [module])
patch_type = "onediff_int8"
patch_value = tuple(list(patches[w_name][1]) + [module])
patches[w_name] = (patch_type, patch_value)
b_name = f"diffusion_model.{name}.bias"
if b_name in patches:
patches[b_name] = tuple(list(patches[b_name]) + [module])
patch_type = "onediff_int8"
patch_value = tuple(list(patches[b_name][1]) + [module])
patches[b_name] = (patch_type, patch_value)

p = set()
for k in patches:
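
The central change in add_patches, sketched below with dummy tensors (illustrative only, not code from the PR): patch entries are now stored as (patch_type, payload) pairs, so calculate_weight can dispatch on an explicit "onediff_int8" tag instead of guessing from tuple length, and the payload carries the quantized module whose int8 scales must be refreshed later.

    # Illustrative only -- dummy stand-ins for the entries built above.
    import torch

    heads, dim_head, in_dim, rank = 8, 64, 640, 4
    attn_module = object()  # stand-in for the fused-attention module

    qkv_lora_up = torch.randn(3, heads, dim_head, rank)  # stacked q/k/v lora-up, (3, H, K, rank)
    qkv_lora_down = torch.randn(3, rank, in_dim)         # stacked q/k/v lora-down
    lora_alpha, lora_mid = 4.0, None                     # remainder of a standard "lora" payload

    patches = {
        # Fused qkv entry: LoRA payload plus the module that owns to_qkv.
        "diffusion_model.attn1.to_qkv.weight": (
            "onediff_int8",
            (qkv_lora_up, qkv_lora_down, lora_alpha, lora_mid, attn_module),
        ),
    }
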
@@ -148,41 +154,24 @@ def calculate_weight(self, patches, weight, key):
is_diffusers_quant_available = True
except:
pass

for p in patches:
alpha = p[0]
v = p[1]
strength_model = p[2]

is_rewrite_qkv = (
True if (len(v) == 4 or len(v) == 5) and "to_qkv" in key else False
)
is_quant = False
if (
is_diffusers_quant_available
and len(v) == 5
and (
isinstance(v[4], diffusers_quant.DynamicQuantLinearModule)
or isinstance(v[4], diffusers_quant.DynamicQuantConvModule)
)
):
is_quant = True
org_weight_scale = (
v[4]
.weight_scale.reshape(
[-1] + [1 for _ in range(len(weight.shape) - 1)]
)
.to(weight.device)
)
weight = weight.to(torch.float32) * org_weight_scale

if strength_model != 1.0:
weight *= strength_model

if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key),)

if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]

if patch_type == "diff":
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
@@ -195,22 +184,16 @@ def calculate_weight(self, patches, weight, key):
weight += alpha * comfy.model_management.cast_to_device(
w1, weight.device, weight.dtype
)
elif len(v) == 4 or is_rewrite_qkv or is_quant: # lora/locon
elif patch_type == "lora": # lora/locon
mat1 = comfy.model_management.cast_to_device(
v[0], weight.device, torch.float32
)
mat2 = comfy.model_management.cast_to_device(
v[1], weight.device, torch.float32
)
if v[2] is not None:
if is_rewrite_qkv:
alpha *= v[2] / mat2.shape[1]
else:
alpha *= v[2] / mat2.shape[0]
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
# TODO(): support rewrite qkv
assert not is_rewrite_qkv

# locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(
v[3], weight.device, torch.float32
@@ -230,52 +213,19 @@ def calculate_weight(self, patches, weight, key):
.transpose(0, 1)
)
try:
if is_rewrite_qkv:
heads = mat1.shape[1]
qkv_lora = alpha * torch.bmm(
mat1.reshape(3, -1, mat2.shape[1]),
mat2.flatten(start_dim=2),
)
qkv_lora = qkv_lora.reshape(
3, heads, -1, qkv_lora.shape[2]
) # reshape to (3, H, K, (BM))
qkv_lora = qkv_lora.permute(
1, 0, 2, 3
) # permute to (H, 3, K, (BM))
weight += qkv_lora.reshape(weight.shape).type(weight.dtype)
else:
weight += (
(
alpha
* torch.mm(
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
)
weight += (
(
alpha
* torch.mm(
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
)
.reshape(weight.shape)
.type(weight.dtype)
)
if is_quant:
weight_max = torch.max(
weight.reshape(weight.shape[0], -1), dim=1
)[0].reshape([-1] + [1 for _ in range(len(weight.shape) - 1)])
weight_scale = torch.abs(weight_max) / 127
weight = torch.clamp(
torch.round(weight / weight_scale), -127, 127
).to(weight.dtype)
weight_acc = (weight * weight_scale).sum(
dim=[d for d in range(1, len(weight.shape))]
)
weight_scale = weight_scale.reshape(v[4].weight_scale.shape).to(
v[4].weight_scale.dtype
)
weight_acc = weight_acc.reshape(v[4].weight_acc.shape).to(
v[4].weight_acc.dtype
)
v[4].weight_scale.copy_(weight_scale)
v[4].weight_acc.copy_(weight_acc)
.reshape(weight.shape)
.type(weight.dtype)
)
except Exception as e:
print("ERROR", key, e)
elif len(v) == 8: # lokr
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
@@ -340,7 +290,7 @@ def calculate_weight(self, patches, weight, key):
)
except Exception as e:
print("ERROR", key, e)
else: # loha
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
@@ -397,7 +347,135 @@ def calculate_weight(self, patches, weight, key):
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif patch_type == "glora":
if v[4] is not None:
alpha *= v[4] / v[0].shape[0]

a1 = comfy.model_management.cast_to_device(
v[0].flatten(start_dim=1), weight.device, torch.float32
)
a2 = comfy.model_management.cast_to_device(
v[1].flatten(start_dim=1), weight.device, torch.float32
)
b1 = comfy.model_management.cast_to_device(
v[2].flatten(start_dim=1), weight.device, torch.float32
)
b2 = comfy.model_management.cast_to_device(
v[3].flatten(start_dim=1), weight.device, torch.float32
)

weight += (
(
(
torch.mm(b2, b1)
+ torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)
)
* alpha
)
.reshape(weight.shape)
.type(weight.dtype)
)
elif patch_type == "onediff_int8":
# import pdb;pdb.set_trace()
is_rewrite_qkv = True if "to_qkv" in key else False
is_quant = False
if (
is_diffusers_quant_available
and len(v) == 5
and (
isinstance(v[4], diffusers_quant.DynamicQuantLinearModule)
or isinstance(v[4], diffusers_quant.DynamicQuantConvModule)
)
):
is_quant = True
org_weight_scale = (
v[4]
.weight_scale.reshape(
[-1] + [1 for _ in range(len(weight.shape) - 1)]
)
.to(weight.device)
)
weight = weight.to(torch.float32) * org_weight_scale

mat1 = comfy.model_management.cast_to_device(
v[0], weight.device, torch.float32
)
mat2 = comfy.model_management.cast_to_device(
v[1], weight.device, torch.float32
)
if v[2] is not None:
if is_rewrite_qkv:
alpha *= v[2] / mat2.shape[1]
else:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
# TODO(): support rewrite qkv
assert not is_rewrite_qkv

# locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(
v[3], weight.device, torch.float32
)
final_shape = [
mat2.shape[1],
mat2.shape[0],
mat3.shape[2],
mat3.shape[3],
]
mat2 = (
torch.mm(
mat2.transpose(0, 1).flatten(start_dim=1),
mat3.transpose(0, 1).flatten(start_dim=1),
)
.reshape(final_shape)
.transpose(0, 1)
)
try:
if is_rewrite_qkv:
heads = mat1.shape[1]
qkv_lora = alpha * torch.bmm(
mat1.reshape(3, -1, mat2.shape[1]),
mat2.flatten(start_dim=2),
)
qkv_lora = qkv_lora.reshape(
3, heads, -1, qkv_lora.shape[2]
) # reshape to (3, H, K, (BM))
qkv_lora = qkv_lora.permute(
1, 0, 2, 3
) # permute to (H, 3, K, (BM))
weight += qkv_lora.reshape(weight.shape).type(weight.dtype)
else:
weight += (
(
alpha
* torch.mm(
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
)
)
.reshape(weight.shape)
.type(weight.dtype)
)
if is_quant:
weight_max = torch.max(
weight.reshape(weight.shape[0], -1), dim=1
)[0].reshape([-1] + [1 for _ in range(len(weight.shape) - 1)])
weight_scale = torch.abs(weight_max) / 127
weight = torch.clamp(
torch.round(weight / weight_scale), -127, 127
).to(weight.dtype)
weight_acc = (weight * weight_scale).sum(
dim=[d for d in range(1, len(weight.shape))]
)
weight_scale = weight_scale.reshape(v[4].weight_scale.shape).to(
v[4].weight_scale.dtype
)
weight_acc = weight_acc.reshape(v[4].weight_acc.shape).to(
v[4].weight_acc.dtype
)
v[4].weight_scale.copy_(weight_scale)
v[4].weight_acc.copy_(weight_acc)
except Exception as e:
print("ERROR", key, e)
return weight


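For reference, a standalone sketch (assumed shapes and dummy values, not code from the PR) of what the "onediff_int8" branch does for a fused to_qkv weight: fold the stacked q/k/v LoRA delta into the dequantized weight with a batched matmul, then recompute the per-output-channel int8 scale and accumulator that are copied back into the DynamicQuant module.

    # Standalone sketch of the fused-qkv LoRA fold plus int8 re-quantization.
    import torch

    heads, dim_head, in_dim, rank = 8, 64, 640, 4
    strength, lora_alpha = 1.0, 4.0

    # Fused qkv weight laid out per head as (H, 3, K, in), flattened to (H*3*K, in).
    weight = torch.randn(heads * 3 * dim_head, in_dim)

    qkv_up = torch.randn(3, heads, dim_head, rank)  # stacked q/k/v lora-up
    qkv_down = torch.randn(3, rank, in_dim)         # stacked q/k/v lora-down

    alpha = strength * lora_alpha / qkv_down.shape[1]  # rank read from dim 1 for qkv

    # (3, H*K, rank) x (3, rank, in) -> one LoRA delta per q/k/v projection
    qkv_lora = alpha * torch.bmm(qkv_up.reshape(3, -1, rank), qkv_down)
    qkv_lora = qkv_lora.reshape(3, heads, dim_head, in_dim).permute(1, 0, 2, 3)
    weight = weight + qkv_lora.reshape(weight.shape)  # back to the (H, 3, K, in) layout

    # Per-output-channel int8 re-quantization, mirroring the math in the diff.
    weight_max = torch.max(weight.reshape(weight.shape[0], -1), dim=1)[0]
    weight_scale = (torch.abs(weight_max) / 127).reshape(-1, 1)
    weight_q = torch.clamp(torch.round(weight / weight_scale), -127, 127)
    weight_acc = (weight_q * weight_scale).sum(dim=1)  # copied into the quant module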