-
Notifications
You must be signed in to change notification settings - Fork 38
/
symbolic_opset11.py
515 lines (411 loc) · 22.7 KB
/
symbolic_opset11.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.onnx.symbolic_helper as sym_help
import warnings
import numpy
from torch.onnx.symbolic_helper import parse_args, _unimplemented
from torch.onnx.symbolic_opset9 import expand
from torch.nn.modules.utils import _single, _pair, _triple
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
# This file exports ONNX ops for opset 11
@parse_args('v', 'f', 'f')
def hardtanh(g, self, min_val, max_val):
    """Export hardtanh as a ``Clip`` node whose bounds are ``Constant`` inputs.

    Args:
        g: graph being built; nodes are appended through ``g.op``.
        self: input value to clip.
        min_val: lower bound, parsed as a float.
        max_val: upper bound, parsed as a float.

    Returns:
        The output of the emitted ``Clip`` node.
    """
    scalar_type = self.type().scalarType()
    if scalar_type is None:
        # The scalar type of the input is not available: fall back to float.
        type_index = 6  # float
    else:
        # Position of the input's ONNX type in scalar_type_to_onnx; the same
        # position selects the matching torch dtype for the bounds below.
        type_index = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[scalar_type])
    bound_dtype = sym_help.scalar_type_to_pytorch_type[type_index]
    min_val = g.op("Constant", value_t=torch.tensor(min_val, dtype=bound_dtype))
    max_val = g.op("Constant", value_t=torch.tensor(max_val, dtype=bound_dtype))
    return g.op("Clip", self, min_val, max_val)
def clamp(g, self, min, max):
    """Export clamp as a ``Clip`` node, casting the bounds to the input's scalar type.

    When the input's scalar type is unknown the bounds are forwarded unchanged.
    A bound that is ``None`` (or a None-valued graph value) is never cast.
    """
    scalar_type = self.type().scalarType()

    def _match_input_type(bound):
        # Absent bounds are passed through untouched.
        if bound is None or sym_help._is_none(bound):
            return bound
        return g.op("Cast", bound, to_i=sym_help.cast_pytorch_to_onnx[scalar_type])

    if scalar_type is None:
        return g.op("Clip", self, min, max)

    lower = _match_input_type(min)
    upper = _match_input_type(max)
    return g.op("Clip", self, lower, upper)
@parse_args('v', 'i')
def pixel_shuffle(g, self, upscale_factor):
    """Export pixel_shuffle as ``DepthToSpace`` in "CRD" mode.

    Only inputs whose recorded sizes have exactly four entries are handled;
    anything else is reported through ``_unimplemented``.
    """
    rank = len(self.type().sizes())
    if rank == 4:
        return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD")
    return _unimplemented("pixel_shuffle", "only support 4d input")
def _interpolate(name, dim, interpolate_mode):
    """Create the symbolic function for one fixed-mode ``upsample_*`` operator.

    ``interpolate_mode`` ("nearest", "linear" or "cubic") is baked into the
    returned function. ``name`` and ``dim`` are accepted but are not read by
    the generated function.
    """
    def symbolic_fn(g, input, output_size, align_corners=None):
        align_corners = sym_help._maybe_get_scalar(align_corners)
        if interpolate_mode == "nearest":
            transformation_mode = "asymmetric"
        elif align_corners:
            transformation_mode = "align_corners"
        else:
            transformation_mode = "pytorch_half_pixel"

        # One empty float tensor serves as both the roi and the scales input.
        placeholder = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))

        # Target shape = first two entries of the input shape + requested output size.
        full_shape = g.op("Shape", input)
        leading_dims = sym_help._slice_helper(g, full_shape, axes=[0], ends=[2], starts=[0])
        requested = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Long"])
        target_shape = g.op("Concat", leading_dims, requested, axis_i=0)

        return g.op("Resize",
                    input,
                    placeholder,  # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize"
                    placeholder,  # scales is not needed since we are sending out_size
                    target_shape,
                    coordinate_transformation_mode_s=transformation_mode,
                    cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                    mode_s=interpolate_mode,  # nearest, linear, or cubic
                    nearest_mode_s="floor")  # only valid when mode="nearest"
    return symbolic_fn


# Nearest-neighbour variants.
upsample_nearest1d = _interpolate('upsample_nearest1d', 3, "nearest")
upsample_nearest2d = _interpolate('upsample_nearest2d', 4, "nearest")
upsample_nearest3d = _interpolate('upsample_nearest3d', 5, "nearest")
# Linear variants.
upsample_linear1d = _interpolate('upsample_linear1d', 3, "linear")
upsample_bilinear2d = _interpolate('upsample_bilinear2d', 4, "linear")
upsample_trilinear3d = _interpolate('upsample_trilinear3d', 5, "linear")
# Cubic variant.
upsample_bicubic2d = _interpolate('upsample_bicubic2d', 4, "cubic")
def __interpolate(g, input, size, scale_factor, mode, align_corners):
    """Export the generic interpolate op as a ``Resize`` node.

    If ``size`` is given, the target shape is passed to ``Resize`` (with an
    empty scales input); otherwise scales derived from ``scale_factor`` are
    passed. Returns the ``Resize`` output, or the result of
    ``sym_help._unimplemented`` when the input rank is needed but unavailable.
    """
    mode = sym_help._maybe_get_const(mode, 's')
    # Any mode name containing 'linear' / 'cubic' is reduced to the bare name.
    if 'linear' in mode:
        mode = 'linear'
    if 'cubic' in mode:
        mode = 'cubic'

    align_corners = sym_help._maybe_get_const(align_corners, 'b')
    if not isinstance(align_corners, bool):
        align_corners = False

    if mode == "nearest":
        transformation_mode = "asymmetric"
    elif align_corners:
        transformation_mode = "align_corners"
    else:
        transformation_mode = "pytorch_half_pixel"

    # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize"
    roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))

    if sym_help._is_none(size):
        # No explicit size: resize by scales instead.
        rank = input.type().dim()
        if not rank:
            return sym_help._unimplemented("interpolate (with scales)", "missing input shape")
        scales = sym_help._interpolate_get_scales(g, scale_factor, rank)
        return g.op("Resize",
                    input,
                    roi,
                    scales,
                    coordinate_transformation_mode_s=transformation_mode,
                    cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                    mode_s=mode,  # nearest, linear, or cubic
                    nearest_mode_s="floor")  # only valid when mode="nearest"

    leading_dims = g.op("Shape", input)
    leading_dims = sym_help._slice_helper(g, leading_dims, axes=[0], ends=[2], starts=[0])

    # In some cases size is not a packed list but a scalar. Confirming that
    # requires the constant's dim(), which is not always available; when it
    # is not, assume a non-scalar unless size is clearly not a packed list.
    try:
        is_scalar = not sym_help._is_packed_list(size) and (sym_help._maybe_get_const(size, 't').dim() == 0)
    except AttributeError:
        is_scalar = not sym_help._is_packed_list(size)
        if not is_scalar:
            warnings.warn("Cannot verify if the output_size is a scalar while exporting interpolate. Assuming that it is not a scalar.")

    if is_scalar:
        rank = input.type().dim()
        if not rank:
            return sym_help._unimplemented("interpolate (with a scalar output_size)", "missing input shape (try giving an array of output_size values)")
        # Repeat the scalar once per dimension after the first two.
        size = unsqueeze(g, size, 0)
        size = g.op("Concat", *([size] * (rank - 2)), axis_i=0)

    size = g.op("Cast", size, to_i=sym_help.cast_pytorch_to_onnx['Long'])
    size = g.op("Concat", leading_dims, size, axis_i=0)
    # Empty scales: the explicit size input drives the resize.
    scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
    return g.op("Resize",
                input,
                roi,
                scales,
                size,
                coordinate_transformation_mode_s=transformation_mode,
                cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                mode_s=mode,  # nearest, linear, or cubic
                nearest_mode_s="floor")
@parse_args('v', 'i', 'v', 'v')
def gather(g, self, dim, index, sparse_grad=False):
    """Export gather as ``GatherElements`` (or an ``ATen`` node in ATen-fallback mode).

    A truthy constant ``sparse_grad`` is reported through ``_unimplemented``.
    """
    sparse_requested = sym_help._maybe_get_const(sparse_grad, 'i')
    if sparse_requested:
        return _unimplemented("gather", "sparse_grad == True")

    use_aten = sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    if not use_aten:
        return g.op("GatherElements", self, index, axis_i=dim)
    return g.op("ATen", self, dim, index, sparse_grad, operator_s="gather")
@parse_args('v', 'i', 'v', 'v')
def scatter(g, self, dim, index, src):
    """Export scatter as ``ScatterElements`` (or an ``ATen`` node in ATen-fallback mode)."""
    use_aten = sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    if not use_aten:
        return g.op("ScatterElements", self, index, src, axis_i=dim)
    return g.op("ATen", self, dim, index, src, operator_s="scatter")
@parse_args('v', 'i', 'none')
def cumsum(g, self, dim, dtype=None):
    """Export cumsum as a ``CumSum`` node along ``dim``.

    The axis is supplied as an int ``Constant`` input. When ``dtype`` is
    truthy and its node kind is not ``prim::Constant``, the result is
    additionally passed through a ``Cast`` to the parsed dtype.
    """
    axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int))
    running_sum = g.op("CumSum", self, axis)

    needs_cast = dtype and dtype.node().kind() != 'prim::Constant'
    if not needs_cast:
        return running_sum

    parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
    return g.op("Cast", running_sum, to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
def masked_select(g, self, mask):
    """Export masked_select as ``GatherND`` over the nonzero positions of the expanded mask."""
    from torch.onnx.symbolic_opset9 import nonzero, expand_as

    # Expand the mask to the shape of self, then take its nonzero positions.
    expanded_mask = expand_as(g, mask, self)
    positions = nonzero(g, expanded_mask)
    return g.op('GatherND', self, positions)
def masked_scatter(g, self, mask, source):
    """Export masked_scatter as ``ScatterND`` at the nonzero positions of the expanded mask."""
    from torch.onnx.symbolic_opset9 import nonzero, expand_as, view, size

    positions = nonzero(g, expand_as(g, mask, self))

    # NOTE: source can have more elements than needed, and it can have an
    # arbitrary shape. ONNX::ScatterND does not support that, so source is
    # flattened and then sliced down to the number of selected positions.
    flat_source = view(g, source, torch.LongTensor([-1]))
    num_positions = size(g, positions, torch.LongTensor([0]))
    trimmed_source = sym_help._slice_helper(g, flat_source,
                                            axes=torch.LongTensor([0]),
                                            starts=torch.LongTensor([0]),
                                            ends=num_positions,
                                            dynamic_slice=True)
    return g.op('ScatterND', self, positions, trimmed_source)
@parse_args('v', 'i', 'i', 'i')
def _unique2(g, self, sorted, return_inverse, return_counts):
    """Export _unique2 as a four-output ``Unique`` node.

    Returns the first, third and fourth outputs (unique values, inverse
    indices, counts); the second output is dropped. ``return_inverse`` and
    ``return_counts`` are parsed but not read.
    """
    unique_outputs = g.op("Unique", self, sorted_i=sorted, outputs=4)
    values, _indices, inverse_indices, counts = unique_outputs
    return values, inverse_indices, counts
def _avg_pool(name, tuple_fn):
    """Create the symbolic function for an ``avg_pool*d`` operator.

    ``tuple_fn`` normalises kernel size and stride for the ``AveragePool``
    attributes; ``name`` is forwarded to ``sym_help._avgpool_helper``.
    """
    @parse_args('v', 'is', 'is', 'is', 'i', 'i', 'none')
    def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override=None):
        padding = sym_help._avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, name)
        if count_include_pad:
            # Pad explicitly with a constant-mode Pad node (two leading zeros
            # prepended to the begin and end halves), then give AveragePool
            # zero padding of its own.
            explicit_pads = ((0,) * 2 + padding) * 2
            input = g.op("Pad", input,
                         g.op("Constant", value_t=torch.tensor(explicit_pads)), mode_s='constant')
            padding = (0,) * len(padding)
        return g.op("AveragePool", input,
                    kernel_shape_i=tuple_fn(kernel_size),
                    strides_i=tuple_fn(stride),
                    pads_i=padding * 2,
                    ceil_mode_i=ceil_mode)
    return symbolic_fn


# One binding per variant, each with the matching tuple normaliser.
avg_pool1d = _avg_pool('avg_pool1d', _single)
avg_pool2d = _avg_pool('avg_pool2d', _pair)
avg_pool3d = _avg_pool('avg_pool3d', _triple)
@parse_args('v', 'i', 'i', 'i', 'i')
def unique_dim(g, self, dim, sorted, return_inverse, return_counts):
    """Export unique_dim as a four-output ``Unique`` node along axis ``dim``.

    Returns the first, third and fourth outputs (unique values, inverse
    indices, counts); the second output is dropped. ``return_inverse`` and
    ``return_counts`` are parsed but not read.
    """
    unique_outputs = g.op("Unique", self, axis_i=dim, sorted_i=sorted, outputs=4)
    values, _indices, inverse_indices, counts = unique_outputs
    return values, inverse_indices, counts
@parse_args('v', 'v', 'i', 'i', 'i', 'none')
def topk(g, self, k, dim, largest, sorted, out=None):
    """Export topk by delegating to ``sym_help._topk_helper``."""
    options = dict(largest=largest, sorted=sorted, out=out)
    return sym_help._topk_helper(g, self, k, dim, **options)
@parse_args('v', 'i', 'i', 'none')
def sort(g, self, dim, decending, out=None):
    """Export sort by delegating to ``sym_help._sort_helper``."""
    options = dict(decending=decending, out=out)
    return sym_help._sort_helper(g, self, dim, **options)
def round(g, self):
    """Export round as a ``Round`` node."""
    rounded = g.op("Round", self)
    return rounded
def _prepare_onnx_paddings(g, dim, pad):
    """Generate paddings in ONNX order based on pad in pytorch.

    Args:
        g: graph being built.
        dim: the dimension of the tensor.
        pad: the paddings in pytorch. The order is dim_n_begin, dim_n_end,
            dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end,
            where m is in range [0, n].

    Returns:
        The paddings in the order dim_0_begin, dim_1_begin, ..., dim_0_end,
        ..., dim_n_end (n being the dimension of input), cast to Long.
    """
    # Leading dimensions missing from "pad" are assumed to have zero padding,
    # so the sequence is extended with zeros up to length dim * 2.
    axis_zero = g.op("Constant", value_t=torch.tensor([0]))
    pad_len = torch.onnx.symbolic_opset9.size(g, pad, axis_zero)

    # extension = [0] * (dim * 2 - len(pad))
    dim_const = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int64))
    two_const = g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))
    extension = g.op("Sub", g.op("Mul", dim_const, two_const), pad_len)

    # Currently ONNX only supports int64 type for Pad
    pad = g.op("Cast", pad, to_i=sym_help.cast_pytorch_to_onnx['Long'])
    # paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ]
    zeros = g.op("ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64))
    paddings = g.op("Concat", pad, zeros, axis_i=0)

    # Reshape to (begin, end) pairs, reverse the pair order and transpose so
    # that all begins come first and all ends second:
    # paddings = [[..., 0, dim_n-1_begin, dim_n_begin],
    #             [..., 0, dim_n-1_end, dim_n_end]]
    pair_shape = g.op("Constant", value_t=torch.tensor([-1, 2]))
    paddings = g.op("Reshape", paddings, pair_shape)
    reversed_pairs = torch.onnx.symbolic_opset10.flip(g, paddings, [0])
    paddings = g.op("Transpose", reversed_pairs, perm_i=[1, 0])

    # Back to 1-D: [..., 0, dim_n-1_begin, dim_n_begin, ..., 0, dim_n-1_end, dim_n_end]
    flat_shape = g.op("Constant", value_t=torch.tensor([-1]))
    paddings = g.op("Reshape", paddings, flat_shape)
    return g.op("Cast", paddings, to_i=sym_help.cast_pytorch_to_onnx['Long'])
def constant_pad_nd(g, input, padding, value=None):
    """Export constant_pad_nd as a constant-mode ``Pad`` node with a fill-value input."""
    fill_value = sym_help._maybe_get_scalar(value)
    fill_value = sym_help._if_scalar_type_as(g, fill_value, input)
    onnx_pads = _prepare_onnx_paddings(g, input.type().dim(), padding)
    return g.op("Pad", input, onnx_pads, fill_value, mode_s="constant")
def reflection_pad(g, input, padding):
    """Export reflection padding as a ``Pad`` node in "reflect" mode."""
    onnx_pads = _prepare_onnx_paddings(g, input.type().dim(), padding)
    return g.op("Pad", input, onnx_pads, mode_s="reflect")
def replication_pad(g, input, padding):
    """Export replication padding as a ``Pad`` node in "edge" mode."""
    onnx_pads = _prepare_onnx_paddings(g, input.type().dim(), padding)
    return g.op("Pad", input, onnx_pads, mode_s="edge")
# The 1d/2d/3d padding ops all share one symbolic function per padding mode.
reflection_pad1d = reflection_pad2d = reflection_pad3d = reflection_pad
replication_pad1d = replication_pad2d = replication_pad3d = replication_pad
def det(g, self):
    """Export det as a ``Det`` node."""
    determinant = g.op("Det", self)
    return determinant
def logdet(g, input):
    """Export logdet as the opset 9 ``log`` symbolic applied to ``det`` of the input."""
    from torch.onnx.symbolic_opset9 import log

    determinant = det(g, input)
    return log(g, determinant)
def arange(g, *args):
    """Export arange as a ``Range`` node.

    The overload is selected by the number of positional arguments (5, 6 or 7);
    missing start/step values default to constants 0 and 1 of the resolved type.

    Raises:
        NotImplementedError: for any other number of arguments.
    """
    def _const_dtype(dtype_arg):
        return sym_help._maybe_get_const(dtype_arg, 'i')

    def _typed_constant(number, type_index):
        torch_dtype = sym_help.scalar_type_to_pytorch_type[type_index]
        return g.op("Constant", value_t=torch.tensor(number, dtype=torch_dtype))

    arg_count = len(args)
    if arg_count == 5:
        # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        type_index, end, start, step = sym_help._arange_cast_helper(
            g, end=args[0], dtype=_const_dtype(args[1]))
        default_start = _typed_constant(0, type_index)
        default_step = _typed_constant(1, type_index)
        return g.op("Range", default_start, end, default_step)

    if arg_count == 6:
        # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        type_index, end, start, step = sym_help._arange_cast_helper(
            g, start=args[0], end=args[1], dtype=_const_dtype(args[2]))
        default_step = _typed_constant(1, type_index)
        return g.op("Range", start, end, default_step)

    if arg_count == 7:
        # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
        type_index, end, start, step = sym_help._arange_cast_helper(
            g, start=args[0], end=args[1], step=args[2], dtype=_const_dtype(args[3]))
        return g.op("Range", start, end, step)

    raise NotImplementedError("Unknown aten::arange signature taking " + str(len(args)) + " arguments.")
@parse_args('v', 'i')
def _dim_arange(g, like, dim):
    """Export _dim_arange: a range whose end is the size of ``like`` along ``dim``.

    In ATen-fallback mode a ``_caffe2::Range`` node is emitted; otherwise the
    five-argument ``arange`` overload is used with dtype code 4.
    """
    shape = g.op('Shape', like)
    dim_index = g.op("Constant", value_t=torch.tensor(dim))
    stop = g.op("Gather", shape, dim_index, axis_i=0)

    use_aten = sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    if not use_aten:
        return arange(g, stop, 4, None, None, None)
    return g.op("_caffe2::Range", stop)
def size(g, self, dim):
    """Export size by delegating to ``sym_help._size_helper``."""
    dim_size = sym_help._size_helper(g, self, dim)
    return dim_size
def squeeze(g, self, dim=None):
    """Export squeeze as a ``Squeeze`` node.

    With no ``dim``, every axis whose recorded size equals 1 is squeezed;
    otherwise only the constant ``dim`` is.
    """
    if dim is not None:
        axes = [sym_help._get_const(dim, 'i', 'dim')]
    else:
        axes = [axis for axis, extent in enumerate(self.type().sizes()) if extent == 1]
    return g.op("Squeeze", self, axes_i=axes)
@parse_args('v', 'i')
def unsqueeze(g, self, dim):
    """Export unsqueeze as an ``Unsqueeze`` node on the single axis ``dim``."""
    axes = [dim]
    return g.op("Unsqueeze", self, axes_i=axes)
def mm(g, self, other):
    """Export mm as a two-input ``Gemm`` node with ``alpha=1.0`` and ``beta=0.0``."""
    product = g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0)
    return product
def index_fill(g, self, dim, index, value):
    """Export index_fill as a ``scatter`` of the expanded fill value.

    In ATen-fallback mode an ``ATen`` ``index_fill`` node is emitted instead.
    """
    dim_value = sym_help._parse_arg(dim, 'i')
    use_aten = sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    if use_aten:
        return g.op("ATen", self, index, value, dim_i=dim_value, operator_s="index_fill")

    expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
    # Bring the fill value to the type of self, then expand it to the index shape.
    fill_value = sym_help._if_scalar_type_as(g, sym_help._maybe_get_scalar(value), self)
    expanded_value = expand(g, fill_value, expanded_index_shape, None)
    return scatter(g, self, dim, expanded_index, expanded_value)
def index_copy(g, self, dim, index, source):
    """Export index_copy as a ``scatter`` of ``source`` at the expanded index.

    In ATen-fallback mode an ``ATen`` ``index_copy`` node is emitted instead.
    """
    dim_value = sym_help._parse_arg(dim, 'i')
    use_aten = sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    if use_aten:
        return g.op("ATen", self, index, source, dim_i=dim_value, operator_s="index_copy")

    # Only the expanded index is needed here; the expanded shape is discarded.
    _, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
    return scatter(g, self, dim, expanded_index, source)
def __rshift_(g, self, other):
    """Export a right shift.

    Byte inputs use ``BitShift`` with direction "RIGHT"; every other type is
    exported as ``Div`` of self by ``Pow(2, other)``.
    """
    self_type = self.type().scalarType()

    # Make sure to cast other to self's type
    # (when self is long, make sure that other is not float).
    if other.type().scalarType() != self_type:
        other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx[self_type])

    if self_type == 'Byte':
        return g.op('BitShift', self, other, direction_s="RIGHT")

    base = g.op('Constant', value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not sym_help._is_fp(self):
        other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx['Float'])
    divisor = g.op('Pow', base, other)
    return g.op('Div', self, divisor)
def __lshift_(g, self, other):
    """Export a left shift.

    Byte inputs use ``BitShift`` with direction "LEFT"; every other type is
    exported as ``Mul`` of self by ``Pow(2, other)``.
    """
    self_type = self.type().scalarType()

    # Make sure to cast other to self's type
    # (when self is long, make sure that other is not float).
    if other.type().scalarType() != self_type:
        other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx[self_type])

    if self_type == 'Byte':
        return g.op('BitShift', self, other, direction_s="LEFT")

    base = g.op('Constant', value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not sym_help._is_fp(self):
        other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx['Float'])
    multiplier = g.op('Pow', base, other)
    return g.op('Mul', self, multiplier)
def _get_im2col_indices_along_dim(g, input_d, kernel_size_d, dilation_d, padding_d, stride_d):
    """Build the indices of the sliding blocks along one spatial dimension.

    Input is always 4-D (N, C, H, W). The kernel slides over dimension d from
    0 to input[d] + 2 x padding[d] - dilation[d] x (kernel_size[d] - 1) in
    steps of the stride.

    Returns:
        The sum of the block start indices (unsqueezed on axis 0) and the
        dilated kernel offsets (reshaped to [-1, 1]).
    """
    # Upper bound for the block start positions.
    total_padding = g.op("Constant", value_t=torch.tensor(padding_d * 2))
    upper_bound = g.op("Add", input_d, total_padding)
    kernel_extent = g.op("Constant", value_t=torch.tensor(dilation_d * (kernel_size_d - 1)))
    upper_bound = g.op("Sub", upper_bound, kernel_extent)

    # Stride the kernel over the input to find the starting indices along dim d.
    block_starts = g.op("Range", g.op("Constant", value_t=torch.tensor(0)),
                        upper_bound, g.op("Constant", value_t=torch.tensor(stride_d)))

    # Apply dilation on the kernel to find its offsets along dim d.
    kernel_offsets = numpy.arange(0, kernel_size_d * dilation_d, dilation_d)
    kernel_offsets = g.op("Constant", value_t=torch.tensor([kernel_offsets]))

    # Broadcast-add the block starting positions and the kernel offsets to get
    # the block indices along dim d.
    block_starts = g.op('Unsqueeze', block_starts, axes_i=[0])  # Reshape to [1, -1]
    offsets_column = g.op('Reshape', kernel_offsets, g.op('Constant', value_t=torch.tensor([-1, 1])))
    return g.op("Add", block_starts, offsets_column)
def _get_im2col_padded_input(g, input, padding_h, padding_w):
    """Pad the input with a ``Pad`` node using ``padding_h`` / ``padding_w``.

    Input is always a 4-D tensor (N, C, H, W). The pads constant follows the
    ONNX layout (dim1_begin, dim2_begin, ..., dim1_end, dim2_end, ...): zeros
    for the first two dimensions, then padding_h and padding_w, for both the
    begin and the end half.
    """
    one_side = [0, 0, padding_h, padding_w]
    onnx_pads = g.op("Constant", value_t=torch.LongTensor(one_side * 2))
    return g.op("Pad", input, onnx_pads)
def _get_im2col_output_shape(g, input, kernel_h, kernel_w):
    """Build the output shape for im2col: [batch, channels * kernel_h * kernel_w, -1]."""
    batch = size(g, input, g.op("Constant", value_t=torch.tensor(0)))
    channels = size(g, input, g.op("Constant", value_t=torch.tensor(1)))
    kernel_area = g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w))
    unfolded_channels = g.op("Mul", channels, kernel_area)

    shape_parts = (
        g.op("Unsqueeze", batch, axes_i=[0]),
        g.op("Unsqueeze", unfolded_channels, axes_i=[0]),
        g.op("Constant", value_t=torch.tensor([-1])),
    )
    return g.op("Concat", *shape_parts, axis_i=0)
@parse_args('v', 'is', 'is', 'is', 'is')
def im2col(g, input, kernel_size, dilation, padding, stride):
    """Export im2col as two ``Gather`` nodes followed by ``Transpose`` and ``Reshape``.

    Input is always a 4-D tensor (N, C, H, W); all other args are int[2].
    """
    height = size(g, input, g.op("Constant", value_t=torch.tensor(2)))
    width = size(g, input, g.op("Constant", value_t=torch.tensor(3)))

    stride_h, stride_w = stride[0], stride[1]
    padding_h, padding_w = padding[0], padding[1]
    dilation_h, dilation_w = dilation[0], dilation[1]
    kernel_h, kernel_w = kernel_size[0], kernel_size[1]

    row_indices = _get_im2col_indices_along_dim(g, height, kernel_h, dilation_h, padding_h, stride_h)
    col_indices = _get_im2col_indices_along_dim(g, width, kernel_w, dilation_w, padding_w, stride_w)

    output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w)
    padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w)

    # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1
    # [[[[1., 2., 3.,],
    #    [4., 5., 6.,],
    #    [7., 8., 9.,]]]]
    # First gather indices along rows (dim=2) with row indices = [[0,1], [1,2]] to get:
    # [[[[[1., 2., 3.],
    #     [4., 5., 6.]],
    #    [[4., 5., 6.],
    #     [7., 8., 9.]]]]]
    # And then gather along cols (dim=4) with col indices = [[0,1], [1,2]] to get:
    # [[[[[[1., 2.],
    #      [4., 5.]],
    #     [[2., 3.],
    #      [5., 6]]],
    #    [[[4., 5.],
    #      [7., 8.]],
    #     [[5., 6.],
    #      [8., 9.]]]]]]
    # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get:
    # [[[1., 2., 4., 5.],
    #   [2., 3., 5., 6.],
    #   [4., 5., 7., 8.],
    #   [5., 6., 8., 9.]]]
    gathered = g.op("Gather", padded_input, row_indices, axis_i=2)
    gathered = g.op("Gather", gathered, col_indices, axis_i=4)
    gathered = g.op("Transpose", gathered, perm_i=[0, 1, 2, 4, 3, 5])
    return g.op("Reshape", gathered, output_shape)