"""
This script shows a naive way, may be not so elegant, to load Lora (safetensors) weights in to diffusers model
For the mechanism of Lora, please refer to https://github.com/cloneofsimo/lora
Copyright 2023: Haofan Wang, Qixun Wang
"""
import torch
from safetensors.torch import load_file
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
# load the base diffusers model in float32 (the LoRA deltas below are also computed in float32)
model_id = "runwayml/stable-diffusion-v1-5"
pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
# load the LoRA weights
model_path = "onePieceWanoSagaStyle_v2Offset.safetensors"
state_dict = load_file(model_path)

LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
alpha = 0.75    # merge strength: scales the LoRA delta added to the base weights
visited = []    # keys already merged (each up/down pair is handled in one pass)
# directly update the weights in the diffusers model
for key in state_dict:
    # it helps to print the keys; they usually look like
    # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"

    # alpha is fixed above, so the per-layer '.alpha' entries are skipped,
    # as are keys whose up/down pair has already been merged
    if '.alpha' in key or key in visited:
        continue

    # decide whether the key belongs to the text encoder or the UNet, and turn
    # the key into a list of attribute-name fragments
    if 'text' in key:
        layer_infos = key.split('.')[0].split(LORA_PREFIX_TEXT_ENCODER + '_')[-1].split('_')
        curr_layer = pipeline.text_encoder
    else:
        layer_infos = key.split('.')[0].split(LORA_PREFIX_UNET + '_')[-1].split('_')
        curr_layer = pipeline.unet
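    # example (hypothetical key): "lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight"
    # gives layer_infos = ['down', 'blocks', '0', 'attentions', '0', 'proj', 'in'];
    # the loop below greedily re-joins these fragments into real attribute names
    # ('down_blocks', '0', 'attentions', '0', 'proj_in')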
    # find the target layer
    temp_name = layer_infos.pop(0)
    while len(layer_infos) > -1:
        try:
            curr_layer = curr_layer.__getattr__(temp_name)
            if len(layer_infos) > 0:
                temp_name = layer_infos.pop(0)
            elif len(layer_infos) == 0:
                break
        except Exception:
            if len(temp_name) > 0:
                temp_name += '_' + layer_infos.pop(0)
            else:
                temp_name = layer_infos.pop(0)
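    # (numeric fragments such as '0' resolve because nn.ModuleList registers its children
    # under string indices, so curr_layer.__getattr__('0') returns the first submodule)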
    # the LoRA forward pass is org_forward(x) + lora_up(lora_down(x)) * multiplier,
    # so the equivalent weight update is alpha * (up @ down)
    pair_keys = []
    if 'lora_down' in key:
        pair_keys.append(key.replace('lora_down', 'lora_up'))
        pair_keys.append(key)
    else:
        pair_keys.append(key)
        pair_keys.append(key.replace('lora_up', 'lora_down'))
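    # pair_keys is ordered [up_key, down_key], e.g. for a hypothetical key
    # "..._k_proj.lora_down.weight" it becomes
    # ["..._k_proj.lora_up.weight", "..._k_proj.lora_down.weight"]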
    # update the weight: W += alpha * (up @ down)
    if len(state_dict[pair_keys[0]].shape) == 4:
        # conv (1x1) case: drop the trailing singleton dims for the matmul
        weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
        weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
        curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
    else:
        # linear case
        weight_up = state_dict[pair_keys[0]].to(torch.float32)
        weight_down = state_dict[pair_keys[1]].to(torch.float32)
        curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
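    # shape sketch for the conv branch (assuming 1x1 kernels, as commonly stored by
    # LoRA trainers): up is (out, rank, 1, 1) and down is (rank, in, 1, 1); after the
    # squeezes the matmul yields (out, in), and the unsqueezes restore (out, in, 1, 1)
    # so the delta matches the conv weight's shape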
    # update visited list
    for item in pair_keys:
        visited.append(item)
# move the merged pipeline to the GPU
pipeline = pipeline.to("cuda")

# bypass the NSFW safety checker (returns the images unchanged)
pipeline.safety_checker = lambda images, clip_input: (images, False)
# prompts; the "<lora:wanostyle_2_offset:1>" tag is A1111-webui syntax and is most likely
# treated as plain text by diffusers, since the LoRA has already been merged above
prompt = '1boy, wanostyle, monkey d luffy, smiling, straw hat, looking at viewer, solo, upper body, ((masterpiece)), (best quality), (extremely detailed), depth of field, sketch, dark intense shadows, sharp focus, soft lighting, hdr, colorful, good composition, fire all around, spectacular, <lora:wanostyle_2_offset:1>, closed shirt, anime screencap, scar under eye, ready to fight, black eyes'
negative_prompt = '(painting by bad-artist-anime:0.9), (painting by bad-artist:0.9), watermark, text, error, blurry, jpeg artifacts, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, artist name, (worst quality, low quality:1.4), bad anatomy, watermark, signature, text, logo'
with torch.no_grad():
    image = pipeline(prompt=prompt,
                     negative_prompt=negative_prompt,
                     height=1280,
                     width=720,
                     num_inference_steps=50,
                     guidance_scale=8).images[0]

image.save("./{}_{}.png".format(prompt[:50], alpha))
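# The merge above modifies the pipeline weights in place; reload the base model to undo it.
# To keep the merged model around, something like pipeline.save_pretrained("merged-wano-sd15")
# should work (the output path here is just an example).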