"""
Base interfaces.
The core interfaces for the user are:
* :class:`Module` and using :func:`make_layer` to directly create a RETURNN layer via dict.
We recommend using this only for directly wrapping RETURNN layers
and not for any higher-level logic,
which should be done as a :class:`Module`.
* :class:`Module`, to write PyTorch-style code, which acts like a subnetwork.
We recommend using this as the base interface
for any higher-level interfaces
(such as a generic decoder interface).
Use :func:`scoped` as a decorator for the ``__call__`` method.
Instances of both objects can be called directly,
and return instances of type :class:`Tensor`,
which can be thought of as analogous to :class:`torch.Tensor` or :class:`tf.Tensor`.
Use ``x.mark_as_loss()`` to mark some output (layer ref) as a loss.
The root network should be a :class:`Module`,
and then you can use ``make_root_net_dict()``
to get the network dict.
Code example::
  class Network(nn.Module):
    def __init__(self):
      super().__init__()
      self.lstm = nn.LSTM(nn.FeatureDim("lstm-out", 1024))

    def __call__(self, x: nn.Tensor) -> nn.Tensor:
      y = self.lstm(x)
      return y

  net = Network()
  net_dict = make_root_net_dict(net, "data")
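Inside such a ``__call__``, a training loss could be marked like this
(a rough sketch; ``nn.cross_entropy`` and the ``targets`` tensor are assumptions
from other parts of the ``nn`` namespace and the extern data, not from this file)::

  loss = nn.cross_entropy(target=targets, estimated=y, estimated_type="logits")
  loss.mark_as_loss("ce")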
---
Code conventions:
- Usual, as in RETURNN, PEP8, 2-space indents, 120 char line limit.
- Pure interface classes are prefixed with `I`.
(`Module` is an exception because it is made analogous to PyTorch).
"""
from __future__ import annotations
import numpy
from typing import Dict, Any, Optional, List, Tuple, Union, Set, Sequence, Iterable, Type
import itertools
import contextlib
from weakref import WeakKeyDictionary
# Some imports are not used here, but imported to make them available in the `nn` namespace.
# noinspection PyUnresolvedReferences
from returnn.tensor import (
Dim,
Tensor as Data,
ControlFlowContext,
batch_dim,
single_step_dim,
)
# noinspection PyUnresolvedReferences
from returnn.datasets.util.vocabulary import Vocabulary
from returnn.tensor.marked_dim import MarkedDim as _MarkedDim
import tree
from .. import nn
LayerDictRaw = Dict[str, Any]
TensorRefRaw = str
NetDictRaw = Dict[str, LayerDictRaw]
RawTensorTypes = Union[int, float, complex, numpy.number, numpy.ndarray, bool, str]
OutShapeType = Union[Set[Union[Dim, _MarkedDim]], tuple, list]
# https://returnn.readthedocs.io/en/latest/configuration_reference/behavior_version.html
# Need Dim.is_equal to be more restrictive (v16).
min_returnn_behavior_version = 16
class Tensor:
"""
Refers to a layer in RETURNN.
An instance of this class can be treated very much like a tensor.
It supports all the common unary and binary math operations such as addition.
This is the intended view point for the user,
to treat instances of this class like a tensor.
You do not create instances of this object explicitly
but they are created via any of the standard functions
like :func:`zeros` etc. or any :func:`Module`,
or via :func:`make_layer` for directly wrapping some RETURNN layer,
or via :func:`get_extern_data` for external data.
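Example (a sketch; ``x`` and ``y`` stand for any tensors obtained as described above)::

  z = (x + y) * 0.5   # elementwise math via the standard Python operators
  mask = z > 0.0      # comparisons return a bool tensor via :func:`nn.compare`
  z.mark_as_default_output()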
"""
def __init__(
self,
*,
name_ctx: nn.NameCtx,
data: Optional[Data] = None,
layer_dict: Optional[LayerDictRaw] = None,
is_ref: bool = False,
):
"""
:param name_ctx: this defines the name of the layer itself
:param data: Data template describing the shape and dtype
and other meta information on the tensor (layer output)
:param is_ref: in RETURNN, there can be references to special layers, like "data:..." or "prev:...",
which are not layers themselves, i.e. we do not have a layer dict for them.
"""
# This will become returnn.tensor.Tensor.raw_tensor, hence it is already named raw_tensor here.
self.raw_tensor = name_ctx
# Do not assign name_ctx.tensor yet because we potentially could raise exceptions later.
assert name_ctx.tensor is None
assert name_ctx.layer_dict is None
if is_ref:
assert layer_dict is None
else: # not is_ref (default)
assert layer_dict is not None
# Note that the following code can potentially raise user errors.
if not data:
data = _data_from_layer_dict(layer_dict, tensor=self)
else:
data = data.copy()
data.control_flow_ctx = nn.NameCtx.inner_control_flow()
if data.have_batch_axis() and not data.batch:
# You could say this is a bug of RETURNN. Or at least RETURNN is just incomplete here.
# RETURNN usually would fix that later when the layer is actually created,
# but we don't do that here.
# We can still try to look at dependencies and use those batch info.
batches = []
for dep in self.raw_tensor.get_tensor_dependencies(_extra_layer_dict=layer_dict):
if dep.tensor is not None and dep.tensor.data.batch and dep.tensor.data.batch not in batches:
batches.append(dep.tensor.data.batch)
if batches:
from returnn.tf.util.data import BatchInfo
data.batch = BatchInfo.get_common_batch_info(batches)
elif name_ctx.root.global_batch:
data.batch = name_ctx.root.global_batch
self.data = data
name_ctx.layer_dict = layer_dict
name_ctx.tensor = self
def __repr__(self):
parts = [self.__class__.__name__, self.raw_tensor.get_abs_name_repr()]
if not hasattr(self, "data"):
return f"<{' '.join(parts)} uninitialized>"
if self.data:
parts.append("[%s]" % ",".join(self.data.get_batch_axes_short_description()))
if nn.is_debug_eager_mode_enabled():
if self.data.placeholder is None:
parts.append("<tf.Tensor: None>")
else:
parts.append(repr(self.data.placeholder))
if self.raw_tensor.layer_dict:
parts.append(
f"via "
+ repr(
self.raw_tensor.module if self.raw_tensor.module else self.raw_tensor.layer_dict.get("class", "?")
)
)
if self.data and self.data.control_flow_ctx:
parts.append(f"ctx={self.data.control_flow_ctx.repr_inner()}")
return f"<{' '.join(parts)}>"
@property
def dims_set(self) -> Set[Dim]:
"""
:return: shape, as a set of dims.
The order must not play a role
(RETURNN principles: https://github.com/rwth-i6/returnn/wiki/RETURNN-principles),
thus return a set.
Note that we do not include any implicit dims here.
Also see :func:`verify_out_shape` and https://github.com/rwth-i6/returnn/issues/1153.
"""
return set(self.data.dim_tags)
@property
def dims(self) -> Tuple[Dim, ...]:
"""
:return: ordered dims.
Note that usually the order should never matter.
For some functions like nn.constant or nn.random_...,
we currently need a specific order,
and often we want to copy the order from some other tensor.
This property dims is supposed to be used for such functions.
Note that the rtype here could potentially change at some point
to a ref-type which just indicates to reuse the same order of this tensor.
So you should not rely on the return type here,
nor make any direct use of the returned value,
except for passing it to functions like nn.constant.
https://github.com/rwth-i6/returnn_common/issues/138
"""
return self.data.dims
def remaining_dims(self, remove: Optional[Union[nn.Dim, Sequence[nn.Dim]]] = None) -> List[Dim]:
"""
:return: ordered dims, with the dims given via ``remove`` excluded
"""
batch_dims = list(self.dims)
if not remove:
pass
elif isinstance(remove, nn.Dim):
batch_dims.remove(remove)
else:
for remove_ in remove:
batch_dims.remove(remove_)
return batch_dims
@property
def dtype(self) -> str:
"""
:return: data type (e.g. "float32")
"""
return self.data.dtype
@property
def feature_dim(self) -> Optional[Dim]:
"""
:return: feature dim
"""
if self.data.feature_dim_axis is None:
return None
return self.data.dims[self.data.feature_dim_axis]
@property
def sparse_dim(self) -> Optional[Dim]:
"""
:return: sparse dim
"""
return self.data.sparse_dim
def verify_out_shape(self, out_shape: OutShapeType):
"""
Verify out_shape via :func:`Data.verify_out_shape`.
This does not add out_shape to the layer dict as we already have that automatically.
Thus, this is purely for verification here on returnn-common side.
Also, we are more relaxed here. We currently ignore missing implicit dims.
https://github.com/rwth-i6/returnn/issues/1153
:return: self, such that you can write this as a chained op
:rtype: Tensor
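Example (a sketch; ``time_dim`` and ``feat_dim`` stand for dim tags you already have at hand)::

  y.verify_out_shape({nn.batch_dim, time_dim, feat_dim})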
"""
self.data.verify_out_shape(out_shape, allow_missing_implicit_dims=True)
return self
def mark_as_loss(
self,
name: str,
*,
scale: Optional[float] = 1.0,
as_error: bool = False,
use_normalized_loss: bool = False,
use_flatten_frames: bool = True,
custom_inv_norm_factor: Optional[nn.Tensor] = None,
) -> Tensor:
"""
Mark this as a loss.
This has the effect that it is specially handled by RETURNN.
Specifically, the optimizer can use it in training,
and it is used for reporting per batch or per epoch,
and for learning rate scheduling.
This currently uses :class:`AsIsLoss` in RETURNN
but this is an implementation detail and might change.
:param name: name of the loss. this name is used for reporting by RETURNN, and also for LR scheduling.
:param scale: scale the loss by this factor for the training optimizer
(but not for any reporting). setting to 0.0 has the effect that this loss is not used by the optimizer.
:param as_error: if True, this loss is reported as an error instead of a loss,
and not used by the training optimizer.
By convention, this is something like the frame error or edit distance, and usually not differentiable anyway.
:param bool use_flatten_frames: If True, will use :func:`returnn.tf.util.basic.flatten_with_seq_len_mask`,
i.e. a "packed" sequence with the padded frames removed, and accumulates over that.
This can be more efficient, also because it will further optimize incoming computations
and e.g. skip softmax computations right before on the padded frames.
This can also avoid issues with inf/nan in some cases.
If False, it will mask the loss to 0 in the padded frames and accumulate over that.
Typically, setting this to True (default) is both more efficient and better.
:param bool use_normalized_loss: the loss used in optimization will be normalized.
E.g. if the overall normalization is sum(loss)/sum(num_frames), this is also what the optimizer will use,
otherwise the optimizer will just use sum(loss).
:param custom_inv_norm_factor:
The standard norm factor is 1/sum(target_seq_len) if the target has a time-axis,
or 1/sum(output_seq_len) if there is no target and the output has a time-axis,
or 1 otherwise. (See :func:`Loss.init` for details.)
This is used for proper normalization of accumulated loss/error per epoch
and also proper normalization per batch for reporting,
no matter if use_normalized_loss is True or False.
If you want to change this norm factor, you can set this.
As a function, it takes (self=self, output=output, layer=layer) and returns a float scalar.
This here is the inverse of the norm factor.
Here we also allow to pass any shape, and it will automatically be reduced via sum.
So you could simply pass target_seq_len directly here.
Basically, for all reporting, it uses sum(loss) * sum(custom_inv_norm_factor).
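Example (a sketch; ``ce`` and ``frame_error`` stand for tensors computed earlier, e.g. via loss functions from the ``nn`` namespace)::

  ce.mark_as_loss("ce", scale=0.5)
  frame_error.mark_as_loss("fer", as_error=True)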
"""
root_scope = self.raw_tensor.root
res = nn.copy(self, name=root_scope.get_new_child(suggested_name=name))
res.raw_tensor.layer_dict["loss"] = "as_is"
loss_opts = {}
if scale is not None and scale != 1:
assert "loss_scale" not in res.raw_tensor.layer_dict
loss_opts["scale"] = scale
if as_error:
loss_opts["as_error"] = True
if use_normalized_loss:
loss_opts["use_normalized_loss"] = True
if not use_flatten_frames:
loss_opts["use_flatten_frames"] = False
if custom_inv_norm_factor is not None:
loss_opts["custom_inv_norm_factor"] = custom_inv_norm_factor
if loss_opts:
res.raw_tensor.layer_dict["loss_opts"] = loss_opts
# Add it to the root name scope marked_losses list.
# Note that this logic might change.
root_scope.marked_losses.append(res)
return res
def mark_as_output(self, *, _scope: Optional[nn.NameCtx] = None) -> Tensor:
"""
Mark this as an output.
This has the effect that RETURNN will in any case construct the corresponding layer.
Also see :func:`mark_as_default_output`.
"""
assert self.raw_tensor.layer_dict, f"mark_as_output can only be called on a layer, not a layer-ref {self}."
if not _scope:
scope = self.raw_tensor.root # mark_as_output always refers to the root
else:
scope = _scope # only for internal use
res = self
if self.raw_tensor is scope.children.get("output"):
pass # not needed
elif self.raw_tensor.parent is not scope:
res = nn.copy(self, name=scope.get_new_child(suggested_name=self.raw_tensor.get_abs_name(join_str="_")))
res.raw_tensor.layer_dict["is_output_layer"] = True
else:
assert self.raw_tensor.parent is scope
assert self.raw_tensor.layer_dict
self.raw_tensor.layer_dict["is_output_layer"] = True
scope.marked_outputs.append(res)
return res
def mark_as_default_output(self) -> Tensor:
"""
Mark this as the default output, i.e. create the "output" layer in the root with a reference to this.
This has the effect that RETURNN will in any case construct the corresponding layer,
and it is the default output layer for forwarding and potential other tasks.
:return: the "output" layer.
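Example (a sketch, e.g. at the end of the root module call; ``logits`` stands for the tensor you want RETURNN to forward)::

  logits.mark_as_default_output()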
"""
res = self.raw_tensor.root.make_default_output(self)
res.mark_as_output()
return res
def _sis_hash(self):
# noinspection PyProtectedMember
return self.raw_tensor._sis_hash()
def __add__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
if isinstance(other, (int, float, numpy.number)) and other == 0:
return self
return nn.combine(self, other, kind="add", name="add")
def __sub__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
if isinstance(other, (int, float, numpy.number)) and other == 0:
return self
return nn.combine(self, other, kind="sub", name="sub")
def __mul__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
if isinstance(other, (int, float, numpy.number)) and other == 1:
return self
return nn.combine(self, other, kind="mul", name="mul")
def __truediv__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
if isinstance(other, (int, float, numpy.number)) and other == 1:
return self
return nn.combine(self, other, kind="truediv", name="truediv")
def __floordiv__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
if isinstance(other, (int, float, numpy.number)) and other == 1:
return self
return nn.combine(self, other, kind="floordiv", name="floordiv")
def __mod__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
return nn.combine(self, other, kind="mod", name="mod")
def __radd__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
if isinstance(other, (int, float, numpy.number)) and other == 0:
return self
return nn.combine(other, self, kind="add", name="add")
def __rsub__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
if isinstance(other, (int, float, numpy.number)) and other == 0:
return nn.neg(self)
return nn.combine(other, self, kind="sub", name="sub")
def __rmul__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
if isinstance(other, (int, float, numpy.number)) and other == 1:
return self
return nn.combine(other, self, kind="mul", name="mul")
def __rtruediv__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
return nn.combine(other, self, kind="truediv", name="truediv")
def __rfloordiv__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
return nn.combine(other, self, kind="floordiv", name="floordiv")
def __rmod__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
return nn.combine(other, self, kind="mod", name="mod")
def __neg__(self) -> Tensor:
return nn.neg(self)
def __invert__(self) -> Tensor:
return nn.logical_not(self)
def __pow__(self, other: Union[RawTensorTypes, Tensor], modulo=None) -> Tensor:
assert modulo is None
if isinstance(other, (int, float, numpy.number)) and other == 1:
return self
return nn.combine(self, other, kind="pow", name="pow")
def __rpow__(self, other: Union[RawTensorTypes, Tensor], modulo=None) -> Tensor:
assert modulo is None
return nn.combine(other, self, kind="pow", name="pow")
def __and__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
if isinstance(other, bool) and other is True:
return self
if isinstance(other, bool) and other is False:
return nn.zeros_like(self)
return nn.combine(self, other, kind="logical_and", name="logical_and")
def __rand__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
if isinstance(other, bool) and other is True:
return self
if isinstance(other, bool) and other is False:
return nn.zeros_like(self)
return nn.combine(other, self, kind="logical_and", name="logical_and")
def __or__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
if isinstance(other, bool) and other is True:
return nn.ones_like(self)
if isinstance(other, bool) and other is False:
return self
return nn.combine(self, other, kind="logical_or", name="logical_or")
def __ror__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
if isinstance(other, bool) and other is True:
return nn.ones_like(self)
if isinstance(other, bool) and other is False:
return self
return nn.combine(other, self, kind="logical_or", name="logical_or")
def __abs__(self) -> Tensor:
return nn.abs(self)
def __ceil__(self) -> Tensor:
return nn.ceil(self)
def __floor__(self) -> Tensor:
return nn.floor(self)
def __eq__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
return nn.compare(self, nn.convert_to_tensor(other), kind="equal")
def __ne__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
return nn.compare(self, nn.convert_to_tensor(other), kind="not_equal")
def __lt__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
return nn.compare(self, nn.convert_to_tensor(other), kind="less")
def __le__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
return nn.compare(self, nn.convert_to_tensor(other), kind="less_equal")
def __gt__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
return nn.compare(self, nn.convert_to_tensor(other), kind="greater")
def __ge__(self, other: Union[RawTensorTypes, Tensor]) -> Tensor:
return nn.compare(self, nn.convert_to_tensor(other), kind="greater_equal")
class Parameter(Tensor):
"""
This represents a (potentially trainable) parameter,
aka ``tf.Variable`` in TensorFlow,
wrapping ``VariableLayer`` in RETURNN.
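Example (a sketch, as it would be written inside a :class:`Module`;
``nn.init.Glorot`` is assumed from the ``nn.init`` module and not defined in this file)::

  self.weight = nn.Parameter((in_dim, out_dim))
  self.weight.initial = nn.init.Glorot()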
"""
def __init__(
self,
shape: Sequence[Dim],
dtype: Optional[str] = None,
*,
trainable: Optional[bool] = None,
auxiliary: bool = False,
non_critical_for_restore: bool = False,
):
"""
:param shape:
:param dtype:
:param trainable: if True, the optimizer will update this parameter in training mode
:param auxiliary: if True, this indicates that this parameter should not be transformed by transformations
such as weight normalization. One example are running statistics, as used for batch normalization.
This usually implies that the parameter is not trainable, i.e. not to be updated by the optimizer,
but usually has some custom update.
This flag is not passed on to RETURNN but just used here for returnn-common logic.
:param non_critical_for_restore: if True, this parameter is not critical for restoring a model.
"""
if not all(isinstance(dim, Dim) for dim in shape):
raise TypeError(f"shape {shape} must be a sequence of Dim")
if not all(isinstance(dim.dimension, int) for dim in shape):
raise ValueError(f"shape {shape} must be static")
if len(shape) != len(set((d, d.match_priority) for d in shape)):
raise ValueError(f"shape {shape} dims must be unique")
# Note: At creation time, we don't know the name yet.
# The name will be inferred by the parent modules and the attribute chain.
# The name_ctx object will be completed by this information later.
# See Tensor.get_name_in_ctx().
name_ctx = nn.NameCtx(name="<unnamed-param>", parent=None)
name_ctx.require_global_access = True
data = Data("parameter", dim_tags=list(shape), dtype=dtype)
layer_dict = {"class": "variable", "shape": list(shape), "param_name": "param"}
if dtype is not None:
layer_dict["dtype"] = dtype
if auxiliary and trainable is None:
trainable = False
if trainable is not None:
layer_dict["trainable"] = trainable
if non_critical_for_restore:
layer_dict["non_critical_for_restore"] = True
super(Parameter, self).__init__(layer_dict=layer_dict, data=data, name_ctx=name_ctx)
self.auxiliary = auxiliary
self._initial = None # type: Optional[nn.init.ParamInitType]
def __copy__(self):
# Should return new copy. https://github.com/rwth-i6/returnn_common/pull/215#issuecomment-1269651064
res = type(self)(shape=self.dims, dtype=self.dtype, trainable=self.trainable, auxiliary=self.auxiliary)
res.initial = self.initial
return res
def __deepcopy__(self, memo=None):
# Should return new copy. https://github.com/rwth-i6/returnn_common/pull/215#issuecomment-1269651064
from copy import deepcopy
res = type(self)(shape=self.dims, dtype=self.dtype, trainable=self.trainable, auxiliary=self.auxiliary)
assert not res.raw_tensor.tensor_parent_modules
res.raw_tensor.tensor_parent_modules.extend(
(memo[id(m)], k) for m, k in self.raw_tensor.tensor_parent_modules if id(m) in memo
)
if isinstance(self.initial, nn.init.ParamInit):
res.initial = deepcopy(self.initial, memo=memo) # noqa
else:
res.initial = self.initial
return res
@property
def initial(self) -> Optional[nn.init.ParamInitType]:
"""initial value of the parameter"""
return self._initial
@initial.setter
def initial(self, value: Optional[nn.init.ParamInitType]):
import tensorflow as tf
# Keep the original ParamInit, so that copies of the Parameter would have a different initial random value.
# https://github.com/rwth-i6/returnn_common/issues/216
self._initial = value
if isinstance(value, nn.init.ParamInit):
value = value(shape=self.dims, dtype=self.dtype)
if value is None:
self.raw_tensor.layer_dict.pop("init", None)
self.raw_tensor.layer_dict.pop("init_by_layer", None)
elif isinstance(value, nn.Tensor):
self.raw_tensor.layer_dict.pop("init", None)
if not value.raw_tensor.parent.can_access_children_from_root:
accessible_parent = value.raw_tensor.parent
while not accessible_parent.can_access_children_from_root:
accessible_parent = accessible_parent.parent
value.raw_tensor.assign_parent(accessible_parent)
# We could also maybe move out all the dependencies.
# However, it's not clear whether this is always safe.
for dep in value.raw_tensor.get_tensor_dependencies():
assert (
dep.parent.can_access_children_from_root
), f"dep {dep} of moved value {value} is not accessible"
self.raw_tensor.layer_dict["init_by_layer"] = value
else:
self.raw_tensor.layer_dict.pop("init_by_layer", None)
self.raw_tensor.layer_dict["init"] = value
if nn.is_debug_eager_mode_enabled():
shape = [d.get_dim_value() for d in self.dims]
if isinstance(value, nn.Tensor):
assert value.data.placeholder is not None
value_tf = value.data.placeholder
else:
value_tf = tf.broadcast_to(tf.convert_to_tensor(value), shape)
if self.data.placeholder is None:
var = tf.Variable(value_tf, shape=[d.get_dim_value() for d in self.dims], dtype=self.data.dtype)
self.data.placeholder = var
else:
var = self.data.placeholder
assert isinstance(var, tf.Variable)
var.assign(value_tf)
@property
def initial_value(self) -> Optional[Union[nn.Tensor, nn.RawTensorTypes]]:
"""
In case initial is a ParamInit, this will return the actual value.
"""
if self.raw_tensor.layer_dict.get("init_by_layer", None) is not None:
return self.raw_tensor.layer_dict["init_by_layer"]
return self.raw_tensor.layer_dict.get("init", None)
@property
def weight_decay(self) -> float:
"""
Weight decay, which is equivalent to L2 loss on the parameters for SGD.
On RETURNN side, whether this is handled separately or is part of the main loss,
can be controlled via the ``decouple_constraints`` config option.
https://github.com/rwth-i6/returnn_common/issues/59#issuecomment-1073913421
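Example (a sketch; ``net.parameters()`` is assumed from the :class:`Module` code and not defined in this file)::

  for param in net.parameters():
    param.weight_decay = 1e-4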
"""
return self.raw_tensor.layer_dict.get("L2", 0.0)
@weight_decay.setter
def weight_decay(self, value: Optional[float]):
if value:
self.raw_tensor.layer_dict["L2"] = value
else:
self.raw_tensor.layer_dict.pop("L2", None)
@property
def trainable(self) -> Optional[bool]:
"""trainable"""
return self.raw_tensor.layer_dict.get("trainable", None)
@trainable.setter
def trainable(self, value: Optional[bool]):
if value is not None:
self.raw_tensor.layer_dict["trainable"] = value
else:
self.raw_tensor.layer_dict.pop("trainable", None)
class LayerState(dict):
"""
Covers all the state of a layer,
i.e. exactly what needs to be stored and passed into the module
the next time you call it, as the initial state.
This behaves somewhat like a namedtuple, although we derive from dict.
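Example (a sketch; ``h`` and ``c`` stand for some tensors, e.g. LSTM hidden and cell state)::

  state = nn.LayerState(h=h, c=c)
  assert state.h is state["h"]  # attribute access is equivalent to dict access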
"""
def __init__(self, *args, **kwargs):
if kwargs:
assert not args
super().__init__(**kwargs)
elif args:
assert len(args) == 1
if isinstance(args[0], dict):
super().__init__(**args[0])
else:
super().__init__(state=args[0])
else:
super().__init__()
def __repr__(self):
return f"{self.__class__.__name__}({', '.join(f'{k}={v!r}' for (k, v) in self.items())})"
def __getattr__(self, item):
if item in self:
return self[item]
raise AttributeError(f"{self}.{item}")
def __setattr__(self, key, value):
self[key] = value
def deep_tensors(self) -> List[nn.Tensor]:
"""See :func:`cls_deep_tensors`."""
return self.cls_deep_tensors(self)
@classmethod
def cls_deep_tensors(cls, obj: Union[LayerState, dict, Any]) -> List[nn.Tensor]:
"""
Iterates through obj and all its sub-objects and collects all tensors.
"""
from .loop import _LoopStateHolder
cache_tensor_names = set() # type: Set[nn.NameCtx] # names because tensors are not hashable
tensors = [] # type: List[nn.Tensor]
queue = [obj]
while queue:
x = queue.pop()
if isinstance(x, nn.Tensor):
if x.raw_tensor not in cache_tensor_names:
cache_tensor_names.add(x.raw_tensor)
tensors.append(x)
elif isinstance(x, (dict, _LoopStateHolder)):
queue.extend(x.values())
elif isinstance(x, (list, tuple)):
queue.extend(x)
else:
raise TypeError(f"unexpected type {type(x)}")
return tensors
def make_layer(
layer_dict: LayerDictRaw,
*,
name: Optional[Union[str, nn.NameCtx]] = None,
predefined_out_data: Optional[Data] = None,
name_ctx_ignore_top_stack_frames: int = 0,
) -> Tensor:
"""
Creates the layer. This also registers the layer instance in the top name ctx.
When no name is given, this assumes that the top name ctx corresponds to this module.
If a layer has params, and you want the param sharing logic,
you should instead derive a new class from :class:`Module`.
Usually, you do not need either of these,
as all standard layers should already be wrapped,
and it should be possible to define any custom logic
using those.
(If this is not the case, please report an issue.)
:param LayerDictRaw layer_dict: can contain :class:`Tensor` instances
:param str|NameCtx|None name:
if str: (suggested) layer name. if given, will create a new :class:`NameCtx`
if NameCtx, will use this.
:param Data|None predefined_out_data: normally we can derive the out data automatically.
If this should be skipped, you can pass this explicitly.
:param int name_ctx_ignore_top_stack_frames: for :func:`NameCtx.current_ctx`.
If your calling function creates exactly one single layer, you might want to ignore its stack frame
and set ignore_top_stack_frames=1 and also set a name for the layer.
If you are potentially creating multiple layers in your calling function,
leave the default ignore_top_stack_frames=0.
Some postprocessing step might anyway simplify obsolete subnetworks,
see :mod:`naming`.
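Example of directly wrapping a RETURNN layer via its layer dict
(a sketch; the ``"reduce"`` layer options follow the usual RETURNN layer options and are not checked against this file)::

  def reduce_max(source: nn.Tensor, *, axis: nn.Dim) -> nn.Tensor:
    return nn.make_layer(
      {"class": "reduce", "mode": "max", "from": source, "axis": axis},
      name="reduce_max")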
"""
if isinstance(name, str) or not name:
parent_ctx = nn.NameCtx.current_ctx(ignore_top_stack_frames=name_ctx_ignore_top_stack_frames + 1)
name_ctx = nn.NameCtx(suggested_name=name, parent=parent_ctx)
created_name_ctx = True
elif isinstance(name, nn.NameCtx):
name_ctx = name
created_name_ctx = False
else:
raise TypeError(f"name must be str or NameCtx, not {type(name)}; or you should pass a module")
assert not name_ctx.tensor and not name_ctx.layer_dict # not yet assigned
layer_dict = layer_dict.copy()
try:
layer = Tensor(layer_dict=layer_dict, name_ctx=name_ctx, data=predefined_out_data)
except Exception as exc:
# Just forward the exception.
# However, if we already created a new name_ctx for it, we can clean this up now.
if created_name_ctx:
assert name_ctx.parent
name_ctx.parent.children.pop(name_ctx.name)
raise exc
if name_ctx.module:
name_ctx.module.calls.append(name_ctx)
for tag in layer.data.dim_tags:
_register_dim_deps_when_novel(tag, [layer])
# Debug output. Similar to the RETURNN template log. Maybe put this behind a flag? Anyway, useful for now.
print(layer)
return layer
def get_extern_data(data: Data) -> Tensor:
"""
Get extern data from root ctx.
As a side effect, it registers the given data as extern data,
and this will be included when creating the RETURNN config,
via :func:`NameCtx.get_returnn_config`.
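Example (a sketch; the dim tags are assumptions for a typical feature input)::

  time_dim = nn.SpatialDim("time")
  in_dim = nn.FeatureDim("input", 80)
  data = nn.get_extern_data(nn.Data("data", dim_tags=[nn.batch_dim, time_dim, in_dim]))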
"""
assert isinstance(data, Data) # the usage was different before. make sure we get this correct
scope = nn.NameCtx.top() # must exist
assert not scope.parent # get_extern_data only allowed (only makes sense) in root name ctx
if data.name not in scope.extern_data:
scope.extern_data[data.name] = data
else:
assert scope.extern_data[data.name] is data
data.batch = _init_global_batch()
root_layer_name = f"data:{data.name}"
out = _get_raw_layer_by_name(root_layer_name, scope=scope, data=data)
for tag in data.dim_tags:
if not tag.is_batch_dim() and tag.is_dynamic() and not tag.dyn_size_ext:
# Undefined dynamic dim tag. Set default data template.
tag.dyn_size_ext = Data(
name=f"{data.name}_default_dyn_size_ext",
dim_tags=[nn.batch_dim],
dtype=data.size_dtype,
batch=data.batch,
)
_register_dim_deps_when_novel(tag, [out])
if nn.is_debug_eager_mode_enabled():
out.data.placeholder = _make_random_tf_tensor_for_returnn_data(out.data)
return out
def _make_random_tf_tensor_for_returnn_data(data: Data):
import tensorflow as tf
shape = []
for dim in data.dim_tags:
if dim.is_batch_dim():
assert data.batch
shape.append(data.batch.dim)
elif dim.dimension is not None:
shape.append(dim.dimension)
else:
dim.complete_dyn_size()
if dim.dyn_size_ext is None:
assert data.batch
dim.dyn_size_ext = Data(
name=f"{data.name}_dummy_dyn_size_ext",
dim_tags=[nn.batch_dim],
dtype=data.size_dtype,
batch=data.batch,
)
if dim.dyn_size_ext.placeholder is None:
dim.dyn_size_ext.placeholder = _make_random_tf_tensor_for_returnn_data(dim.dyn_size_ext)
shape.append(tf.reduce_max(dim.dyn_size_ext.placeholder))
dtype = tf.as_dtype(data.dtype)
if dtype.is_integer:
if data.sparse:
return tf.random.uniform(shape=shape, dtype=dtype, minval=0, maxval=data.dim)
else:
import binascii
c = abs(binascii.crc32(data.name.encode("utf8"))) % 21 + 3
shape = tf.convert_to_tensor(shape)
c_tf = tf.constant(c, name="dummy_random_const", dtype=dtype)
rnd = tf.broadcast_to(c_tf, shape)
rnd_diff = tf.random.uniform(shape=shape, minval=0, maxval=2**31 - 1, dtype=dtype)
rnd_diff = rnd_diff % tf.reshape(tf.minimum(tf.range(0, tf.size(rnd), dtype=dtype) + 1, c_tf - 2), shape)
rnd = tf.clip_by_value(rnd - rnd_diff, 1, c_tf)
return rnd
assert dtype.is_floating # not implemented otherwise
return tf.random.normal(shape=shape, dtype=dtype)
def _get_raw_layer_by_name(name: str, *, scope: Optional[nn.NameCtx] = None, data: Data) -> Tensor:
"""
Special layer can be "data:..." or whatever.
"""
if not scope:
scope = nn.NameCtx.current_ctx() # must exist
return scope.get_child_tensor(name, data=data)
def _get_sub_layer(layer: Tensor, name: str, *, data: Data) -> Tensor:
"""
Like the "{layer}/{name}" syntax in RETURNN.
Normally this should only be needed for internal usage.
"""
out = layer.raw_tensor.get_child_tensor(name, data=data)
if nn.is_debug_eager_mode_enabled():
assert layer.raw_tensor.debug_layer
import returnn.tf.layers.base
assert isinstance(layer.raw_tensor.debug_layer, returnn.tf.layers.base.LayerBase)
sub_layer = layer.raw_tensor.debug_layer.get_sub_layer(name)
assert sub_layer and sub_layer.output.dim_tags == out.data.dim_tags
out.raw_tensor.debug_layer = sub_layer
out.data = sub_layer.output
return out
class ReturnnConstructTemplateException(Exception):
"""
In :func:`_data_from_layer_dict`, when we call layer_class.get_out_data_from_opts,
we potentially can get errors, often due to user mistakes.
We wrap those errors in this exception for better reporting.
"""
def _init_global_batch():
from returnn.tf.util.data import BatchInfo
root_name_ctx = nn.NameCtx.top().root
if root_name_ctx.global_batch:
return root_name_ctx.global_batch
if nn.is_debug_eager_mode_enabled():
import tensorflow as tf
root_name_ctx.global_batch = BatchInfo.make_global_batch_info(
tf.constant(3, name="global_batch")
) # https://xkcd.com/221/, but prime
else:
# We need some global batch info, and this needs a tensor (e.g. placeholder),
# but we don't have any tensor yet, nor do we want to create any tensors at this point.
# So we pass the dummy value -1.
# Such dummy global batch info with -1 will be handled specially in RETURNN init_batch_info,
# and it will be replaced with the real global batch.
root_name_ctx.global_batch = BatchInfo.make_global_batch_info(-1)
return root_name_ctx.global_batch
def _data_from_layer_dict(layer_dict: LayerDictRaw, *, tensor: Tensor) -> Data:
"""
Use RETURNN layer_class.get_out_data_from_opts to get the :class:`Data`.
For this function, we need to set up some dummy network and dummy source layers.
"""
from returnn.tf.network import TFNetwork, ExternData
from returnn.tf.layers.base import InternalLayer, LayerBase
from returnn.util import BehaviorVersion
from returnn.config import get_global_config
config = get_global_config(auto_create=True)
config.typed_dict.setdefault("behavior_version", min_returnn_behavior_version)
BehaviorVersion.set(config.typed_dict["behavior_version"])
loop = nn.NameCtx.inner_loop() # Note: for control_flow_ctx, we should also check Cond
net = TFNetwork(
config=config,
extern_data=ExternData(),
name="dummy_net",
train_flag=True, # should not have an effect usually for templates, except maybe in debug-eager-mode
inside_rec_time_dim=loop.axis if loop else None,
control_flow_ctx=nn.NameCtx.inner_control_flow(),
)
net.extern_data.set_batch_info(_init_global_batch())
ref_to_layer_name = {} # type: Dict[nn.NameCtx, str]
def _get_unique_name(name) -> str:
reserved_names = set(net.layers.keys()) | {"data"}
if name not in reserved_names:
return name
i = 0
while True:
name_ = f"{name}_{i}"
if name_ not in reserved_names:
return name_
i += 1
def _get_layer_name(ref: Tensor) -> str:
if ref.raw_tensor in ref_to_layer_name:
return ref_to_layer_name[ref.raw_tensor]
name = _get_unique_name(ref.raw_tensor.name)
ref_to_layer_name[ref.raw_tensor] = name
assert name not in net.layers
data = ref.data.copy()
net.layers[name] = InternalLayer(name=name, network=net, output=data)
return name
def _map_layer_dict_elem(value):
if isinstance(value, Tensor):
return _get_layer_name(value)
return value
layer_dict = tree.map_structure(_map_layer_dict_elem, layer_dict)
out_name = _get_unique_name(tensor.raw_tensor.name)
net_dict = {
out_name: layer_dict,
# Simple workaround in case the layer wants to access its previous layer.
# https://github.com/rwth-i6/returnn_common/issues/243
f"prev:{out_name}": {"class": "constant", "shape": ()},
}
if nn.is_debug_eager_mode_enabled():
_add_layer = None # implies to really construct the layer
else:
# Creates only a template layer.
def _add_layer(name: str, layer_class: Type[LayerBase], **layer_desc) -> LayerBase:
# noinspection PyProtectedMember
layer_desc = net._create_layer_layer_desc(name=name, layer_desc=layer_desc, template=True)
try:
out_data = layer_class.get_out_data_from_opts(**layer_desc)
out_data = layer_class.fixup_out_data(out_data, **layer_desc)
except Exception as exc:
msgs = ["The RETURNN call\n", f" {layer_class.__name__}.get_out_data_from_opts(\n"]
for key, v in layer_desc.items():
msgs.append(f" {key}={v!r},\n")
msgs += [
" )\n",
"raised the exception:\n",
f" {type(exc).__name__} {exc!s}\n",
"(See above for the RETURNN exception traceback.)",
]
# Use `with_traceback`, such that the user directly sees the full traceback,
# and also that debuggers stop right where it matters.