
Commit

Add distribution strategy support for model.fit/eval/predict. (keras-team#119)

* Add unit/integration test for tf.distribute.

* Fix format

* Skip the test case for non-tf backend

* Fix typo

* Fix format and unit test context config.

* Address review comments.

* Add support for h5 weights loading.

* Fix test

* Add support for a -1 dimension in the `Reshape` operation. (keras-team#103)

The code to compute the output shape is now shared between the `Reshape` operation and the `Reshape` layer.

* Added ReLU activation layer (keras-team#104)

* added relu

* add relu

* added correctness test

* reformatted

* updates based on review

* Fix docstring

* Added R2score (keras-team#106)

* Add meanX metrics

* All regression metrics except for root mean squared error

* Formatting issues

* Add RootMeanSquaredError

* Docstring spacing

* Line too long fix

* Add R2Score

* Docstring fixes

* Fix test

* Fix tests

* Adds RemoteMonitor Callback (keras-team#108)

* Add Remote Monitor Callback

* Add Remote Monitor Callback

* Add Remote Monitor Callback

* Add Remote Monitor

* Add wrapper layer.

* Add learning rate schedules (keras-team#102)

* Add learning rate schedules

* Some review comments

* Use fancy new serialization tests

* s/TensorFlow/backend in docstring

* Update docstrings

* More review comments

* Added LeakyReLU activation layer (keras-team#109)

* added LeakyReLU

* update docstring

* reformat

* update config

* updated test name

* Fix docstrings

* Fix init and update tests to import from correct path (keras-team#110)

* Add distribute support for tensorflow trainer.

* Revert the previous merge edit.

* Fix lint issue

* Address review comments.

* Add TPU strategy support

* Fix lint

---------

Co-authored-by: Francois Chollet <francois.chollet@gmail.com>
Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com>
Co-authored-by: divyasreepat <divyashreepathihalli@gmail.com>
Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com>
Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Co-authored-by: Ian Stenbit <3072903+ianstenbit@users.noreply.github.com>
7 people authored May 9, 2023
1 parent 7c76b35 commit f02520d
Showing 2 changed files with 270 additions and 3 deletions.
58 changes: 58 additions & 0 deletions integration_tests/distribute_training_test.py
@@ -0,0 +1,58 @@
import numpy as np
import tensorflow as tf

from keras_core import layers
from keras_core import losses
from keras_core import models
from keras_core import metrics
from keras_core import optimizers
from keras_core.utils import rng_utils


def test_model_fit():
    cpus = tf.config.list_physical_devices("CPU")
    tf.config.set_logical_device_configuration(
        cpus[0],
        [
            tf.config.LogicalDeviceConfiguration(),
            tf.config.LogicalDeviceConfiguration(),
        ],
    )

    rng_utils.set_random_seed(1337)

    strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
    with strategy.scope():
        inputs = layers.Input((100,), batch_size=32)
        x = layers.Dense(256, activation="relu")(inputs)
        x = layers.Dense(256, activation="relu")(x)
        x = layers.Dense(256, activation="relu")(x)
        x = layers.BatchNormalization()(x)
        outputs = layers.Dense(16)(x)
        model = models.Model(inputs, outputs)

        model.summary()

        x = np.random.random((50000, 100))
        y = np.random.random((50000, 16))
        batch_size = 32
        epochs = 5

        model.compile(
            optimizer=optimizers.SGD(learning_rate=0.001),
            loss=losses.MeanSquaredError(),
            metrics=[metrics.MeanSquaredError()],
            # TODO(scottzhu): Find out which variable is not created eagerly
            # and breaks the usage of XLA.
            jit_compile=False,
        )
        history = model.fit(
            x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
        )

        print("History:")
        print(history.history)


if __name__ == "__main__":
    test_model_fit()
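
The test above exercises fit() only. As a purely illustrative sketch (not part of this commit), evaluate() and predict() could be driven under the same strategy scope, reusing the model, x, y, and batch_size names defined in test_model_fit():

# Hypothetical follow-up inside test_model_fit(), reusing `model`, `x`,
# `y`, and `batch_size` from above; evaluate() and predict() go through
# the distribute-aware test/predict functions added in this commit.
eval_results = model.evaluate(x, y, batch_size=batch_size)
print("Eval loss and metrics:", eval_results)

predictions = model.predict(x, batch_size=batch_size)
print("Predictions shape:", predictions.shape)  # expected: (50000, 16)
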
215 changes: 212 additions & 3 deletions keras_core/backend/tensorflow/trainer.py
@@ -19,6 +19,27 @@ def __init__(self):
        self.test_function = None
        self.predict_function = None

        # Model must be created under scope of DistStrat it will be trained
        # with.
        if tf.distribute.has_strategy():
            self._distribute_strategy = tf.distribute.get_strategy()
        else:
            self._distribute_strategy = None

        self._distribute_reduction_method = None

    @property
    def distribute_strategy(self):
        return self._distribute_strategy or tf.distribute.get_strategy()

    @property
    def distribute_reduction_method(self):
        return self._distribute_reduction_method or "auto"

    @distribute_reduction_method.setter
    def distribute_reduction_method(self, value):
        self._distribute_reduction_method = value

    def train_step(self, data):
        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)

@@ -83,7 +104,13 @@ def one_step_on_data(data):
        def one_step_on_iterator(iterator):
            """Runs a single training step given a Dataset iterator."""
            data = next(iterator)
            return one_step_on_data(data)
            outputs = self.distribute_strategy.run(
                one_step_on_data, args=(data,)
            )
            outputs = reduce_per_replica(
                outputs,
                self.distribute_strategy,
                reduction=self.distribute_reduction_method,
            )
            return outputs

        if not self.run_eagerly:
            train_function = tf.function(
@@ -110,7 +137,13 @@ def one_step_on_data(data):
        def one_step_on_iterator(iterator):
            """Runs a single test step given a Dataset iterator."""
            data = next(iterator)
            return one_step_on_data(data)
            outputs = self.distribute_strategy.run(
                one_step_on_data, args=(data,)
            )
            outputs = reduce_per_replica(
                outputs,
                self.distribute_strategy,
                reduction=self.distribute_reduction_method,
            )
            return outputs

        if not self.run_eagerly:
            test_function = tf.function(
@@ -137,7 +170,13 @@ def one_step_on_data(data):
        def one_step_on_iterator(iterator):
            """Runs a single predict step given a Dataset iterator."""
            data = next(iterator)
            return one_step_on_data(data)
            outputs = self.distribute_strategy.run(
                one_step_on_data, args=(data,)
            )
            outputs = reduce_per_replica(
                outputs,
                self.distribute_strategy,
                reduction=self.distribute_reduction_method,
            )
            return outputs

        if not self.run_eagerly:
            predict_function = tf.function(
@@ -430,3 +469,173 @@ def catch_stop_iteration(self):
            )
            self._current_iterator = None
            self.data_adapter.on_epoch_end()


def reduce_per_replica(values, strategy, reduction):
    """Attempt to reduce the structure `values` to single values.

    Given `values` (a `tf.Tensor` or a `PerReplica` structure), which
    represents the values across all the replicas, `reduce_per_replica`
    attempts to "reduce" those values and returns the corresponding structure
    that represents only single values.

    Currently, `reduce_per_replica` is only used for reducing the metric
    results from `tf.distribute.Strategy.run()`. Depending on the underlying
    `Strategy` implementation, `values` may be a `PerReplica` object, which
    can be thought of as a collection of values across the replicas, or a
    `tf.Tensor`, if the strategy has already conducted the reduction for the
    downstream library.

    There are five possible outcomes of reduction:

    1) if the `values` is a structure of simple `tf.Tensor`s, meaning that
       reduction is not actually needed, `reduce_per_replica` returns the
       structure as-is.
    2) else, if `reduction="auto"`, then the best reduction strategy is
       chosen based on the current environment. This should only be used
       for training cases (`fit()`).
    3) else, if `reduction="first"`, then `reduce_per_replica` returns the
       values of the first replica. This is used in the case of training and
       evaluation, where `values` is expected to hold the same value across
       the replicas as a result of `Strategy`'s synchronization across the
       replicas. `reduce_per_replica` does not synchronize the values.
    4) else, if `reduction="sum"`, then `reduce_per_replica` returns the sum
       of values for all replicas. This may be used in the custom training
       loop case, where each replica contains different values which are not
       synchronized.
    5) else, if `reduction="concat"`, then `reduce_per_replica` returns the
       concatenation of the values across the replicas, along the axis of
       dimension 0. This is used in the inference case (`predict()`).

    Args:
        values: Structure of `PerReplica` objects or `tf.Tensor`s.
            `tf.Tensor`s are returned as-is.
        strategy: `tf.distribute.Strategy` object.
        reduction: One of `"auto"`, `"first"`, `"concat"`, or `"sum"`.
            `"auto"` will select `"first"` when used under a TPUStrategy, or
            `"sum"` otherwise.

    Returns:
        Structure of `Tensor`s, representing the result of reduction.

    Raises:
        ValueError: if the reduction method is not supported.
    """

    if reduction == "auto":
        # Ignore TPU strategy, which should default to "first".
        reduction = "sum"

    def _reduce(v):
        """Reduce a single `PerReplica` object."""
        if _collective_all_reduce_multi_worker(strategy):
            if reduction == "concat":
                return _multi_worker_concat(v, strategy)
            elif reduction == "sum":
                return strategy.reduce("SUM", v, axis=None)

        if not _is_per_replica_instance(v):
            return v
        elif reduction == "first":
            return strategy.experimental_local_results(v)[0]
        elif reduction == "concat":
            if _is_tpu_multi_host(strategy):
                return _tpu_multi_host_concat(v, strategy)
            else:
                return concat(strategy.experimental_local_results(v))
        elif reduction == "sum":
            return tf.reduce_sum(strategy.experimental_local_results(v))
        else:
            raise ValueError(
                '`reduction` must be "first", "concat", "sum", or "auto". '
                f"Received: reduction={reduction}."
            )

    return tf.nest.map_structure(_reduce, values)


def _multi_worker_concat(v, strategy):
    """Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
    replicas = strategy.gather(v, axis=0)
    # v might not have the same shape on different replicas
    if _is_per_replica_instance(v):
        shapes = tf.concat(
            [
                tf.expand_dims(tf.shape(single_value)[0], axis=0)
                for single_value in v.values
            ],
            axis=0,
        )
        all_shapes = strategy.gather(shapes, axis=0)
    else:
        # v is a tensor. This may happen when, say, we have 2x1 multi-worker.
        all_shapes = strategy.gather(
            tf.expand_dims(tf.shape(v)[0], axis=0), axis=0
        )

    replicas = tf.split(
        replicas,
        num_or_size_splits=all_shapes,
        num=strategy.num_replicas_in_sync,
    )
    ordered_replicas = []
    num_replicas_per_worker = len(strategy.extended.worker_devices)
    for replica_id in range(num_replicas_per_worker):
        ordered_replicas += replicas[replica_id::num_replicas_per_worker]
    return concat(ordered_replicas)


def concat(tensors, axis=0):
    """Concats `tensor`s along `axis`."""
    if isinstance(tensors[0], tf.SparseTensor):
        return tf.sparse.concat(axis=axis, sp_inputs=tensors)
    elif _is_scalar(tensors[0]):
        return tf.stack(tensors, axis=axis)
    else:
        return tf.concat(tensors, axis=axis)


def _tpu_multi_host_concat(v, strategy):
    """Correctly order TPU PerReplica objects."""
    replicas = strategy.experimental_local_results(v)
    # When distributed datasets are created from Tensors / NumPy,
    # TPUStrategy.experimental_distribute_dataset shards data in
    # (Replica, Host) order, and TPUStrategy.experimental_local_results
    # returns it in (Host, Replica) order.
    num_replicas_per_host = strategy.extended.num_replicas_per_host
    ordered_replicas = []
    for replica_id in range(num_replicas_per_host):
        ordered_replicas += replicas[replica_id::num_replicas_per_host]
    return concat(ordered_replicas)
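

# Editorial aside (illustration only, not part of this file): the strided
# slicing above regroups local results by replica index. With 2 hosts and
# 2 replicas per host, plain Python lists show the reordering:
#
#     replicas = ["h0r0", "h0r1", "h1r0", "h1r1"]   # (Host, Replica) order
#     ordered = []
#     for replica_id in range(2):                   # num_replicas_per_host
#         ordered += replicas[replica_id::2]
#     # ordered == ["h0r0", "h1r0", "h0r1", "h1r1"]  # (Replica, Host) order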


def _collective_all_reduce_multi_worker(strategy):
    return (
        isinstance(strategy, tf.distribute.MultiWorkerMirroredStrategy)
    ) and strategy.extended._in_multi_worker_mode()


def _is_per_replica_instance(obj):
    return isinstance(obj, tf.distribute.DistributedValues) and isinstance(
        obj, tf.__internal__.CompositeTensor
    )


def _is_scalar(x):
    return isinstance(x, (tf.Tensor, tf.Variable)) and x.shape.rank == 0


def _is_tpu_multi_host(strategy):
    return _is_tpu_strategy(strategy) and strategy.extended.num_hosts > 1


def _is_tpu_strategy(strategy):
    return _is_tpu_strategy_class(strategy.__class__)


def _is_tpu_strategy_class(clz):
    def is_tpu_strat(k):
        return k.__name__.startswith("TPUStrategy")

    if is_tpu_strat(clz):
        return True
    return any(map(_is_tpu_strategy_class, clz.__bases__))
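

For reference, here is a minimal standalone sketch of how the new reduce_per_replica helper combines per-replica values returned by Strategy.run(). It is not part of this commit; it assumes two logical CPU devices (set up as in the integration test above) and that the helper is importable from keras_core.backend.tensorflow.trainer:

import tensorflow as tf

from keras_core.backend.tensorflow.trainer import reduce_per_replica

# Split the host CPU into two logical devices so that MirroredStrategy
# runs two replicas, mirroring distribute_training_test.py above.
cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
    cpus[0],
    [
        tf.config.LogicalDeviceConfiguration(),
        tf.config.LogicalDeviceConfiguration(),
    ],
)
strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])


def step():
    # Each replica returns a scalar, standing in for a per-replica metric.
    ctx = tf.distribute.get_replica_context()
    return tf.cast(ctx.replica_id_in_sync_group + 1, tf.float32)


per_replica = strategy.run(step)  # holds 1.0 on replica 0, 2.0 on replica 1

# "sum" adds the local results; "auto" resolves to "sum" in the code above.
print(reduce_per_replica(per_replica, strategy, reduction="sum"))  # 3.0
# "first" keeps the first replica's value (used when replicas are in sync).
print(reduce_per_replica(per_replica, strategy, reduction="first"))  # 1.0
# "concat" stacks scalars (or concatenates batches), as used by predict().
print(reduce_per_replica(per_replica, strategy, reduction="concat"))  # [1. 2.]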
