forked from facebookresearch/deit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
augment.py
123 lines (99 loc) · 3.37 KB
/
augment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
"""
3Augment implementation
Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino)
and timm DA(https://github.com/rwightman/pytorch-image-models)
"""
import torch
from torchvision import transforms
from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
import numpy as np
from torchvision import datasets, transforms
import random
from PIL import ImageFilter, ImageOps
import torchvision.transforms.functional as TF
class GaussianBlur(object):
"""
Apply Gaussian Blur to the PIL image.
"""
def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
self.prob = p
self.radius_min = radius_min
self.radius_max = radius_max
def __call__(self, img):
do_it = random.random() <= self.prob
if not do_it:
return img
img = img.filter(
ImageFilter.GaussianBlur(
radius=random.uniform(self.radius_min, self.radius_max)
)
)
return img
class Solarization(object):
"""
Apply Solarization to the PIL image.
"""
def __init__(self, p=0.2):
self.p = p
def __call__(self, img):
if random.random() < self.p:
return ImageOps.solarize(img)
else:
return img
class gray_scale(object):
"""
Apply Solarization to the PIL image.
"""
def __init__(self, p=0.2):
self.p = p
self.transf = transforms.Grayscale(3)
def __call__(self, img):
if random.random() < self.p:
return self.transf(img)
else:
return img
class horizontal_flip(object):
"""
Apply Solarization to the PIL image.
"""
def __init__(self, p=0.2,activate_pred=False):
self.p = p
self.transf = transforms.RandomHorizontalFlip(p=1.0)
def __call__(self, img):
if random.random() < self.p:
return self.transf(img)
else:
return img
def new_data_aug_generator(args = None):
img_size = args.input_size
remove_random_resized_crop = args.src
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
primary_tfl = []
scale=(0.08, 1.0)
interpolation='bicubic'
if remove_random_resized_crop:
primary_tfl = [
transforms.Resize(img_size, interpolation=3),
transforms.RandomCrop(img_size, padding=4,padding_mode='reflect'),
transforms.RandomHorizontalFlip()
]
else:
primary_tfl = [
RandomResizedCropAndInterpolation(
img_size, scale=scale, interpolation=interpolation),
transforms.RandomHorizontalFlip()
]
secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0),
Solarization(p=1.0),
GaussianBlur(p=1.0)])]
if args.color_jitter is not None and not args.color_jitter==0:
secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter))
final_tfl = [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
]
return transforms.Compose(primary_tfl+secondary_tfl+final_tfl)