forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
variable.h
768 lines (654 loc) · 31 KB
/
variable.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
#pragma once
#include <torch/csrc/utils/python_stub.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/cpp_hook.h>
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
namespace torch { namespace autograd {
struct Node;
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Variable
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// A `Variable` augments a `Tensor` with the ability to interact in our
/// autograd machinery. Conceptually, `Variable`s travel along `Edge`s between
/// `Node`s in the autograd graph. A `Variable` can either be a leaf, like a
/// weight in a neural network, or an interior variable, when it is the result
/// of an operation between variables. Every `Variable` also stores another
/// `Variable` called its `grad` (gradient). If the variable is a leaf, its
/// gradient will be accumulated into this variable.
///
/// Gradient Edges
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Furthermore, `Variable`s have the notion of a `gradient_edge`, which is the
/// edge in the autograd graph that connects the variable to a particular input
/// of the gradient function that will be invoked with the variable during the
/// backward pass. More precisely, this gradient function can be one of two
/// things:
/// 1. A `grad_fn`, if the variable is in the interior of the graph. This is the
/// gradient of the function that produced the variable.
/// 2. A `grad_accumulator`, if the variable is a leaf, which accumulates the
/// incoming gradient into its `grad` variable.
///
/// Versioning
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Another major feature of `Variable`s is *versions*. Versions are
/// incremented when an in-place mutation of a variable occurs. Versions are
/// useful when constructing `SavedVariable`s, which take a snapshot of a
/// `Variable` at a certain version. You can retrieve a `Variable`'s version
/// through its `current_version()` method.
///
/// Views
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// It is possible for a `Variable` to be a *view* of another `Variable`, in
/// which case it tracks that `Variable`'s data and autograd history. Beyond
/// construction, the interface of a view is identical to that of a regular
/// `Variable`. You can determine whether `Variable` is in fact a view by
/// probing its `is_view()` method. Note that the *view* semantics are only
/// meaningful for `Variable` relations that are relevant to autograd. For
/// example, if you hide your code from autograd using `.no_grad()`, the
/// `Variable`s will not be registered as having view relations, even if they
/// share storage.
/// See NOTE [ Autograd View Variables ] for more details.
///
///
/// Interface
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// `Variable` inherits from `Tensor` and thus its API is a superset of that of
/// `Tensor`. This means you can perform all the usual mathematical and other
/// operations you can perform on `Tensor`s also on `Variable`s. Furthermore,
/// `Variable` and `Tensor` actually convert implicitly between each other. You
/// can thus call functions defined on `Tensor`s also with `Variable`s. For
/// this, the `Variable` class allows implicit construction from `Tensor`. It is
/// the responsibility of calling code to ensure that this constructor is
/// invoked only when the `Tensor` contains autograd metadata. Most notably, it
/// is *not* correct to construct a brand new `Variable` from a `Tensor` using
/// this constructor. To do so, you must use the `make_variable` free function
/// instead. To create a view variable, use `make_variable_view`.
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
struct TORCH_API Variable : public at::Tensor {
/// Default constructor. Yields an undefined `Variable`, which is also what the
/// factory functions in this header return when given undefined data.
Variable() = default;
// Factory Functions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// TODO: These factory functions don't need to be friends anymore. Move them out of
// the Variable class.
// NOTE(review): their inline definitions in this header still build the result
// through the private `Variable(c10::intrusive_ptr<at::TensorImpl>)`
// constructor, so friendship is currently required -- confirm before moving.
/// Creates a `Variable` that is a *view* of another (*base*) variable.
/// The `gradient_edge` is an optional (gradient_function, input_number) pair.
/// `is_differentiable` is a bool that specifies whether this view is
/// differentiable, i.e., whether the relation should be tracked by autograd.
/// See NOTE [ Autograd View Variables ] for details.
friend Variable make_variable_view(
Variable base,
at::Tensor data,
bool is_differentiable,
bool allow_tensor_metadata_change,
Edge gradient_edge);
/// Creates a `Variable` from the given `Tensor`. If `data` is the sole owner of
/// its `TensorImpl` and that impl reports `unique_version()`, the impl is
/// reused in place; otherwise it is shallow-copied. `requires_grad` should be
/// set only for leaves, and determines whether the `Variable` will accumulate
/// gradients. NOTE: `data` must *not* be a `Variable` already (enforced with
/// TORCH_CHECK). Its dynamic type *must* be `Tensor`.
friend Variable make_variable(
at::Tensor data,
bool requires_grad,
bool allow_tensor_metadata_change);
/// Creates a `Variable` from the given `Tensor`, shallow-copying its underlying
/// `TensorImpl`. `gradient_edge` should be a (function, input_nr) pair
/// specifying the function in the autograd graph, and which particular input
/// of that function this variable is connected to.
friend Variable make_variable(
at::Tensor data,
Edge gradient_edge,
bool allow_tensor_metadata_change);
// Tensor Conversions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// "Downcasts" a `Tensor` into a `Variable`. Only call this on tensors you
// know are Variables. Both overloads TORCH_CHECK that the source is either
// undefined or already a Variable.
/*implicit*/ Variable(at::Tensor const& rhs) : at::Tensor(rhs) {
TORCH_CHECK(
is_variable() || !defined(),
"Tensor that was converted to Variable was not actually a Variable");
}
/*implicit*/ Variable(at::Tensor&& rhs)
: at::Tensor(std::move(rhs)) {
TORCH_CHECK(
is_variable() || !defined(),
"Tensor that was converted to Variable was not actually a Variable");
}
// NOTE: Assignment operators to Tensor come for free from the constructors.
/// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended
/// to be used from functions that need to access the `Variable`'s equivalent `Tensor`
/// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`).
///
/// One notable difference with the legacy `.data()` function is that changes to the
/// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset)
/// will not update the original `Variable`, due to the fact that this function
/// shallow-copies the `Variable`'s underlying TensorImpl.
at::Tensor tensor_data() const noexcept;
/// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data`
/// in Python, which creates a new `Variable` that shares the same storage and
/// tensor metadata with the original `Variable`, but with a completely new
/// autograd history.
///
/// NOTE: If we change the tensor metadata (e.g. sizes / strides /
/// storage / storage_offset) of a variable created from `var.variable_data()`, those
/// changes will not update the original variable `var`. In `.variable_data()`, we set
/// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal,
/// in order to prevent users from changing metadata of `var.variable_data()`
/// and expecting the original variable `var` to also be updated.
at::Tensor variable_data() const noexcept;
// Gradient Node and Edges
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Gets the gradient function of the `Variable`. If this is a leaf variable,
/// the pointer returned will be null.
///
/// For View Variables:
/// Gets the up-to-date grad_fn. If the shared data or base was modified, we
/// re-create the grad_fn to express the up-to-date view relationship between
/// this and the base Variable.
const std::shared_ptr<Node>& grad_fn() const;
/// Gets the raw gradient function pointer, whatever it currently is (no
/// view-related refresh as done by `grad_fn()`).
Node* grad_fn_unsafe() const;
/// Set the gradient accumulator of the `Variable`. This is only applicable to
/// leaf variables. Interior variables should call `set_gradient_edge()`.
/// Only a weak reference is stored.
void set_grad_accumulator(std::weak_ptr<Node> grad_accumulator);
/// Attempts to get a pointer to the gradient accumulator of the `Variable`,
/// if it still exists. If the gradient accumulator function has been
/// destroyed, returns a `nullptr`.
std::shared_ptr<Node> try_get_grad_accumulator() const;
/// Gets the gradient accumulator of the `Variable` if it has one, or else
/// create one on the fly and return it.
std::shared_ptr<Node> grad_accumulator() const;
/// Returns the "canonical" gradient edge of this `Variable`, i.e. either the
/// gradient function if this is an interior `Variable`, or the gradient
/// accumulator otherwise. If the `Variable` is interior, the returned `Edge`
/// will store the input index of the `Node` to which this variable is
/// connected in its `input_nr` field. For leaves, the `input_nr` is always
/// zero. Note that `set_gradient_edge` and `gradient_edge` are not
/// symmetric. You must use `set_gradient_edge` to set the `grad_fn` and
/// `set_grad_accumulator` to set the accumulator.
Edge gradient_edge() const {
// If grad_fn is null (as is the case for a leaf node), we instead
// interpret the gradient function to be a gradient accumulator, which will
// accumulate its inputs into the grad property of the variable. These
// nodes get suppressed in some situations, see "suppress gradient
// accumulation" below. Note that only variables which have `requires_grad =
// True` can have gradient accumulators.
// NOTE(review): no "suppress gradient accumulation" section exists in this
// header; the reference presumably points at another file -- confirm.
if (const auto& gradient = grad_fn()) {
return Edge(gradient, output_nr());
} else {
return Edge(grad_accumulator(), 0);
}
}
/// Returns a copy of this `Variable` that is detached from its autograd graph
/// and has a blank version. This method is OK to call if the `Variable` is a
/// view.
/// NOTE: Previously, if we change the tensor metadata (e.g. sizes / strides /
/// storage / storage_offset) of a tensor created from `detach()`, those metadata
/// in the original tensor will also be updated. However, the new behavior is that
/// those metadata changes to the detached tensor will not update the original tensor
/// anymore, and in the `detach()` function we need to set `allow_tensor_metadata_change_`
/// to false to make such changes explicitly illegal, in order to prevent users from
/// changing metadata of the detached tensor and expecting the original tensor to also
/// be updated.
Variable detach() const;
/// Like `detach()`, but removes this `Variable` in-place. This method may
/// only be called on non-view `Variable`s. You can use `is_view()` to check
/// this. If this `Variable` is a view, throws an `std::runtime_error()`.
void detach_();
/// Computes the gradient of current tensor w.r.t. graph leaves.
void backward(
const Tensor& gradient,
bool keep_graph,
bool create_graph) const;
/// Sets the tensor data held by this `Variable` to be the same as `new_data`.
/// It requires that `new_data` and `Variable` have compatible tensor type, by
/// checking `_has_compatible_shallow_copy_type(this, new_data)`.
void set_data(const at::Tensor &new_data) const;
/// Set the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the
/// `Variable`.
/// NOTE: This will always set the `grad_fn`, even if this is a leaf variable,
/// and never the `grad_accumulator`. For the latter, use
/// `set_grad_accumulator`. This allows late construction of an interior
/// `Variable`.
void set_gradient_edge(Edge edge) noexcept;
/// Returns the input index of the gradient `Node` to which this
/// `Variable` is connected. Note: input indexes of the gradient `Node`
/// correspond to output indexes of the corresponding forward `Node`.
uint32_t output_nr() const noexcept;
/// True if this `Variable` is a leaf and thus does not have a `grad_fn`.
bool is_leaf() const noexcept;
// Versions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Increments the version count of this `Variable`.
void bump_version() noexcept;
/// Replaces the version counter held by the underlying `TensorImpl`.
void set_version_counter(const c10::VariableVersion& version_counter) noexcept;
/// Retrieves this `Variable`'s version counter.
const c10::VariableVersion& version_counter() const noexcept;
/// Retrieves the current value of the `Variable`'s version counter.
/// Equivalent to calling `version_counter().current_version()`.
uint32_t current_version() const noexcept;
// Autograd Graph Interaction
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Update the `grad_fn` of an existing Variable. Called after in-place
/// modifications.
///
/// For View Variables:
/// Called after in-place modifications. Modifies the grad_fn of the base
/// Variable.
void rebase_history(Edge gradient_edge);
// Hooks
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Appends a `FunctionPreHook` to this Variable's `hooks_` list.
void add_hook(std::shared_ptr<FunctionPreHook> hook);
/// Returns the list of `FunctionPreHook`s added through `add_hook()`.
const std::vector<std::shared_ptr<FunctionPreHook>>& hooks() const noexcept;
/// Removes every `FunctionPreHook` added through `add_hook()`.
void clear_hooks();
// SFINAE helpers that pick the `register_hook` overload from the hook's return
// type when invoked with a `Variable`: `void` or `Variable`. Both resolve to
// `unsigned`, the type of the index `register_hook` returns.
// NOTE(review): `std::result_of` is deprecated in C++17 and removed in C++20.
template <typename T>
using hook_return_void_t = c10::guts::enable_if_t<std::is_void<typename std::result_of<T&(Variable)>::type>::value, unsigned>;
template <typename T>
using hook_return_var_t = c10::guts::enable_if_t<std::is_same<typename std::result_of<T&(Variable)>::type, Variable>::value, unsigned>;
// Remove hook at given position (an index returned by `register_hook`).
void remove_hook(unsigned pos);
// Both `register_hook` overloads return the index of the hook in the list,
// which can be used to remove the hook via `remove_hook`.
// Register a hook with no return value
template <typename T>
hook_return_void_t<T> register_hook(T&& hook);
// Register a hook with variable return value
template <typename T>
hook_return_var_t<T> register_hook(T&& hook);
// View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Returns true if this `Variable` is a view of another `Variable`.
bool is_view() const noexcept;
/// Returns the `Variable` that this `Variable` is a view of. If this
/// `Variable` is not a view, throw a `std::runtime_error`.
const Variable& base() const;
// Miscellaneous
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Sets / returns the name stored in this Variable's autograd metadata.
void set_name(const std::string& name);
const std::string& name() const noexcept;
/// Reads / writes the `PyObject*` stored on the underlying `TensorImpl`.
PyObject* pyobj() const noexcept;
void set_pyobj(PyObject* pyobj) noexcept;
struct AutogradMeta;
/// Returns the `TensorImpl`'s autograd metadata downcast to `AutogradMeta`.
/// Goes through `get()`, so the Variable must be defined.
Variable::AutogradMeta* get_autograd_meta() const noexcept;
private:
struct DifferentiableViewMeta;
// Private Methods
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Wraps an impl without the is_variable() check done by the public
// converting constructors; used by the factory functions above.
Variable(c10::intrusive_ptr<at::TensorImpl> self);
// Returns the underlying TensorImpl after TORCH_CHECK-ing `defined()`.
at::TensorImpl* get() const;
// Defined out of line; `register_hook` calls it when `cpp_hooks_list` is null
// (presumably it creates that list -- confirm in the .cpp).
void create_cpp_hook();
};
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Variable::AutogradMeta
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Each `Variable` has one unique `AutogradMeta` struct, which stores autograd
/// metadata fields that are necessary for tracking the Variable's autograd history.
///
/// Scalar fields carry in-class default initializers so they are never left
/// indeterminate; the out-of-line constructor may still assign them.
struct TORCH_API Variable::AutogradMeta : public c10::AutogradMetaInterface {
  /// Name written by `Variable::set_name()` and read by `Variable::name()`.
  std::string name;
  /// Gradient `Variable` exposed through `grad()`.
  Variable grad_;
  /// Gradient function; `Variable::is_leaf()` is true exactly when this is null.
  std::shared_ptr<Node> grad_fn_;
  /// Non-owning handle to the gradient accumulator
  /// (see `Variable::try_get_grad_accumulator()`).
  std::weak_ptr<Node> grad_accumulator_;
  /// Pre-hooks appended by `Variable::add_hook()`.
  std::vector<std::shared_ptr<FunctionPreHook>> hooks_;
  /// C++ hooks appended by `Variable::register_hook()`; may be null, in which
  /// case `register_hook()` calls `create_cpp_hook()` first.
  std::shared_ptr<hooks_list> cpp_hooks_list;
  // Only meaningful on leaf variables (must be false otherwise)
  bool requires_grad_ = false;
  // When true, `Variable::base()` treats this object as a DifferentiableViewMeta.
  bool is_view_ = false;
  // The "output number" of this variable; e.g., if this variable
  // was the second output of a function, then output_nr == 1.
  // We use this to make sure we can setup the backwards trace
  // correctly when this variable is passed to another function.
  uint32_t output_nr_ = 0;
  // Mutex to ensure that concurrent read operations that modify internal
  // state are still thread-safe. Used by grad_fn() and
  // grad_accumulator().
  std::mutex mutex_;

  /// Sets the `requires_grad` property of `Variable`. This should be true for
  /// leaf variables that want to accumulate gradients, and false for all other
  /// variables.
  /// Fails via TORCH_CHECK when `requires_grad` is true but `self_impl`'s dtype
  /// is not a floating point type.
  void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) override {
    TORCH_CHECK(
        !requires_grad || at::isFloatingType(at::typeMetaToScalarType(self_impl->dtype())),
        "Only Tensors of floating point dtype can require gradients");
    requires_grad_ = requires_grad;
  }

  /// True for leaves flagged via `set_requires_grad`, and for every variable
  /// that has a gradient function.
  bool requires_grad() const override {
    return requires_grad_ || grad_fn_;
  }

  /// Accesses the gradient `Variable` of this `Variable`.
  Variable& grad() override {
    return grad_;
  }

  const Variable& grad() const override {
    return grad_;
  }

  AutogradMeta(
      at::TensorImpl* self_impl,
      bool requires_grad = false,
      Edge gradient_edge = Edge());
};
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Variable::DifferentiableViewMeta
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// NOTE [ Autograd View Variables ]
///
/// Many operations return Variable that shares storage with an input Variable.
/// The returned Variable is called a **view** Variable on the input **base**
/// Variable.
///
/// In PyTorch, we have two types of views: differentiable views, and
/// non-differentiable views. In either type, to support proper version
/// checking, the base and view Variables must always share the same
/// version_counter.
///
///
/// Differentiable Views
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Differentiable views are the view variables where you want gradients to flow
/// back to the base variables. Out-of-place operations on views are quite
/// straightforward, but in-place ones are very tricky. Even if the base
/// variable may not require grad when we create the view, we still need to
/// track the view relation because future in-place ops may require backpropagating
/// through it. For example, we need to support
///
/// (1) in-place operation on view, e.g.,
///
/// # Have:
/// # base.requires_grad = False
/// # var.requires_grad = True
/// base[1] = var # i.e., base[1].copy_(var)
/// torch.autograd.grad(base.sum(), var) <- should return an all ones tensor
///
/// (2) in-place operation on base after view is created, e.g.,
///
/// # Have:
/// # base.requires_grad = False
/// # var.requires_grad = True
/// view = base[1]
/// base.copy_(var)
/// torch.autograd.grad(view.sum(), var) <- should return a tensor with
/// var[1] filled with all ones and
/// zeros everywhere else
///
/// Variable::DifferentiableViewMeta is created to support gradient tracking of
/// such **in-place** operations. In particular,
/// + if an in-place op is done on base, the grad_fn field of the view may
/// become stale. So accesses should always go through grad_fn(), which
/// reconstructs an updated grad_fn if the version_counter has incremented.
/// All other fields are always valid.
/// + if an in-place op is done on view, in rebase_history() of view, which is
/// called after every in-place op in VariableType.cpp, the grad_fn of base
/// is updated.
///
///
/// Non-Differentiable Views
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// In certain cases, although function outputs share storage with inputs, they
/// will **never** require gradient history tracking. Instead of registering the
/// view relation via DifferentiableViewMeta in autograd, the views will be
/// using usual AutogradMeta and just share the version counters with the base
/// Variables.
/// Such views include:
/// 1. Views created from .detach()
/// 2. Views that are non-differentiable by their nature.
/// E.g., `sparse_tensor.indices()` is an integral view on a (possibly)
/// floating point tensor.
/// See top of `derivatives.yaml` on how to specify that outputs of a
/// function are non-differentiable.
/// These are called non-differentiable views as the gradients do not flow
/// through the view relation.
/// Relevant logic for non-differentiable views is implemented in
/// make_variable_view below, and wrap_output of gen_variable_type.py.
/// Autograd metadata for differentiable views; see NOTE [ Autograd View Variables ].
struct TORCH_API Variable::DifferentiableViewMeta : public Variable::AutogradMeta {
  /// The base `Variable` (never a view).
  Variable base_;
  /// The value of the version_counter at the time grad_fn was created. The
  /// grad_fn field is stale if attr_version !=
  /// version_counter.current_version().
  /// Default-initialized so it is never indeterminate; the out-of-line
  /// constructor is expected to record the real version.
  uint32_t attr_version = 0;

  /// A view additionally requires grad whenever its base does.
  bool requires_grad() const override {
    return requires_grad_ || grad_fn_ || (is_view_ && base_.requires_grad());
  }

  DifferentiableViewMeta(at::TensorImpl* self_impl, Variable base, Edge gradient_edge);
  ~DifferentiableViewMeta() override;
};
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Variable Implementation
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Factory Functions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// NOTE: `allow_tensor_metadata_change` is set to true by default, because there
/// are a lot of call sites to these factory functions that need to change the
/// variable's size or storage afterwards, and they don't expect the original
/// tensor (where the variable is created from) to be updated. Setting
/// `allow_tensor_metadata_change_` to false by default would unnecessarily
/// prevent those changes from happening and is undesirable.
// See NOTE [ Autograd View Variables ] for details.
/// Wraps `data` as a view of `base`. Returns an undefined `Variable` when
/// `data` is undefined.
inline Variable make_variable_view(
    Variable base,
    at::Tensor data,
    bool is_differentiable = true,
    bool allow_tensor_metadata_change = true,
    Edge gradient_edge = Edge()) {
  if (!data.defined()) {
    return Variable();
  }
  if (is_differentiable) {
    // Differentiable view: history is tracked by a DifferentiableViewMeta that
    // takes ownership of `base`. The version counter passed here is built from
    // 0; sharing it with the base is presumably done by the
    // DifferentiableViewMeta constructor (defined elsewhere) -- TODO confirm.
    auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
        /*version_counter=*/0,
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    data_impl_copy->set_autograd_meta(c10::guts::make_unique<Variable::DifferentiableViewMeta>(
        data_impl_copy.get(), std::move(base), std::move(gradient_edge)));
    // Move the fresh impl in: avoids an atomic refcount increment/decrement.
    return Variable(std::move(data_impl_copy));
  }
  // Non-differentiable view: a plain AutogradMeta (requires_grad = false) that
  // only shares the base's version counter.
  auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
      /*version_counter=*/base.version_counter(),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  data_impl_copy->set_autograd_meta(c10::guts::make_unique<Variable::AutogradMeta>(
      data_impl_copy.get(), false, std::move(gradient_edge)));
  return Variable(std::move(data_impl_copy));
}
/// Wraps a plain (non-Variable) tensor as a new `Variable`.
/// Fails via TORCH_CHECK if `data` is already a Variable; returns an undefined
/// `Variable` when `data` is undefined.
inline Variable make_variable(
    at::Tensor data,
    bool requires_grad = false,
    bool allow_tensor_metadata_change = true) {
  TORCH_CHECK(
      !data.is_variable(),
      "Must not create a new variable from a variable, use its .tensor_data()");
  if (!data.defined()) {
    return Variable();
  }
  if (data.getIntrusivePtr().use_count() == 1 && data.getIntrusivePtr()->unique_version()) {
    // Fast path: `data` is the only owner of its TensorImpl and the impl
    // reports unique_version(), so attach the autograd metadata in place
    // instead of shallow-copying.
    auto data_impl = data.getIntrusivePtr();
    data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
    data_impl->set_autograd_meta(c10::guts::make_unique<Variable::AutogradMeta>(data_impl.get(), requires_grad));
    return Variable(std::move(data_impl));
  }
  // Shared impl: shallow-copy it (with a version counter built from 0) and
  // attach the metadata to the copy, leaving `data`'s impl untouched.
  auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
      /*version_counter=*/0,
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  data_impl_copy->set_autograd_meta(c10::guts::make_unique<Variable::AutogradMeta>(
      data_impl_copy.get(), requires_grad));
  // Move, consistent with the fast path above: no extra refcount round-trip.
  return Variable(std::move(data_impl_copy));
}
/// Wraps a plain (non-Variable) tensor as a new interior `Variable` attached to
/// `gradient_edge`. Always shallow-copies the `TensorImpl` (with a version
/// counter built from 0). Fails via TORCH_CHECK if `data` is already a
/// Variable; returns an undefined `Variable` when `data` is undefined.
inline Variable make_variable(
    at::Tensor data,
    Edge gradient_edge,
    bool allow_tensor_metadata_change = true) {
  TORCH_CHECK(
      !data.is_variable(),
      "Must not create a new variable from a variable, use its .tensor_data()");
  if (!data.defined()) {
    return Variable();
  }
  auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
      /*version_counter=*/0,
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  data_impl_copy->set_autograd_meta(c10::guts::make_unique<Variable::AutogradMeta>(
      data_impl_copy.get(), false, std::move(gradient_edge)));
  // Move the fresh impl in: avoids an atomic refcount increment/decrement.
  return Variable(std::move(data_impl_copy));
}
// Tensor Conversion
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Downcasts the `Tensor` reference to a `Variable` reference.
/// The check is unconditional (not DEBUG-only): if `tensor.is_variable()` is
/// false, this fails via TORCH_CHECK, which throws `c10::Error` (not
/// `std::invalid_argument`).
inline Variable& as_variable_ref(at::Tensor& tensor) {
  TORCH_CHECK(
      tensor.is_variable(),
      "Attempted to cast a Tensor to a Variable, but "
      "the dynamic type of the value is not Variable.");
  return static_cast<Variable&>(tensor);
}
/// Const overload: views a `const at::Tensor&` as a `const Variable&`, after
/// TORCH_CHECK-ing that the tensor really is a Variable.
inline const Variable& as_variable_ref(const at::Tensor& tensor) {
  TORCH_CHECK(
      tensor.is_variable(),
      "Attempted to cast a Tensor to a Variable, but the dynamic type of the value is not Variable.");
  const Variable& as_var = static_cast<const Variable&>(tensor);
  return as_var;
}
/// Returns a plain `at::Tensor` built from a shallow copy of this Variable's
/// `TensorImpl`, passing along this Variable's own version counter and its
/// `allow_tensor_metadata_change()` setting.
///
/// NOTE: `get()` TORCH_CHECKs that the Variable is defined. Because this
/// function is `noexcept`, calling it on an undefined Variable ends in
/// `std::terminate` rather than a catchable exception.
inline at::Tensor Variable::tensor_data() const noexcept {
  // Validate once and reuse, instead of three separate get() checks.
  at::TensorImpl* self_impl = get();
  auto self_impl_copy = self_impl->shallow_copy_and_detach(
      /*version_counter=*/self_impl->version_counter(),
      /*allow_tensor_metadata_change=*/self_impl->allow_tensor_metadata_change());
  return at::Tensor(std::move(self_impl_copy));
}
/// Returns a new Variable that shares storage and tensor metadata with this
/// one but has fresh autograd metadata (requires_grad = false), a version
/// counter built from 0, and `allow_tensor_metadata_change` = false, so that
/// metadata changes on the result are made illegal rather than silently
/// diverging from the original.
///
/// NOTE: `get()` TORCH_CHECKs that the Variable is defined, and the
/// AutogradMeta allocation may throw; since this function is `noexcept`,
/// either failure ends in `std::terminate`.
inline at::Tensor Variable::variable_data() const noexcept {
  auto self_impl_copy = get()->shallow_copy_and_detach(
      /*version_counter=*/0,
      /*allow_tensor_metadata_change=*/false);
  self_impl_copy->set_autograd_meta(c10::guts::make_unique<Variable::AutogradMeta>(self_impl_copy.get(), false));
  // Move the fresh impl in: avoids an atomic refcount increment/decrement.
  return at::Tensor(std::move(self_impl_copy));
}
// Gradient Node and Edges
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Raw, non-owning access to the stored `grad_fn_` (no view refresh).
inline Node* Variable::grad_fn_unsafe() const {
  const auto& stored_fn = get_autograd_meta()->grad_fn_;
  return stored_fn.get();
}
inline void Variable::set_grad_accumulator(
std::weak_ptr<Node> grad_accumulator) {
get_autograd_meta()->grad_accumulator_ = std::move(grad_accumulator);
}
inline std::shared_ptr<Node> Variable::try_get_grad_accumulator() const {
return get_autograd_meta()->grad_accumulator_.lock();
}
/// A detached Variable is a non-differentiable view of `*this` whose tensor
/// metadata may not be changed.
inline Variable Variable::detach() const {
  return make_variable_view(
      /*base=*/*this,
      /*data=*/*this,
      /*is_differentiable=*/false,
      /*allow_tensor_metadata_change=*/false,
      /*gradient_edge=*/Edge());
}
/// Copies the edge's function and input index into the autograd metadata.
inline void Variable::set_gradient_edge(Edge edge) noexcept {
  AutogradMeta* meta = get_autograd_meta();
  meta->grad_fn_ = std::move(edge.function);
  meta->output_nr_ = edge.input_nr;
}
inline uint32_t Variable::output_nr() const noexcept {
return get_autograd_meta()->output_nr_;
}
/// A Variable without a gradient function is a leaf.
inline bool Variable::is_leaf() const noexcept {
  return !get_autograd_meta()->grad_fn_;
}
// Versions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Forwards the new version counter to the underlying TensorImpl.
inline void Variable::set_version_counter(
    const c10::VariableVersion& version_counter) noexcept {
  auto* impl = unsafeGetTensorImpl();
  impl->set_version_counter(version_counter);
}
/// Delegates the version increment to the underlying TensorImpl.
inline void Variable::bump_version() noexcept {
  auto* impl = unsafeGetTensorImpl();
  impl->bump_version();
}
/// Current value of the TensorImpl's version counter.
inline uint32_t Variable::current_version() const noexcept {
  const auto& counter = unsafeGetTensorImpl()->version_counter();
  return counter.current_version();
}
/// Exposes the TensorImpl's version counter by reference.
inline const c10::VariableVersion& Variable::version_counter() const noexcept {
  auto* impl = unsafeGetTensorImpl();
  return impl->version_counter();
}
// Hooks
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Appends `hook` to the autograd metadata's pre-hook list.
inline void Variable::add_hook(std::shared_ptr<FunctionPreHook> hook) {
  auto& pre_hooks = get_autograd_meta()->hooks_;
  pre_hooks.push_back(std::move(hook));
}
inline const std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks()
const noexcept {
return get_autograd_meta()->hooks_;
}
/// Drops every pre-hook appended through add_hook().
inline void Variable::clear_hooks() {
  auto& pre_hooks = get_autograd_meta()->hooks_;
  pre_hooks.clear();
}
/// Registers a C++ hook whose result type is `void`.
/// Fails via TORCH_CHECK if this Variable does not require grad. Returns the
/// hook's index in `cpp_hooks_list`, usable with `remove_hook()`.
template <typename T>
auto Variable::register_hook(T&& hook) -> Variable::hook_return_void_t<T> {
  TORCH_CHECK(requires_grad(), "cannot register a hook on a variable that "
    "doesn't require gradient");
  // Reference to the member itself, so it observes the list if
  // create_cpp_hook() fills it in below.
  auto& list = get_autograd_meta()->cpp_hooks_list;
  if (!list) {
    create_cpp_hook();
  }
  unsigned idx = list->size();
  // The list stores Variable-returning callables, so adapt the void hook:
  // the wrapper calls it and then returns an undefined Variable (it does NOT
  // return the grad argument).
  std::function<void(Variable)> fn(std::forward<T>(hook));
  list->emplace_back([fn](Variable grad) {
    fn(grad);
    return Variable();
  });
  return idx;
}
/// Registers a C++ hook whose result type is `Variable`.
/// Fails via TORCH_CHECK if this Variable does not require grad. Returns the
/// hook's index in `cpp_hooks_list`, usable with `remove_hook()`.
template <typename T>
auto Variable::register_hook(T&& hook) -> Variable::hook_return_var_t<T> {
  TORCH_CHECK(requires_grad(), "cannot register a hook on a variable that "
    "doesn't require gradient");
  // Reference to the member itself, so it observes the list if
  // create_cpp_hook() fills it in below.
  auto& list = get_autograd_meta()->cpp_hooks_list;
  if (!list) {
    create_cpp_hook();
  }
  unsigned idx = list->size();
  // Forward the callable so rvalue hooks are moved rather than copied.
  list->emplace_back(std::forward<T>(hook));
  return idx;
}
// View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline bool Variable::is_view() const noexcept {
return get_autograd_meta()->is_view_;
}
/// Returns the base of a view; throws std::runtime_error for non-views.
inline const Variable& Variable::base() const {
  if (!is_view()) {
    throw std::runtime_error("Can't get base of non-view Variable");
  }
  // Assumes is_view_ is only set on DifferentiableViewMeta instances
  // (as created by make_variable_view) -- the downcast relies on it.
  auto* view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
  return view_meta->base_;
}
// Miscellaneous
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline void Variable::set_name(const std::string& name) {
get_autograd_meta()->name = name;
}
inline const std::string& Variable::name() const noexcept {
return get_autograd_meta()->name;
}
/// Stores the Python object pointer on the underlying TensorImpl.
inline void Variable::set_pyobj(PyObject* pyobj) noexcept {
  auto* impl = get();
  impl->set_pyobj(pyobj);
}
/// Reads the Python object pointer from the underlying TensorImpl.
inline PyObject* Variable::pyobj() const noexcept {
  auto* impl = get();
  return impl->pyobj();
}
/// Downcasts the TensorImpl's generic autograd metadata to AutogradMeta.
inline Variable::AutogradMeta* Variable::get_autograd_meta() const noexcept {
  auto* generic_meta = get()->autograd_meta();
  return static_cast<Variable::AutogradMeta*>(generic_meta);
}
// Private Methods
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Adopts `self` directly; unlike the public Tensor-converting constructors,
// no is_variable() check is performed here.
inline Variable::Variable(c10::intrusive_ptr<at::TensorImpl> self)
    : at::Tensor(std::move(self)) {
}
// Checked accessor for the TensorImpl; get_autograd_meta(), tensor_data(),
// variable_data() and the pyobj accessors all route through here.
inline at::TensorImpl* Variable::get() const {
  TORCH_CHECK(defined(), "Called Variable::get() on an undefined Variable");
  at::TensorImpl* impl = unsafeGetTensorImpl();
  return impl;
}
}} // namespace torch::autograd