forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Context.cpp
187 lines (153 loc) · 4.13 KB
/
Context.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
#include <ATen/Config.h>
#include <ATen/Context.h>
#include <c10/core/TensorOptions.h>
#include <algorithm>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <ATen/Tensor.h>
#include <ATen/cpu/FlushDenormal.h>
#include <TH/TH.h> // for USE_LAPACK
#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#endif // USE_FBGEMM
namespace at {
// Builds a Context whose THC/THH state handles start out null. Both handles
// get deleters that do nothing, so destroying or resetting them never frees
// the pointee. (Presumably those states are owned elsewhere, e.g. by the
// CUDA/HIP backends — NOTE(review): confirm.)
Context::Context()
    : thc_state(nullptr, [](THCState* /*unused*/) {}),
      thh_state(nullptr, [](THHState* /*unused*/) {}) {}
// TODO: This could be bad juju if someone calls globalContext() in the
// destructor of an object with static lifetime.
Context& globalContext() {
static Context globalContext_;
return globalContext_;
}
// NB: This method is *purely* whether or not a user requested
// that CuDNN was enabled, it doesn't actually say anything about
// whether or not CuDNN is actually usable.
bool Context::userEnabledCuDNN() const {
return enabled_cudnn;
}
void Context::setUserEnabledCuDNN(bool e) {
enabled_cudnn = e;
}
bool Context::userEnabledMkldnn() const {
return enabled_mkldnn;
}
void Context::setUserEnabledMkldnn(bool e) {
enabled_mkldnn = e;
}
bool Context::deterministicCuDNN() const {
return deterministic_cudnn;
}
void Context::setDeterministicCuDNN(bool b) {
deterministic_cudnn = b;
}
bool Context::deterministic() const {
return _deterministic;
}
void Context::setDeterministic(bool b) {
_deterministic = b;
}
// Raises an error (via TORCH_CHECK) naming `caller` when the global Context
// has determinism enabled; otherwise does nothing. Operations without a
// deterministic implementation call this before running.
void Context::alertNotDeterministic(c10::string_view const& caller) {
  // Nothing to report unless deterministic-only mode is on.
  if (!globalContext().deterministic()) {
    return;
  }
  TORCH_CHECK(false,
      caller, " does not have a deterministic implementation, but you set "
      "'torch.set_deterministic(True)'. You can turn off determinism just "
      "for this operation if that's acceptable for your application. You "
      "can also file an issue at https://github.com/pytorch/pytorch/issues "
      "to help us prioritize adding deterministic support for this operation.");
}
bool Context::benchmarkCuDNN() const {
return benchmark_cudnn;
}
void Context::setBenchmarkCuDNN(bool b) {
benchmark_cudnn = b;
}
// Compile-time answer: true iff AT_MKL_ENABLED() held when this file was built.
bool Context::hasMKL() const {
#if AT_MKL_ENABLED()
  constexpr bool kBuiltWithMkl = true;
#else
  constexpr bool kBuiltWithMkl = false;
#endif
  return kBuiltWithMkl;
}
// Compile-time answer: true iff AT_MKLDNN_ENABLED() held when this file was built.
bool Context::hasMKLDNN() const {
#if AT_MKLDNN_ENABLED()
  constexpr bool kBuiltWithMkldnn = true;
#else
  constexpr bool kBuiltWithMkldnn = false;
#endif
  return kBuiltWithMkldnn;
}
// Compile-time answer: true iff the compiler defined _OPENMP for this file.
bool Context::hasOpenMP() const {
#ifdef _OPENMP
  constexpr bool kBuiltWithOpenMP = true;
#else
  constexpr bool kBuiltWithOpenMP = false;
#endif
  return kBuiltWithOpenMP;
}
// Compile-time answer: true iff USE_LAPACK is defined (it is pulled in via
// <TH/TH.h>, per the include list at the top of this file).
bool Context::hasLAPACK() const {
#ifdef USE_LAPACK
  constexpr bool kBuiltWithLapack = true;
#else
  constexpr bool kBuiltWithLapack = false;
#endif
  return kBuiltWithLapack;
}
// Returns the explicitly selected quantized engine. If none was selected,
// returns the highest-priority supported engine, which is the last entry of
// supportedQEngines().
at::QEngine Context::qEngine() const {
  // Evaluated unconditionally, matching the eager argument of value_or().
  const at::QEngine fallback = supportedQEngines().back();
  if (quantized_engine.has_value()) {
    return *quantized_engine;
  }
  return fallback;
}
// Selects `e` as the quantized engine used by qEngine().
// Throws (via TORCH_CHECK) if `e` is not listed in supportedQEngines(); in
// that case the current selection is left unchanged.
void Context::setQEngine(at::QEngine e) {
  const auto& qengines = supportedQEngines();
  const bool is_supported =
      std::find(qengines.begin(), qengines.end(), e) != qengines.end();
  TORCH_CHECK(is_supported, "quantized engine ", toString(e), " is not supported");
  quantized_engine = e;
}
// Returns the quantized engines available in this build, in ascending
// priority order. qEngine() falls back to the LAST element when no engine
// was set explicitly. Every configuration pushes at::kNoQEngine, so the list
// is never empty and back() is always valid.
// The list is built once, on the first call, by the immediately-invoked
// lambda below. It is cached in a function-local static, and the returned
// reference refers to that cache.
const std::vector<at::QEngine>& Context::supportedQEngines() const {
static auto supported_qengines = []() {
std::vector<at::QEngine> engines = {};
// Engines are listed in priority order: later entries win.
// By default we prefer FBGEMM if we're running on the server side.
// QNNPACK has some issues on the server side, so it is not the default there.
#ifdef C10_MOBILE
// Mobile: QNNPACK (when compiled in) ranks above kNoQEngine.
engines.push_back(at::kNoQEngine);
#ifdef USE_PYTORCH_QNNPACK
engines.push_back(at::kQNNPACK);
#endif
#else // C10_MOBILE
// Server: kNoQEngine ranks above QNNPACK, so QNNPACK is never the default.
#ifdef USE_PYTORCH_QNNPACK
engines.push_back(at::kQNNPACK);
#endif
engines.push_back(at::kNoQEngine);
#endif // C10_MOBILE
// FBGEMM is appended last (highest priority). It is added only when it is
// compiled in AND fbgemmSupportedCPU() accepts the current CPU.
#ifdef USE_FBGEMM
if (fbgemm::fbgemmSupportedCPU()) {
engines.push_back(at::kFBGEMM);
}
#endif
return engines;
}();
return supported_qengines;
}
// Compile-time answer: true iff USE_XNNPACK was defined for this file.
bool Context::isXNNPACKAvailable() const {
#ifdef USE_XNNPACK
  constexpr bool kBuiltWithXnnpack = true;
#else
  constexpr bool kBuiltWithXnnpack = false;
#endif
  return kBuiltWithXnnpack;
}
bool Context::releaseWeightsWhenPrepacking() const {
return release_original_weights;
}
void Context::setReleaseWeightsWhenPrepacking(bool e) {
release_original_weights = e;
}
// Forwards to at::cpu::set_flush_denormal(on) and returns its result.
// NOTE(review): presumably false means the CPU does not support the mode —
// confirm in ATen/cpu/FlushDenormal.
bool Context::setFlushDenormal(bool on) {
  const bool result = at::cpu::set_flush_denormal(on);
  return result;
}
// Returns the CPU allocator, which here is whatever getTHDefaultAllocator()
// provides (the pointer is returned as-is; ownership is not transferred).
Allocator* getCPUAllocator() {
  Allocator* const allocator = getTHDefaultAllocator();
  return allocator;
}
} // namespace at