forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
BatchedTensorImpl.h
134 lines (109 loc) · 4.62 KB
/
BatchedTensorImpl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#pragma once
#include <bitset>
#include <ATen/ArrayRef.h>
#include <ATen/SmallVector.h>
#include <ATen/Tensor.h>
namespace at {
// Maximum number of tensor dimensions assumed by the vmap machinery. Other
// parts of the codebase rely on the same limit, but this is the only place
// where it is given a name.
constexpr int64_t kVmapMaxTensorDims{64};

// Number of distinct vmap levels. Valid levels lie in the half-open range
// [0, 64), so at most 64 vmaps may be nested.
constexpr int64_t kVmapNumLevels{64};

// How many BatchDim elements a BatchDims keeps on the stack. Five nested vmaps
// is expected to cover most uses; tune this value if that stops being true.
constexpr int64_t kBatchDimsStackSize{5};
// A BatchDim describes one "private" dimension of a Tensor created inside of
// vmap. It pairs a `level` (an identifier for the vmap the dimension was
// created inside) with a `dim` (the dimension being vmap'ed over).
struct BatchDim {
  // Note the argument order: `level` first, then `dim`.
  BatchDim(int64_t level, int64_t dim) : dim_{dim}, level_{level} {}

  // Identifier of the vmap this batch dimension belongs to.
  int64_t level() const { return level_; }

  // Index of the dimension that is being vmap'ed over.
  int64_t dim() const { return dim_; }

 private:
  int64_t dim_;
  int64_t level_;
};
// List of BatchDim held in a SmallVector; kBatchDimsStackSize (defined above)
// is the number of elements kept on the stack.
using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
// ArrayRef over BatchDim elements. In this header it is the return type of
// BatchedTensorImpl::bdims() and the parameter type of createBatchDimBitset().
using BatchDimsRef = ArrayRef<BatchDim>;
// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDim
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
//
// The batch dimensions are treated as being "private"; they are not user-visible.
// For example, in the following Tensor,
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
// dimensions 0 and 1 are batch dimensions.
//
// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
// dim 0, which is equivalent to dim 2 in the underlying ones(2, 3, 5, 7) tensor.
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
  // Constructs a BatchedTensorImpl from the underlying `value` Tensor and the
  // list of batch dimensions `bdims` that are private on it.
  explicit BatchedTensorImpl(Tensor value, BatchDims bdims);

  // Returns a reference to BatchDims that represent which dimensions of this
  // tensor are private.
  BatchDimsRef bdims() const { return bdims_; }

  // BatchedTensorImpl wraps a Tensor; this returns that underlying Tensor.
  const Tensor& value() const { return value_; }

  // Given a public dimension index, return the dimension index in the underlying
  // value() tensor.
  // For example, if we have
  //    bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=2)])
  // bt.actualDim(0) -> 1
  // bt.actualDim(1) -> 3
  // bt.actualDim(2) -> Error
  // NOTE(review): `wrap_dim` presumably controls whether a negative `dim` is
  // wrapped around before being translated -- confirm against the definition.
  int64_t actualDim(int64_t dim, bool wrap_dim = true) const;

  // Override a bunch of methods inherited from TensorImpl to return error messages.
  bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
  IntArrayRef strides() const override;
  int64_t stride(int64_t d) const override;
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;
  bool has_storage() const override;
  const Storage& storage() const override;
  int64_t storage_offset() const override;

 private:
  // see Note: [BatchedTensorImpl levels invariant]
  void checkInvariants() const;

  // The underlying Tensor, exposed through value().
  Tensor value_;

  // Note: [BatchedTensorImpl levels invariant]
  // There is an invariant that the BatchDims must be stored in increasing `level`
  // order. That is, for i < j, bdims_[i].level must be less than bdims_[j].level.
  BatchDims bdims_;
};
// Returns true when the dispatch key set of `tensor`'s TensorImpl contains
// DispatchKey::Batched.
inline bool isBatched(const Tensor& tensor) {
  const auto* impl = tensor.unsafeGetTensorImpl();
  return impl->key_set().has(DispatchKey::Batched);
}
// It is unsafe to call this on a Tensor that is not backed by a
// BatchedTensorImpl: the TensorImpl pointer is static_cast'ed without any
// check. Please use `maybeGetBatched` whenever possible.
//
// `tensor` is taken by const reference so that a call does not copy the Tensor
// handle just to read its impl pointer. The returned pointer is non-owning.
inline BatchedTensorImpl* unsafeGetBatched(const Tensor& tensor) {
  return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
}
// Checked counterpart of `unsafeGetBatched`: returns the BatchedTensorImpl
// behind `tensor`, or nullptr when isBatched(tensor) is false.
//
// `tensor` is taken by const reference so that a call does not copy the Tensor
// handle. The returned pointer is non-owning.
inline BatchedTensorImpl* maybeGetBatched(const Tensor& tensor) {
  return isBatched(tensor) ? unsafeGetBatched(tensor) : nullptr;
}
// Builds a bitset with one bit per possible tensor dimension: bit i is set
// exactly when some entry of `bdims` has dim() == i, i.e. dim i is a batch dim.
// std::bitset::set(pos) bounds-checks `pos`, so a dim outside
// [0, kVmapMaxTensorDims) raises std::out_of_range.
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(BatchDimsRef bdims) {
  std::bitset<kVmapMaxTensorDims> mask;
  for (auto it = bdims.begin(); it != bdims.end(); ++it) {
    mask.set(it->dim());
  }
  return mask;
}
// Streams a BatchDim in the form "(lvl=<level>, dim=<dim>)" and returns the
// stream so that insertions can be chained.
inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
  return out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
}
// Use this to construct a BatchedTensor from a regular Tensor.
// `bdims` lists the (level, dim) pairs to treat as batch dimensions of `tensor`.
// Only declared here; the definition lives outside this header.
// NOTE(review): `bdims` presumably has to satisfy the increasing-level
// invariant documented on BatchedTensorImpl -- confirm against the definition.
TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);
// Adds a batch dim to `tensor`, returning a BatchedTensor.
// `level` and `dim` are the two components of the new BatchDim (see BatchDim).
// Only declared here; the definition lives outside this header.
// NOTE(review): behavior when `tensor` is already a BatchedTensor is not
// visible from this header -- confirm against the definition.
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
}