diff --git a/onediff_comfy_nodes/utils/diffusers_quant_utils.py b/onediff_comfy_nodes/utils/diffusers_quant_utils.py
index 25526de9f..e83d9fe32 100644
--- a/onediff_comfy_nodes/utils/diffusers_quant_utils.py
+++ b/onediff_comfy_nodes/utils/diffusers_quant_utils.py
@@ -216,7 +216,9 @@ def _rewrite_attention(attn):
     os.environ["ONEFLOW_FUSE_QUANT_TO_MATMUL"] = old_env
 
 
-def replace_module_with_quantizable_module(diffusion_model, calibrate_info_path):
+def replace_module_with_quantizable_module(
+    diffusion_model, calibrate_info_path, use_rewrite_attn=True
+):
     from diffusers_quant.utils import get_quantize_module
 
     _use_graph()
@@ -242,24 +244,25 @@ def replace_module_with_quantizable_module(diffusion_model, calibrate_info_path)
             convert_fn=maybe_allow_in_graph,
         )
         modify_sub_module(diffusion_model, sub_module_name, sub_mod)
 
+    if use_rewrite_attn:
+        print(f"{use_rewrite_attn=}, rewrite CrossAttention")
+        try:
+            # rewrite CrossAttention to use qkv
+            from comfy.ldm.modules.attention import CrossAttention
+
+            match_func = lambda m: isinstance(
+                m, CrossAttention
+            ) and _can_use_flash_attn(m)
+            can_rewrite_modules = search_modules(diffusion_model, match_func)
+            print(f"rewrite {len(can_rewrite_modules)=} CrossAttention")
+            for k, v in can_rewrite_modules.items():
+                if f"{k}.to_q" in calibrate_info:
+                    _rewrite_attention(v)  # diffusion_model is modified in-place
+                else:
+                    print(f"skip {k+'.to_q'} not in calibrate_info")
-    try:
-        # rewrite CrossAttentionPytorch to use qkv
-        from comfy.ldm.modules.attention import CrossAttentionPytorch
-
-        match_func = lambda m: isinstance(
-            m, CrossAttentionPytorch
-        ) and _can_use_flash_attn(m)
-        can_rewrite_modules = search_modules(diffusion_model, match_func)
-        print(f"rewrite {len(can_rewrite_modules)=} CrossAttentionPytorch")
-        for k, v in can_rewrite_modules.items():
-            if f"{k}.to_q" in calibrate_info:
-                _rewrite_attention(v)  # diffusion_model is modified in-place
-            else:
-                print(f"skip {k+'.to_q'} not in calibrate_info")
-
-    except Exception as e:
-        print(e)
+        except Exception as e:
+            raise RuntimeError(f"rewrite CrossAttention failed: {e}")
 
 
 def find_quantizable_modules(
diff --git a/onediff_comfy_nodes/utils/model_patcher.py b/onediff_comfy_nodes/utils/model_patcher.py
index 4226bfaa6..8575272e5 100644
--- a/onediff_comfy_nodes/utils/model_patcher.py
+++ b/onediff_comfy_nodes/utils/model_patcher.py
@@ -103,32 +103,38 @@ def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
                     or to_v_w_name not in patches
                 ):
                     continue
-                to_q_w = patches[to_q_w_name]
-                to_k_w = patches[to_k_w_name]
-                to_v_w = patches[to_v_w_name]
+                to_q_w = patches[to_q_w_name][1]
+                to_k_w = patches[to_k_w_name][1]
+                to_v_w = patches[to_v_w_name][1]
                 assert to_q_w[2] == to_k_w[2] and to_q_w[2] == to_v_w[2]
                 to_qkv_w_name = f"diffusion_model.{name}.to_qkv.weight"
                 dim_head = module.to_qkv.out_features // module.heads // 3
-                patches[to_qkv_w_name] = tuple(
-                    [
-                        torch.stack((to_q_w[0], to_k_w[0], to_v_w[0]), dim=0).reshape(
-                            3, module.heads, dim_head, -1
-                        ),  # (3, H, K, (BM))
-                        torch.stack((to_q_w[1], to_k_w[1], to_v_w[1]), dim=0),
-                    ]
-                    + list(to_q_w[2:])
-                )
+                tmp_list = [
+                    torch.stack((to_q_w[0], to_k_w[0], to_v_w[0]), dim=0).reshape(
+                        3, module.heads, dim_head, -1
+                    ),  # (3, H, K, (BM))
+                    torch.stack((to_q_w[1], to_k_w[1], to_v_w[1]), dim=0),
+                ] + list(to_q_w[2:])
+
+                patch_type = "onediff_int8"
+                patch_value = tuple(tmp_list + [module])
+                patches[to_qkv_w_name] = (patch_type, patch_value)
+
 
             if is_diffusers_quant_available:
                 if isinstance(
                     module, diffusers_quant.DynamicQuantLinearModule
                 ) or isinstance(module, diffusers_quant.DynamicQuantConvModule):
                     w_name = f"diffusion_model.{name}.weight"
                     if w_name in patches:
-                        patches[w_name] = tuple(list(patches[w_name]) + [module])
+                        patch_type = "onediff_int8"
+                        patch_value = tuple(list(patches[w_name][1]) + [module])
+                        patches[w_name] = (patch_type, patch_value)
                     b_name = f"diffusion_model.{name}.bias"
                     if b_name in patches:
-                        patches[b_name] = tuple(list(patches[b_name]) + [module])
+                        patch_type = "onediff_int8"
+                        patch_value = tuple(list(patches[b_name][1]) + [module])
+                        patches[b_name] = (patch_type, patch_value)
 
         p = set()
         for k in patches:
@@ -148,34 +154,11 @@ def calculate_weight(self, patches, weight, key):
             is_diffusers_quant_available = True
         except:
             pass
-
         for p in patches:
            alpha = p[0]
            v = p[1]
            strength_model = p[2]
 
-            is_rewrite_qkv = (
-                True if (len(v) == 4 or len(v) == 5) and "to_qkv" in key else False
-            )
-            is_quant = False
-            if (
-                is_diffusers_quant_available
-                and len(v) == 5
-                and (
-                    isinstance(v[4], diffusers_quant.DynamicQuantLinearModule)
-                    or isinstance(v[4], diffusers_quant.DynamicQuantConvModule)
-                )
-            ):
-                is_quant = True
-                org_weight_scale = (
-                    v[4]
-                    .weight_scale.reshape(
-                        [-1] + [1 for _ in range(len(weight.shape) - 1)]
-                    )
-                    .to(weight.device)
-                )
-                weight = weight.to(torch.float32) * org_weight_scale
-
             if strength_model != 1.0:
                 weight *= strength_model
 
@@ -183,6 +166,12 @@ def calculate_weight(self, patches, weight, key):
                 v = (self.calculate_weight(v[1:], v[0].clone(), key),)
 
             if len(v) == 1:
+                patch_type = "diff"
+            elif len(v) == 2:
+                patch_type = v[0]
+                v = v[1]
+
+            if patch_type == "diff":
                 w1 = v[0]
                 if alpha != 0.0:
                     if w1.shape != weight.shape:
@@ -195,7 +184,7 @@ def calculate_weight(self, patches, weight, key):
                         weight += alpha * comfy.model_management.cast_to_device(
                             w1, weight.device, weight.dtype
                         )
-            elif len(v) == 4 or is_rewrite_qkv or is_quant:  # lora/locon
+            elif patch_type == "lora":  # lora/locon
                 mat1 = comfy.model_management.cast_to_device(
                     v[0], weight.device, torch.float32
                 )
@@ -203,14 +192,8 @@ def calculate_weight(self, patches, weight, key):
                     v[1], weight.device, torch.float32
                 )
                 if v[2] is not None:
-                    if is_rewrite_qkv:
-                        alpha *= v[2] / mat2.shape[1]
-                    else:
-                        alpha *= v[2] / mat2.shape[0]
+                    alpha *= v[2] / mat2.shape[0]
                 if v[3] is not None:
-                    # TODO(): support rewrite qkv
-                    assert not is_rewrite_qkv
-
                     # locon mid weights, hopefully the math is fine because I didn't properly test it
                     mat3 = comfy.model_management.cast_to_device(
                         v[3], weight.device, torch.float32
@@ -230,52 +213,19 @@ def calculate_weight(self, patches, weight, key):
                         .transpose(0, 1)
                     )
                 try:
-                    if is_rewrite_qkv:
-                        heads = mat1.shape[1]
-                        qkv_lora = alpha * torch.bmm(
-                            mat1.reshape(3, -1, mat2.shape[1]),
-                            mat2.flatten(start_dim=2),
-                        )
-                        qkv_lora = qkv_lora.reshape(
-                            3, heads, -1, qkv_lora.shape[2]
-                        )  # reshape to (3, H, K, (BM))
-                        qkv_lora = qkv_lora.permute(
-                            1, 0, 2, 3
-                        )  # permute to (H, 3, K, (BM))
-                        weight += qkv_lora.reshape(weight.shape).type(weight.dtype)
-                    else:
-                        weight += (
-                            (
-                                alpha
-                                * torch.mm(
-                                    mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
-                                )
+                    weight += (
+                        (
+                            alpha
+                            * torch.mm(
+                                mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
                             )
-                            .reshape(weight.shape)
-                            .type(weight.dtype)
                         )
-                    if is_quant:
-                        weight_max = torch.max(
-                            weight.reshape(weight.shape[0], -1), dim=1
-                        )[0].reshape([-1] + [1 for _ in range(len(weight.shape) - 1)])
-                        weight_scale = torch.abs(weight_max) / 127
-                        weight = torch.clamp(
-                            torch.round(weight / weight_scale), -127, 127
-                        ).to(weight.dtype)
-                        weight_acc = (weight * weight_scale).sum(
-                            dim=[d for d in range(1, len(weight.shape))]
-                        )
-                        weight_scale = weight_scale.reshape(v[4].weight_scale.shape).to(
-                            v[4].weight_scale.dtype
-                        )
-                        weight_acc = weight_acc.reshape(v[4].weight_acc.shape).to(
-                            v[4].weight_acc.dtype
-                        )
-                        v[4].weight_scale.copy_(weight_scale)
-                        v[4].weight_acc.copy_(weight_acc)
+                        .reshape(weight.shape)
+                        .type(weight.dtype)
+                    )
                 except Exception as e:
                     print("ERROR", key, e)
-            elif len(v) == 8:  # lokr
+            elif patch_type == "lokr":
                 w1 = v[0]
                 w2 = v[1]
                 w1_a = v[3]
@@ -340,7 +290,7 @@ def calculate_weight(self, patches, weight, key):
                     )
                 except Exception as e:
                     print("ERROR", key, e)
-            else:  # loha
+            elif patch_type == "loha":
                 w1a = v[0]
                 w1b = v[1]
                 if v[2] is not None:
@@ -397,7 +347,135 @@ def calculate_weight(self, patches, weight, key):
                     weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
                 except Exception as e:
                     print("ERROR", key, e)
+            elif patch_type == "glora":
+                if v[4] is not None:
+                    alpha *= v[4] / v[0].shape[0]
+
+                a1 = comfy.model_management.cast_to_device(
+                    v[0].flatten(start_dim=1), weight.device, torch.float32
+                )
+                a2 = comfy.model_management.cast_to_device(
+                    v[1].flatten(start_dim=1), weight.device, torch.float32
+                )
+                b1 = comfy.model_management.cast_to_device(
+                    v[2].flatten(start_dim=1), weight.device, torch.float32
+                )
+                b2 = comfy.model_management.cast_to_device(
+                    v[3].flatten(start_dim=1), weight.device, torch.float32
+                )
+                weight += (
+                    (
+                        (
+                            torch.mm(b2, b1)
+                            + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)
+                        )
+                        * alpha
+                    )
+                    .reshape(weight.shape)
+                    .type(weight.dtype)
+                )
+            elif patch_type == "onediff_int8":
+                # import pdb;pdb.set_trace()
+                is_rewrite_qkv = True if "to_qkv" in key else False
+                is_quant = False
+                if (
+                    is_diffusers_quant_available
+                    and len(v) == 5
+                    and (
+                        isinstance(v[4], diffusers_quant.DynamicQuantLinearModule)
+                        or isinstance(v[4], diffusers_quant.DynamicQuantConvModule)
+                    )
+                ):
+                    is_quant = True
+                    org_weight_scale = (
+                        v[4]
+                        .weight_scale.reshape(
+                            [-1] + [1 for _ in range(len(weight.shape) - 1)]
+                        )
+                        .to(weight.device)
+                    )
+                    weight = weight.to(torch.float32) * org_weight_scale
+
+                mat1 = comfy.model_management.cast_to_device(
+                    v[0], weight.device, torch.float32
+                )
+                mat2 = comfy.model_management.cast_to_device(
+                    v[1], weight.device, torch.float32
+                )
+                if v[2] is not None:
+                    if is_rewrite_qkv:
+                        alpha *= v[2] / mat2.shape[1]
+                    else:
+                        alpha *= v[2] / mat2.shape[0]
+                if v[3] is not None:
+                    # TODO(): support rewrite qkv
+                    assert not is_rewrite_qkv
+
+                    # locon mid weights, hopefully the math is fine because I didn't properly test it
+                    mat3 = comfy.model_management.cast_to_device(
+                        v[3], weight.device, torch.float32
+                    )
+                    final_shape = [
+                        mat2.shape[1],
+                        mat2.shape[0],
+                        mat3.shape[2],
+                        mat3.shape[3],
+                    ]
+                    mat2 = (
+                        torch.mm(
+                            mat2.transpose(0, 1).flatten(start_dim=1),
+                            mat3.transpose(0, 1).flatten(start_dim=1),
+                        )
+                        .reshape(final_shape)
+                        .transpose(0, 1)
+                    )
+                try:
+                    if is_rewrite_qkv:
+                        heads = mat1.shape[1]
+                        qkv_lora = alpha * torch.bmm(
+                            mat1.reshape(3, -1, mat2.shape[1]),
+                            mat2.flatten(start_dim=2),
+                        )
+                        qkv_lora = qkv_lora.reshape(
+                            3, heads, -1, qkv_lora.shape[2]
+                        )  # reshape to (3, H, K, (BM))
+                        qkv_lora = qkv_lora.permute(
+                            1, 0, 2, 3
+                        )  # permute to (H, 3, K, (BM))
+                        weight += qkv_lora.reshape(weight.shape).type(weight.dtype)
+                    else:
+                        weight += (
+                            (
+                                alpha
+                                * torch.mm(
+                                    mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
+                                )
+                            )
+                            .reshape(weight.shape)
+                            .type(weight.dtype)
+                        )
+                    if is_quant:
+                        weight_max = torch.max(
+                            weight.reshape(weight.shape[0], -1), dim=1
+                        )[0].reshape([-1] + [1 for _ in range(len(weight.shape) - 1)])
+                        weight_scale = torch.abs(weight_max) / 127
+                        weight = torch.clamp(
+                            torch.round(weight / weight_scale), -127, 127
+                        ).to(weight.dtype)
+                        weight_acc = (weight * weight_scale).sum(
+                            dim=[d for d in range(1, len(weight.shape))]
+                        )
+                        weight_scale = weight_scale.reshape(v[4].weight_scale.shape).to(
+                            v[4].weight_scale.dtype
+                        )
+                        weight_acc = weight_acc.reshape(v[4].weight_acc.shape).to(
+                            v[4].weight_acc.dtype
+                        )
+                        v[4].weight_scale.copy_(weight_scale)
+                        v[4].weight_acc.copy_(weight_acc)
+                except Exception as e:
+                    print("ERROR", key, e)
 
         return weight