-
Notifications
You must be signed in to change notification settings - Fork 951
/
torch2coreml.py
1738 lines (1454 loc) · 72.3 KB
/
torch2coreml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#
from python_coreml_stable_diffusion import (
unet, controlnet, chunk_mlprogram
)
import argparse
from collections import OrderedDict, defaultdict
from copy import deepcopy
import coremltools as ct
from diffusers import (
StableDiffusionPipeline,
DiffusionPipeline,
ControlNetModel
)
from diffusionkit.tests.torch2coreml import (
convert_mmdit_to_mlpackage,
convert_vae_to_mlpackage
)
import gc
from huggingface_hub import snapshot_download
import logging
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
import numpy as np
import os
import requests
import shutil
import time
import re
import pathlib
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_grad_enabled(False)
from types import MethodType
def _get_coreml_inputs(sample_inputs, args):
    """Build one ``ct.TensorType`` input description per sample input.

    Args:
        sample_inputs: mapping of input name -> sample value. Values that are
            ``torch.Tensor`` have their dtype taken from ``.numpy()``; any other
            value must expose ``.shape`` and ``.dtype`` directly.
        args: accepted for call-site compatibility; not read here.

    Returns:
        list of ``ct.TensorType`` in the iteration order of ``sample_inputs``.
    """
    input_specs = []
    for input_name, sample in sample_inputs.items():
        if isinstance(sample, torch.Tensor):
            sample_dtype = sample.numpy().dtype
        else:
            sample_dtype = sample.dtype
        input_specs.append(
            ct.TensorType(name=input_name, shape=sample.shape, dtype=sample_dtype)
        )
    return input_specs
def compute_psnr(a, b):
    """Compute the Peak-Signal-to-Noise-Ratio (in dB) of ``a`` against ``b``.

    The peak is ``max(|b|)`` and the noise is the root-mean-square of ``a - b``.
    Small epsilons keep the result finite when ``b`` is all zeros or when the
    two arrays are identical.

    Both inputs are promoted to float64 before any arithmetic. Squaring the
    difference in a low-precision dtype can otherwise overflow (float16
    saturates at 65504, so any |diff| > ~256 becomes ``inf``). Integer inputs
    could likewise wrap around.

    Args:
        a: array-like under test.
        b: array-like reference; also defines the peak value.

    Returns:
        PSNR in dB as a numpy float.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    max_b = np.abs(b).max()
    rmse = np.sqrt(np.mean((a - b) ** 2))
    eps = 1e-5
    eps2 = 1e-10
    return 20 * np.log10((max_b + eps) / (rmse + eps2))
# Minimum PSNR (in dB) that report_correctness() accepts between a baseline
# output and a converted output; anything lower makes it raise ValueError.
ABSOLUTE_MIN_PSNR = 35
def report_correctness(original_outputs, final_outputs, log_prefix):
    """Log and enforce PSNR parity between two compatible tensors.

    Args:
        original_outputs: reference array (e.g. the PyTorch baseline output).
        final_outputs: array to validate (e.g. the Core ML output).
        log_prefix: text prepended to the log line to identify the comparison.

    Returns:
        The PSNR (dB) of ``final_outputs`` measured against ``original_outputs``.

    Raises:
        ValueError: if that PSNR is below ``ABSOLUTE_MIN_PSNR`` or is NaN
            (for example when the converted model produced NaN outputs).
    """
    original_psnr = compute_psnr(original_outputs, original_outputs)
    final_psnr = compute_psnr(original_outputs, final_outputs)
    dB_change = final_psnr - original_psnr
    logger.info(
        f"{log_prefix}: PSNR changed by {dB_change:.1f} dB ({original_psnr:.1f} -> {final_psnr:.1f})"
    )
    # Written as `not (x >= threshold)` rather than `x < threshold`: every
    # comparison with NaN is False, so `<` would let a NaN PSNR pass the check.
    if not final_psnr >= ABSOLUTE_MIN_PSNR:
        raise ValueError(f"{final_psnr:.1f} dB is too low!")
    logger.info(
        f"{final_psnr:.1f} dB > {ABSOLUTE_MIN_PSNR} dB (minimum allowed) parity check passed"
    )
    return final_psnr
def _get_out_path(args, submodule_name):
fname = f"Stable_Diffusion_version_{args.model_version}_{submodule_name}.mlpackage"
fname = fname.replace("/", "_")
return os.path.join(args.o, fname)
def _convert_to_coreml(submodule_name, torchscript_module, sample_inputs,
                       output_names, args, out_path=None, precision=None, compute_unit=None):
    """Convert a traced PyTorch module to a Core ML ML Program.

    If a package already exists at the output path, it is loaded instead of
    being re-converted.

    Args:
        submodule_name: used to derive the default output path and for logging.
        torchscript_module: the traced module passed to ``ct.convert``.
        sample_inputs: name -> sample value mapping used to build the input specs.
        output_names: names given to the float32 outputs of the converted model.
        args: reads ``compute_unit`` (when ``compute_unit`` is not given) and
            ``check_output_correctness``; also used to build the default path.
        out_path: explicit destination; defaults to ``_get_out_path(args, submodule_name)``.
        precision: forwarded to ``ct.convert`` as ``compute_precision``.
        compute_unit: a ``ct.ComputeUnit``; falls back to ``ct.ComputeUnit[args.compute_unit]``.

    Returns:
        ``(coreml_model, out_path)``. The model is not saved here; callers save it.
    """
    if out_path is None:
        out_path = _get_out_path(args, submodule_name)
    compute_unit = compute_unit or ct.ComputeUnit[args.compute_unit]

    if not os.path.exists(out_path):
        logger.info(f"Converting {submodule_name} to CoreML..")
        coreml_model = ct.convert(
            torchscript_module,
            convert_to="mlprogram",
            minimum_deployment_target=ct.target.macOS13,
            inputs=_get_coreml_inputs(sample_inputs, args),
            outputs=[ct.TensorType(name=output_name, dtype=np.float32)
                     for output_name in output_names],
            compute_units=compute_unit,
            compute_precision=precision,
            skip_model_load=not args.check_output_correctness,
        )
    else:
        logger.info(f"Skipping export because {out_path} already exists")
        logger.info(f"Loading model from {out_path}")
        load_started_at = time.time()
        # Each load of an .mlpackage triggers a model compilation that can take
        # up to a few minutes. The Swift CLI uses precompiled .mlmodelc bundles,
        # which pay that cost only on first load.
        coreml_model = ct.models.MLModel(out_path, compute_units=compute_unit)
        elapsed = time.time() - load_started_at
        logger.info(f"Loading {out_path} took {elapsed:.1f} seconds")
        coreml_model.compute_unit = compute_unit

    # This drops only this function's reference to the traced module; memory
    # is reclaimed once the caller releases its own reference too.
    del torchscript_module
    gc.collect()
    return coreml_model, out_path
def quantize_weights(args):
    """Palettize (look-up-table quantize) converted models to ``args.quantize_nbits`` bits.

    Covers the fixed set of core model packages plus one ControlNet package per
    entry in ``args.convert_controlnet``. Packages missing on disk are skipped
    by ``_quantize_weights``.
    """
    nbits = args.quantize_nbits

    core_model_names = ("text_encoder", "text_encoder_2", "unet", "refiner", "control-unet")
    for model_name in core_model_names:
        logger.info(f"Quantizing {model_name} to {nbits}-bit precision")
        _quantize_weights(_get_out_path(args, model_name), model_name, nbits)

    for controlnet_model_version in (args.convert_controlnet or []):
        # ControlNet packages are named after the HF model id, with "/" flattened.
        controlnet_model_name = controlnet_model_version.replace("/", "_")
        logger.info(f"Quantizing {controlnet_model_name} to {nbits}-bit precision")
        controlnet_path = os.path.join(args.o, f"ControlNet_{controlnet_model_name}.mlpackage")
        _quantize_weights(controlnet_path, controlnet_model_name, nbits)
def _quantize_weights(out_path, model_name, nbits):
    """Palettize the weights of the Core ML package at ``out_path`` in place.

    Uses k-means palettization with ``nbits`` bits for every op type except
    ``gather``, which is excluded so the embedding table stays unquantized.
    The result overwrites ``out_path``. When ``out_path`` does not exist, the
    function logs and returns without doing anything.

    Args:
        out_path: path of the ``.mlpackage`` to quantize (and overwrite).
        model_name: name used only in log messages.
        nbits: palette size in bits, forwarded to ``OpPalettizerConfig``.
    """
    if not os.path.exists(out_path):
        logger.info(
            f"Skipped quantizing {model_name} (Not found at {out_path})")
        return

    logger.info(f"Quantizing {model_name}")
    mlmodel = ct.models.MLModel(out_path,
                                compute_units=ct.ComputeUnit.CPU_ONLY)
    op_config = ct.optimize.coreml.OpPalettizerConfig(
        mode="kmeans",
        nbits=nbits,
    )
    config = ct.optimize.coreml.OptimizationConfig(
        global_config=op_config,
        op_type_configs={
            "gather": None  # avoid quantizing the embedding table
        }
    )
    # `save` returns None, so there is nothing useful to bind from it.
    palettized_model = ct.optimize.coreml.palettize_weights(mlmodel, config=config)
    palettized_model.save(out_path)
    logger.info("Done")
def _compile_coreml_model(source_model_path, output_dir, final_name):
    """ Compiles Core ML models using the coremlcompiler utility from Xcode toolchain

    The compiled bundle is renamed to ``<output_dir>/<final_name>.mlmodelc``.
    If that target already exists, compilation is skipped.

    Args:
        source_model_path: path of the ``.mlpackage`` to compile.
        output_dir: directory receiving the compiled model (created if missing).
        final_name: base name (without extension) for the compiled bundle.

    Returns:
        Path of the compiled ``.mlmodelc`` bundle.

    Raises:
        subprocess.CalledProcessError: if ``xcrun coremlcompiler`` exits non-zero.
    """
    import subprocess

    target_path = os.path.join(output_dir, f"{final_name}.mlmodelc")
    if os.path.exists(target_path):
        logger.warning(
            f"Found existing compiled model at {target_path}! Skipping..")
        return target_path

    logger.info(f"Compiling {source_model_path}")
    source_model_name = os.path.basename(
        os.path.splitext(source_model_path)[0])
    os.makedirs(output_dir, exist_ok=True)
    # Pass the arguments as a list (no shell), so paths containing spaces or
    # shell metacharacters are handled correctly. check=True surfaces compiler
    # failures directly instead of as a confusing error from shutil.move below.
    subprocess.run(
        ["xcrun", "coremlcompiler", "compile", source_model_path, output_dir],
        check=True,
    )
    # coremlcompiler names its output after the source package; rename it.
    compiled_output = os.path.join(output_dir, f"{source_model_name}.mlmodelc")
    shutil.move(compiled_output, target_path)
    return target_path
def _download_t5_model(args, t5_save_path):
t5_url = args.text_encoder_t5_url
match = re.match(r'https://huggingface.co/(.+)/resolve/main/(.+)', t5_url)
if not match:
raise ValueError(f"Invalid Hugging Face URL: {t5_url}")
repo_id, model_subpath = match.groups()
download_path = snapshot_download(
repo_id=repo_id,
revision="main",
allow_patterns=[f"{model_subpath}/*"]
)
logger.info(f"Downloaded T5 model to {download_path}")
# Move the downloaded model to the top level of the Resources directory
logger.info(f"Copying T5 model from {download_path} to {t5_save_path}")
cache_path = os.path.join(download_path, model_subpath)
shutil.copytree(cache_path, t5_save_path)
def _download_file(url, path, timeout=60):
    """Download ``url`` and write the raw response body to ``path``.

    Raises ``requests.HTTPError`` on a non-2xx response, so an error page is
    never saved in place of the expected resource. The file is opened only after
    a successful response, so network failures do not leave an empty file.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    with open(path, "wb") as f:
        f.write(response.content)


def bundle_resources_for_swift_cli(args):
    """
    - Compiles Core ML models from mlpackage into mlmodelc format
    - Download tokenizer resources for the text encoder
    - Optionally (``args.include_t5``) fetch the pre-converted T5 encoder and its tokenizer files

    Returns:
        Path of the ``Resources`` directory inside ``args.o``.
    """
    resources_dir = os.path.join(args.o, "Resources")
    if not os.path.exists(resources_dir):
        os.makedirs(resources_dir, exist_ok=True)
        logger.info(f"Created {resources_dir} for Swift CLI assets")

    # Compile model using coremlcompiler (Significantly reduces the load time for unet)
    for source_name, target_name in [("text_encoder", "TextEncoder"),
                                     ("text_encoder_2", "TextEncoder2"),
                                     ("vae_decoder", "VAEDecoder"),
                                     ("vae_encoder", "VAEEncoder"),
                                     ("unet", "Unet"),
                                     ("unet_chunk1", "UnetChunk1"),
                                     ("unet_chunk2", "UnetChunk2"),
                                     ("refiner", "UnetRefiner"),
                                     ("refiner_chunk1", "UnetRefinerChunk1"),
                                     ("refiner_chunk2", "UnetRefinerChunk2"),
                                     ("mmdit", "MultiModalDiffusionTransformer"),
                                     ("control-unet", "ControlledUnet"),
                                     ("control-unet_chunk1", "ControlledUnetChunk1"),
                                     ("control-unet_chunk2", "ControlledUnetChunk2"),
                                     ("safety_checker", "SafetyChecker")]:
        source_path = _get_out_path(args, source_name)
        if os.path.exists(source_path):
            target_path = _compile_coreml_model(source_path, resources_dir,
                                                target_name)
            logger.info(f"Compiled {source_path} to {target_path}")
        else:
            logger.warning(
                f"{source_path} not found, skipping compilation to {target_name}.mlmodelc"
            )

    if args.convert_controlnet:
        for controlnet_model_version in args.convert_controlnet:
            controlnet_model_name = controlnet_model_version.replace("/", "_")
            fname = f"ControlNet_{controlnet_model_name}.mlpackage"
            source_path = os.path.join(args.o, fname)
            controlnet_dir = os.path.join(resources_dir, "controlnet")
            # e.g. "lllyasviel_sd-controlnet-canny" -> "LllyasvielSdControlnetCanny"
            target_name = "".join([word.title() for word in re.split('_|-', controlnet_model_name)])

            if os.path.exists(source_path):
                target_path = _compile_coreml_model(source_path, controlnet_dir,
                                                    target_name)
                logger.info(f"Compiled {source_path} to {target_path}")
            else:
                logger.warning(
                    f"{source_path} not found, skipping compilation to {target_name}.mlmodelc"
                )

    # Fetch and save vocabulary JSON file for text tokenizer
    logger.info("Downloading and saving tokenizer vocab.json")
    _download_file(args.text_encoder_vocabulary_url,
                   os.path.join(resources_dir, "vocab.json"))
    logger.info("Done")

    # Fetch and save merged pairs JSON file for text tokenizer
    logger.info("Downloading and saving tokenizer merges.txt")
    _download_file(args.text_encoder_merges_url,
                   os.path.join(resources_dir, "merges.txt"))
    logger.info("Done")

    # Fetch and save pre-converted T5 text encoder model
    t5_model_name = "TextEncoderT5.mlmodelc"
    t5_save_path = os.path.join(resources_dir, t5_model_name)
    if args.include_t5:
        if not os.path.exists(t5_save_path):
            logger.info("Downloading pre-converted T5 encoder model TextEncoderT5.mlmodelc")
            _download_t5_model(args, t5_save_path)
            logger.info("Done")
        else:
            logger.info(f"Skipping T5 download as {t5_save_path} already exists")

        # Fetch and save T5 text tokenizer JSON files
        logger.info("Downloading and saving T5 tokenizer files tokenizer_config.json and tokenizer.json")
        _download_file(args.text_encoder_t5_config_url,
                       os.path.join(resources_dir, "tokenizer_config.json"))
        _download_file(args.text_encoder_t5_data_url,
                       os.path.join(resources_dir, "tokenizer.json"))
        logger.info("Done")

    return resources_dir
from transformers.models.clip import modeling_clip
# Copied from https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/clip/modeling_clip.py#L677C1-L692C1
def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_length: int = 0):
    """ Patch to replace torch.finfo(dtype).min with -1e4

    Builds an additive causal attention mask of shape
    ``(batch, 1, tgt_len, tgt_len + past_key_values_length)``. Positions a query
    may attend to are 0, and future positions are -1e4. The first
    ``past_key_values_length`` key columns (cached keys) are always 0.
    """
    batch_size, target_len = input_ids_shape

    # Start fully masked, then open up the lower triangle (including the diagonal).
    causal = torch.full((target_len, target_len), torch.tensor(-1e4, device=device), device=device)
    positions = torch.arange(causal.size(-1), device=device)
    # Entry [i, j] is True when j < i + 1, i.e. key j is not in query i's future.
    causal.masked_fill_(positions < (positions + 1).view(causal.size(-1), 1), 0)
    causal = causal.to(dtype)

    if past_key_values_length > 0:
        cached_prefix = torch.zeros(target_len, past_key_values_length, dtype=dtype, device=device)
        causal = torch.cat([cached_prefix, causal], dim=-1)

    full_key_len = target_len + past_key_values_length
    return causal[None, None, :, :].expand(batch_size, 1, target_len, full_key_len)
# Swap transformers' CLIP causal-mask builder for patched_make_causal_mask, which
# fills masked positions with -1e4 instead of torch.finfo(dtype).min.
# NOTE(review): only effective for transformers versions whose CLIP code still
# calls modeling_clip._make_causal_mask — confirm against the pinned version.
modeling_clip._make_causal_mask = patched_make_causal_mask
def convert_text_encoder(text_encoder, tokenizer, submodule_name, args):
    """ Converts the text encoder component of Stable Diffusion

    Args:
        text_encoder: PyTorch text encoder module (cast to float32 here).
        tokenizer: only ``model_max_length`` is read, to size the sample input.
        submodule_name: e.g. ``"text_encoder"`` or ``"text_encoder_2"``; names the
            output package and labels log messages.
        args: reads ``xl_version``, ``model_version`` and
            ``check_output_correctness``. Output location and compute unit are
            resolved from ``args`` by the path and conversion helpers.

    Conversion is skipped when the output package already exists.
    """
    text_encoder = text_encoder.to(dtype=torch.float32)
    out_path = _get_out_path(args, submodule_name)
    if os.path.exists(out_path):
        logger.info(
            f"`{submodule_name}` already exists at {out_path}, skipping conversion."
        )
        return

    # Create sample inputs for tracing, conversion and correctness verification
    text_encoder_sequence_length = tokenizer.model_max_length

    sample_text_encoder_inputs = {
        "input_ids":
        torch.randint(
            text_encoder.config.vocab_size,
            (1, text_encoder_sequence_length),
            # https://github.com/apple/coremltools/issues/1423
            dtype=torch.float32,
        )
    }
    sample_text_encoder_inputs_spec = {
        k: (v.shape, v.dtype)
        for k, v in sample_text_encoder_inputs.items()
    }
    logger.info(f"Sample inputs spec: {sample_text_encoder_inputs_spec}")

    class TextEncoder(nn.Module):

        def __init__(self, with_hidden_states_for_layer=None):
            super().__init__()
            self.text_encoder = text_encoder
            self.with_hidden_states_for_layer = with_hidden_states_for_layer

        def forward(self, input_ids):
            if self.with_hidden_states_for_layer is not None:
                output = self.text_encoder(input_ids, output_hidden_states=True)
                hidden_embeds = output.hidden_states[self.with_hidden_states_for_layer]
                if "text_embeds" in output:
                    return (hidden_embeds, output.text_embeds)
                else:
                    return (hidden_embeds, output.pooler_output)
            else:
                return self.text_encoder(input_ids, return_dict=False)

    # SD XL uses the hidden states after the encoder layers from both encoders,
    # and the pooled `text_embeds` output of the second encoder.
    hidden_layer = -2 if args.xl_version else None
    reference_text_encoder = TextEncoder(with_hidden_states_for_layer=hidden_layer).eval()

    logger.info(f"JIT tracing {submodule_name}..")
    # Traced with int32 ids, although the Core ML input is declared float32 above.
    reference_text_encoder = torch.jit.trace(
        reference_text_encoder,
        (sample_text_encoder_inputs["input_ids"].to(torch.int32), ),
    )
    logger.info("Done.")

    if args.xl_version:
        output_names = ["hidden_embeds", "pooled_outputs"]
    else:
        output_names = ["last_hidden_state", "pooled_outputs"]

    coreml_text_encoder, out_path = _convert_to_coreml(
        submodule_name, reference_text_encoder, sample_text_encoder_inputs,
        output_names, args)

    # Set model metadata
    coreml_text_encoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
    if args.xl_version:
        coreml_text_encoder.license = "OpenRAIL++-M (https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md)"
    else:
        coreml_text_encoder.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
    coreml_text_encoder.version = args.model_version
    coreml_text_encoder.short_description = \
        "Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \
        "Please refer to https://arxiv.org/abs/2112.10752 for details."

    # Set the input descriptions
    coreml_text_encoder.input_description[
        "input_ids"] = "The token ids that represent the input text"

    # Set the output descriptions
    if args.xl_version:
        coreml_text_encoder.output_description[
            "hidden_embeds"] = "Hidden states after the encoder layers"
    else:
        coreml_text_encoder.output_description[
            "last_hidden_state"] = "The token embeddings as encoded by the Transformer model"
    coreml_text_encoder.output_description[
        "pooled_outputs"] = "The version of the `last_hidden_state` output after pooling"

    coreml_text_encoder.save(out_path)

    # Label with the actual submodule, e.g. text_encoder_2 for SDXL's second encoder.
    logger.info(f"Saved {submodule_name} into {out_path}")

    # Parity check PyTorch vs CoreML
    if args.check_output_correctness:
        baseline_out = text_encoder(
            sample_text_encoder_inputs["input_ids"].to(torch.int32),
            output_hidden_states=args.xl_version,
            return_dict=True,
        )
        if args.xl_version:
            # TODO: maybe check pooled_outputs too
            baseline_out = baseline_out.hidden_states[hidden_layer].numpy()
        else:
            baseline_out = baseline_out.last_hidden_state.numpy()

        coreml_out = coreml_text_encoder.predict(
            {k: v.numpy() for k, v in sample_text_encoder_inputs.items()}
        )
        coreml_out = coreml_out["hidden_embeds" if args.xl_version else "last_hidden_state"]
        report_correctness(
            baseline_out, coreml_out,
            f"{submodule_name} baseline PyTorch to reference CoreML")

    del reference_text_encoder, coreml_text_encoder
    gc.collect()
def modify_coremltools_torch_frontend_badbmm():
    """
    Modifies coremltools torch frontend for baddbmm to be robust to the `beta`
    (and `alpha`) argument being of non-float dtype:
    e.g. https://github.com/huggingface/diffusers/blob/v0.8.1/src/diffusers/models/attention.py#L315

    Removes any `baddbmm` converter already in coremltools' torch op registry
    and registers the one defined below in its place.
    """
    from coremltools.converters.mil import register_torch_op
    from coremltools.converters.mil.mil import Builder as mb
    from coremltools.converters.mil.frontend.torch.ops import _get_inputs
    from coremltools.converters.mil.frontend.torch.torch_op_registry import _TORCH_OPS_REGISTRY

    if "baddbmm" in _TORCH_OPS_REGISTRY:
        del _TORCH_OPS_REGISTRY["baddbmm"]

    def _cast_int32_scalar_to_fp32(scalar, arg_name):
        """Return ``scalar`` cast to fp32 if its constant value is int32, else unchanged."""
        # getattr guards against a constant value that carries no numpy dtype.
        if getattr(scalar.val, "dtype", None) != np.int32:
            return scalar
        scalar = mb.cast(x=scalar, dtype="fp32")
        logger.warning(
            f"Casted the `{arg_name}`(value={scalar.val}) argument of `baddbmm` op "
            "from int32 to float32 dtype for conversion!")
        return scalar

    @register_torch_op
    def baddbmm(context, node):
        """
        baddbmm(Tensor input, Tensor batch1, Tensor batch2, Scalar beta=1, Scalar alpha=1)
        output = beta * input + alpha * (batch1 @ batch2)

        Notice that batch1 and batch2 must be 3-D tensors each containing the same number of matrices.
        If batch1 is a (b×n×m) tensor, batch2 is a (b×m×p) tensor, then input must be broadcastable with a (b×n×p) tensor
        and out will be a (b×n×p) tensor.
        """
        assert len(node.outputs) == 1
        inputs = _get_inputs(context, node, expected=5)
        bias, batch1, batch2, beta, alpha = inputs

        if beta.val != 1.0:
            # Apply scaling factor beta to the bias.
            beta = _cast_int32_scalar_to_fp32(beta, "beta")
            bias = mb.mul(x=beta, y=bias, name=bias.name + "_scaled")
            context.add(bias)

        if alpha.val != 1.0:
            # Apply scaling factor alpha to the input. An int32 alpha needs the
            # same cast as beta to multiply a float tensor.
            alpha = _cast_int32_scalar_to_fp32(alpha, "alpha")
            batch1 = mb.mul(x=alpha, y=batch1, name=batch1.name + "_scaled")
            context.add(batch1)

        bmm_node = mb.matmul(x=batch1, y=batch2, name=node.name + "_bmm")
        context.add(bmm_node)

        baddbmm_node = mb.add(x=bias, y=bmm_node, name=node.name)
        context.add(baddbmm_node)
def convert_vae_decoder(pipe, args):
    """Convert the VAE decoder (post_quant_conv + decoder) of a pipeline to Core ML.

    The default fp16 sample input is replaced by float32 input, float32 compute
    precision and CPU_AND_GPU compute units when ``args.xl_version`` is set and
    no ``args.custom_vae_version`` is given. Conversion is skipped if the output
    package already exists. ``pipe.vae.decoder`` is deleted afterwards.

    Raises:
        RuntimeError: if ``pipe.unet`` is gone. ``convert_unet()`` deletes it,
            and its config provides the default latent size.
    """
    out_path = _get_out_path(args, "vae_decoder")
    if os.path.exists(out_path):
        logger.info(
            f"`vae_decoder` already exists at {out_path}, skipping conversion."
        )
        return

    if not hasattr(pipe, "unet"):
        raise RuntimeError(
            "convert_unet() deletes pipe.unet to save RAM. "
            "Please use convert_vae_decoder() before convert_unet()")

    latent_height = args.latent_h or pipe.unet.config.sample_size
    latent_width = args.latent_w or pipe.unet.config.sample_size
    z_shape = (1, pipe.vae.config.latent_channels, latent_height, latent_width)  # B, C, H, W

    needs_fp32 = args.custom_vae_version is None and args.xl_version
    inputs_dtype = torch.float32 if needs_fp32 else torch.float16
    compute_precision = ct.precision.FLOAT32 if needs_fp32 else None
    # FIXME: Hardcoding to CPU_AND_GPU since ANE doesn't support FLOAT32
    compute_unit = ct.ComputeUnit.CPU_AND_GPU if needs_fp32 else None

    sample_vae_decoder_inputs = {"z": torch.rand(*z_shape, dtype=inputs_dtype)}
    z_fp32 = sample_vae_decoder_inputs["z"].to(torch.float32)

    class VAEDecoder(nn.Module):
        """ Wrapper nn.Module wrapper for pipe.decode() method
        """

        def __init__(self):
            super().__init__()
            self.post_quant_conv = pipe.vae.post_quant_conv.to(dtype=torch.float32)
            self.decoder = pipe.vae.decoder.to(dtype=torch.float32)

        def forward(self, z):
            return self.decoder(self.post_quant_conv(z))

    baseline_decoder = VAEDecoder().eval()

    # No optimization needed for the VAE Decoder as it is a pure ConvNet
    traced_vae_decoder = torch.jit.trace(baseline_decoder, (z_fp32, ))

    modify_coremltools_torch_frontend_badbmm()
    coreml_vae_decoder, out_path = _convert_to_coreml(
        "vae_decoder", traced_vae_decoder, sample_vae_decoder_inputs,
        ["image"], args, precision=compute_precision, compute_unit=compute_unit)

    # Model metadata
    coreml_vae_decoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
    coreml_vae_decoder.license = (
        "OpenRAIL++-M (https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md)"
        if args.xl_version else
        "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
    )
    coreml_vae_decoder.version = args.model_version
    coreml_vae_decoder.short_description = (
        "Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. "
        "Please refer to https://arxiv.org/abs/2112.10752 for details."
    )
    coreml_vae_decoder.input_description["z"] = (
        "The denoised latent embeddings from the unet model after the last step of reverse diffusion"
    )
    coreml_vae_decoder.output_description["image"] = "Generated image normalized to range [-1, 1]"

    coreml_vae_decoder.save(out_path)
    logger.info(f"Saved vae_decoder into {out_path}")

    # Parity check PyTorch vs CoreML
    if args.check_output_correctness:
        baseline_out = baseline_decoder(z=z_fp32).numpy()
        coreml_inputs = {name: tensor.numpy() for name, tensor in sample_vae_decoder_inputs.items()}
        coreml_out = next(iter(coreml_vae_decoder.predict(coreml_inputs).values()))
        report_correctness(baseline_out, coreml_out,
                           "vae_decoder baseline PyTorch to baseline CoreML")

    del traced_vae_decoder, pipe.vae.decoder, coreml_vae_decoder
    gc.collect()
def convert_vae_decoder_sd3(args):
    """ Converts the VAE component of Stable Diffusion 3

    Conversion is delegated to DiffusionKit's ``convert_vae_to_mlpackage``. The
    resulting package is then loaded, stamped with metadata and package versions,
    and saved to ``_get_out_path(args, "vae_decoder")``. DiffusionKit's
    intermediate package is removed afterwards. Nothing is done if the output
    package already exists.
    """
    out_path = _get_out_path(args, "vae_decoder")
    if os.path.exists(out_path):
        logger.info(
            f"`vae_decoder` already exists at {out_path}, skipping conversion."
        )
        return

    # Convert the VAE Decoder model via DiffusionKit
    converted_vae_path = convert_vae_to_mlpackage(
        model_version=args.model_version,
        latent_h=args.latent_h,
        latent_w=args.latent_w,
        output_dir=args.o,
    )

    # Load converted model
    coreml_vae_decoder = ct.models.MLModel(converted_vae_path)

    # Set model metadata
    coreml_vae_decoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
    coreml_vae_decoder.license = "Stability AI Community License (https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md)"
    coreml_vae_decoder.version = args.model_version
    coreml_vae_decoder.short_description = \
        "Stable Diffusion 3 generates images conditioned on text or other images as input through the diffusion process. " \
        "Please refer to https://arxiv.org/pdf/2403.03206 for details."

    # Set the input descriptions
    coreml_vae_decoder.input_description["z"] = \
        "The denoised latent embeddings from the unet model after the last step of reverse diffusion"

    # Set the output descriptions
    coreml_vae_decoder.output_description[
        "image"] = "Generated image normalized to range [-1, 1]"

    # Set package version metadata (aliased so the two `__version__`s don't shadow each other)
    from python_coreml_stable_diffusion._version import __version__ as ml_stable_diffusion_version
    coreml_vae_decoder.user_defined_metadata["com.github.apple.ml-stable-diffusion.version"] = ml_stable_diffusion_version
    from diffusionkit.version import __version__ as diffusionkit_version
    coreml_vae_decoder.user_defined_metadata["com.github.argmax.diffusionkit.version"] = diffusionkit_version

    # Save the updated model
    coreml_vae_decoder.save(out_path)
    logger.info(f"Saved vae_decoder into {out_path}")

    # Delete DiffusionKit's intermediate package, but never the package just saved
    # (guards against both paths resolving to the same location).
    if os.path.exists(converted_vae_path) and \
            os.path.abspath(converted_vae_path) != os.path.abspath(out_path):
        shutil.rmtree(converted_vae_path)

    del coreml_vae_decoder
    gc.collect()
def convert_vae_encoder(pipe, args):
    """Convert the VAE encoder (encoder + quant_conv) of a pipeline to Core ML.

    The sample image is RGB at 8x the latent size. When ``args.xl_version`` is
    set, the model uses float32 input, float32 compute precision and
    CPU_AND_GPU compute units instead of the default fp16 sample input.
    Conversion is skipped if the output package already exists.
    ``pipe.vae.encoder`` is deleted afterwards.

    Raises:
        RuntimeError: if ``pipe.unet`` is gone. ``convert_unet()`` deletes it,
            and its config provides the default latent size.
    """
    out_path = _get_out_path(args, "vae_encoder")
    if os.path.exists(out_path):
        logger.info(
            f"`vae_encoder` already exists at {out_path}, skipping conversion."
        )
        return

    if not hasattr(pipe, "unet"):
        raise RuntimeError(
            "convert_unet() deletes pipe.unet to save RAM. "
            "Please use convert_vae_encoder() before convert_unet()")

    image_height = (args.latent_h or pipe.unet.config.sample_size) * 8
    image_width = (args.latent_w or pipe.unet.config.sample_size) * 8
    # B, C (RGB range from -1 to 1), H, W
    x_shape = (1, 3, image_height, image_width)

    use_fp32 = args.xl_version
    inputs_dtype = torch.float32 if use_fp32 else torch.float16
    compute_precision = ct.precision.FLOAT32 if use_fp32 else None
    # FIXME: Hardcoding to CPU_AND_GPU since ANE doesn't support FLOAT32
    compute_unit = ct.ComputeUnit.CPU_AND_GPU if use_fp32 else None

    sample_vae_encoder_inputs = {"x": torch.rand(*x_shape, dtype=inputs_dtype)}
    x_fp32 = sample_vae_encoder_inputs["x"].to(torch.float32)

    class VAEEncoder(nn.Module):
        """ Wrapper nn.Module wrapper for pipe.encode() method
        """

        def __init__(self):
            super().__init__()
            self.quant_conv = pipe.vae.quant_conv.to(dtype=torch.float32)
            self.encoder = pipe.vae.encoder.to(dtype=torch.float32)

        def forward(self, x):
            return self.quant_conv(self.encoder(x))

    baseline_encoder = VAEEncoder().eval()

    # No optimization needed for the VAE Encoder as it is a pure ConvNet
    traced_vae_encoder = torch.jit.trace(baseline_encoder, (x_fp32, ))

    modify_coremltools_torch_frontend_badbmm()
    coreml_vae_encoder, out_path = _convert_to_coreml(
        "vae_encoder", traced_vae_encoder, sample_vae_encoder_inputs,
        ["latent"], args, precision=compute_precision, compute_unit=compute_unit)

    # Model metadata
    coreml_vae_encoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
    coreml_vae_encoder.license = (
        "OpenRAIL++-M (https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md)"
        if args.xl_version else
        "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
    )
    coreml_vae_encoder.version = args.model_version
    coreml_vae_encoder.short_description = (
        "Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. "
        "Please refer to https://arxiv.org/abs/2112.10752 for details."
    )
    coreml_vae_encoder.input_description["x"] = (
        "The input image to base the initial latents on normalized to range [-1, 1]"
    )
    coreml_vae_encoder.output_description["latent"] = (
        "The latent embeddings from the unet model from the input image."
    )

    coreml_vae_encoder.save(out_path)
    logger.info(f"Saved vae_encoder into {out_path}")

    # Parity check PyTorch vs CoreML
    if args.check_output_correctness:
        baseline_out = baseline_encoder(x=x_fp32).numpy()
        coreml_inputs = {name: tensor.numpy() for name, tensor in sample_vae_encoder_inputs.items()}
        coreml_out = next(iter(coreml_vae_encoder.predict(coreml_inputs).values()))
        report_correctness(baseline_out, coreml_out,
                           "vae_encoder baseline PyTorch to baseline CoreML")

    del traced_vae_encoder, pipe.vae.encoder, coreml_vae_encoder
    gc.collect()
def convert_unet(pipe, args, model_name=None):
""" Converts the UNet component of Stable Diffusion
"""
if args.unet_support_controlnet:
unet_name = "control-unet"
else:
unet_name = model_name or "unet"
out_path = _get_out_path(args, unet_name)
# Check if Unet was previously exported and then chunked
unet_chunks_exist = all(
os.path.exists(
out_path.replace(".mlpackage", f"_chunk{idx+1}.mlpackage"))
for idx in range(2))
if args.chunk_unet and unet_chunks_exist:
logger.info("`unet` chunks already exist, skipping conversion.")
del pipe.unet
gc.collect()
return
# If original Unet does not exist, export it from PyTorch+diffusers
elif not os.path.exists(out_path):
# Prepare sample input shapes and values
batch_size = 2 # for classifier-free guidance
if args.unet_batch_one:
batch_size = 1 # for not using classifier-free guidance
sample_shape = (
batch_size, # B
pipe.unet.config.in_channels, # C
args.latent_h or pipe.unet.config.sample_size, # H
args.latent_w or pipe.unet.config.sample_size, # W
)
if not hasattr(pipe, "text_encoder"):
raise RuntimeError(
"convert_text_encoder() deletes pipe.text_encoder to save RAM. "
"Please use convert_unet() before convert_text_encoder()")
if hasattr(pipe, "text_encoder") and pipe.text_encoder is not None:
text_token_sequence_length = pipe.text_encoder.config.max_position_embeddings
hidden_size = pipe.text_encoder.config.hidden_size,
elif hasattr(pipe, "text_encoder_2") and pipe.text_encoder_2 is not None:
text_token_sequence_length = pipe.text_encoder_2.config.max_position_embeddings
hidden_size = pipe.text_encoder_2.config.hidden_size,
encoder_hidden_states_shape = (
batch_size,
args.text_encoder_hidden_size or pipe.unet.config.cross_attention_dim or hidden_size,
1,
args.text_token_sequence_length or text_token_sequence_length,
)
# Create the scheduled timesteps for downstream use
DEFAULT_NUM_INFERENCE_STEPS = 50
pipe.scheduler.set_timesteps(DEFAULT_NUM_INFERENCE_STEPS)
sample_unet_inputs = OrderedDict([
("sample", torch.rand(*sample_shape)),
("timestep",
torch.tensor([pipe.scheduler.timesteps[0].item()] *
(batch_size)).to(torch.float32)),
("encoder_hidden_states", torch.rand(*encoder_hidden_states_shape))
])
# Prepare inputs
baseline_sample_unet_inputs = deepcopy(sample_unet_inputs)
baseline_sample_unet_inputs[
"encoder_hidden_states"] = baseline_sample_unet_inputs[
"encoder_hidden_states"].squeeze(2).transpose(1, 2)
# Initialize reference unet
if args.xl_version:
unet_cls = unet.UNet2DConditionModelXL
# Sample time_ids
height = (args.latent_h or pipe.unet.config.sample_size) * 8
width = (args.latent_w or pipe.unet.config.sample_size) * 8
original_size = (height, width) # output_resolution
crops_coords_top_left = (0, 0) # topleft_crop_cond
target_size = (height, width) # resolution_cond
if hasattr(pipe.config, "requires_aesthetics_score") and pipe.config.requires_aesthetics_score:
# Part of SDXL's micro-conditioning as explained in section 2.2 of
# [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
# simulate an aesthetic score of the generated image by influencing the positive and negative text conditions.
aesthetic_score = 6.0 # default aesthetic_score
negative_aesthetic_score = 2.5 # default negative_aesthetic_score
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
else:
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
time_ids = [
add_neg_time_ids,
add_time_ids
]
# Pooled text embedding from text_encoder_2
text_embeds_shape = (
batch_size,
pipe.text_encoder_2.config.hidden_size
)
additional_xl_inputs = OrderedDict([
("time_ids", torch.tensor(time_ids).to(torch.float32)),
("text_embeds", torch.rand(*text_embeds_shape)),
])
sample_unet_inputs.update(additional_xl_inputs)
baseline_sample_unet_inputs['added_cond_kwargs'] = additional_xl_inputs
else:
unet_cls = unet.UNet2DConditionModel
reference_unet = unet_cls(support_controlnet=args.unet_support_controlnet, **pipe.unet.config).eval()
load_state_dict_summary = reference_unet.load_state_dict(
pipe.unet.state_dict())
if args.unet_support_controlnet:
from .unet import calculate_conv2d_output_shape
additional_residuals_shapes = []
# conv_in
out_h, out_w = calculate_conv2d_output_shape(
(args.latent_h or pipe.unet.config.sample_size),
(args.latent_w or pipe.unet.config.sample_size),
reference_unet.conv_in,
)
additional_residuals_shapes.append(
(batch_size, reference_unet.conv_in.out_channels, out_h, out_w))
# down_blocks
for down_block in reference_unet.down_blocks:
additional_residuals_shapes += [
(batch_size, resnet.out_channels, out_h, out_w) for resnet in down_block.resnets
]
if hasattr(down_block, "downsamplers") and down_block.downsamplers is not None:
for downsampler in down_block.downsamplers:
out_h, out_w = calculate_conv2d_output_shape(out_h, out_w, downsampler.conv)
additional_residuals_shapes.append(
(batch_size, down_block.downsamplers[-1].conv.out_channels, out_h, out_w))
# mid_block
additional_residuals_shapes.append(
(batch_size, reference_unet.mid_block.resnets[-1].out_channels, out_h, out_w)
)
baseline_sample_unet_inputs["down_block_additional_residuals"] = ()
for i, shape in enumerate(additional_residuals_shapes):
sample_residual_input = torch.rand(*shape)
sample_unet_inputs[f"additional_residual_{i}"] = sample_residual_input
if i == len(additional_residuals_shapes) - 1:
baseline_sample_unet_inputs["mid_block_additional_residual"] = sample_residual_input
else:
baseline_sample_unet_inputs["down_block_additional_residuals"] += (sample_residual_input, )
sample_unet_inputs_spec = {
k: (v.shape, v.dtype)
for k, v in sample_unet_inputs.items()
}
logger.info(f"Sample UNet inputs spec: {sample_unet_inputs_spec}")
# JIT trace
logger.info("JIT tracing..")
reference_unet = torch.jit.trace(reference_unet,
list(sample_unet_inputs.values()))
logger.info("Done.")
if args.check_output_correctness:
baseline_out = pipe.unet.to(torch.float32)(**baseline_sample_unet_inputs,
return_dict=False)[0].numpy()
reference_out = reference_unet(*sample_unet_inputs.values())[0].numpy()
report_correctness(baseline_out, reference_out,
"unet baseline to reference PyTorch")
del pipe.unet
gc.collect()
coreml_sample_unet_inputs = {
k: v.numpy().astype(np.float16)
for k, v in sample_unet_inputs.items()
}
coreml_unet, out_path = _convert_to_coreml(unet_name, reference_unet,
coreml_sample_unet_inputs,
["noise_pred"], args)
del reference_unet
gc.collect()
# Set model metadata
coreml_unet.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
if args.xl_version:
coreml_unet.license = "OpenRAIL++-M (https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md)"
else:
coreml_unet.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
coreml_unet.version = args.model_version if model_name != "refiner" or not hasattr(args, "refiner_version") else args.refiner_version
coreml_unet.short_description = \
"Stable Diffusion generates images conditioned on text or other images as input through the diffusion process. " \
"Please refer to https://arxiv.org/abs/2112.10752 for details."
# Set the input descriptions
coreml_unet.input_description["sample"] = \
"The low resolution latent feature maps being denoised through reverse diffusion"
coreml_unet.input_description["timestep"] = \
"A value emitted by the associated scheduler object to condition the model on a given noise schedule"
coreml_unet.input_description["encoder_hidden_states"] = \
"Output embeddings from the associated text_encoder model to condition to generated image on text. " \
"A maximum of 77 tokens (~40 words) are allowed. Longer text is truncated. " \
"Shorter text does not reduce computation."
if args.xl_version:
coreml_unet.input_description["time_ids"] = \
"Additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks."
coreml_unet.input_description["text_embeds"] = \
"Additional embeddings from text_encoder_2 that if specified are added to the embeddings that are passed along to the UNet blocks."
# Set the output descriptions
coreml_unet.output_description["noise_pred"] = \
"Same shape and dtype as the `sample` input. " \
"The predicted noise to facilitate the reverse diffusion (denoising) process"
# Set package version metadata
from python_coreml_stable_diffusion._version import __version__
coreml_unet.user_defined_metadata["com.github.apple.ml-stable-diffusion.version"] = __version__
coreml_unet.save(out_path)
logger.info(f"Saved unet into {out_path}")
# Parity check PyTorch vs CoreML
if args.check_output_correctness:
coreml_out = list(
coreml_unet.predict(coreml_sample_unet_inputs).values())[0]
report_correctness(baseline_out, coreml_out,
"unet baseline PyTorch to reference CoreML")
del coreml_unet
gc.collect()
else: