-
Notifications
You must be signed in to change notification settings - Fork 62
/
test_dataset.py
104 lines (84 loc) · 3 KB
/
test_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Run: python3 tests/test_dataset.py
import sys
def test_video_dataset():
    """Smoke-test VideoDataset: the directory-style (txt column files) and the
    CSV-style (metadata.csv) configurations must produce identical samples."""
    from dataset import VideoDataset

    # Arguments shared by both dataset flavors.
    shared_kwargs = dict(
        data_root="assets/tests/",
        max_num_frames=49,
        id_token=None,
        random_flip=None,
    )
    from_txt = VideoDataset(
        caption_column="prompts.txt",
        video_column="videos.txt",
        **shared_kwargs,
    )
    from_csv = VideoDataset(
        dataset_file="assets/tests/metadata.csv",
        caption_column="caption",
        video_column="video",
        **shared_kwargs,
    )

    assert len(from_txt) == 1
    assert len(from_csv) == 1
    # Full clip: 49 frames, 3 channels, 480x720 resolution.
    assert from_txt[0]["video"].shape == (49, 3, 480, 720)
    # Both loading paths must decode to the very same tensor.
    assert (from_txt[0]["video"] == from_csv[0]["video"]).all()
    print(from_txt[0]["video"].shape)
def test_video_dataset_with_resizing():
    """Smoke-test VideoDatasetWithResizing: txt-column and CSV metadata
    configurations must agree, with frames resampled into the T2V bucket."""
    from dataset import VideoDatasetWithResizing

    # Arguments shared by both dataset flavors.
    shared_kwargs = dict(
        data_root="assets/tests/",
        max_num_frames=49,
        id_token=None,
        random_flip=None,
    )
    from_txt = VideoDatasetWithResizing(
        caption_column="prompts.txt",
        video_column="videos.txt",
        **shared_kwargs,
    )
    from_csv = VideoDatasetWithResizing(
        dataset_file="assets/tests/metadata.csv",
        caption_column="caption",
        video_column="video",
        **shared_kwargs,
    )

    assert len(from_txt) == 1
    assert len(from_csv) == 1
    # 48 (not 49) frames: T2V frame-bucket sampling trims the clip.
    assert from_txt[0]["video"].shape == (48, 3, 480, 720)
    # Both loading paths must decode to the very same tensor.
    assert (from_txt[0]["video"] == from_csv[0]["video"]).all()
    print(from_txt[0]["video"].shape)
def test_video_dataset_with_bucket_sampler():
    """Check that BucketSampler yields whole resolution buckets of batch_size
    samples: the first batch at 480x720, the second at 256x360."""
    import torch
    from dataset import BucketSampler, VideoDatasetWithResizing
    from torch.utils.data import DataLoader

    dataset = VideoDatasetWithResizing(
        data_root="assets/tests/",
        caption_column="prompts_multi.txt",
        video_column="videos_multi.txt",
        max_num_frames=49,
        id_token=None,
        random_flip=None,
    )
    sampler = BucketSampler(dataset, batch_size=8)

    def collate_fn(data):
        # The sampler emits an entire bucket at once, so data[0] is the
        # list of samples making up one batch.
        bucket = data[0]
        prompts = [sample["prompt"] for sample in bucket]
        frames = torch.stack([sample["video"] for sample in bucket])
        return prompts, frames

    dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, collate_fn=collate_fn)

    for batch_index, (captions, videos) in enumerate(dataloader):
        assert len(captions) == 8 and isinstance(captions[0], str)
        if batch_index == 0:
            assert videos.shape == (8, 48, 3, 480, 720)
        else:
            assert videos.shape == (8, 48, 3, 256, 360)
            break
if __name__ == "__main__":
    # Make the project-local `dataset` module importable before any test runs.
    sys.path.append("./training")
    for run_test in (
        test_video_dataset,
        test_video_dataset_with_resizing,
        test_video_dataset_with_bucket_sampler,
    ):
        run_test()