Skip to content

Commit

Permalink
Use SeedGenerator backend when creating variable (#439)
Browse files Browse the repository at this point in the history
  • Loading branch information
sampathweb authored Jul 11, 2023
1 parent feee944 commit f990033
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
6 changes: 3 additions & 3 deletions keras_core/layers/preprocessing/tf_data_layer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from tensorflow import nest

from keras_core import backend
import keras_core.backend
from keras_core.layers.layer import Layer
from keras_core.random.seed_generator import SeedGenerator
from keras_core.utils import backend_utils
Expand All @@ -22,7 +22,7 @@ def __init__(self, **kwargs):

def __call__(self, inputs, **kwargs):
if backend_utils.in_tf_graph() and not isinstance(
inputs, backend.KerasTensor
inputs, keras_core.backend.KerasTensor
):
# We're in a TF graph, e.g. a tf.data pipeline.
self.backend.set_backend("tensorflow")
Expand All @@ -47,7 +47,7 @@ def __call__(self, inputs, **kwargs):

@tracking.no_automatic_dependency_tracking
def _get_seed_generator(self, backend=None):
if backend is None or backend == self.backend._backend:
if backend is None or backend == keras_core.backend.backend():
return self.generator
if not hasattr(self, "_backend_generators"):
self._backend_generators = {}
Expand Down
13 changes: 8 additions & 5 deletions keras_core/random/seed_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np

import keras_core.backend
from keras_core.api_export import keras_core_export


Expand Down Expand Up @@ -30,9 +31,9 @@ def __init__(self, seed, **kwargs):
if kwargs:
raise ValueError(f"Unrecognized keyword arguments: {kwargs}")
if custom_backend is not None:
backend = custom_backend
self.backend = custom_backend
else:
from keras_core import backend
self.backend = keras_core.backend

if seed is None:
seed = make_default_seed()
Expand All @@ -43,9 +44,9 @@ def __init__(self, seed, **kwargs):

def seed_initializer(*args, **kwargs):
dtype = kwargs.get("dtype", None)
return backend.convert_to_tensor([seed, 0], dtype=dtype)
return self.backend.convert_to_tensor([seed, 0], dtype=dtype)

self.state = backend.Variable(
self.state = self.backend.Variable(
seed_initializer,
shape=(2,),
dtype="uint32",
Expand All @@ -65,7 +66,9 @@ def draw_seed(seed):
seed_state = seed.state
# Use * 1 to create a copy
new_seed_value = seed_state.value * 1
increment = convert_to_tensor(np.array([0, 1]), dtype="uint32")
increment = seed.backend.convert_to_tensor(
np.array([0, 1]), dtype="uint32"
)
seed.state.assign(seed_state + increment)
return new_seed_value
elif isinstance(seed, int):
Expand Down

0 comments on commit f990033

Please sign in to comment.