forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
_tensor.py
1662 lines (1422 loc) · 64.8 KB
/
_tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# mypy: allow-untyped-defs
import copyreg
import enum
import functools
import warnings
from collections import OrderedDict
from copy import deepcopy
from numbers import Number
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch._C as _C
from torch._namedtensor_internals import (
check_serializing_named_tensor,
is_ellipsis,
resolve_ellipsis,
single_ellipsis_index,
unzip_namedshape,
update_names,
)
from torch.overrides import (
get_default_nowrap_functions,
handle_torch_function,
has_torch_function,
has_torch_function_unary,
has_torch_function_variadic,
)
def _handle_torch_function_and_wrap_type_error_to_not_implemented(f):
    """Wrap ``f`` so it honours ``__torch_function__`` overrides and maps ``TypeError`` to ``NotImplemented``.

    If any positional argument overrides ``__torch_function__``, the call is routed
    through ``handle_torch_function`` with the wrapper itself as the public API.
    Otherwise ``f`` is called directly. A ``TypeError`` raised on either path is
    swallowed and ``NotImplemented`` is returned instead.
    """

    @functools.wraps(f, assigned=functools.WRAPPER_ASSIGNMENTS)
    def _dispatching_wrapper(*args, **kwargs):
        # See https://github.com/pytorch/pytorch/issues/75462
        try:
            if not has_torch_function(args):
                return f(*args, **kwargs)
            return handle_torch_function(_dispatching_wrapper, args, *args, **kwargs)
        except TypeError:
            return NotImplemented

    return _dispatching_wrapper
# Should not be used, this is kept only for BC of loading old serialized Tensor subclasses
def _rebuild_from_type(func, type, args, dict):
    """Legacy unpickling helper: rebuild a tensor with ``func(*args)``.

    For ``Tensor`` itself the rebuilt tensor is returned unchanged. For any other
    ``type`` it is re-wrapped with ``as_subclass(type)`` and its ``__dict__`` is
    replaced by ``dict``.

    NB: ``type`` and ``dict`` shadow builtins; they are kept because they are part
    of the historical signature.
    """
    rebuilt = func(*args)
    if type is Tensor:
        return rebuilt
    rebuilt = rebuilt.as_subclass(type)
    rebuilt.__dict__ = dict
    return rebuilt
def _rebuild_from_type_v2(func, new_type, args, state):
    """Unpickling helper referenced by ``Tensor.__reduce_ex__``.

    Rebuilds the plain tensor with ``func(*args)``, re-wraps it as ``new_type`` when
    the types differ, then restores ``state``. A ``__setstate__`` that the subclass
    overrides is honoured; otherwise ``torch._utils._set_obj_state`` is used.
    """
    rebuilt = func(*args)
    if type(rebuilt) is not new_type:
        rebuilt = rebuilt.as_subclass(new_type)
    # Tensor does define __setstate__ even though it doesn't define __getstate__,
    # so the generic state restore is used unless a subclass overrides it.
    setstate_impl = getattr(rebuilt.__class__, "__setstate__", Tensor.__setstate__)
    if setstate_impl is Tensor.__setstate__:
        return torch._utils._set_obj_state(rebuilt, state)
    rebuilt.__setstate__(state)
    return rebuilt
# NB: If you subclass Tensor, and want to share the subclassed class
# across processes, you must also update torch/multiprocessing/reductions.py
# to define a ForkingPickler serialization mode for the class.
#
# NB: If you add a new method to Tensor, you must update
# torch/_C/__init__.pyi.in to add a type annotation for your method;
# otherwise, it will not show up in autocomplete.
class Tensor(torch._C.TensorBase):
_is_param: bool
def _clear_non_serializable_cached_data(self):
    r"""Drop entries from the tensor's ``__dict__`` that would prevent serialization.

    Wrapper subclasses with custom-dispatched sizes / strides cache that information
    in non-serializable PyCapsules inside ``__dict__``. Those entries must be removed
    before the tensor can be pickled or deep-copied.

    Any subclass that overrides this MUST call ``super()._clear_non_serializable_cached_data().``
    Additional data cleared by an override must be transparently re-cacheable so
    the subclass keeps working afterwards.
    """
    if has_torch_function_unary(self):
        return handle_torch_function(
            Tensor._clear_non_serializable_cached_data, (self,), self
        )
    # NB: Wrapper subclasses that implement custom-dispatched sizes / strides cache
    # this info via non-serializable PyCapsules.
    capsule_keys = (
        "_sym_sizes_capsule",
        "_sym_sizes_capsule_len",
        "_sym_strides_capsule",
        "_sym_strides_capsule_len",
    )
    instance_dict = self.__dict__
    for capsule_key in capsule_keys:
        # Missing keys are fine: nothing may have been cached yet.
        instance_dict.pop(capsule_key, None)
def __deepcopy__(self, memo):
    # copy.deepcopy() hook for Tensor. ``memo`` is the id() -> copy dictionary threaded
    # through a deepcopy: it is consulted on entry, and the new tensor is recorded in
    # it just before returning.
    if has_torch_function_unary(self):
        return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo)
    # Only graph leaves can be deep-copied; tensors produced by autograd ops raise.
    if not self.is_leaf:
        raise RuntimeError(
            "Only Tensors created explicitly by the user "
            "(graph leaves) support the deepcopy protocol at the moment. "
            "If you were attempting to deepcopy a module, this may be because "
            "of a torch.nn.utils.weight_norm usage, "
            "see https://github.com/pytorch/pytorch/pull/103001"
        )
    # This tensor was already copied earlier in the same deepcopy: reuse that copy.
    if id(self) in memo:
        return memo[id(self)]
    # Build the copy without recording autograd history.
    with torch.no_grad():
        # TODO: skipping storage copy is wrong for meta, as meta
        # does accurate alias tracking; however, the code below
        # doesn't work because of
        # https://github.com/pytorch/pytorch/issues/47442
        # Update the test in test_serialization if you remove 'meta' from here
        #
        # Path 1: sparse tensors, the listed device types, storage-less tensors on the
        # PrivateUse1 backend, and subclasses whose data_ptr() is 0 are copied with
        # clone() rather than by copying storage.
        if (
            self.is_sparse
            or self.device.type
            in ["lazy", "xla", "mtia", "mps", "maia", "meta", "ipu"]
            or (
                not torch._C._has_storage(self)
                and self.device.type == torch._C._get_privateuse1_backend_name()
            )
            or (type(self) is not Tensor and self.data_ptr() == 0)
        ):
            new_tensor = self.clone()
            if type(new_tensor) is not type(self):
                raise RuntimeError(
                    "The default implementation of __deepcopy__() for wrapper subclasses "
                    "only works for subclass types that implement clone() and for which "
                    "cloning returns another instance of the same subclass. You should either "
                    "properly implement clone() for your subclass or override __deepcopy__() "
                    "if it is intended behavior for clone() to return an instance of a "
                    "different type."
                )
        else:
            # Path 2: deep-copy the typed storage (memo is passed through) and rebuild
            # a tensor with self's storage_offset / size / stride on top of it.
            new_storage = self._typed_storage()._deepcopy(memo)
            if self.is_quantized:
                # quantizer_params can be different type based on torch attribute
                quantizer_params: Union[
                    Tuple[torch.qscheme, float, int],
                    Tuple[torch.qscheme, Tensor, Tensor, int],
                ]
                if self.qscheme() == torch.per_tensor_affine:
                    quantizer_params = (
                        self.qscheme(),
                        self.q_scale(),
                        self.q_zero_point(),
                    )
                elif self.qscheme() in (
                    torch.per_channel_affine,
                    torch.per_channel_affine_float_qparams,
                ):
                    quantizer_params = (
                        self.qscheme(),
                        self.q_per_channel_scales(),
                        self.q_per_channel_zero_points(),
                        self.q_per_channel_axis(),
                    )
                else:
                    raise RuntimeError(
                        f"Unsupported qscheme {self.qscheme()} in deepcopy"
                    )
                # TODO: Once we decide to break serialization FC, no longer
                # need to wrap with TypedStorage
                new_tensor = torch._utils._rebuild_qtensor(
                    torch.storage.TypedStorage(
                        wrap_storage=new_storage._untyped_storage,
                        dtype=self.dtype,
                        _internal=True,
                    ),
                    self.storage_offset(),
                    self.size(),
                    self.stride(),
                    quantizer_params,
                    self.requires_grad,
                    self._backward_hooks,
                )
                if type(new_tensor) is not type(self):
                    raise RuntimeError(
                        "The default implementation of __deepcopy__() for quantized tensors "
                        "expects the tensor returned by torch._utils._rebuild_qtensor() to "
                        "match the type of the instance being copied. If you encounter this, "
                        "please open an issue on PyTorch's GitHub."
                    )
            else:
                # Start from an empty tensor of self's type (via new_empty) and point it
                # at the copied storage.
                new_tensor = self.new_empty([])
                if type(new_tensor) is not type(self):
                    raise RuntimeError(
                        "The default implementation of __deepcopy__() for non-wrapper subclasses "
                        "only works for subclass types that implement new_empty() and for which "
                        "that function returns another instance of the same subclass. You should "
                        "either properly implement new_empty() for your subclass or override "
                        "__deepcopy__() if it is intended behavior for new_empty() to return "
                        "an instance of a different type."
                    )
                new_tensor.set_(
                    new_storage, self.storage_offset(), self.size(), self.stride()
                )
                # Reflect self's conjugate / negative bits in the copy as physical ops;
                # presumably the set_()-built tensor does not carry those lazy bits.
                if self.is_conj():
                    new_tensor = new_tensor.conj_physical()
                if self.is_neg():
                    new_tensor = new_tensor.neg()
        if self.requires_grad:
            new_tensor.requires_grad_()
        # .grad is deep-copied with the same memo.
        if self.grad is not None:
            new_tensor.grad = self.grad.__deepcopy__(memo)
        if type(self) is not Tensor:
            if type(new_tensor) is not type(self):
                raise RuntimeError(
                    "Type of deepcopy result does not match the type of the source tensor. "
                    "If you encounter this, please open an issue on PyTorch's GitHub."
                )
            # Plain Tensors don't have slots
            slots_to_save = copyreg._slotnames(self.__class__)  # type: ignore[attr-defined]
            for slot in slots_to_save:
                if hasattr(self, slot):
                    setattr(new_tensor, slot, deepcopy(getattr(self, slot), memo))
        # don't try to deepcopy non-serializable cached data
        # (NB: this removes those cache entries from self.__dict__ in place.)
        self._clear_non_serializable_cached_data()
        new_tensor.__dict__ = deepcopy(self.__dict__, memo)
        memo[id(self)] = new_tensor
        return new_tensor
def __reduce_ex__(self, proto):
    # pickle hook for Tensor.
    # - Plain tensors without Python-level state (and FakeTensors when fake tensors
    #   are materialized) return _reduce_ex_internal's (rebuild_func, args) directly.
    # - Anything else is wrapped so that _rebuild_from_type_v2 restores the original
    #   type and state on load.
    materialize_fake_tensors = (
        torch.serialization._serialization_tls.materialize_fake_tensors
    )
    # Python-level object state attached to this tensor (it references the actual
    # tensor __dict__, per the note further down).
    state = torch._utils._get_obj_state(self)
    # Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has
    # some state that cannot be pickled
    # NB: this fast-path check runs *before* __torch_function__ dispatch below.
    if (
        # TODO: remove hasattr, it's a hack to support versions of torch that
        # don't have _subclasses
        hasattr(torch, "_subclasses")
        and type(self) is torch._subclasses.fake_tensor.FakeTensor
        and materialize_fake_tensors
    ) or (type(self) is Tensor and not state):
        # Fast path for regular tensor without Python state.
        return self._reduce_ex_internal(proto)
    if has_torch_function_unary(self):
        return handle_torch_function(Tensor.__reduce_ex__, (self,), self, proto)
    func, args = self._reduce_ex_internal(proto)
    # sizes / strides cache needs to be cleared here because it'll just be re-cached
    # if cleared earlier. Note that state references the -actual- tensor dict.
    self._clear_non_serializable_cached_data()
    return (_rebuild_from_type_v2, (func, type(self), args, state))
def storage(self):
    r"""
    storage() -> torch.TypedStorage

    Return the underlying :class:`TypedStorage`.

    .. warning::

        :class:`TypedStorage` is deprecated. It will be removed in the future, and
        :class:`UntypedStorage` will be the only storage class. To access the
        :class:`UntypedStorage` directly, use :attr:`Tensor.untyped_storage()`.
    """
    if not has_torch_function_unary(self):
        # Emit the deprecation warning from this frame (stacklevel=2 targets the caller).
        torch.storage._warn_typed_storage_removal(stacklevel=2)
        return self._typed_storage()
    return handle_torch_function(Tensor.storage, (self,), self)
# For internal use only, to avoid raising deprecation warning
def _typed_storage(self):
    """Return a ``TypedStorage`` wrapping ``self.untyped_storage()`` with ``self.dtype``.

    Unlike :meth:`storage`, this does not emit the TypedStorage deprecation warning.
    """
    return torch.TypedStorage(
        wrap_storage=self.untyped_storage(),
        dtype=self.dtype,
        _internal=True,
    )
def _reduce_ex_internal(self, proto):
    # Core pickling logic for Tensor, called from __reduce_ex__.
    # Returns a (rebuild_callable, args) pair chosen by device type, quantization,
    # layout and subclass kind. ``proto`` is accepted but not read in this body.
    check_serializing_named_tensor(self)
    from torch.utils.hooks import warn_if_has_hooks
    # See Note [Don't serialize hooks]
    warn_if_has_hooks(self)
    # Hooks are not serialized: this empty OrderedDict is what gets passed to the
    # rebuild functions below (see "# previously was self._backward_hooks").
    backward_hooks: Dict[Any, Any] = OrderedDict()
    # Thread-local serialization flags (presumably set by torch.serialization's
    # skip_data context manager -- confirm).
    skip_data = torch.serialization._serialization_tls.skip_data
    materialize_fake_tensors = (
        torch.serialization._serialization_tls.materialize_fake_tensors
    )
    # XLA / MAIA, or a PrivateUse1 backend without storage: round-trip through a
    # CPU copy of the tensor.
    if self.device.type in ["xla", "maia"] or (
        not torch._C._has_storage(self)
        and self.device.type == torch._C._get_privateuse1_backend_name()
    ):
        if skip_data:
            raise RuntimeError(
                "Cannot serialize tensors on backends with no storage under skip_data context manager"
            )
        cpu_tensor = self.cpu()
        return (
            torch._utils._rebuild_device_tensor_from_cpu_tensor,
            (cpu_tensor, self.dtype, str(self.device), self.requires_grad),
        )
    # Legacy comment that does not hold anymore.
    # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors.
    # We considered a few options:
    # 1. CPU tensor can't be used here.
    #    Otherwise in torch.load CPU storage is reconstructed with randomly
    #    initialized data, moved onto backend device, and then storage is updated
    #    to the serialized content. This works perfectly for CPU/CUDA but not these backends;
    #    their tensors are disconnected with storage so they don't get the update.
    # 2. Python list is not a good fit due to performance reason.
    #    `tolist()` converts every single element in the tensor into python objects
    #    and serialize them one by one.
    if self.device.type in ["mtia"]:
        # Convert BFloat16 tensors to Float32 before conversion to numpy, as numpy doesn't
        # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype,
        # this would reconstruct the BFloat16 tensor from numpy.
        if skip_data:
            raise RuntimeError(
                "Cannot serialize tensors on backends with no storage under skip_data context manager"
            )
        numpy_tensor = (
            self.cpu().numpy()
            if self.dtype != torch.bfloat16
            else self.cpu().to(torch.float32).numpy()
        )
        return (
            torch._utils._rebuild_device_tensor_from_numpy,
            (numpy_tensor, self.dtype, str(self.device), self.requires_grad),
        )
    if self.device.type == "meta":
        # NB: This implementation BREAKS storage sharing. Current
        # hypothesis is that no one cares for meta tensors.
        # Only metadata (dtype, size, stride, requires_grad) is recorded.
        if skip_data:
            warnings.warn(
                "Serializing tensors on the meta device under skip_data context manager is a no-op"
            )
        arg_meta = (
            self.dtype,
            tuple(self.size()),
            self.stride(),
            self.requires_grad,
        )
        return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta)
    if self.is_quantized:
        if skip_data:
            raise RuntimeError(
                "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature"
            )
        # quantizer_params can be different type based on torch attribute
        quantizer_params: Union[
            Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]
        ]
        if self.qscheme() == torch.per_tensor_affine:
            quantizer_params = (
                torch.per_tensor_affine,
                self.q_scale(),
                self.q_zero_point(),
            )
        elif self.qscheme() in (
            torch.per_channel_affine,
            torch.per_channel_affine_float_qparams,
        ):
            # convert scales and zero points to tuple to avoid recursive calls
            # when/if we get multi-axis quantized tensors in the future, the shape
            # is recoverable from the main tensor shape
            # NOTE(review): despite the comment above, the scale / zero-point tensors
            # are passed as-is (no tuple conversion). Also, torch.per_channel_affine
            # is recorded even when self.qscheme() is
            # per_channel_affine_float_qparams, whereas __deepcopy__ records
            # self.qscheme(). Confirm the float-qparams scheme survives a round trip.
            quantizer_params = (
                torch.per_channel_affine,
                self.q_per_channel_scales(),
                self.q_per_channel_zero_points(),
                self.q_per_channel_axis(),
            )
        else:
            raise RuntimeError(
                f"Serialization is not supported for tensors of type {self.qscheme()}"
            )
        # TODO: Once we decide to break serialization FC, no longer
        # need to wrap with TypedStorage
        args_qtensor = (
            torch.storage.TypedStorage(
                wrap_storage=self._typed_storage()._untyped_storage,
                dtype=self.dtype,
                _internal=True,
            ),
            self.storage_offset(),
            tuple(self.size()),
            self.stride(),
            quantizer_params,
            self.requires_grad,
            backward_hooks,
        )
        return (torch._utils._rebuild_qtensor, args_qtensor)
    elif self.is_sparse:
        # Only the COO layout is handled here; it is serialized as
        # (indices, values, size, is_coalesced).
        if self.layout == torch.sparse_coo:
            args_sparse = (
                self.layout,
                (self._indices(), self._values(), self.size(), self.is_coalesced()),
            )
        else:
            raise NotImplementedError(
                f"sparse tensor __reduce_ex__ for layout `{self.layout}`"
            )
        return (torch._utils._rebuild_sparse_tensor, args_sparse)
    elif self.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        # Compressed sparse layouts: row-compressed (CSR/BSR) vs column-compressed
        # (CSC/BSC) use different index accessors.
        if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
            compressed_indices, plain_indices = (
                self.crow_indices(),
                self.col_indices(),
            )
        else:
            compressed_indices, plain_indices = (
                self.ccol_indices(),
                self.row_indices(),
            )
        args_sparse_compressed = (
            self.layout,
            (
                compressed_indices,
                plain_indices,
                self.values(),
                self.size(),
            ),
        )
        return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed)
    elif self.is_nested:
        if skip_data:
            raise RuntimeError(
                "Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature"
            )
        args_nested = (
            # NB: values() currently returns the storage as a buffer in an unsafe way.
            # Ideally, we'd use a private API for this instead. TODO: Switch to this if
            # we ever get around to adding it.
            self.values(),
            self._nested_tensor_size(),
            self._nested_tensor_strides(),
            self._nested_tensor_storage_offsets(),
        )
        return (torch._utils._rebuild_nested_tensor, args_nested)
    elif (
        type(self) is not torch.Tensor
        and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
        and (
            isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor)
            or (
                not isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
                and self.data_ptr() == 0
            )
        )
    ):
        # Wrapper subclasses with a custom __torch_dispatch__ that are either
        # FunctionalTensors or have no data pointer: only metadata is recorded.
        arg_wrapper_subclass = (
            type(self),
            self.dtype,
            tuple(self.size()),
            self.stride(),
            self.storage_offset(),
            self.layout,
            self.device,
            self.requires_grad,
        )
        return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)
    elif (
        type(self) is not torch.Tensor
        and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
        and (
            isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
            and not (skip_data and materialize_fake_tensors)
        )
    ):
        # FakeTensors, unless they are being materialized under skip_data: also
        # serialized as metadata only.
        arg_wrapper_subclass = (
            type(self),
            self.dtype,
            tuple(self.size()),
            self.stride(),
            self.storage_offset(),
            self.layout,
            self.device,
            self.requires_grad,
        )
        return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)
    else:
        # Regular strided tensors. dtypes listed by torch.storage._new_dtypes() use
        # _rebuild_tensor_v3 with an UntypedStorage (dtype appended to args below);
        # all others use _rebuild_tensor_v2 with a TypedStorage.
        v3_dtypes = torch.storage._new_dtypes()
        if self.dtype in v3_dtypes:
            rebuild_func = torch._utils._rebuild_tensor_v3
            storage = self.untyped_storage()
        else:
            # TODO: Once we decide to break serialization FC, no longer
            # need to wrap with TypedStorage
            rebuild_func = torch._utils._rebuild_tensor_v2  # type: ignore[assignment]
            storage = torch.storage.TypedStorage(
                wrap_storage=self._typed_storage()._untyped_storage,
                dtype=self.dtype,
                _internal=True,
            )  # type: ignore[assignment]
        # TODO: remove hasattr, it's a hack to support versions of torch that
        # don't have _subclasses
        # FakeTensor under skip_data: tag the storage with the fake tensor's device.
        if (
            hasattr(torch, "_subclasses")
            and isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
            and skip_data
        ):
            storage._fake_device = self.device
        args = (
            storage,
            self.storage_offset(),
            tuple(self.size()),
            self.stride(),
            self.requires_grad,
            backward_hooks,
        )  # previously was self._backward_hooks
        if isinstance(storage, torch.storage.UntypedStorage):
            args = args + (self.dtype,)  # type: ignore[assignment]
        # Extra tensor metadata is appended only when non-empty.
        metadata = torch._utils.get_tensor_metadata(self)
        if metadata:
            args = args + (metadata,)  # type: ignore[assignment]
        return (rebuild_func, args)
def __setstate__(self, state):
    """Restore a tensor from legacy pickled ``state``.

    Accepted shapes of ``state``:

    * 4 items: legacy Tensor state, forwarded unchanged to ``set_()``.
    * 5 items: legacy Variable state. Item 0 becomes ``.data``; items 3 and 2 become
      ``requires_grad`` and ``_backward_hooks``.
    * 3 items: ``(requires_grad, <ignored>, _backward_hooks)``.

    Any other length fails when the 3-item form is unpacked.

    Warning: this method is NOT called when you torch.load() a tensor;
    that is managed by _rebuild_tensor_v2.
    """
    if has_torch_function_unary(self):
        return handle_torch_function(Tensor.__setstate__, (self,), self, state)
    if not self.is_leaf:
        raise RuntimeError("__setstate__ can be only called on leaf Tensors")
    field_count = len(state)
    if field_count == 4:
        # legacy serialization of Tensor
        self.set_(*state)
        return
    if field_count == 5:
        # legacy serialization of Variable
        self.data = state[0]
        grad_flag, hooks = state[3], state[2]
    else:
        grad_flag, _, hooks = state
    # The setting of _backward_hooks is expected to be a no-op.
    # See Note [Don't serialize hooks]
    self.requires_grad = grad_flag
    self._backward_hooks = hooks
def __repr__(self, *, tensor_contents=None):
    """Return the printable form built by ``torch._tensor_str._str``.

    ``tensor_contents`` is forwarded unchanged to the formatter.
    """
    if not has_torch_function_unary(self):
        # All strings are unicode in Python 3.
        return torch._tensor_str._str(self, tensor_contents=tensor_contents)
    return handle_torch_function(
        Tensor.__repr__, (self,), self, tensor_contents=tensor_contents
    )
def backward(
    self, gradient=None, retain_graph=None, create_graph=False, inputs=None
):
    r"""Computes the gradient of current tensor wrt graph leaves.

    The graph is differentiated using the chain rule. If the tensor is
    non-scalar (i.e. its data has more than one element) and requires
    gradient, the function additionally requires specifying a ``gradient``.
    It should be a tensor of matching type and shape, that represents
    the gradient of the differentiated function w.r.t. ``self``.

    This function accumulates gradients in the leaves - you might need to zero
    ``.grad`` attributes or set them to ``None`` before calling it.
    See :ref:`Default gradient layouts<default-grad-layouts>`
    for details on the memory layout of accumulated gradients.

    .. note::

        If you run any forward ops, create ``gradient``, and/or call ``backward``
        in a user-specified CUDA stream context, see
        :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.

    .. note::

        When ``inputs`` are provided and a given input is not a leaf,
        the current implementation will call its grad_fn (though it is not strictly
        needed to get these gradients).
        It is an implementation detail on which the user should not rely.
        See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.

    Args:
        gradient (Tensor, optional): The gradient of the function
            being differentiated w.r.t. ``self``.
            This argument can be omitted if ``self`` is a scalar.
        retain_graph (bool, optional): If ``False``, the graph used to compute
            the grads will be freed. Note that in nearly all cases setting
            this option to True is not needed and often can be worked around
            in a much more efficient way. Defaults to the value of
            ``create_graph``.
        create_graph (bool, optional): If ``True``, graph of the derivative will
            be constructed, allowing to compute higher order derivative
            products. Defaults to ``False``.
        inputs (sequence of Tensor, optional): Inputs w.r.t. which the gradient will be
            accumulated into ``.grad``. All other tensors will be ignored. If not
            provided, the gradient is accumulated into all the leaf Tensors that were
            used to compute ``self``.
    """
    if has_torch_function_unary(self):
        return handle_torch_function(
            Tensor.backward,
            (self,),
            self,
            gradient=gradient,
            retain_graph=retain_graph,
            create_graph=create_graph,
            inputs=inputs,
        )
    # Delegate to the functional autograd entry point with ``self`` as the root.
    torch.autograd.backward(
        self, gradient, retain_graph, create_graph, inputs=inputs
    )
def register_hook(self, hook):
    r"""Registers a backward hook.

    The hook will be called every time a gradient with respect to the
    Tensor is computed. The hook should have the following signature::

        hook(grad) -> Tensor or None

    The hook should not modify its argument, but it can optionally return
    a new gradient which will be used in place of :attr:`grad`.

    This function returns a handle with a method ``handle.remove()``
    that removes the hook from the tensor.

    .. note::
        See :ref:`backward-hooks-execution` for more information on when this hook
        is executed, and how its execution is ordered relative to other hooks.

    Raises:
        RuntimeError: if the tensor does not require gradient.

    Example::

        >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
        >>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
        >>> v.backward(torch.tensor([1., 2., 3.]))
        >>> v.grad
        tensor([2., 4., 6.])

        >>> h.remove()  # removes the hook
    """
    if has_torch_function_unary(self):
        return handle_torch_function(Tensor.register_hook, (self,), self, hook)
    if not self.requires_grad:
        raise RuntimeError(
            "cannot register a hook on a tensor that doesn't require gradient"
        )
    if self._backward_hooks is None:
        # First hook on this tensor: create the hook dict and, for non-leaf tensors,
        # register it with the producing grad_fn.
        self._backward_hooks = OrderedDict()
        if self.grad_fn is not None:
            self.grad_fn._register_hook_dict(self)
    from torch.utils.hooks import RemovableHandle

    handle = RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle
def register_post_accumulate_grad_hook(self, hook):
r"""Registers a backward hook that runs after grad accumulation.
The hook will be called after all gradients for a tensor have been accumulated,
meaning that the .grad field has been updated on that tensor. The post
accumulate grad hook is ONLY applicable for leaf tensors (tensors without a
.grad_fn field). Registering this hook on a non-leaf tensor will error!
The hook should have the following signature::
hook(param: Tensor) -> None
Note that, unlike other autograd hooks, this hook operates on the tensor
that requires grad and not the grad itself. The hook can in-place modify
and access its Tensor argument, including its .grad field.
This function returns a handle with a method ``handle.remove()``
that removes the hook from the module.
.. note::
See :ref:`backward-hooks-execution` for more information on how when this hook
is executed, and how its execution is ordered relative to other hooks. Since
this hook runs during the backward pass, it will run in no_grad mode (unless
create_graph is True). You can use torch.enable_grad() to re-enable autograd
within the hook if you need it.
Example::
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> lr = 0.01
>>> # simulate a simple SGD update
>>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v
tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)
>>> h.remove() # removes the hook
"""
if has_torch_function_unary(self):
return handle_torch_function(
Tensor.register_post_accumulate_grad_hook, (self,), self, hook
)
if not self.requires_grad:
raise RuntimeError(
"cannot register a hook on a tensor that doesn't require gradient"
)
if self.grad_fn is not None:
raise RuntimeError(
"post accumulate grad hooks cannot be registered on non-leaf tensors"
)
if self._post_accumulate_grad_hooks is None:
self._post_accumulate_grad_hooks: Dict[Any, Any] = OrderedDict()
from torch.utils.hooks import RemovableHandle
handle = RemovableHandle(self._post_accumulate_grad_hooks)
self._post_accumulate_grad_hooks[handle.id] = hook
return handle
def reinforce(self, reward):
    """Removed API: always raises ``RuntimeError`` with migration guidance.

    Each line of the guidance text is stripped of surrounding whitespace before
    being used as the error message.
    """
    removal_notice = r"""reinforce() was removed.
Use torch.distributions instead.
See https://pytorch.org/docs/main/distributions.html
Instead of:
probs = policy_network(state)
action = probs.multinomial()
next_state, reward = env.step(action)
action.reinforce(reward)
action.backward()
Use:
probs = policy_network(state)
# NOTE: categorical is equivalent to what used to be called multinomial
m = torch.distributions.Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()
"""
    stripped_lines = (text_line.strip() for text_line in removal_notice.split("\n"))
    raise RuntimeError("\n".join(stripped_lines))
# Expose the C-implemented _C.TensorBase.detach as Tensor.detach, passing it through
# _C._add_docstr with the text below (presumably attaches the docstring and returns
# the same callable -- confirm). NB: the docstring text is runtime data, not a comment.
detach = _C._add_docstr(
    _C.TensorBase.detach,
    r"""
Returns a new Tensor, detached from the current graph.
The result will never require gradient.
This method also affects forward mode AD gradients and the result will never
have forward mode AD gradients.
.. note::
Returned Tensor shares the same storage with the original one.
In-place modifications on either of them will be seen, and may trigger
errors in correctness checks.
""",
)
# In-place counterpart: expose _C.TensorBase.detach_ as Tensor.detach_ with the
# docstring below attached via _C._add_docstr. NB: the docstring text is runtime data.
detach_ = _C._add_docstr(
    _C.TensorBase.detach_,
    r"""
Detaches the Tensor from the graph that created it, making it a leaf.
Views cannot be detached in-place.
This method also affects forward mode AD gradients and the result will never
have forward mode AD gradients.
""",
)
def is_shared(self):
    r"""Report whether the tensor's storage lives in shared memory.

    This is always ``True`` for CUDA tensors.
    """
    if not has_torch_function_unary(self):
        return self._typed_storage()._is_shared()
    return handle_torch_function(Tensor.is_shared, (self,), self)
def share_memory_(self):
    r"""Move the underlying storage to shared memory and return ``self``.

    This is a no-op if the underlying storage is already in shared memory
    and for CUDA tensors. Tensors in shared memory cannot be resized.

    See :meth:`torch.UntypedStorage.share_memory_` for more details.
    """
    if has_torch_function_unary(self):
        return handle_torch_function(Tensor.share_memory_, (self,), self)
    backing_storage = self._typed_storage()
    backing_storage._share_memory_()
    return self
def module_load(self, other, assign=False):
    r"""Defines how to transform ``other`` when loading it into ``self`` in :meth:`~nn.Module.load_state_dict`.

    Used when :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.

    ``self`` is expected to be a parameter or buffer of an ``nn.Module`` and ``other``
    the value stored in the state dictionary under the corresponding key. The value
    returned here is what gets swapped with ``self`` via
    :func:`~torch.utils.swap_tensors` in :meth:`~nn.Module.load_state_dict`.

    .. note::
        This method should always return a new object that is not ``self`` or ``other``.
        The default implementation returns ``other.detach()`` when ``assign`` is
        ``True`` and ``self.copy_(other).detach()`` otherwise.

    Args:
        other (Tensor): value in state dict with key corresponding to ``self``
        assign (bool): the assign argument passed to :meth:`nn.Module.load_state_dict`
    """
    if has_torch_function_variadic(self, other):
        return handle_torch_function(
            Tensor.module_load, (self, other), self, other, assign=assign
        )
    # assign=True adopts ``other``; otherwise its values are copied into ``self`` in place.
    return other.detach() if assign else self.copy_(other).detach()
def __reversed__(self):
    r"""Reverses the tensor along dimension 0.

    0-dimensional tensors have no dimension to flip and are returned as-is.
    """
    if has_torch_function_unary(self):
        return handle_torch_function(Tensor.__reversed__, (self,), self)
    return self if self.dim() == 0 else self.flip(0)
def norm(
    self,
    p: Optional[Union[float, str]] = "fro",
    dim=None,
    keepdim=False,
    dtype=None,
):
    r"""See :func:`torch.norm`"""
    if not has_torch_function_unary(self):
        return torch.norm(self, p=p, dim=dim, keepdim=keepdim, dtype=dtype)
    return handle_torch_function(
        Tensor.norm, (self,), self, p=p, dim=dim, keepdim=keepdim, dtype=dtype
    )
def solve(self, other):
    """Legacy method kept for backward compatibility; delegates to ``torch._linalg_utils.solve``.

    NOTE(review): that helper is believed to raise a RuntimeError pointing users to
    ``torch.linalg.solve`` -- confirm against the installed version.
    """
    import torch._linalg_utils as linalg_utils

    return linalg_utils.solve(self, other)
def lstsq(self, other):
    """Legacy method kept for backward compatibility; delegates to ``torch._linalg_utils.lstsq``.

    NOTE(review): that helper is believed to raise a RuntimeError pointing users to
    ``torch.linalg.lstsq`` -- confirm against the installed version.
    """
    import torch._linalg_utils as linalg_utils

    return linalg_utils.lstsq(self, other)
def eig(self, eigenvectors=False):
    """Legacy method kept for backward compatibility; delegates to ``torch._linalg_utils.eig``.

    NOTE(review): that helper is believed to raise a RuntimeError pointing users to
    ``torch.linalg.eig`` -- confirm against the installed version.
    """
    import torch._linalg_utils as linalg_utils

    return linalg_utils.eig(self, eigenvectors=eigenvectors)
def symeig(self, eigenvectors=False):
    """Legacy method kept for backward compatibility; delegates to ``torch._linalg_utils._symeig``.

    NOTE(review): that helper is believed to raise a RuntimeError pointing users to
    ``torch.linalg.eigh`` -- confirm against the installed version.
    """
    import torch._linalg_utils as linalg_utils

    return linalg_utils._symeig(self, eigenvectors=eigenvectors)
def lu(self, pivot=True, get_infos=False):
    r"""See :func:`torch.lu`"""
    if has_torch_function_unary(self):
        return handle_torch_function(
            Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos
        )
    # Callers asking for ``infos`` inspect failures themselves, so error checking is
    # only requested when get_infos is False (and vice versa).
    factorization = torch._lu_with_info(
        self, pivot=pivot, check_errors=not get_infos
    )
    lu_data, pivots, infos = factorization
    return (lu_data, pivots, infos) if get_infos else (lu_data, pivots)
def stft(
    self,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: "Optional[Tensor]" = None,
    center: bool = True,
    pad_mode: str = "reflect",
    normalized: bool = False,
    onesided: Optional[bool] = None,
    return_complex: Optional[bool] = None,
):
    r"""See :func:`torch.stft`

    .. warning::
      This function changed signature at version 0.4.1. Calling with
      the previous signature may cause error or return incorrect result.
    """
    # Options shared by the __torch_function__ path and the direct call.
    options = {
        "hop_length": hop_length,
        "win_length": win_length,
        "window": window,
        "center": center,
        "pad_mode": pad_mode,
        "normalized": normalized,
        "onesided": onesided,
        "return_complex": return_complex,
    }
    if has_torch_function_unary(self):
        return handle_torch_function(Tensor.stft, (self,), self, n_fft, **options)
    return torch.stft(self, n_fft, **options)
def istft(
    self,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: "Optional[Tensor]" = None,
    center: bool = True,
    normalized: bool = False,
    onesided: Optional[bool] = None,
    length: Optional[int] = None,
    return_complex: bool = False,
):
    r"""See :func:`torch.istft`"""
    # Options shared by the __torch_function__ path and the direct call.
    options = {
        "hop_length": hop_length,
        "win_length": win_length,
        "window": window,
        "center": center,
        "normalized": normalized,
        "onesided": onesided,
        "length": length,
        "return_complex": return_complex,
    }
    if has_torch_function_unary(self):
        return handle_torch_function(Tensor.istft, (self,), self, n_fft, **options)
    return torch.istft(self, n_fft, **options)
def resize(self, *sizes):
if has_torch_function_unary(self):
return handle_torch_function(Tensor.resize, (self,), self, *sizes)
warnings.warn("non-inplace resize is deprecated")