Skip to content

Commit

Permalink
Deterministic/Replay mode for augmentations (#350)
Browse files Browse the repository at this point in the history
* version bump

* add proof of concept for deterministic mode (when you can store and replay applied transforms)

* wip

* format with black

* ugly but working prototype

* make it a little more readable

* add working replay mode and jupyter notebook with example.

* add example to readme
  • Loading branch information
albu committed Sep 27, 2019
1 parent 14df752 commit 9942689
Show file tree
Hide file tree
Showing 6 changed files with 486 additions and 22 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@

**Serialization** [`serialization.ipynb`](https://github.com/albu/albumentations/blob/master/notebooks/serialization.ipynb)

**Replay/Deterministic mode** [`replay.ipynb`](https://github.com/albu/albumentations/blob/master/notebooks/replay.ipynb)

You can use this [Google Colaboratory notebook](https://colab.research.google.com/drive/1JuZ23u0C0gx93kV0oJ8Mq0B6CBYhPLXy#scrollTo=GwFN-In3iagp&forceEdit=true&offline=true&sandboxMode=true)
to adjust image augmentation parameters and see the resulting images.

Expand Down
103 changes: 102 additions & 1 deletion albumentations/core/composition.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import division
from collections import defaultdict

import random

Expand All @@ -10,8 +11,9 @@
from albumentations.core.transforms_interface import DualTransform
from albumentations.core.utils import format_args, Params
from albumentations.augmentations.bbox_utils import BboxProcessor
from albumentations.core.serialization import SERIALIZABLE_REGISTRY, instantiate_lambda

__all__ = ["Compose", "OneOf", "OneOrOther", "BboxParams", "KeypointParams"]
__all__ = ["Compose", "OneOf", "OneOrOther", "BboxParams", "KeypointParams", "ReplayCompose"]


REPR_INDENT_STEP = 2
Expand Down Expand Up @@ -64,6 +66,9 @@ def __init__(self, transforms, p):
self.transforms = Transforms(transforms)
self.p = p

self.replay_mode = False
self.applied_in_replay = False

def __getitem__(self, item):
return self.transforms[item]

Expand Down Expand Up @@ -94,11 +99,23 @@ def _to_dict(self):
"transforms": [t._to_dict() for t in self.transforms],
}

def get_dict_with_id(self):
return {
"__class_fullname__": self.get_class_fullname(),
"id": id(self),
"params": None,
"transforms": [t.get_dict_with_id() for t in self.transforms],
}

def add_targets(self, additional_targets):
if additional_targets:
for t in self.transforms:
t.add_targets(additional_targets)

def set_deterministic(self, flag, save_key="replay"):
for t in self.transforms:
t.set_deterministic(flag, save_key)


class Compose(BaseCompose):
"""Compose transforms and handle all transformations regrading bounding boxes
Expand Down Expand Up @@ -193,6 +210,11 @@ def __init__(self, transforms, p=0.5):
self.transforms_ps = [t / s for t in transforms_ps]

def __call__(self, force_apply=False, **data):
if self.replay_mode:
for t in self.transforms:
data = t(**data)
return data

if force_apply or random.random() < self.p:
random_state = np.random.RandomState(random.randint(0, 2 ** 32 - 1))
t = random_state.choice(self.transforms.transforms, p=self.transforms_ps)
Expand All @@ -207,6 +229,11 @@ def __init__(self, first=None, second=None, transforms=None, p=0.5):
super(OneOrOther, self).__init__(transforms, p)

def __call__(self, force_apply=False, **data):
if self.replay_mode:
for t in self.transforms:
data = t(**data)
return data

if random.random() < self.p:
return self.transforms[0](force_apply=True, **data)
else:
Expand Down Expand Up @@ -248,6 +275,80 @@ def __call__(self, force_apply=False, **data):
return data


class ReplayCompose(Compose):
def __init__(
self, transforms, bbox_params=None, keypoint_params=None, additional_targets=None, p=1.0, save_key="replay"
):
super(ReplayCompose, self).__init__(transforms, bbox_params, keypoint_params, additional_targets, p)
self.set_deterministic(True, save_key=save_key)
self.save_key = save_key

def __call__(self, force_apply=False, **kwargs):
kwargs[self.save_key] = defaultdict(dict)
result = super(ReplayCompose, self).__call__(force_apply=force_apply, **kwargs)
serialized = self.get_dict_with_id()
self.fill_with_params(serialized, result[self.save_key])
self.fill_applied(serialized)
result[self.save_key] = serialized
return result

@staticmethod
def replay(saved_augmentations, **kwargs):
augs = ReplayCompose._restore_for_replay(saved_augmentations)
return augs(force_apply=True, **kwargs)

@staticmethod
def _restore_for_replay(transform_dict, lambda_transforms=None):
"""
Args:
transform (dict): A dictionary with serialized transform pipeline.
lambda_transforms (dict): A dictionary that contains lambda transforms, that
is instances of the Lambda class.
This dictionary is required when you are restoring a pipeline that contains lambda transforms. Keys
in that dictionary should be named same as `name` arguments in respective lambda transforms from
a serialized pipeline.
"""
transform = transform_dict
applied = transform["applied"]
params = transform["params"]
lmbd = instantiate_lambda(transform, lambda_transforms)
if lmbd:
transform = lmbd
else:
name = transform["__class_fullname__"]
args = {k: v for k, v in transform.items() if k not in ["__class_fullname__", "applied", "params"]}
cls = SERIALIZABLE_REGISTRY[name]
if "transforms" in args:
args["transforms"] = [
ReplayCompose._restore_for_replay(t, lambda_transforms=lambda_transforms)
for t in args["transforms"]
]
transform = cls(**args)

transform.params = params
transform.replay_mode = True
transform.applied_in_replay = applied
return transform

def fill_with_params(self, serialized, all_params):
params = all_params.get(serialized.get("id"))
serialized["params"] = params
del serialized["id"]
for transform in serialized.get("transforms", []):
self.fill_with_params(transform, all_params)

def fill_applied(self, serialized):
if "transforms" in serialized:
applied = [self.fill_applied(t) for t in serialized["transforms"]]
serialized["applied"] = any(applied)
else:
serialized["applied"] = serialized.get("params") is not None
return serialized["applied"]

def _to_dict(self):
raise NotImplementedError("You cannot serialize ReplayCompose")


class BboxParams(Params):
"""
Parameters of bounding boxes
Expand Down
26 changes: 16 additions & 10 deletions albumentations/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,7 @@ def to_dict(transform, on_not_implemented_error="raise"):
return {"__version__": __version__, "transform": transform_dict}


def from_dict(transform_dict, lambda_transforms=None):
"""
Args:
transform (dict): A dictionary with serialized transform pipeline.
lambda_transforms (dict): A dictionary that contains lambda transforms, that is instances of the Lambda class.
This dictionary is required when you are restoring a pipeline that contains lambda transforms. Keys
in that dictionary should be named same as `name` arguments in respective lambda transforms from
a serialized pipeline.
"""
transform = transform_dict["transform"]
def instantiate_lambda(transform, lambda_transforms=None):
if transform.get("__type__") == "Lambda":
name = transform["__name__"]
if lambda_transforms is None:
Expand All @@ -87,6 +78,21 @@ def from_dict(transform_dict, lambda_transforms=None):
if transform is None:
raise ValueError("Lambda transform with {name} was not found in `lambda_transforms`".format(name=name))
return transform


def from_dict(transform_dict, lambda_transforms=None):
"""
Args:
transform (dict): A dictionary with serialized transform pipeline.
lambda_transforms (dict): A dictionary that contains lambda transforms, that is instances of the Lambda class.
This dictionary is required when you are restoring a pipeline that contains lambda transforms. Keys
in that dictionary should be named same as `name` arguments in respective lambda transforms from
a serialized pipeline.
"""
transform = transform_dict["transform"]
lmbd = instantiate_lambda(transform, lambda_transforms)
if lmbd:
return lmbd
name = transform["__class_fullname__"]
args = {k: v for k, v in transform.items() if k != "__class_fullname__"}
cls = SERIALIZABLE_REGISTRY[name]
Expand Down
62 changes: 52 additions & 10 deletions albumentations/core/transforms_interface.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import absolute_import

import random
from warnings import warn

import cv2
from copy import deepcopy

from albumentations.core.serialization import SerializableMeta
from albumentations.core.six import add_metaclass
Expand Down Expand Up @@ -44,33 +46,68 @@ def to_tuple(param, low=None, bias=None):

@add_metaclass(SerializableMeta)
class BasicTransform(object):
call_backup = None

def __init__(self, always_apply=False, p=0.5):
self.p = p
self.always_apply = always_apply
self._additional_targets = {}

# replay mode params
self.deterministic = False
self.save_key = "replay"
self.params = {}
self.replay_mode = False
self.applied_in_replay = False

def __call__(self, force_apply=False, **kwargs):
if self.replay_mode:
if self.applied_in_replay:
return self.apply_with_params(self.params, **kwargs)
else:
return kwargs

if (random.random() < self.p) or self.always_apply or force_apply:
params = self.get_params()
params = self.update_params(params, **kwargs)

if self.targets_as_params:
assert all(key in kwargs for key in self.targets_as_params), "{} requires {}".format(
self.__class__.__name__, self.targets_as_params
)
targets_as_params = {k: kwargs[k] for k in self.targets_as_params}
params_dependent_on_targets = self.get_params_dependent_on_targets(targets_as_params)
params.update(params_dependent_on_targets)
res = {}
for key, arg in kwargs.items():
if arg is not None:
target_function = self._get_target_function(key)
target_dependencies = {k: kwargs[k] for k in self.target_dependence.get(key, [])}
res[key] = target_function(arg, **dict(params, **target_dependencies))
else:
res[key] = None
return res
if self.deterministic:
if self.targets_as_params:
warn(
self.get_class_fullname() + " could work incorrectly in ReplayMode for other input data"
" because its' params depend on targets."
)
kwargs[self.save_key][id(self)] = deepcopy(params)
return self.apply_with_params(params, **kwargs)

return kwargs

def apply_with_params(self, params, force_apply=False, **kwargs):
if params is None:
return kwargs
params = self.update_params(params, **kwargs)
res = {}
for key, arg in kwargs.items():
if arg is not None:
target_function = self._get_target_function(key)
target_dependencies = {k: kwargs[k] for k in self.target_dependence.get(key, [])}
res[key] = target_function(arg, **dict(params, **target_dependencies))
else:
res[key] = None
return res

def set_deterministic(self, flag, save_key="replay"):
assert save_key != "params", "params save_key is reserved"
self.deterministic = flag
self.save_key = save_key
return self

def __repr__(self):
state = self.get_base_init_args()
state.update(self.get_transform_init_args())
Expand Down Expand Up @@ -151,6 +188,11 @@ def _to_dict(self):
state.update(self.get_transform_init_args())
return state

def get_dict_with_id(self):
d = self._to_dict()
d["id"] = id(self)
return d


class DualTransform(BasicTransform):
"""Transform for segmentation task."""
Expand Down
291 changes: 291 additions & 0 deletions notebooks/replay.ipynb

Large diffs are not rendered by default.

24 changes: 23 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from albumentations.core.transforms_interface import to_tuple, ImageOnlyTransform, DualTransform
from albumentations.augmentations.bbox_utils import check_bboxes
from albumentations.core.composition import OneOrOther, Compose, OneOf, PerChannel
from albumentations.core.composition import OneOrOther, Compose, OneOf, PerChannel, ReplayCompose
from albumentations.augmentations.transforms import HorizontalFlip, Rotate, Blur, MedianBlur
from .compat import mock, MagicMock, Mock, call

Expand Down Expand Up @@ -138,3 +138,25 @@ def test_per_channel_multi():
image = np.ones((8, 8, 5))
data = augmentation(image=image)
assert data


def test_deterministic_oneof():
aug = ReplayCompose([OneOf([HorizontalFlip(), Blur()])], p=1)
for i in range(10):
image = (np.random.random((8, 8)) * 255).astype(np.uint8)
image2 = np.copy(image)
data = aug(image=image)
assert "replay" in data
data2 = ReplayCompose.replay(data["replay"], image=image2)
assert np.array_equal(data["image"], data2["image"])


def test_deterministic_one_or_other():
aug = ReplayCompose([OneOrOther(HorizontalFlip(), Blur())], p=1)
for i in range(10):
image = (np.random.random((8, 8)) * 255).astype(np.uint8)
image2 = np.copy(image)
data = aug(image=image)
assert "replay" in data
data2 = ReplayCompose.replay(data["replay"], image=image2)
assert np.array_equal(data["image"], data2["image"])

0 comments on commit 9942689

Please sign in to comment.