forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
MKLDNNConversions.cpp
141 lines (121 loc) · 4.65 KB
/
MKLDNNConversions.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/utils/ParamUtils.h>
namespace at { namespace native {
#if AT_MKLDNN_ENABLED()
// Converts an MKL-DNN tensor into an ordinary strided CPU tensor.
//
// A dense tensor with the same sizes and options (but strided layout) is
// allocated first. If the ideep tensor holds data, `to_public` writes it into
// the dense tensor's buffer. The dense tensor then adopts the strides reported
// by the public view, so its metadata matches how the data was written.
Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor) {
  ideep::tensor& src_itensor = itensor_from_mkldnn(mkldnn_tensor);
  auto src_dims = src_itensor.get_dims();

  // Copy the ideep dims into an int64_t vector for at::empty.
  // NOTE(review): an earlier comment said ideep dims are int32_t, but
  // `as_strided_` below accepts `src_dims` directly as sizes, which suggests
  // the element type is already int64_t in this ideep version -- confirm.
  const std::vector<int64_t> dense_sizes(src_dims.begin(), src_dims.end());
  Tensor dense = at::empty(
      dense_sizes, mkldnn_tensor.options().layout(c10::kStrided));

  if (!src_itensor.is_empty()) {
    // Write the data into `dense`'s float buffer, then use the strides of
    // the resulting public view.
    auto public_view = src_itensor.to_public(dense.template data_ptr<float>());
    dense.as_strided_(src_dims, public_view.get_strides());
  }
  return dense;
}
// Converts a dense CPU float tensor into an MKL-DNN tensor.
//
// Requirements (checked in this order, each raising on failure):
// CPU device, strided layout, float dtype, and at most 5 dimensions.
// Non-contiguous inputs are made contiguous before their data is copied
// into a freshly allocated MKL-DNN tensor.
Tensor dense_to_mkldnn(const Tensor& cpu_tensor) {
  TORCH_CHECK(cpu_tensor.device().type() == DeviceType::CPU,
              "dense_to_mkldnn expects CPU tensor input");
  TORCH_CHECK(cpu_tensor.layout() == Layout::Strided,
              "dense_to_mkldnn expects strided tensor input");
  TORCH_CHECK(cpu_tensor.scalar_type() == ScalarType::Float,
              "dense_to_mkldnn expects float tensor input");
  TORCH_CHECK(cpu_tensor.dim() <= 5,
              "Can't convert cpu tensor with the number of dimensions > 5");

  // TODO: consider to convert non-contiguous tensor to `ideep::tensor` directly.
  Tensor contiguous_src = cpu_tensor.contiguous();
  Tensor result =
      empty_mkldnn(contiguous_src.sizes(), contiguous_src.options());

  // Copy the contiguous f32 data into the MKL-DNN tensor's storage,
  // interpreting it with the destination's own dims.
  ideep::tensor& dst_itensor = itensor_from_mkldnn(result);
  float* src_data = contiguous_src.template data_ptr<float>();
  dst_itensor.feed_from(
      dst_itensor.get_dims(), ideep::tensor::data_type::f32, src_data);
  return result;
}
// MKL-DNN uses a special non-public format for conv2d weights, while
// dense_to_mkldnn only produces MKL-DNN tensors in a public format. The ideep
// conv kernel reorders implicitly when the weight is not already in the
// optimized format. When this note was written, that on-the-fly reorder cost
// roughly 20% in performance. This function performs the reorder once, ahead
// of time.
//
// Returns a new MKL-DNN tensor that holds `self`'s data in the weight layout
// ideep::convolution_forward expects for the given stride, padding (applied
// symmetrically), dilation and groups.
Tensor mkldnn_reorder_conv2d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups) {
  // Copy of the ideep tensor (not a reference), so the reshape below acts on
  // the copy.
  auto src_weight = itensor_from_mkldnn(self);

  // Legacy mkldnn conv2d jitted modules may store a grouped weight as 5-D
  // [g, o/g, i, h, w] rather than 4-D [o, i, h, w]. Ideally the weight would
  // be reordered back during serialization. For backward compatibility,
  // merge the leading two dims (g * o/g) to restore the 4-D form.
  if (src_weight.ndims() == 5) {
    auto legacy_dims = src_weight.get_dims();
    src_weight.reshape({legacy_dims[0] * legacy_dims[1],
                        legacy_dims[2],
                        legacy_dims[3],
                        legacy_dims[4]});
  }

  // The same `padding` is passed for both padding arguments (symmetric).
  auto expected_desc = ideep::convolution_forward::expected_weights_desc(
      src_weight.get_dims(),
      src_weight.get_data_type(),
      {stride.begin(), stride.end()},
      {padding.begin(), padding.end()},
      {padding.begin(), padding.end()},
      {dilation.begin(), dilation.end()},
      groups,
      ideep::algorithm::convolution_direct);

  ideep::tensor reordered;
  reordered.init(expected_desc);
  reordered.feed_from(src_weight);
  return new_with_itensor_mkldnn(std::move(reordered), self.options());
}
// Reorders a conv3d weight stored as an MKL-DNN tensor into the layout that
// ideep::convolution_forward reports as expected for the given stride,
// padding (applied symmetrically), dilation and groups. This avoids an
// implicit reorder inside the convolution kernel. The weight's dims are used
// as-is; no legacy reshape is applied here.
Tensor mkldnn_reorder_conv3d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups) {
  // Copy of the ideep tensor backing `self`.
  auto src_weight = itensor_from_mkldnn(self);

  // The same `padding` is passed for both padding arguments (symmetric).
  auto expected_desc = ideep::convolution_forward::expected_weights_desc(
      src_weight.get_dims(),
      src_weight.get_data_type(),
      {stride.begin(), stride.end()},
      {padding.begin(), padding.end()},
      {padding.begin(), padding.end()},
      {dilation.begin(), dilation.end()},
      groups,
      ideep::algorithm::convolution_direct);

  // Allocate the destination in the expected format and copy `self` into it.
  ideep::tensor reordered;
  reordered.init(expected_desc);
  reordered.feed_from(src_weight);
  return new_with_itensor_mkldnn(std::move(reordered), self.options());
}
#else
// Stub used when PyTorch is built without MKL-DNN: always raises.
// The message is prefixed with the operator name, consistent with the other
// stubs in this branch, so the failing op is identifiable.
Tensor mkldnn_to_dense(const Tensor& /*mkldnn_tensor*/) {
  TORCH_CHECK(false, "mkldnn_to_dense: MKL-DNN build is disabled");
}
// Stub used when PyTorch is built without MKL-DNN: always raises.
// The message is prefixed with the operator name, consistent with the other
// stubs in this branch, so the failing op is identifiable.
Tensor dense_to_mkldnn(const Tensor& /*cpu_tensor*/) {
  TORCH_CHECK(false, "dense_to_mkldnn: MKL-DNN build is disabled");
}
// Stub used when PyTorch is built without MKL-DNN: every call raises.
// Parameters are intentionally unnamed because none of them are read.
Tensor mkldnn_reorder_conv2d_weight(
    const Tensor& /*self*/,
    IntArrayRef /*padding*/,
    IntArrayRef /*stride*/,
    IntArrayRef /*dilation*/,
    int64_t /*groups*/) {
  TORCH_CHECK(false, "mkldnn_reorder_conv2d_weight: MKL-DNN build is disabled");
}
// Stub used when PyTorch is built without MKL-DNN: every call raises.
// Parameters are intentionally unnamed because none of them are read.
Tensor mkldnn_reorder_conv3d_weight(
    const Tensor& /*self*/,
    IntArrayRef /*padding*/,
    IntArrayRef /*stride*/,
    IntArrayRef /*dilation*/,
    int64_t /*groups*/) {
  TORCH_CHECK(false, "mkldnn_reorder_conv3d_weight: MKL-DNN build is disabled");
}
#endif // AT_MKLDNN_ENABLED()
}}