forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TensorNames.h
74 lines (58 loc) · 2.46 KB
/
TensorNames.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#pragma once
#include <ATen/WrapDimUtils.h>

#include <utility>
namespace at { namespace namedinference {
// TensorName and TensorNames are wrappers around Dimname and DimnameList
// that contain helper functions to make writing name inference rules easier.
//
// A TensorName represents a Dimname associated with some DimnameList (from a Tensor).
// This encapsulates all the information that is needed to check if names *match*
// and to *unify* names.
//
// Definition: Two names in two tensors *match* if they are equal, or if at
// least one of them is a wildcard that can be *refined* to the other name.
//
// Definition: unify(name, other) fails if the names do not match. Otherwise,
// it returns the most refined of name and other.
//
// Here is an example of checking if two names match.
//   tensor: Tensor[A, None]
//   other:  Tensor[A]
//
// Let's say we wish to check if tensor.names[-1] matches other.names[-1].
// None (in tensor) cannot match A (in other) because if the None were refined
// to A, `tensor` would have duplicate names [A, A]. Therefore we need to check
// tensor.names [A, None] for the existence of A.
struct CAFFE2_API TensorName {
  // Builds the TensorName for position `origin_idx` of `origin`.
  //
  // `origin` is kept as a non-owning ArrayRef view, so the Dimnames it refers
  // to must outlive this TensorName.
  //
  // `origin_idx` may be negative (counted from the end); it is wrapped with
  // maybe_wrap_dim before indexing. wrap_scalar is disabled on purpose: with
  // the default (true), an empty `origin` would be treated as having one
  // dimension, index 0 would be accepted, and origin[0] would read past the
  // end of the (unchecked) ArrayRef. With wrap_scalar=false, maybe_wrap_dim
  // rejects that case with an error instead.
  explicit TensorName(ArrayRef<Dimname> origin, int origin_idx)
    : origin_(origin),
      name_(origin[maybe_wrap_dim(
          origin_idx, origin.size(), /*wrap_scalar=*/false)]),
      origin_idx_(origin_idx) {}

  // Returns whichever of *this and `other` is the more refined name, per the
  // definitions above; fails if the two names do not match.
  // op_name is only used for error reporting.
  const TensorName& unify(const TensorName& other, const char* op_name) const;

  // Converts back to a plain Dimname (presumably name_; the definition is not
  // in this header).
  Dimname toDimname() const;

 private:
  // The full name list of the tensor this name came from (non-owning view).
  ArrayRef<Dimname> origin_;
  // The name at origin_idx_ within origin_.
  Dimname name_;
  // The index exactly as passed to the constructor, NOT wrapped, so it may be
  // negative. Presumably used for error messages; verify in TensorNames.cpp.
  int origin_idx_; // A named tensor can have at most 64 dims.

  // Stream output for a TensorName (presumably used when reporting errors).
  CAFFE2_API friend std::ostream& operator<<(
      std::ostream& out,
      const TensorName& tensorname);
};
using TensorNameVec = SmallVector<TensorName, 10>;

// An ordered list of TensorName (one entry per dimension) with helpers used
// when writing name inference rules.
struct CAFFE2_API TensorNames {
  // Creates TensorNames covering every entry of `names`.
  explicit TensorNames(ArrayRef<Dimname> names);

  // Create TensorNames from names[start:end]. Each individual TensorName stores
  // `names`, NOT names[start:end], because the original tensor's names are `names`.
  explicit TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end);

  // Unifies these names with `other`, aligning the lists from the right as the
  // name suggests, and returns *this (updated in place, per the "Inplace"
  // suffix and the reference return type).
  // op_name is only used for error reporting.
  TensorNames& unifyFromRightInplace(
      const TensorNames& other,
      const char* op_name = "unify");

  // Validates that the names are unique; presumably raises an error mentioning
  // `op_name` if a duplicate is found (definition not in this header).
  void checkUnique(const char* op_name) const;

  // Appends `name` as the last entry.
  void append(TensorName&& name);

  // Returns the names as a vector of plain Dimnames.
  std::vector<Dimname> toDimnameVec() const;

 private:
  // Adopts an already-built vector of TensorName. `names` is a named rvalue
  // reference, which binds as an lvalue, so it must be std::move'd here;
  // otherwise the whole vector would be copied.
  explicit TensorNames(TensorNameVec&& names) : names_(std::move(names)) {}

  TensorNameVec names_;
};
}} // namespace at::namedinference