-
Notifications
You must be signed in to change notification settings - Fork 256
/
ops.py
6727 lines (5550 loc) · 221 KB
/
ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2022 The Mesh TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mesh TensorFlow ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import functools
import itertools
import operator
import os
import re
import gin
from mesh_tensorflow import utils
import numpy as np
import six
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow.compat.v1 as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops.gen_nn_ops import conv3d_backprop_input_v2
from tensorflow.python.ops.nn_ops import conv3d_backprop_filter_v2
# A named axis of a Tensor or Mesh: `name` is a str and `size` an int.
Dimension = collections.namedtuple("Dimension", ["name", "size"])


def convert_to_dimension(d):
    """Converts input to a Dimension.

    Args:
      d: Dimension, tuple (string, int), or None. The size may be a Python int
        or a NumPy integer scalar; NumPy scalars are converted to int.

    Returns:
      Dimension or None.

    Raises:
      ValueError: If d cannot be converted to a Dimension.
    """
    if d is None:
        return None
    if isinstance(d, Dimension):
        if not isinstance(d.name, str) or not isinstance(
                d.size, (int, np.integer)):
            raise ValueError("Bad dimension %s" % (d,))
        if not isinstance(d.size, int):
            # Normalize NumPy integer sizes so "%d" formatting and arithmetic
            # downstream always see a plain int.
            return Dimension(d.name, int(d.size))
        return d
    try:
        name, size = d
    except (TypeError, ValueError):
        # Not a 2-element iterable (e.g. an int or a 3-tuple); report it with
        # the documented exception type instead of a bare unpacking error.
        raise ValueError("could not convert %s to Dimension" % (d,))
    if isinstance(name, str) and isinstance(size, (int, np.integer)):
        return Dimension(name, size if isinstance(size, int) else int(size))
    raise ValueError("could not convert %s to Dimension" % (d,))


class Shape(object):
    """Shape of a Tensor or Mesh.

    #### Examples

    ```python
    # Create shape [4, 8] with names "x" and "y" respectively.
    shape = mtf.Shape([mtf.Dimension("x", 4), mtf.Dimension("y", 8)])
    ```
    """

    def __init__(self, dims):
        """Constructs a shape for a Tensor or Mesh.

        Args:
          dims: Iterable of Dimensions or (name, size) pairs. Any iterable is
            accepted, including a generator or another Shape.

        Raises:
          ValueError: If Dimensions are repeated or cannot be converted.
        """
        self._dims = [convert_to_dimension(d) for d in tuple(dims)]
        # Check the converted list rather than `dims`. `dims` may be a one-shot
        # iterator that was already consumed above, or may hold unhashable
        # [name, size] lists.
        if len(set(self._dims)) != len(self._dims):
            raise ValueError(
                "Shape must not have repeated dimensions %s" % (self._dims,))

    @property
    def dims(self):
        """A fresh list copy of the Dimensions."""
        return list(self._dims)

    @property
    def ndims(self):
        """Number of dimensions."""
        return len(self._dims)

    def __repr__(self):
        return self.to_string

    def __eq__(self, other):
        # Let Python fall back to its default handling for non-Shape operands
        # instead of raising AttributeError on `other.dims`.
        if not isinstance(other, Shape):
            return NotImplemented
        return self.dims == other.dims

    def __ne__(self, other):
        if not isinstance(other, Shape):
            return NotImplemented
        return self.dims != other.dims

    def __add__(self, other):
        """Concatenates dims with a Shape, a single Dimension, or an iterable."""
        if isinstance(other, Shape):
            other = other.dims
        if isinstance(other, Dimension):
            # Checked before the generic iterable case, because a Dimension
            # is itself a tuple.
            other = [other]
        return Shape(self.dims + list(other))

    def __sub__(self, other):
        """Removes dims; a single Dimension operand must be present."""
        if other is None:
            return self
        if isinstance(other, Shape):
            other = other.dims
        if isinstance(other, Dimension):
            if other not in self.dims:
                raise ValueError(
                    "Subtracting a dimension from a shape requires that the shape"
                    " contain that dimension. Use shape - [dimension] for the case"
                    " where the dimension may not be in the shape.")
            other = [other]
        return Shape([d for d in self.dims if d not in other])

    def __len__(self):
        return len(self._dims)

    def __getitem__(self, key):
        return self._dims[key]

    def __iter__(self):
        return iter(self._dims)

    @property
    def to_integer_list(self):
        """Dimension sizes, in order."""
        return [d.size for d in self.dims]

    @property
    def size(self):
        """Product of the dimension sizes (via `list_product`)."""
        return list_product(self.to_integer_list)

    @property
    def to_string(self):
        """Human-readable form, e.g. "Shape[x=4, y=8]"."""
        return "Shape[%s]" % ", ".join(
            ["%s=%d" % (d.name, d.size) for d in self.dims])

    @property
    def cumprod(self):
        """Cumulative product (exclusive) of Dimension sizes."""
        return _cumprod(self.to_integer_list)[:-1]

    def cumprod_to_tensor_axis(self, cumprod):
        """Maximum tensor axis i such that self.cumprod[i] == cumprod, or None."""
        try:
            return len(self) - 1 - self.cumprod[::-1].index(cumprod)
        except ValueError:
            return None

    @property
    def dimension_names(self):
        """Dimension names, in order."""
        return [d.name for d in self.dims]

    def rename_dimension(self, old_name, new_name):
        """Returns a copy where one dimension is renamed."""
        if old_name not in self.dimension_names:
            raise ValueError("Shape %s does not have dimension named %s"
                             % (self, old_name))
        return Shape(
            [Dimension(new_name, d.size) if d.name == old_name else d
             for d in self.dims])

    def resize_dimension(self, name, new_size):
        """Returns a copy where one dimension has a different size."""
        if name not in self.dimension_names:
            raise ValueError("Shape %s does not have dimension named %s"
                             % (self, name))
        return Shape(
            [Dimension(name, new_size) if d.name == name else d
             for d in self.dims])

    def get_dim_by_name(self, name):
        """Get the Dimension with `name` from this shape.

        Args:
          name: a string, the name of the dimension we wish to get

        Returns:
          Dimension with `name`

        Raises:
          ValueError: if the shape does not contain a dimension with `name`
        """
        for d in self.dims:
            if d.name == name:
                return d
        raise ValueError("Dimension {} not found in {}.".format(
            name, self.to_string))
def convert_to_shape(x):
    """Coerces `x` into a Shape.

    Args:
      x: a Shape (returned as-is), None (returned as-is), a str (parsed by
        `_parse_string_to_list_of_pairs` with `seconds_to_int=True` and then
        handed to the Shape constructor), or anything else the Shape
        constructor accepts.

    Returns:
      Shape or None.

    Raises:
      ValueError: If x cannot be converted to a Shape.
    """
    if x is None or isinstance(x, Shape):
        return x
    if isinstance(x, str):
        dims = _parse_string_to_list_of_pairs(x, seconds_to_int=True)
    else:
        dims = x
    return Shape(dims)
class LayoutRules(object):
    """Rules mapping tensor dimension names onto mesh dimension names.

    Each rule is a (tensor_dim_name, mesh_dim_name) pair. A tensor dimension
    whose name appears in a rule is split across the named mesh dimension.

    #### Examples

    ```python
    # Map "d_ff" and "heads" Tensor Dimensions to the "model" Mesh Dimension.
    layout_rules = mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])
    ```
    """

    def __init__(self, pairs):
        """Constructs a layout.

        Args:
          pairs: Set-like of string pairs (tensor_dim_name, mesh_dim_name).
            They are stored as a set, so duplicates collapse.
        """
        self._pairs = set(pairs)

    def __repr__(self):
        return "LayoutRules%s" % (self._pairs,)

    def tensor_dimension_to_mesh_axis(self, tensor_dimension, mesh_shape):
        """Finds the mesh axis that `tensor_dimension` is split over.

        Args:
          tensor_dimension: Dimension.
          mesh_shape: Shape.

        Returns:
          The index into `mesh_shape` of the matching mesh dimension, or None
          if no rule matches.

        Raises:
          ValueError: If one Tensor dimension maps to two mesh dimensions.
        """
        matching_axes = []
        for axis, mesh_dim in enumerate(mesh_shape):
            if (tensor_dimension.name, mesh_dim.name) in self._pairs:
                matching_axes.append(axis)
        if len(matching_axes) > 1:
            raise ValueError(
                "Tensor dimension maps to multiple mesh dimensions"
                " tensor_dimension=%s mesh_shape=%s layout=%s"
                % (tensor_dimension, mesh_shape, self._pairs))
        if not matching_axes:
            return None
        return matching_axes[0]

    def tensor_layout(self, tensor_shape, mesh_shape):
        """Builds the TensorLayout of `tensor_shape` on a mesh of `mesh_shape`.

        Args:
          tensor_shape: Shape.
          mesh_shape: Shape.

        Returns:
          TensorLayout.

        Raises:
          ValueError: If two Tensor Dimensions map to the same Mesh Dimensions.
        """
        mesh_axes = []
        for dim in tensor_shape:
            mesh_axes.append(self.tensor_dimension_to_mesh_axis(dim, mesh_shape))
        split_axes = [axis for axis in mesh_axes if axis is not None]
        if len(set(split_axes)) < len(split_axes):
            raise ValueError(
                "Two Tensor Dimensions may not map to the same Mesh Dimension:"
                " layout=%s tensor_shape=%s mesh_shape=%s " %
                (self, tensor_shape, mesh_shape))
        return TensorLayout(mesh_axes)

    def mesh_dimension_name_to_tensor_dimension_names(self, mesh_dimension_name):
        """Tensor dimension names whose rules target `mesh_dimension_name`.

        The order follows iteration over the internal set.
        """
        names = []
        for tensor_dim_name, mesh_dim_name in self._pairs:
            if mesh_dim_name == mesh_dimension_name:
                names.append(tensor_dim_name)
        return names
def convert_to_layout_rules(x):
    """Coerces `x` into a LayoutRules.

    Args:
      x: a LayoutRules (returned as-is), a str (parsed by
        `_parse_string_to_list_of_pairs`), or a set-like of string pairs.

    Returns:
      LayoutRules.
    """
    if isinstance(x, LayoutRules):
        return x
    pairs = _parse_string_to_list_of_pairs(x) if isinstance(x, str) else x
    return LayoutRules(pairs)
class TensorLayout(object):
    """Injective partial map between Tensor axes and Mesh axes.

    A TensorLayout is a tuple of length tensor.ndims. Entry i is either the
    integer mesh axis that tensor dimension i is split over, or None when that
    tensor dimension is not split.

    #### Examples

    ```python
    # Split first and last Tensor dimensions according to mesh axes 0 and 1.
    tensor_layout = mtf.TensorLayout([0, None, 1])
    ```
    """

    def __init__(self, tensor_axis_to_mesh_axis):
        """Creates a TensorLayout.

        Args:
          tensor_axis_to_mesh_axis: List-like where each element is an int or
            None. It is copied into a tuple.
        """
        self._tensor_axis_to_mesh_axis = tuple(tensor_axis_to_mesh_axis)

    def __eq__(self, other):
        mine = self.tensor_axis_to_mesh_axis
        theirs = other.tensor_axis_to_mesh_axis
        return mine == theirs

    def __ne__(self, other):
        mine = self.tensor_axis_to_mesh_axis
        theirs = other.tensor_axis_to_mesh_axis
        return mine != theirs

    def __repr__(self):
        return "TensorLayout" + str(self.tensor_axis_to_mesh_axis)

    def __len__(self):
        return len(self._tensor_axis_to_mesh_axis)

    def __getitem__(self, key):
        return self._tensor_axis_to_mesh_axis[key]

    def __iter__(self):
        return iter(self._tensor_axis_to_mesh_axis)

    @property
    def tensor_axis_to_mesh_axis(self):
        """The underlying tuple of optional integers."""
        return self._tensor_axis_to_mesh_axis

    @property
    def is_fully_replicated(self):
        """True when no tensor dimension is split (every entry is None)."""
        axes = self._tensor_axis_to_mesh_axis
        return axes.count(None) == len(axes)

    def mesh_axis_to_tensor_axis(self, mesh_ndims):
        """Inverts the map: for each mesh axis, which tensor axis uses it.

        Args:
          mesh_ndims: int, the number of mesh axes.

        Returns:
          Tuple of length mesh_ndims. Entry j is the first tensor axis mapped
          to mesh axis j, or None.
        """
        axes = self._tensor_axis_to_mesh_axis
        inverse = []
        for mesh_axis in xrange(mesh_ndims):
            if mesh_axis in axes:
                inverse.append(axes.index(mesh_axis))
            else:
                inverse.append(None)
        return tuple(inverse)
class Graph(object):
    """Mesh-TensorFlow graph.

    Holds the operation list, the lists of all and of trainable variables, and
    the bookkeeping that `unique_name` uses to hand out unique operation names.
    """

    def __init__(self):
        self._operations = []
        self._trainable_variables = []
        self._all_variables = []
        # Maps a name used in the graph to the next id to use for that name.
        self._names_in_use = {}
        # Not written anywhere in this class; presumably populated elsewhere
        # when Variables are created -- TODO confirm.
        self.name_to_variable = {}
        # TF variable scope that was active when this Graph was constructed.
        self.captured_variable_scope = tf.get_variable_scope()

    def __repr__(self):
        return self.to_string

    @property
    def operations(self):
        """The graph's operation list (the live list object, not a copy)."""
        return self._operations

    @property
    def trainable_variables(self):
        """The live list of trainable variables."""
        return self._trainable_variables

    @property
    def all_variables(self):
        """The live list of all variables."""
        return self._all_variables

    @property
    def to_string(self):
        """One line per operation, built from each operation's `to_string`."""
        return "\n".join([op.to_string for op in self.operations])

    def unique_name(self, name, mark_as_used=True):
        """Like tf.Graph.unique_name, returns a unique operation name for `name`.

        The name is prefixed with the current TF variable scope, if any. If the
        (case-insensitive) name has been used before, a "_<n>" suffix is
        appended. The returned name keeps the caller's original casing.

        Args:
          name: The name for an operation.
          mark_as_used: whether to mark this name as being used.

        Returns:
          A string to use as the name for the operation.
        """
        scope_name = tf.get_variable_scope().name
        if scope_name:
            name = scope_name + "/" + name
        # As in TensorFlow, treat names as case insensitive when deciding whether
        # they are in use.
        name_key = name.lower()
        i = self._names_in_use.get(name_key, 0)
        if mark_as_used:
            self._names_in_use[name_key] = i + 1
        if i > 0:
            # The name was used before: probe "<name>_<i>", "<name>_<i+1>", ...
            # until a key that is not in use is found.
            base_name_key = name_key
            while name_key in self._names_in_use:
                name_key = "%s_%d" % (base_name_key, i)
                i += 1
            if mark_as_used:
                self._names_in_use[name_key] = 1
            # The loop increments `i` once past the suffix that was chosen.
            name = "%s_%d" % (name, i-1)
        return name

    @gin.configurable
    def rewrite_stack_variables(self,
                                max_combined_variable_size=2 ** 29,
                                max_combined_slice_size=2 ** 27,
                                mesh_to_impl=None):
        """Rewrite the current graph to combine variables.

        This helps speed up graph construction times in the case of large meshes
        and large numbers of variables.

        This function should be called after graph construction (it is called by
        default in the Lowering constructor).

        When we find a set of variables with the same shape/dtype/etc, we replace
        them with one StackedVariable and an "unstack" operation. The
        StackedVariable has multiple master variables (so as to maintain checkpoint
        compatibility), but only one slice variable per device. We point the inputs
        of later operations to the outputs of the "unstack" operations, instead of
        the outputs of the defunct single variables.

        In order for variables to be combinable, they must be set in the same Assign
        operation(s) - so it is necessary to call mtf.grouped_assign() from the
        optimizer instead of many separate calls to mtf.assign(). The assign
        operations get rewritten to set the appropriate stacked variables.

        TODO(noam): Combining to larger sizes seems to cause errors on TPU.
          debug this. Perhaps we should try to keep the combined master variables
          on the same device.

        Args:
          max_combined_variable_size: an integer; upper bound on the total
            number of elements (variable count * op.shape.size) in one stack.
            None disables this limit.
          max_combined_slice_size: an integer; upper bound on variable count *
            per-device slice size. Only applied when `mesh_to_impl` is given.
          mesh_to_impl: an optional dictionary from Mesh to MeshImpl

        Raises:
          ValueError: if the graph already contains a StackedVariable (i.e.
            this rewrite was already applied).
        """
        # pylint: disable=protected-access
        # Rebuild the operation and variable lists from scratch. Ops that
        # survive are re-appended below. Stacked variables are NOT re-appended;
        # the StackedVariable / unstack ops that replace them are presumably
        # appended by their own constructors (not visible here) -- confirm.
        all_variables = self._all_variables
        operations = self._operations
        self._operations = []
        self._all_variables = []
        self._trainable_variables = []
        # We can only stack variables which share the same set of assignment
        # operations.
        # Maps each variable to a string of the names of the Assign ops that
        # write it, in operation order. This string is part of the stacking key.
        var_to_assign_ops = collections.defaultdict(str)
        for op in operations:
            if isinstance(op, Assign):
                for v in op._variables:
                    var_to_assign_ops[v] += op.name + ", "
        # Two variables with the same "key" can be stacked together.
        def var_key(v):
            return str([v.mesh,
                        v.shape,
                        str(v.dtype.__dict__),
                        v.trainable,
                        var_to_assign_ops[v]])
        # Per key, a FIFO of candidate variables in all_variables order.
        key_to_vars = collections.defaultdict(collections.deque)
        for v in all_variables:
            key_to_vars[var_key(v)].append(v)
        # Variable name -> (StackedVariable, position within the stack).
        individual_to_stacked = {}
        for op in operations:
            if isinstance(op, StackedVariable):
                raise ValueError("stack_variables() should not be called twice.")
            elif isinstance(op, Variable):
                if op.name in individual_to_stacked:
                    # Already absorbed into an earlier stack; drop it.
                    continue
                similar_vars = key_to_vars[var_key(op)]
                num_to_stack = len(similar_vars)
                if max_combined_variable_size is not None:
                    num_to_stack = min(
                        num_to_stack, max_combined_variable_size // op.shape.size)
                if mesh_to_impl is not None:
                    mesh_impl = mesh_to_impl[op.mesh]
                    if mesh_impl.size == 1:
                        num_to_stack = 1  # no point in stacking for single processors.
                    slice_size = mesh_impl.slice_size(op.shape)
                    num_to_stack = min(
                        num_to_stack, max_combined_slice_size // slice_size)
                # Always take at least the current variable itself.
                num_to_stack = max(1, num_to_stack)
                to_stack = [similar_vars.popleft() for _ in xrange(num_to_stack)]
                if num_to_stack > 1:
                    stacked_var = StackedVariable(to_stack)
                    stack_dim = stacked_var.shape.dims[0]
                    unstacked = unstack(stacked_var.outputs[0], stack_dim)
                    unstack_op = unstacked[0].operation
                    # replace the output Tensors of the unstack operation with the
                    # Tensors which were the outputs of the original variable operations.
                    # Later operations use these Tensors as inputs.
                    unstack_op._outputs = [v.outputs[0] for v in to_stack]
                    for t in unstack_op._outputs:
                        t._operation = unstack_op
                    for idx, v in enumerate(to_stack):
                        individual_to_stacked[v.name] = stacked_var, idx
                else:
                    # Not stacked: keep the original variable as-is. The popped
                    # candidate must be this very op.
                    assert op == to_stack[0]
                    self._operations.append(op)
                    self._all_variables.append(op)
                    if op.trainable:
                        self._trainable_variables.append(op)
            else:
                if isinstance(op, Assign):
                    # Rewrite the grouped assignment to stack up the values and then
                    # assign to the stacked variables.
                    new_variables = []
                    new_values = []
                    var_to_val = dict(zip([v.name for v in op._variables], op._inputs))
                    for var, val in zip(op._variables, op._inputs):
                        if var.name in individual_to_stacked:
                            stacked_var, pos = individual_to_stacked[var.name]
                            # Only the stack member at position 0 emits the one
                            # stacked assignment. The other members are covered
                            # by it and are skipped.
                            if pos == 0:
                                vals = [var_to_val[n] for n in stacked_var.original_names]
                                new_variables.append(stacked_var)
                                new_values.append(
                                    stack(vals, stacked_var.shape.dims[0].name, 0))
                        else:
                            new_variables.append(var)
                            new_values.append(val)
                    op._variables = new_variables
                    op._inputs = new_values
                self._operations.append(op)
        # pylint: enable=protected-access

    def combine_assignments(self, assignments):
        """Rewrite the current graph to combine "Assign" operations.

        Combine similar Assign operations into grouped Assign operations.
        This is useful when using the rewrite_stack_variables() optimization,
        since variables can only be stacked if they are present in the same set
        of Assign operations.

        This function takes a list of Assign operations and returns a possibly
        shorter list of Assign operations. The input Assignment operations
        are removed from the graph and become invalid. Assignments are grouped
        by their `assign_fn`; one new Assign is created per group.

        Args:
          assignments: a list of Assign objects

        Returns:
          a list of Assign objects

        Raises:
          ValueError: if any element of `assignments` is not an Assign.
        """
        group_by_fn = collections.defaultdict(list)
        for a in assignments:
            if not isinstance(a, Assign):
                raise ValueError("ops should be instances of mtf.Assign")
            group_by_fn[a.assign_fn].append(a)
        assignments_set = set(assignments)
        self._operations = [
            op for op in self._operations if op not in assignments_set]
        ret = []
        for fn, ops in six.iteritems(group_by_fn):
            variables = []
            values = []
            for a in ops:
                variables.extend(a.variables)
                values.extend(a.inputs)
            ret.append(Assign(variables, values, fn))
        return ret

    def make_variables_untrainable(self, variables):
        """Makes the variables untrainable.

        Clears each variable's `_trainable` flag and removes it from this
        graph's trainable-variable list.

        Args:
          variables: a list of Variable objects
        """
        variables = set(variables)
        for v in variables:
            v._trainable = False  # pylint: disable=protected-access
        self._trainable_variables = [
            v for v in self._trainable_variables if v not in variables
        ]

    def clone_operations(self, ops, input_mapping):
        """Clone a portion of the graph, but with different inputs.

        The different inputs are specified by the `input_mapping` dictionary, which
        maps from input Tensor in the original operations to input Tensor in the
        cloned operations. If an original operation uses an external input that is
        not in `input_mapping`, then the original input is used for the cloned
        operation.

        The function returns a list of cloned operations as well as an
        `extended_mapping` dictionary which consists of the union of the input
        mapping and the map from original-operation-output to
        cloned-operation-output.

        Variables and Random operations are not cloned.

        Args:
          ops: a list of operations
          input_mapping: a dictionary from Tensor to Tensor

        Returns:
          cloned_operations: a list of operations
          extended_mapping: a dictionary from Tensor to Tensor

        Raises:
          ValueError: if `input_mapping` already maps an output of a cloned op.
        """
        # pylint: disable=protected-access
        # Shallow copy so the caller's dict is not mutated.
        mapping = copy.copy(input_mapping)
        prev_num_operations = len(self.operations)
        for op in ops:
            if isinstance(op, Variable):
                continue
            if isinstance(op, RandomOperation):
                # The random values will be copied instead of recomputed.
                # TODO(noam): Use stateless_random to allow for recompute.
                tf.logging.warning(
                    "Not cloning random operation, so as to ensure the same values.")
                continue
            # Shallow copy of the op, then rewire its inputs and give it fresh
            # output Tensors. The copy keeps the original op's name.
            new_op = copy.copy(op)
            # new_op._name = self.unique_name(op.name)
            self._operations.append(new_op)
            new_op._inputs = [mapping.get(t, t) for t in op._inputs]
            new_op._outputs = []
            for i, t in enumerate(op.outputs):
                new_t = Tensor(new_op, t.shape, t.dtype, t.name, i)
                new_t.usable = True
                new_op._outputs.append(new_t)
                if t in mapping:
                    raise ValueError(
                        "input mapping should not contain any of the outputs"
                        " of the cloned operations")
                mapping[t] = new_t
        # pylint: enable=protected-access
        # Ops appended to self._operations since the call started. This assumes
        # nothing else appends or removes operations during the loop.
        return self.operations[prev_num_operations:], mapping

    def capture_operations(self, fn):
        """Run a function and capture the list of operations it generates.

        The captured operations are those appended to the operation list while
        `fn` runs (everything past the list's pre-call length).

        Args:
          fn: a function taking no arguments

        Returns:
          fn_output: the function output
          captured_operations: a list of Operation
        """
        n = len(self.operations)
        y = fn()
        return y, self.operations[n:]
class Lowering(object):
    """Lowering of a Graph from Mesh-TensorFlow to TensorFlow.

    #### Examples

    Below we form a Graph with one Tensor and lower it to recover the original
    tf.Tensor.

    ```python
    from mesh_tensorflow import placement_mesh_impl
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    inputs = tf.constant(0.)
    mtf_inputs = mtf.import_tf_tensor(mesh,
                                      inputs=inputs,
                                      shape=mtf.Shape([]))
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    outputs = lowering.export_to_tf_tensor(mtf_inputs)  # tf.constant(0.)
    ```
    """

    def __init__(self, graph, mesh_to_impl, autostack=True, log_file=None):
        """Lowers every operation of `graph`, then logs a summary.

        Args:
          graph: Graph.
          mesh_to_impl: {Mesh: MeshImpl}. Keys are the Mesh's in the graph and
            their values are MeshImpl's, which map Tensor Dimension names to
            Mesh Dimension names.
          autostack: a boolean. If True, then the graph gets rewritten to
            reduce the number of variables (see rewrite_stack_variables()).
            This is a helpful performance optimization for large meshes.
            For more fine-grained control, you can call
            graph.rewrite_stack_variables() yourself before creating the Lowering.
          log_file: an optional string. If provided, information about the variables
            and operations will also be logged to this file.
        """
        self.mesh_to_impl = mesh_to_impl  # {Mesh: MeshImpl}
        self.graph = graph
        if autostack:
            # Replaces the graph's operation list, so it must run before the
            # lowering loop below reads graph.operations.
            self.autostack()
        self._counters = []
        self.tensors = {}  # {Tensor: Mesh.LaidOutTensor}
        self.operations = {}  # {Operation: tf.Operation}
        self.variables = {}  # {Variable: LaidOutVariable}

        for op in graph.operations:
            with tf.name_scope(op.name):
                op.lower(self)
            op_type_name = type(op).__name__
            for out in op.outputs:
                self.add_counter("output/%s" % op_type_name,
                                 self.laid_out_size(out))
                self.add_counter("output_unique/%s" % op_type_name, out.size)

        def emit_report(stream=None):
            """Logs variable sizes and counters, mirroring them to `stream`."""
            variable_groups = (
                (graph.trainable_variables, "Trainable Variables", True),
                (graph.all_variables, "All Variables", False),
            )
            for variables, title, verbose in variable_groups:
                log_variable_sizes(
                    variables,
                    title,
                    verbose=verbose,
                    mesh_to_impl=self.mesh_to_impl,
                    log_file=stream)
            _log_info_also_to_file(
                "Counters:\n" + pretty_print_counters(self._counters),
                log_file=stream)

        if not log_file:
            emit_report()
            return
        with tf.io.gfile.GFile(log_file, mode="w") as stream:
            emit_report(stream)

    def mesh_impl(self, m):
        """MeshImpl for `m`, which is either a Mesh or has a `.mesh` attribute."""
        mesh = m if isinstance(m, Mesh) else m.mesh
        return self.mesh_to_impl[mesh]

    def export_to_tf_tensor(self, x):
        """Turn a Tensor into a tf.Tensor.

        Args:
          x: Tensor.

        Returns:
          tf.Tensor.
        """
        impl = self.mesh_impl(x)
        laid_out = self.tensors[x].to_laid_out_tensor()
        return impl.export_to_tf_tensor(x, laid_out)

    def lowered_operation(self, op):
        """The lowered result stored for Mesh-TensorFlow operation `op`."""
        return self.operations[op]

    def copy_masters_to_slices(self):
        """Op(s) copying master variable values into their slices.

        When the environment variable MTF_SEQUENCE_MODE is "1", this expects
        exactly one MeshImpl and returns the last of its
        `copy_master_to_slice_ops`. Otherwise it groups every lowered variable's
        `copy_master_to_slices`.
        """
        if os.environ.get("MTF_SEQUENCE_MODE", "") != "1":
            return tf.group(
                [v.copy_master_to_slices for v in six.itervalues(self.variables)])
        impls = list(six.itervalues(self.mesh_to_impl))
        assert len(impls) == 1
        return impls[0].copy_master_to_slice_ops[-1]

    def copy_slices_to_masters(self):
        """tf.group of every lowered variable's `copy_slices_to_master`."""
        copies = [v.copy_slices_to_master for v in six.itervalues(self.variables)]
        return tf.group(copies)

    def add_counter(self, key, value):
        """Records the (key, value) pair; `value` must be an int."""
        assert isinstance(value, int)
        self._counters.append((key, value))

    @property
    def counters(self):
        """List of recorded (key, value) pairs, in insertion order."""
        return self._counters

    def laid_out_size(self, tensor):
        """Total size of all slices.

        Args:
          tensor: Tensor.

        Returns:
          int.
        """
        impl = self.mesh_impl(tensor)
        return impl.laid_out_size(tensor.shape)

    def set_tensor_lowering(self, tensor, laid_out_tensor):
        """Stores the lowering of `tensor` after checking its slice shape."""
        self.verify_slice_shapes(tensor, laid_out_tensor)
        self.tensors[tensor] = laid_out_tensor

    def verify_slice_shapes(self, tensor, laid_out_tensor):
        """Raises ValueError if `laid_out_tensor` has the wrong slice shape."""
        expected = self.mesh_impl(tensor).slice_shape(tensor.shape)
        actual = laid_out_tensor.slice_shape
        if actual == expected:
            return
        raise ValueError(
            "Wrong slice shape: correct_shape = %s actual shape = %s"
            % (expected, actual))

    def autostack(self):
        """Rewrite graph to combine similarly-shaped variables (faster startup)."""
        num_slices = sum(
            self.mesh_to_impl[v.mesh].size for v in self.graph.all_variables)
        # With lots of variable slices startup is slow, so stack aggressively.
        # Otherwise stacking hurts memory utilization, so only stack small
        # variables.
        max_combined_slice_size = 2 ** 27 if num_slices >= 2 ** 16 else 2 ** 16
        self.graph.rewrite_stack_variables(
            mesh_to_impl=self.mesh_to_impl,
            max_combined_slice_size=max_combined_slice_size)
class Mesh(object):
    """A placeholder with no functionality.

    A Graph is built with each Tensor assigned to a Mesh. The Mesh knows
    neither its shape nor its implementation; a Lowering later assigns each
    Mesh to a MeshImpl.
    """

    def __init__(self, graph, name, variable_placer=None):
        self._graph, self._name, self._variable_placer = (
            graph, name, variable_placer)

    @property
    def graph(self):
        """The Graph this mesh belongs to."""
        return self._graph

    @property
    def variable_placer_fn(self):
        """The placer's `device_function`, or "cpu:0" when no placer was given."""
        placer = self._variable_placer
        if placer is None:
            return "cpu:0"
        return placer.device_function
class MeshImpl(object):
"""Implementation of a Mesh.
Unlike Mesh, MeshImpl carries Shape and LayoutRules. Subclasses of MeshImpl
also carry devices.
#### Examples
```python
shape = mtf.Shape([mtf.Dimension("batch", 4),
mtf.Dimension("model", 8)])
layout_rules = mtf.LayoutRules([("batch", "batch"),
("d_ff", "model"),
("heads", "model")])
mesh_impl = mtf.MeshImpl(shape=shape, layout_rules=layout_rules)
```
"""
def __init__(self, shape, layout_rules):
    """Builds a mesh implementation from a shape and layout rules.

    Args:
      shape: Shape, or anything `convert_to_shape` accepts (e.g. a str).
      layout_rules: LayoutRules, or anything `convert_to_layout_rules` accepts
        (a str or a set-like of string pairs).
    """
    mesh_shape = convert_to_shape(shape)
    self._shape = mesh_shape
    rules = convert_to_layout_rules(layout_rules)
    self._layout_rules = rules
@property
def shape(self):
return self._shape
@property
def ndims(self):
return len(self._shape)
@property
def layout_rules(self):
return self._layout_rules
@property
def size(self):
return self.shape.size
@property
def supports_control_dependencies(self):
return True
def tensor_dimension_to_mesh_axis(self, tensor_dimension):
    """Mesh axis that `tensor_dimension` is split over, per the layout rules.

    Args:
      tensor_dimension: Dimension.

    Returns:
      int index into this mesh's shape, or None if the dimension is not split.
    """
    rules = self.layout_rules
    return rules.tensor_dimension_to_mesh_axis(tensor_dimension, self.shape)
def tensor_layout(self, arg):
    """TensorLayout of a Tensor (via its shape) or of a Shape on this mesh.

    Args:
      arg: Tensor or Shape.

    Returns:
      TensorLayout.
    """
    tensor_shape = arg.shape if isinstance(arg, Tensor) else arg
    return self.layout_rules.tensor_layout(tensor_shape, self.shape)
def mesh_axis_to_cumprod(self, tensor_shape):
    """For each mesh axis, the exclusive cumulative product of tensor sizes.

    For mesh axis j, if tensor axis t is split over j, the entry is
    tensor_shape.cumprod[t] (the product of the sizes of tensor axes before t).
    Otherwise the entry is None.

    Args:
      tensor_shape: Shape.

    Returns:
      list with length self.ndims where each element is an integer or None.
    """
    mesh_to_tensor = self.tensor_layout(tensor_shape).mesh_axis_to_tensor_axis(
        self.ndims)
    cumprods = tensor_shape.cumprod
    result = []
    for tensor_axis in mesh_to_tensor:
        result.append(None if tensor_axis is None else cumprods[tensor_axis])
    return result
def slice_shape(self, tensor_shape):
    """Per-processor slice sizes of a tensor with `tensor_shape`.

    Unsplit dimensions keep their full size. A dimension split over a mesh
    axis is divided by that mesh dimension's size.

    Args:
      tensor_shape: Shape.

    Returns:
      list of integers with length tensor_shape.ndims.

    Raises:
      ValueError: If a Tensor dimension is not divisible by the corresponding
        Mesh dimension.
    """
    layout = self.tensor_layout(tensor_shape)
    sizes = []
    for dim, mesh_axis in zip(tensor_shape, layout.tensor_axis_to_mesh_axis):
        if mesh_axis is None:
            sizes.append(dim.size)
            continue
        num_parts = self.shape[mesh_axis].size
        if dim.size % num_parts:
            raise ValueError(
                "Tensor dimension size not divisible by mesh dimension size:"
                " tensor_shape=%s tensor_layout=%s"
                % (tensor_shape, layout))
        sizes.append(dim.size // num_parts)
    return sizes
def slice_begin(self, tensor_shape, pnum):
"""Begin position for the tensor slice for the given processor.
Args:
tensor_shape: Shape.
pnum: int <= self.size.
Returns:
list of integers with length tensor_shape.ndims.
"""
tensor_layout = self.tensor_layout(tensor_shape)
coordinates = pnum_to_processor_coordinates(self.shape, pnum)
ret = []
for dim_size, mesh_axis in zip(
tensor_shape.to_integer_list, tensor_layout.tensor_axis_to_mesh_axis):
if mesh_axis is None:
ret.append(0)
else: