forked from keras-team/keras-cv
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add distribution strategy support for model.fit/eval/predict. (keras-…
…team#119) * Add unit/integration test for tf.distribute. * Fix format * Skip the test case for non-tf backend * Fix typo * Fix format and unit test context config. * Address review comments. * Add support for h5 weights loading. * Fix test * Add support for a -1 dimension in the `Reshape` operation. (keras-team#103) The code to compute the output shape is now shared between the `Reshape` operation and the `Reshape` layer. * Added ReLU activation layer (keras-team#104) * added relu * add relu * added correctness test * reformated * updates based on review * Fix docstring * Added R2score (keras-team#106) * Add meanX metrics * All regression metrics except for root mean squared error * Formatting issues * Add RootMeanSquaredError * Docstring spacing * Line too long fix * Add R2Score * Docstring fixes * Fix test * Fix tests * Adds RemoteMonitor Callback (keras-team#108) * Add Remote Monitor Callback * Add Remote Monitor Callback * Add Remote Monitor Callback * Add Remote Monitor * Add wrapper layer. * Add learning rate schedules (keras-team#102) * Add learning rate schedules * Some review comments * Use fancy new serialization tests * s/TensorFlow/backend in docstring * Update docstrings * More review comments * Added LeakyReLU activation layer (keras-team#109) * added LeakyReLu * update docstring * reformat * update config * updated test name * Fix docstrings * Fix init and update tests to import from correct path (keras-team#110) * Add distribute support for tensorflow trainer. * Revert the previous merge edit. * Fix lint issue * Address review comments. * Add TPU strategy support * Fix lint --------- Co-authored-by: Francois Chollet <francois.chollet@gmail.com> Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com> Co-authored-by: divyasreepat <divyashreepathihalli@gmail.com> Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com> Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Co-authored-by: Ian Stenbit <3072903+ianstenbit@users.noreply.github.com>
- Loading branch information
1 parent
7c76b35
commit f02520d
Showing
2 changed files
with
270 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from keras_core import layers | ||
from keras_core import losses | ||
from keras_core import models | ||
from keras_core import metrics | ||
from keras_core import optimizers | ||
from keras_core.utils import rng_utils | ||
|
||
|
||
def test_model_fit(): | ||
|
||
cpus = tf.config.list_physical_devices("CPU") | ||
tf.config.set_logical_device_configuration( | ||
cpus[0], | ||
[ | ||
tf.config.LogicalDeviceConfiguration(), | ||
tf.config.LogicalDeviceConfiguration(), | ||
], | ||
) | ||
|
||
rng_utils.set_random_seed(1337) | ||
|
||
strategy = tf.distribute.MirroredStrategy(['CPU:0', 'CPU:1']) | ||
with strategy.scope(): | ||
inputs = layers.Input((100,), batch_size=32) | ||
x = layers.Dense(256, activation="relu")(inputs) | ||
x = layers.Dense(256, activation="relu")(x) | ||
x = layers.Dense(256, activation="relu")(x) | ||
x = layers.BatchNormalization()(x) | ||
outputs = layers.Dense(16)(x) | ||
model = models.Model(inputs, outputs) | ||
|
||
model.summary() | ||
|
||
x = np.random.random((50000, 100)) | ||
y = np.random.random((50000, 16)) | ||
batch_size = 32 | ||
epochs = 5 | ||
|
||
model.compile( | ||
optimizer=optimizers.SGD(learning_rate=0.001), | ||
loss=losses.MeanSquaredError(), | ||
metrics=[metrics.MeanSquaredError()], | ||
# TODO(scottzhu): Find out where is the variable that is not created eagerly | ||
# and break the usage of XLA. | ||
jit_compile=False, | ||
) | ||
history = model.fit( | ||
x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2 | ||
) | ||
|
||
print("History:") | ||
print(history.history) | ||
|
||
if __name__ == "__main__": | ||
test_model_fit() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters