Skip to content

Commit

Permalink
update RandAugmentV3
Browse files Browse the repository at this point in the history
  • Loading branch information
flytocc committed Mar 8, 2023
1 parent e6b4f1e commit 72b016b
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 9 deletions.
5 changes: 2 additions & 3 deletions ppcls/configs/ImageNet/MobileViTv3/MobileViTv3_x0_5.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,9 @@ DataLoader:
use_log_aspect: True
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
config_str: rand-m9-mstd0.5-inc1
- RandAugmentV3:
num_layers: 2
interpolation: bicubic
img_size: 256
- NormalizeImage:
scale: 1.0/255.0
mean: [0.0, 0.0, 0.0]
Expand Down
5 changes: 2 additions & 3 deletions ppcls/configs/ImageNet/MobileViTv3/MobileViTv3_x0_75.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,9 @@ DataLoader:
use_log_aspect: True
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
config_str: rand-m9-mstd0.5-inc1
- RandAugmentV3:
num_layers: 2
interpolation: bicubic
img_size: 256
- NormalizeImage:
scale: 1.0/255.0
mean: [0.0, 0.0, 0.0]
Expand Down
5 changes: 2 additions & 3 deletions ppcls/configs/ImageNet/MobileViTv3/MobileViTv3_x1_0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,9 @@ DataLoader:
use_log_aspect: True
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
config_str: rand-m9-mstd0.5-inc1
- RandAugmentV3:
num_layers: 2
interpolation: bicubic
img_size: 256
- NormalizeImage:
scale: 1.0/255.0
mean: [0.0, 0.0, 0.0]
Expand Down
85 changes: 85 additions & 0 deletions ppcls/data/preprocess/ops/randaugment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import random
from .operators import RawColorJitter
from .timm_autoaugment import _pil_interp
from paddle.vision.transforms import transforms as T

import numpy as np
Expand Down Expand Up @@ -260,3 +261,87 @@ def rotate_with_fill(img, magnitude):
"invert": lambda img, _: ImageOps.invert(img),
"cutout": lambda img, magnitude: cutout(img, magnitude, replace=fillcolor[0])
}


class RandAugmentV3(RandAugment):
"""Customed RandAugment for MobileViTv2"""

def __init__(self,
num_layers=2,
magnitude=3,
fillcolor=(0, 0, 0),
interpolation="bicubic"):
self.num_layers = num_layers
self.magnitude = magnitude
self.max_level = 10
interpolation = _pil_interp(interpolation)

abso_level = self.magnitude / self.max_level
self.level_map = {
"shearX": 0.3 * abso_level,
"shearY": 0.3 * abso_level,
"translateX": 150.0 / 331.0 * abso_level,
"translateY": 150.0 / 331.0 * abso_level,
"rotate": 30 * abso_level,
"color": 0.9 * abso_level,
"posterize": 8 - int(4.0 * abso_level),
"solarize": 255.0 * (1 - abso_level),
"contrast": 0.9 * abso_level,
"sharpness": 0.9 * abso_level,
"brightness": 0.9 * abso_level,
"autocontrast": 0,
"equalize": 0,
"invert": 0
}

rnd_ch_op = random.choice

self.func = {
"shearX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0),
interpolation,
fillcolor=fillcolor),
"shearY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0),
interpolation,
fillcolor=fillcolor),
"translateX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0),
interpolation,
fillcolor=fillcolor),
"translateY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])),
interpolation,
fillcolor=fillcolor),
"rotate": lambda img, magnitude: img.rotate(
magnitude * rnd_ch_op([-1, 1]),
interpolation,
fillcolor=fillcolor),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"posterize": lambda img, magnitude:
ImageOps.posterize(img, magnitude),
"solarize": lambda img, magnitude:
ImageOps.solarize(img, magnitude),
"contrast": lambda img, magnitude:
ImageEnhance.Contrast(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"sharpness": lambda img, magnitude:
ImageEnhance.Sharpness(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"brightness": lambda img, magnitude:
ImageEnhance.Brightness(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"autocontrast": lambda img, _:
ImageOps.autocontrast(img),
"equalize": lambda img, _: ImageOps.equalize(img),
"invert": lambda img, _: ImageOps.invert(img)
}

0 comments on commit 72b016b

Please sign in to comment.