diff --git a/integration_tests/distribute_training_test.py b/integration_tests/distribute_training_test.py
new file mode 100644
index 0000000000..74460e50ce
--- /dev/null
+++ b/integration_tests/distribute_training_test.py
@@ -0,0 +1,58 @@
+import numpy as np
+import tensorflow as tf
+
+from keras_core import layers
+from keras_core import losses
+from keras_core import models
+from keras_core import metrics
+from keras_core import optimizers
+from keras_core.utils import rng_utils
+
+
+def test_model_fit():
+
+    cpus = tf.config.list_physical_devices("CPU")
+    tf.config.set_logical_device_configuration(
+        cpus[0],
+        [
+            tf.config.LogicalDeviceConfiguration(),
+            tf.config.LogicalDeviceConfiguration(),
+        ],
+    )
+
+    rng_utils.set_random_seed(1337)
+
+    strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
+    with strategy.scope():
+        inputs = layers.Input((100,), batch_size=32)
+        x = layers.Dense(256, activation="relu")(inputs)
+        x = layers.Dense(256, activation="relu")(x)
+        x = layers.Dense(256, activation="relu")(x)
+        x = layers.BatchNormalization()(x)
+        outputs = layers.Dense(16)(x)
+        model = models.Model(inputs, outputs)
+
+    model.summary()
+
+    x = np.random.random((50000, 100))
+    y = np.random.random((50000, 16))
+    batch_size = 32
+    epochs = 5
+
+    model.compile(
+        optimizer=optimizers.SGD(learning_rate=0.001),
+        loss=losses.MeanSquaredError(),
+        metrics=[metrics.MeanSquaredError()],
+        # TODO(scottzhu): Find out which variable is not created eagerly
+        # and breaks the usage of XLA.
+        jit_compile=False,
+    )
+    history = model.fit(
+        x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
+    )
+
+    print("History:")
+    print(history.history)
+
+if __name__ == "__main__":
+    test_model_fit()
\ No newline at end of file
diff --git a/keras_core/backend/tensorflow/trainer.py b/keras_core/backend/tensorflow/trainer.py
index 00982bc5fa..ebd5137583 100644
--- a/keras_core/backend/tensorflow/trainer.py
+++ b/keras_core/backend/tensorflow/trainer.py
@@ -19,6 +19,27 @@ def __init__(self):
         self.test_function = None
         self.predict_function = None

+        # Model must be created under the scope of the DistributionStrategy
+        # it will be trained with.
+        if tf.distribute.has_strategy():
+            self._distribute_strategy = tf.distribute.get_strategy()
+        else:
+            self._distribute_strategy = None
+
+        self._distribute_reduction_method = None
+
+    @property
+    def distribute_strategy(self):
+        return self._distribute_strategy or tf.distribute.get_strategy()
+
+    @property
+    def distribute_reduction_method(self):
+        return self._distribute_reduction_method or "auto"
+
+    @distribute_reduction_method.setter
+    def distribute_reduction_method(self, value):
+        self._distribute_reduction_method = value
+
     def train_step(self, data):
         x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)

@@ -83,7 +104,13 @@ def one_step_on_data(data):
         def one_step_on_iterator(iterator):
             """Runs a single training step given a Dataset iterator."""
             data = next(iterator)
-            return one_step_on_data(data)
+            outputs = self.distribute_strategy.run(
+                one_step_on_data, args=(data,))
+            outputs = reduce_per_replica(
+                outputs, self.distribute_strategy,
+                reduction=self.distribute_reduction_method
+            )
+            return outputs

         if not self.run_eagerly:
             train_function = tf.function(
@@ -110,7 +137,13 @@ def one_step_on_data(data):
         def one_step_on_iterator(iterator):
             """Runs a single test step given a Dataset iterator."""
             data = next(iterator)
-            return one_step_on_data(data)
+            outputs = self.distribute_strategy.run(
+                one_step_on_data, args=(data,))
+            outputs = reduce_per_replica(
+                outputs, self.distribute_strategy,
+                reduction=self.distribute_reduction_method
+            )
+            return outputs

         if not self.run_eagerly:
             test_function = tf.function(
@@ -137,7 +170,13 @@ def one_step_on_data(data):
         def one_step_on_iterator(iterator):
             """Runs a single predict step given a Dataset iterator."""
             data = next(iterator)
-            return one_step_on_data(data)
+            outputs = self.distribute_strategy.run(
+                one_step_on_data, args=(data,))
+            outputs = reduce_per_replica(
+                outputs, self.distribute_strategy,
+                reduction=self.distribute_reduction_method
+            )
+            return outputs

         if not self.run_eagerly:
             predict_function = tf.function(
@@ -430,3 +469,173 @@ def catch_stop_iteration(self):
                 )
             self._current_iterator = None
             self.data_adapter.on_epoch_end()
+
+
+def reduce_per_replica(values, strategy, reduction):
+    """Attempt to reduce the structure `values` to single values.
+
+    Given `values` (a `tf.Tensor` or a `PerReplica` structure),
+    which represents the values across all the replicas, `reduce_per_replica`
+    attempts to "reduce" those values and returns the corresponding structure
+    that represents only single values.
+
+    Currently, `reduce_per_replica` is only used for reducing the metric
+    results from `tf.distribute.Strategy.run()`. Depending on the underlying
+    `Strategy` implementation, `values` may be a `PerReplica` object,
+    which can be thought of as a collection of values across the replicas,
+    or a `tf.Tensor`, if the strategy has already conducted the reduction
+    for the downstream library.
+
+    There are five possible outcomes of reduction:
+
+    1) if `values` is a structure of simple `tf.Tensor`s, meaning that
+       reduction is not actually needed, `reduce_per_replica` returns the
+       structure as-is.
+    2) else, if `reduction="auto"`, then the best reduction strategy is
+       chosen based on the current environment. This should only be used
+       for training cases (`fit()`).
+    3) else, if `reduction="first"`, then `reduce_per_replica`
+       returns the values of the first replica. This is used in the case of
+       training and evaluation, where `values` is expected to hold the same
+       value across the replicas as a result of `Strategy`'s synchronization
+       across the replicas.
+       `reduce_per_replica` does not synchronize the values.
+    4) else, if `reduction="sum"`, then `reduce_per_replica` returns the sum
+       of values for all replicas. This may be used in the custom training
+       loop case, where each replica contains different values which are not
+       synchronized.
+    5) else, if `reduction="concat"`, then `reduce_per_replica`
+       returns the concatenation of the values across the replicas, along the
+       axis of dimension 0. This is used in the inference case (`predict()`).
+
+    Args:
+        values: Structure of `PerReplica` objects or `tf.Tensor`s.
+            `tf.Tensor`s are returned as-is.
+        strategy: `tf.distribute.Strategy` object.
+        reduction: One of `"auto"`, `"first"`, `"concat"`, or `"sum"`.
+            `"auto"` will select `"first"` when used under a TPUStrategy, or
+            `"sum"` otherwise.
+
+    Returns:
+        Structure of `Tensor`s, representing the result of reduction.
+
+    Raises:
+        ValueError: if the reduction method is not supported.
+    """
+
+    if reduction == "auto":
+        # TPU strategy, which should default to "first", is ignored here.
+        reduction = "sum"
+
+    def _reduce(v):
+        """Reduce a single `PerReplica` object."""
+        if _collective_all_reduce_multi_worker(strategy):
+            if reduction == "concat":
+                return _multi_worker_concat(v, strategy)
+            elif reduction == "sum":
+                return strategy.reduce("SUM", v, axis=None)
+
+        if not _is_per_replica_instance(v):
+            return v
+        elif reduction == "first":
+            return strategy.experimental_local_results(v)[0]
+        elif reduction == "concat":
+            if _is_tpu_multi_host(strategy):
+                return _tpu_multi_host_concat(v, strategy)
+            else:
+                return concat(strategy.experimental_local_results(v))
+        elif reduction == "sum":
+            return tf.reduce_sum(strategy.experimental_local_results(v))
+        else:
+            raise ValueError(
+                '`reduction` must be "first", "concat", "sum", or "auto". '
+                f"Received: reduction={reduction}."
+            )
+
+    return tf.nest.map_structure(_reduce, values)
+
+
+def _multi_worker_concat(v, strategy):
+    """Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
+    replicas = strategy.gather(v, axis=0)
+    # v might not have the same shape on different replicas.
+    if _is_per_replica_instance(v):
+        shapes = tf.concat(
+            [
+                tf.expand_dims(tf.shape(single_value)[0], axis=0)
+                for single_value in v.values
+            ],
+            axis=0,
+        )
+        all_shapes = strategy.gather(shapes, axis=0)
+    else:
+        # v is a tensor. This may happen when, say, we have a 2x1 multi-worker
+        # setup.
+        all_shapes = strategy.gather(
+            tf.expand_dims(tf.shape(v)[0], axis=0), axis=0
+        )
+
+    replicas = tf.split(
+        replicas,
+        num_or_size_splits=all_shapes,
+        num=strategy.num_replicas_in_sync,
+    )
+    ordered_replicas = []
+    num_replicas_per_worker = len(strategy.extended.worker_devices)
+    for replica_id in range(num_replicas_per_worker):
+        ordered_replicas += replicas[replica_id::num_replicas_per_worker]
+    return concat(ordered_replicas)
+
+
+def concat(tensors, axis=0):
+    """Concatenates `tensors` along `axis`."""
+    if isinstance(tensors[0], tf.SparseTensor):
+        return tf.sparse.concat(axis=axis, sp_inputs=tensors)
+    elif _is_scalar(tensors[0]):
+        return tf.stack(tensors, axis=axis)
+    else:
+        return tf.concat(tensors, axis=axis)
+
+
+def _tpu_multi_host_concat(v, strategy):
+    """Correctly order TPU PerReplica objects."""
+    replicas = strategy.experimental_local_results(v)
+    # When distributed datasets are created from Tensors / NumPy,
+    # TPUStrategy.experimental_distribute_dataset shards data in
+    # (Replica, Host) order, and TPUStrategy.experimental_local_results
+    # returns it in (Host, Replica) order.
+    num_replicas_per_host = strategy.extended.num_replicas_per_host
+    ordered_replicas = []
+    for replica_id in range(num_replicas_per_host):
+        ordered_replicas += replicas[replica_id::num_replicas_per_host]
+    return concat(ordered_replicas)
+
+
+def _collective_all_reduce_multi_worker(strategy):
+    return (
+        isinstance(strategy, tf.distribute.MultiWorkerMirroredStrategy)
+    ) and strategy.extended._in_multi_worker_mode()
+
+
+def _is_per_replica_instance(obj):
+    return isinstance(obj, tf.distribute.DistributedValues) and isinstance(
+        obj, tf.__internal__.CompositeTensor
+    )
+
+
+def _is_scalar(x):
+    return isinstance(x, (tf.Tensor, tf.Variable)) and x.shape.rank == 0
+
+
+def _is_tpu_multi_host(strategy):
+    return _is_tpu_strategy(strategy) and strategy.extended.num_hosts > 1
+
+
+def _is_tpu_strategy(strategy):
+    return _is_tpu_strategy_class(strategy.__class__)
+
+
+def _is_tpu_strategy_class(clz):
+    def is_tpu_strat(k):
+        return k.__name__.startswith("TPUStrategy")

+    if is_tpu_strat(clz):
+        return True
+    return any(map(_is_tpu_strategy_class, clz.__bases__))
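
Note (reviewer sketch, not part of the patch): the snippet below illustrates the reduction semantics that the new reduce_per_replica helper implements. It assumes the helper is importable from keras_core.backend.tensorflow.trainer once this diff is applied, and it reuses the two-logical-CPU trick from the integration test so that MirroredStrategy has two replicas.

    import tensorflow as tf

    # Assumed import path, matching the module modified in this diff.
    from keras_core.backend.tensorflow.trainer import reduce_per_replica

    # Split the single physical CPU into two logical devices so that
    # MirroredStrategy runs with two replicas (same setup as the test).
    cpus = tf.config.list_physical_devices("CPU")
    tf.config.set_logical_device_configuration(
        cpus[0],
        [
            tf.config.LogicalDeviceConfiguration(),
            tf.config.LogicalDeviceConfiguration(),
        ],
    )
    strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])


    def step():
        # Each replica returns a different scalar:
        # replica 0 -> 1.0, replica 1 -> 2.0.
        replica_id = (
            tf.distribute.get_replica_context().replica_id_in_sync_group
        )
        return tf.cast(replica_id + 1, tf.float32)


    # strategy.run returns a PerReplica value holding one result per replica.
    per_replica = strategy.run(step)
    print(reduce_per_replica(per_replica, strategy, reduction="sum"))     # 3.0
    print(reduce_per_replica(per_replica, strategy, reduction="first"))   # 1.0
    print(reduce_per_replica(per_replica, strategy, reduction="concat"))  # [1.0, 2.0]

With reduction="auto" the helper currently resolves to "sum", which is what the train and test functions above rely on, while predict uses "concat" to stitch per-replica batches back together along axis 0.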