Skip to content

Commit

Permalink
Bugfix in GaussNoise (#315)
Browse files Browse the repository at this point in the history
* Bugfix in GaussNoise

* Added check for variance of gaussian noise to be non negative

* var_limit it GaussNoise should be non negative

* bugfix. var can be int
  • Loading branch information
ternaus authored Aug 20, 2019
1 parent aff2cce commit 1bc367f
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1881,30 +1881,47 @@ class GaussNoise(ImageOnlyTransform):
Args:
var_limit ((float, float) or float): variance range for noise. If var_limit is a single float, the range
will be (-var_limit, var_limit). Default: (10., 50.).
will be (0, var_limit). Default: (10.0, 50.0).
mean (float): mean of the noise. Default: 0
p (float): probability of applying the transform. Default: 0.5.
Targets:
image
Image types:
uint8
uint8, float32
"""

def __init__(self, var_limit=(10., 50.), always_apply=False, p=0.5):
def __init__(self, var_limit=(10.0, 50.0), mean=None, always_apply=False, p=0.5):
super(GaussNoise, self).__init__(always_apply, p)
self.var_limit = to_tuple(var_limit)
if isinstance(var_limit, tuple):
if var_limit[0] < 0:
raise ValueError("Lower var_limit should be non negative.")
if var_limit[1] < 0:
raise ValueError("Upper var_limit should be non negative.")
self.var_limit = var_limit
elif isinstance(var_limit, (int, float)):
if var_limit < 0:
raise ValueError(" var_limit should be non negative.")

self.var_limit = (0, var_limit)

self.mean = mean

def apply(self, img, gauss=None, **params):
return F.gauss_noise(img, gauss=gauss)

def get_params_dependent_on_targets(self, params):
image = params['image']
var = random.uniform(self.var_limit[0], self.var_limit[1])
mean = var
sigma = var ** 0.5
random_state = np.random.RandomState(random.randint(0, 2 ** 32 - 1))
gauss = random_state.normal(mean, sigma, image.shape)

if self.mean is None:
DeprecationWarning('In the version 0.4.0 default behavior of GaussNoise mean will be changed to 0.')
self.mean = var

gauss = random_state.normal(self.mean, sigma, image.shape)
return {
'gauss': gauss
}
Expand Down

0 comments on commit 1bc367f

Please sign in to comment.