import argparse
import datetime
import glob
import os
import sys
import time
import wandb
import numpy as np
from functools import partial
from omegaconf import OmegaConf
from packaging import version
from prefetch_generator import BackgroundGenerator
import torch
from tqdm import tqdm
from einops import rearrange

# Prefer the unified `lightning` package; fall back to a standalone
# `pytorch_lightning` install if it is not available.
try:
    import lightning.pytorch as pl
    from lightning.pytorch import seed_everything
    from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
    from lightning.pytorch.trainer import Trainer
    from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
    LIGHTNING_PACK_NAME = "lightning.pytorch."
except ImportError:
    import pytorch_lightning as pl
    from pytorch_lightning import seed_everything
    from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
    from pytorch_lightning.trainer import Trainer
    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
    LIGHTNING_PACK_NAME = "pytorch_lightning."

from ipdb import set_trace as st
from utils.util import instantiate_from_config, tensor2img
# from utils.modules.attention import enable_flash_attentions
from main import get_parser, load_state_dict, nondefault_trainer_args
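
# This is an inference/sampling script: it loads a trained video diffusion model
# from a --resume checkpoint (or a logdir), optionally reads captions from
# --caps_path, and writes generated samples into a `samples-*` directory derived
# from the checkpoint name.  A minimal sketch of an invocation, assuming the flag
# names exposed by get_parser() in main.py match the option attributes used below
# (hypothetical example, not a verified command line; paths are placeholders):
#
#   python generate.py --base configs/model.yaml --resume logs/run/checkpoints/last.ckpt \
#       --seed 23 --ngpu 1 --use_ddim True --ddim_step 50 \
#       --unconditional_guidance_scale 7.5 --save_mode byvideo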


def get_dataloader(data_cfg, batch_size, shuffle=False):
    print(f"Shuffle datasets : {shuffle}")
    import torch.utils.data as Data
    # data_cfg.params.shuffle = shuffle
    dataset = instantiate_from_config(data_cfg)
    dataloader = Data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=False,
        drop_last=False,
    )
    print(data_cfg)
    print(f"- len(dataset): {len(dataset)}")
    print(f"- len(dataloader): {len(dataloader)}")
    return dataloader


if __name__ == "__main__":
    sys.path.append(os.getcwd())
    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)
    opt, unknown = parser.parse_known_args()

    # set save directories
    if opt.resume:
        rank_zero_info("Resuming from {}".format(opt.resume))
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            logdir = "/".join(paths[:-2])
            rank_zero_info("logdir: {}".format(logdir))
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        logdir = opt.logdir
        ckpt = "nan"

    if opt.caps_path is not None:
        sampledir = os.path.join(logdir, f"samples_given_{os.path.basename(opt.caps_path).split('.')[0]}")
    elif opt.new_unconditional_guidance is False:
        sampledir = os.path.join(logdir, f'samples-{opt.save_mode}-ucgs{opt.unconditional_guidance_scale}-'
                                         f'{os.path.basename(ckpt).split(".")[0]}-ddim{opt.ddim_step}')
    else:
        sampledir = os.path.join(logdir, f'samples-{opt.save_mode}-ucgsimg{opt.unconditional_guidance_scale_img}'
                                         f'-ucgsvid{opt.unconditional_guidance_scale_vid}-'
                                         f'{os.path.basename(ckpt).split(".")[0]}')
    sampledir = sampledir + opt.suffix
    os.makedirs(sampledir, exist_ok=True)

    seed_everything(opt.seed)

    # init and save configs
    configs = [OmegaConf.load(cfg.strip()) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)
    lightning_config = config.pop("lightning", OmegaConf.create())

    # merge trainer cli with config
    trainer_config = lightning_config.get("trainer", OmegaConf.create())
    trainer_config["devices"] = opt.ngpu or trainer_config["devices"]
    print(f"!!! WARNING: Number of GPUs is {trainer_config['devices']}")
    for k in nondefault_trainer_args(opt):
        trainer_config[k] = getattr(opt, k)
    trainer_opt = argparse.Namespace(**trainer_config)
    lightning_config.trainer = trainer_config

    # model: force full-precision sampling and rebuild the noise schedule from the
    # model's stored linear_start / linear_end before moving it to the GPU
    config.model["params"].update({"use_fp16": False})
    load_strict = trainer_config.get('ckpt_load_strict', True)
    model = instantiate_from_config(config.model).cpu()
    model.load_state_dict(load_state_dict(ckpt.strip(), location='cpu'), strict=load_strict)
    print(f"Loaded ckpt from {ckpt} with strict={load_strict}")
    model.register_schedule(linear_start=model.linear_start, linear_end=model.linear_end)
    model = model.cuda().eval()
# data
print(f"- Loading validation data...")
bs = opt.batch_size or config.data.params.batch_size
if opt.caps_path is not None:
config.data.params.validation.target = "data.custom.VideoFolderDataset_Inference"
config.data.params.validation.params.caps_path = opt.caps_path
config.data.params.validation.params.num_replication = opt.num_replication
if opt.dataset_root is not None:
config.data.params.validation.params.root = opt.dataset_root
config.data.params.validation.params.max_data_num = opt.total_sample_number
dataloader = get_dataloader(config.data.params.validation, bs, opt.shuffle)
part_num = len(dataloader) / opt.total_part
start_idx = int((opt.cur_part - 1) * part_num)
end_idx = int(opt.cur_part * part_num)
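    # start_idx / end_idx above give this job's contiguous slice of batches:
    # cur_part is 1-indexed out of total_part, so several jobs can split the
    # generation work over the same dataloader.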

    # start to generate
    vc = None
    ddim_step = opt.ddim_step
    save_mode = opt.save_mode  # bybatch, byvideo, byframe
    verbose = opt.test_verbose
    video_length = opt.video_length
    total_sample_number = opt.total_sample_number
    use_ddim = opt.use_ddim and (ddim_step > 0)
    print(f"- Saving generated samples to {sampledir}")
    print(f"- Use ddim: {use_ddim} with ddim steps: {ddim_step}")
    print(f"- Cur part {opt.cur_part}/{opt.total_part} with idx from {start_idx} to {end_idx}")

    batch_idx = 0
    for batch in tqdm(iter(dataloader), desc=f"new ucgs: {opt.new_unconditional_guidance}"):
        if start_idx <= batch_idx < end_idx:
            if len(os.listdir(sampledir)) >= total_sample_number:
                final_number = max(len(os.listdir(sampledir)), bs * (batch_idx + 1))
                print(f"Generated {final_number} video samples in {sampledir}!")
                exit()

            if vc is None:
                vc, _ = model.get_input(batch, None)
                x_T = torch.randn_like(vc)
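            # NOTE: x_T is drawn once from the first batch's latent shape and then
            # reused for every subsequent batch, so all samples start from the same
            # initial noise for a given seed.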

            start_t = time.time()
            if opt.new_unconditional_guidance is False:
                unconditional_guidance_scale = opt.unconditional_guidance_scale
                try:
                    sample_log = model.log_videos_parallel(batch, N=bs, n_frames=video_length, x_T=x_T, verbose=verbose,
                                                           sample=False, use_ddim=use_ddim, ddim_steps=ddim_step,
                                                           unconditional_guidance_scale=unconditional_guidance_scale)
                except Exception:
                    # fall back to the sequential sampler if the parallel path is unavailable
                    sample_log = model.log_videos(batch, N=bs, n_frames=video_length, x_T=x_T, verbose=verbose,
                                                  sample=False, use_ddim=use_ddim, ddim_steps=ddim_step,
                                                  unconditional_guidance_scale=unconditional_guidance_scale)
            else:
                # separate guidance scales for the image and video branches
                ucgs_img = opt.unconditional_guidance_scale_img
                ucgs_vid = opt.unconditional_guidance_scale_vid
                try:
                    sample_log = model.log_videos_parallel(batch, N=bs, n_frames=video_length, x_T=x_T, verbose=verbose,
                                                           sample=False, use_ddim=use_ddim, ddim_steps=ddim_step,
                                                           ucgs_image=ucgs_img, ucgs_video=ucgs_vid)
                except Exception:
                    sample_log = model.log_videos(batch, N=bs, n_frames=video_length, x_T=x_T, verbose=verbose,
                                                  sample=False, use_ddim=use_ddim, ddim_steps=ddim_step,
                                                  ucgs_image=ucgs_img, ucgs_video=ucgs_vid)
            end_t = time.time()
            # print(f"Generation time: {end_t - start_t}s")
            # print(sample_log.keys())

            if use_ddim is False:
                cur_video = sample_log["samples_ddpm"]
            elif opt.new_unconditional_guidance is False:
                cur_video = sample_log[f"samples_ug_scale_{unconditional_guidance_scale:.2f}"]
            else:
                cur_video = sample_log[f"samples_ug_scale_i{ucgs_img:.2f}v{ucgs_vid:.2f}"]
            cur_video = cur_video.detach().cpu()  # b c t h w
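            # Three save layouts: "bybatch" writes one image grid per batch,
            # "byvideo" one grid per video, and "byframe" one jpg per frame
            # under a per-video directory.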
if save_mode == "bybatch":
save = tensor2img(cur_video)
save.save(os.path.join(sampledir, f"{batch_idx:04d}.jpg"))
elif save_mode == "byvideo":
video_names = batch['video_name']
for b, name in enumerate(video_names):
save = tensor2img(cur_video[b].unsqueeze(0))
video_name = f"b{batch_idx:04d}{b:02d}-v{name}-s{opt.seed}"
save.save(os.path.join(sampledir, f"{video_name}.jpg"))
elif save_mode == "byframe":
video_names = batch['video_name']
for b, name in enumerate(video_names):
video_name = f"b{batch_idx:04d}{b:02d}-v{name}-s{opt.seed}"
save_path = os.path.join(sampledir, video_name)
os.makedirs(save_path, exist_ok=True)
for t in range(video_length):
frame = tensor2img(cur_video[b, :, t, :, :])
# frame.save(os.path.join(save_path, f"{t:04d}.png"))
frame.save(os.path.join(save_path, f"{t:04d}.jpg"))
else:
raise NotImplementedError
batch_idx += 1
torch.cuda.empty_cache()