Skip to content

Commit

Permalink
Fix torch ci
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Dec 14, 2024
1 parent 5fc7b6a commit 60aa8d4
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,13 @@ def __init__(self, factor=0.5, data_format=None, **kwargs):
self.random_generator = self.backend.random.SeedGenerator()

def get_random_transformation(self, images, training=True, seed=None):
if seed is None:
seed = self._get_seed_generator(self.backend._backend)
random_values = self.backend.random.uniform(
shape=(self.backend.core.shape(images)[0],),
minval=0,
maxval=1,
seed=self.random_generator,
seed=seed,
)
should_apply = self.backend.numpy.expand_dims(
random_values < self.factor, axis=[1, 2, 3]
Expand Down

0 comments on commit 60aa8d4

Please sign in to comment.