forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_pytorch_onnx_onnxruntime.py
3755 lines (3022 loc) · 139 KB
/
test_pytorch_onnx_onnxruntime.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import unittest
import onnxruntime # noqa
import torch
import numpy as np
import io
import itertools
import copy
from torch.nn.utils import rnn as rnn_utils
from model_defs.lstm_flattening_result import LstmFlatteningResult
from model_defs.rnn_model_with_packed_sequence import RnnModelWithPackedSequence
from test_pytorch_common import (skipIfUnsupportedMinOpsetVersion, enableScriptTest,
skipIfUnsupportedOpsetVersion, skipIfNoLapack)
from test_pytorch_common import BATCH_SIZE
from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
import model_defs.word_language_model as word_language_model
import torchvision
import onnx
def ort_test_with_input(ort_sess, input, output, rtol, atol):
    """Run ``ort_sess`` on ``input`` and check the results against ``output``.

    ``input`` and ``output`` may be (nested) tuples/lists of tensors; both are
    flattened with ``torch.jit._flatten`` first. The i-th flattened input is
    fed to the i-th graph input reported by ``ort_sess.get_inputs()``.

    Args:
        ort_sess: an ``onnxruntime.InferenceSession`` (anything exposing
            ``get_inputs()`` and ``run(None, feed)``).
        input: model inputs to feed to ONNX Runtime.
        output: expected outputs computed by PyTorch.
        rtol, atol: tolerances forwarded to ``np.testing.assert_allclose``.

    Raises:
        AssertionError: if the number of outputs differs, or if any ORT output
            is not close to its PyTorch counterpart.
    """
    input, _ = torch.jit._flatten(input)
    output, _ = torch.jit._flatten(output)

    def to_numpy(tensor):
        # detach() is harmless on tensors that do not require grad, so it can
        # be applied unconditionally before converting to NumPy.
        return tensor.detach().cpu().numpy()

    inputs = [to_numpy(t) for t in input]
    outputs = [to_numpy(t) for t in output]

    # Query the session's input metadata once. Positional indexing keeps the
    # original IndexError when more inputs are supplied than the graph has.
    ort_input_meta = ort_sess.get_inputs()
    ort_inputs = {ort_input_meta[i].name: arr for i, arr in enumerate(inputs)}
    ort_outs = ort_sess.run(None, ort_inputs)

    # compare onnxruntime and PyTorch results
    assert len(outputs) == len(ort_outs), "number of outputs differ"
    for out, ort_out in zip(outputs, ort_outs):
        np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol)
def run_model_test(self, model, batch_size=2, state_dict=None,
                   input=None, use_gpu=True, rtol=0.001, atol=1e-7,
                   example_outputs=None, do_constant_folding=True,
                   dynamic_axes=None, test_with_inputs=None,
                   input_names=None, output_names=None,
                   fixed_batch_size=False):
    """Export ``model`` to an in-memory ONNX file and compare ORT with PyTorch.

    ``self`` supplies ``opset_version`` and ``keep_initializers_as_inputs``.
    ``state_dict``, ``use_gpu`` and ``example_outputs`` are accepted for
    signature compatibility but are not read here; the exporter receives the
    outputs computed from ``input`` instead.

    Args:
        model: module to evaluate and export (switched to eval mode).
        batch_size: leading dimension of the random (batch, 3, 224, 224)
            input used when ``input`` is None.
        input: a tensor or a tuple of model inputs.
        rtol, atol: tolerances for the ORT/PyTorch comparison.
        do_constant_folding, dynamic_axes, input_names, output_names,
        fixed_batch_size: forwarded to ``torch.onnx._export``.
        test_with_inputs: optional iterable of extra inputs that are run
            through the same exported session.
    """
    def as_tuple(value):
        # Wrap a bare tensor so it can be splatted / flattened uniformly.
        if isinstance(value, torch.Tensor):
            return (value,)
        return value

    model.eval()
    if input is None:
        input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)

    with torch.no_grad():
        input = as_tuple(input)
        # In-place ops inside the model may mutate their arguments, so each
        # consumer (forward, export, ORT run) receives its own deep copy.
        expected = as_tuple(model(*copy.deepcopy(input)))

        buffer = io.BytesIO()
        torch.onnx._export(model, copy.deepcopy(input), buffer,
                           opset_version=self.opset_version,
                           example_outputs=expected,
                           do_constant_folding=do_constant_folding,
                           keep_initializers_as_inputs=self.keep_initializers_as_inputs,
                           dynamic_axes=dynamic_axes,
                           input_names=input_names, output_names=output_names,
                           fixed_batch_size=fixed_batch_size)

        session = onnxruntime.InferenceSession(buffer.getvalue())
        ort_test_with_input(session, copy.deepcopy(input), expected, rtol, atol)

        # Re-run the already exported graph on any additional inputs.
        if test_with_inputs is not None:
            for extra_input in test_with_inputs:
                extra_input = as_tuple(extra_input)
                extra_expected = as_tuple(model(*copy.deepcopy(extra_input)))
                # ORT gets the uncopied extra input; PyTorch used a copy.
                ort_test_with_input(session, extra_input, extra_expected, rtol, atol)
class TestONNXRuntime(unittest.TestCase):
from torch.onnx.symbolic_helper import _export_onnx_opset_version
opset_version = _export_onnx_opset_version
keep_initializers_as_inputs = True # For IR version 3 type export.
def setUp(self):
    """Seed every RNG used by the tests and disable script testing by default."""
    seed = 0
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed=seed)
    self.is_script_test_enabled = False
def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True,
             batch_size=2, use_gpu=True, dynamic_axes=None, test_with_inputs=None,
             input_names=None, output_names=None, fixed_batch_size=False):
    """Check ``model`` with ``run_model_test``, forwarding every option.

    When ``self.is_script_test_enabled`` is set, ``torch.jit.script(model)``
    is checked first, followed by the original ``model``.
    """
    options = dict(batch_size=batch_size, input=input, use_gpu=use_gpu,
                   rtol=rtol, atol=atol,
                   do_constant_folding=do_constant_folding,
                   dynamic_axes=dynamic_axes, test_with_inputs=test_with_inputs,
                   input_names=input_names, output_names=output_names,
                   fixed_batch_size=fixed_batch_size)
    candidates = []
    if self.is_script_test_enabled:
        candidates.append(torch.jit.script(model))
    candidates.append(model)
    for candidate in candidates:
        run_model_test(self, candidate, **options)
def run_model_test_with_external_data(self, model, input, rtol=0.001, atol=1e-7,
                                      example_outputs=None, do_constant_folding=True,
                                      dynamic_axes=None, input_names=None, output_names=None,
                                      ort_optim_on=True):
    """Export ``model`` with ``use_external_data_format=True`` and compare ORT with PyTorch.

    The model is written to ``model.onnx`` inside a temporary directory.
    ``example_outputs`` is accepted but not read; the exporter receives the
    outputs computed from ``input``.

    Args:
        ort_optim_on: use ORT_ENABLE_EXTENDED graph optimization when True,
            ORT_DISABLE_ALL otherwise.
    """
    import os
    import tempfile

    model.eval()
    with torch.no_grad():
        if isinstance(input, torch.Tensor):
            input = (input,)
        # Deep-copy before each use because in-place ops may mutate inputs.
        expected = model(*copy.deepcopy(input))
        if isinstance(expected, torch.Tensor):
            expected = (expected,)

        with tempfile.TemporaryDirectory() as workdir:
            onnx_path = os.path.join(workdir, 'model.onnx')
            torch.onnx.export(model, copy.deepcopy(input), onnx_path,
                              opset_version=self.opset_version,
                              example_outputs=expected,
                              verbose=False,
                              do_constant_folding=do_constant_folding,
                              keep_initializers_as_inputs=self.keep_initializers_as_inputs,
                              dynamic_axes=dynamic_axes,
                              input_names=input_names, output_names=output_names,
                              use_external_data_format=True)
            # Build and run the session while the temporary directory still
            # exists, since the session is loaded from onnx_path (and
            # presumably its external data files alongside it).
            levels = onnxruntime.GraphOptimizationLevel
            sess_options = onnxruntime.SessionOptions()
            sess_options.graph_optimization_level = (
                levels.ORT_ENABLE_EXTENDED if ort_optim_on else levels.ORT_DISABLE_ALL)
            session = onnxruntime.InferenceSession(onnx_path, sess_options=sess_options)
            ort_test_with_input(session, copy.deepcopy(input), expected, rtol, atol)
@skipIfUnsupportedMinOpsetVersion(9)  # The external data format was released with opset 9.
def test_embedding_model_with_external_data(self):
    """Embedding followed by Linear, exported with external data."""
    class LargeModel(torch.nn.Module):
        def __init__(self):
            super(LargeModel, self).__init__()
            embedding_dim = 15
            num_embeddings = 4 * 100
            self.emb = torch.nn.Embedding(num_embeddings, embedding_dim)
            self.lin1 = torch.nn.Linear(embedding_dim, 1)
            self.seq = torch.nn.Sequential(self.emb, self.lin1)

        def forward(self, input):
            return self.seq(input)

    net = LargeModel()
    token_ids = torch.tensor([2], dtype=torch.long)
    self.run_model_test_with_external_data(net, token_ids)
@skipIfUnsupportedMinOpsetVersion(9)  # The external data format was released with opset 9.
def test_mobilenet_v2_with_external_data(self):
    """Pretrained MobileNetV2 exported with external data."""
    net = torchvision.models.mobilenet_v2(pretrained=True)
    images = torch.randn(2, 3, 224, 224, requires_grad=True)
    # ONNX Runtime graph optimizations are turned off in this test because
    # the ORT optimizer does not support the external data format. Once that
    # support is added, ort_optim_on can be left at its default (True).
    self.run_model_test_with_external_data(net, images, rtol=1e-3, atol=1e-5,
                                           ort_optim_on=False)
@skipIfUnsupportedMinOpsetVersion(9)  # The external data format was released with opset 9.
def test_attribute_with_external_data(self):
    """A large constant (2x1024 ones) added to the input, with external data."""
    class LargeModel(torch.nn.Module):
        def forward(self, x):
            return x + torch.ones(2, 1024)

    column = torch.randn(2, 1)
    self.run_model_test_with_external_data(LargeModel(), column)
@skipIfUnsupportedMinOpsetVersion(9)  # The external data format was released with opset 9.
@unittest.skip("Enable this once large model with subgraph is supported in ORT")
def test_subgraph_with_external_data(self):
    """A scripted loop (an ONNX subgraph) that adds a large constant each iteration."""
    class LargeModel(torch.nn.Module):
        def forward(self, x):
            for _ in range(x.size(0)):
                x = x + torch.ones(2, 1024)
            return x

    column = torch.randn(2, 1)
    self.run_model_test_with_external_data(torch.jit.script(LargeModel()), column)
# Export Torchvision models
def test_alexnet(self):
    """Pretrained AlexNet: ORT output must match PyTorch."""
    net = torchvision.models.alexnet(pretrained=True)
    images = torch.randn(2, 3, 224, 224, requires_grad=True)
    self.run_test(net, (images,))
def test_densenets(self):
    """Pretrained DenseNet-121."""
    net = torchvision.models.densenet121(pretrained=True)
    images = torch.randn(2, 3, 224, 224, requires_grad=True)
    self.run_test(net, (images,), rtol=1e-3, atol=1e-5)
def test_googlenet(self):
    """Pretrained GoogLeNet."""
    net = torchvision.models.googlenet(pretrained=True)
    images = torch.randn(2, 3, 224, 224, requires_grad=True)
    self.run_test(net, (images,), rtol=1e-3, atol=1e-5)
def test_inception(self):
    """Pretrained Inception v3."""
    net = torchvision.models.inception_v3(pretrained=True)
    images = torch.randn(2, 3, 224, 224, requires_grad=True)
    self.run_test(net, (images,), rtol=1e-3, atol=1e-5)
def test_mnasnet(self):
    """Pretrained MNASNet 1.0."""
    net = torchvision.models.mnasnet1_0(pretrained=True)
    images = torch.randn(2, 3, 224, 224, requires_grad=True)
    self.run_test(net, (images,), rtol=1e-3, atol=1e-5)
def test_mobilenet(self):
    """Pretrained MobileNetV2 (in-memory export)."""
    net = torchvision.models.mobilenet_v2(pretrained=True)
    images = torch.randn(2, 3, 224, 224, requires_grad=True)
    self.run_test(net, (images,), rtol=1e-3, atol=1e-5)
def test_resnet(self):
    """Pretrained ResNet-50 with the default tolerances."""
    net = torchvision.models.resnet50(pretrained=True)
    images = torch.randn(2, 3, 224, 224, requires_grad=True)
    self.run_test(net, (images,))
def test_shufflenet(self):
    """Pretrained ShuffleNetV2 x1.0."""
    net = torchvision.models.shufflenet_v2_x1_0(pretrained=True)
    images = torch.randn(2, 3, 224, 224, requires_grad=True)
    self.run_test(net, (images,), rtol=1e-3, atol=1e-5)
def test_squeezenet(self):
    """Pretrained SqueezeNet 1.1 with the default tolerances."""
    net = torchvision.models.squeezenet1_1(pretrained=True)
    images = torch.randn(2, 3, 224, 224, requires_grad=True)
    self.run_test(net, (images,))
def test_vgg(self):
    """VGG-19 without and then with batch norm, both fed the same input."""
    plain = torchvision.models.vgg19(pretrained=True)
    images = torch.randn(2, 3, 224, 224, requires_grad=True)
    self.run_test(plain, (images,), rtol=1e-3, atol=1e-5)
    batch_normed = torchvision.models.vgg19_bn(pretrained=True)
    self.run_test(batch_normed, (images,), rtol=1e-3, atol=1e-5)
@skipIfUnsupportedMinOpsetVersion(11)
def test_fcn(self):
    """Pretrained FCN-ResNet101 segmentation model."""
    net = torchvision.models.segmentation.segmentation.fcn_resnet101(pretrained=True)
    images = torch.randn(2, 3, 224, 224, requires_grad=True)
    self.run_test(net, (images,), rtol=1e-3, atol=1e-5)
@skipIfUnsupportedMinOpsetVersion(11)
def test_deeplab(self):
    """Pretrained DeepLabV3-ResNet101 segmentation model."""
    net = torchvision.models.segmentation.segmentation.deeplabv3_resnet101(pretrained=True)
    images = torch.randn(2, 3, 224, 224, requires_grad=True)
    self.run_test(net, (images,), rtol=1e-3, atol=1e-5)
def test_r3d_18_video(self):
    """Pretrained R3D-18 video model on a 5-D (1, 3, 4, 112, 112) clip."""
    net = torchvision.models.video.r3d_18(pretrained=True)
    clip = torch.randn(1, 3, 4, 112, 112, requires_grad=True)
    self.run_test(net, (clip,), rtol=1e-3, atol=1e-5)
def test_mc3_18_video(self):
    """Pretrained MC3-18 video model on a 5-D (1, 3, 4, 112, 112) clip."""
    net = torchvision.models.video.mc3_18(pretrained=True)
    clip = torch.randn(1, 3, 4, 112, 112, requires_grad=True)
    self.run_test(net, (clip,), rtol=1e-3, atol=1e-5)
def test_r2plus1d_18_video(self):
    """Pretrained R(2+1)D-18 video model on a 5-D (1, 3, 4, 112, 112) clip."""
    net = torchvision.models.video.r2plus1d_18(pretrained=True)
    clip = torch.randn(1, 3, 4, 112, 112, requires_grad=True)
    self.run_test(net, (clip,), rtol=1e-3, atol=1e-5)
def test_reshape_constant_fold(self):
    """Reshape of a registered buffer, a candidate for constant folding."""
    class Reshape(torch.nn.Module):
        def __init__(self):
            super(Reshape, self).__init__()
            self.register_buffer("weight", torch.ones(5))

        def forward(self, x):
            scale_1 = self.weight.reshape(1, -1, 1, 1)
            return x * scale_1

    sample = torch.randn(4, 5)
    self.run_test(Reshape(), (sample,), rtol=1e-3, atol=1e-5)
def run_word_language_model(self, model_name):
    """Export a small ``word_language_model.RNNModel`` whose RNN type is ``model_name``."""
    ntokens, emsize, nhid, nlayers = 50, 5, 5, 5
    dropout, tied, batchsize = 0.2, False, 5
    model = word_language_model.RNNModel(model_name, ntokens, emsize,
                                         nhid, nlayers, dropout, tied,
                                         batchsize)
    tokens = torch.arange(0, ntokens).long().view(-1, batchsize)
    # Only the CPU version is supported, since the tracer does not work for GPU RNNs.
    self.run_test(model, (tokens, model.hidden))
@skipIfUnsupportedMinOpsetVersion(11)
def test_faster_rcnn(self):
    """Pretrained Faster R-CNN (ResNet-50 FPN) with min_size=200, max_size=300."""
    detector = torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(
        pretrained=True, min_size=200, max_size=300)
    detector.eval()
    images = torch.randn(2, 3, 200, 300, requires_grad=True)
    self.run_test(detector, (images,), rtol=1e-3, atol=1e-5)
def get_image_from_url(self, url):
    """Download an image, save it next to this test file, and return it as a tensor.

    The image is converted to RGB and resized to (300, 200) with bilinear
    resampling, then passed through ``transforms.ToTensor``.

    Args:
        url: image URL; the basename of its path is used as the local filename.

    Returns:
        The tensor produced by ``torchvision.transforms.ToTensor``.
    """
    import os
    from urllib.parse import urlsplit
    from urllib import request
    from PIL import Image
    from torchvision import transforms
    from torch._utils_internal import get_writable_path

    filename = os.path.basename(urlsplit(url)[2])
    data_dir = get_writable_path(os.path.dirname(__file__))
    path = os.path.join(data_dir, filename)
    # Close the HTTP response deterministically instead of relying on GC.
    with request.urlopen(url, timeout=15) as response:
        data = response.read()
    with open(path, 'wb') as f:
        f.write(data)
    # Image.open holds the file open until the image is closed; convert()
    # returns a new, fully loaded image, so the source handle can be released.
    with Image.open(path) as raw_image:
        image = raw_image.convert("RGB")
    image = image.resize((300, 200), Image.BILINEAR)
    to_tensor = transforms.ToTensor()
    return to_tensor(image)
def get_test_images(self):
    """Return a one-element list holding the sample image as a tensor."""
    image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
    return [self.get_image_from_url(url=image_url)]
@skipIfUnsupportedMinOpsetVersion(11)
def test_mask_rcnn(self):
    """Pretrained Mask R-CNN (ResNet-50 FPN) on the downloaded sample image."""
    detector = torchvision.models.detection.mask_rcnn.maskrcnn_resnet50_fpn(
        pretrained=True, min_size=200, max_size=300)
    sample_images = self.get_test_images()
    self.run_test(detector, (sample_images,), rtol=1e-3, atol=1e-5)
@skipIfUnsupportedMinOpsetVersion(11)
def test_keypoint_rcnn(self):
    """Pretrained Keypoint R-CNN (ResNet-50 FPN) on the downloaded sample image."""
    detector = torchvision.models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(
        pretrained=True, min_size=200, max_size=300)
    sample_images = self.get_test_images()
    self.run_test(detector, (sample_images,), rtol=1e-3, atol=1e-5)
def test_word_language_model_RNN_TANH(self):
    """Word language model built with model type "RNN_TANH"."""
    model_type = "RNN_TANH"
    self.run_word_language_model(model_type)
def test_word_language_model_RNN_RELU(self):
    """Word language model built with model type "RNN_RELU"."""
    model_type = "RNN_RELU"
    self.run_word_language_model(model_type)
def test_word_language_model_LSTM(self):
    """Word language model built with model type "LSTM"."""
    model_type = "LSTM"
    self.run_word_language_model(model_type)
def test_word_language_model_GRU(self):
    """Word language model built with model type "GRU"."""
    model_type = "GRU"
    self.run_word_language_model(model_type)
def test_index_1d(self):
    """Integer index on the first dimension."""
    def first_row(input):
        return input[0]

    self._test_index_generic(first_row)
def test_index_2d_1dimslice(self):
    """Slice 0:1 on dim 0 with a full slice on dim 1."""
    def leading_slice(input):
        return input[0:1, :]

    self._test_index_generic(leading_slice)
def test_index_2d_sliceint(self):
    """Integer index on dim 0 followed by a full slice."""
    def second_row(input):
        return input[1, :]

    self._test_index_generic(second_row)
def test_index_2d_neg_slice(self):
    """Slice with a negative stop (0:-1) on dim 0."""
    def drop_last_row(input):
        return input[0:-1, :]

    self._test_index_generic(drop_last_row)
@skipIfUnsupportedMinOpsetVersion(9)
def test_index_mask(self):
    """Boolean-mask indexing with both uint8 and bool masks."""
    for mask_dtype in (torch.uint8, torch.bool):
        # Bind mask_dtype as a default to avoid late-binding surprises.
        self._test_index_generic(
            lambda input, mask_dtype=mask_dtype: input[torch.tensor([0, 1, 0], dtype=mask_dtype)])
def test_dict(self):
    """Dict input keyed by a tensor; the first key is added to its value."""
    class MyModel(torch.nn.Module):
        def forward(self, x_in):
            x_out = {}
            x_out["test_key_out"] = torch.add(x_in[list(x_in.keys())[0]], list(x_in.keys())[0])
            return x_out

    tensor_keyed = {torch.tensor(1.): torch.randn(1, 2, 3)}
    self.run_test(MyModel(), (tensor_keyed,))
def test_dict_str(self):
    """Dict input keyed by a string, returning a dict output."""
    class MyModel(torch.nn.Module):
        def forward(self, x_in):
            x_out = {}
            x_out["test_key_out"] = torch.add(x_in["test_key_in"], 2.)
            return x_out

    str_keyed = {"test_key_in": torch.randn(1, 2, 3)}
    self.run_test(MyModel(), (str_keyed,))
@skipIfUnsupportedMinOpsetVersion(9)
def test_cste_script(self):
    """Scripted zeros/ones whose shapes come from the input's sizes."""
    class ConstantScriptModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return torch.zeros(x.size(0)), torch.ones((x.size(1), x.size(0)), dtype=torch.int64)

    sample = torch.randn(3, 4)
    self.run_test(ConstantScriptModel(), sample)
def test_scalar_tensor(self):
    """torch.scalar_tensor built from dynamic input sizes (all axes dynamic)."""
    class ScalarTensorModel(torch.nn.Module):
        def forward(self, input):
            return torch.scalar_tensor(input.size(0)), \
                torch.scalar_tensor(input.size(1), dtype=torch.int64)

    export_input = torch.randn(2, 3, 4)
    other_input = torch.randn(7, 8, 9)
    self.run_test(ScalarTensorModel(), export_input, test_with_inputs=[other_input],
                  input_names=['input_1'],
                  dynamic_axes={'input_1': [0, 1, 2]})
def test_hardtanh(self):
    """Hardtanh with custom bounds (-1.5, 2.5) over the integers -5..4."""
    clipper = torch.nn.Hardtanh(-1.5, 2.5)
    ramp = torch.arange(-5, 5).to(dtype=torch.float32)
    self.run_test(clipper, ramp)
def test_hardtanh_script_with_default_values(self):
    """Scripted functional hardtanh using its default bounds."""
    class HardtanhScript(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return torch.nn.functional.hardtanh(x)

    ramp = torch.arange(-5, 5).to(dtype=torch.float32)
    self.run_test(HardtanhScript(), ramp)
def test_clamp(self):
    """clamp with both bounds, only ``min``, and only ``max``."""
    class ClampModel(torch.nn.Module):
        def forward(self, x):
            return x.clamp(-0.5, 0.5)

    class ClampMinModel(torch.nn.Module):
        def forward(self, x):
            return x.clamp(min=-0.5)

    class ClampMaxModel(torch.nn.Module):
        def forward(self, x):
            return x.clamp(max=0.5)

    # Each variant gets its own fresh random input.
    for model_cls in (ClampModel, ClampMinModel, ClampMaxModel):
        self.run_test(model_cls(), torch.randn(3, 4))
@skipIfUnsupportedMinOpsetVersion(11)
def test_clamp_dyn(self):
    """Scripted clamp whose bounds come from the input's dynamic sizes."""
    class ClampMaxModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return x.clamp(None, x.size(0))

    class ClampMinModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return x.clamp(x.size(0), None)

    class ClampMinMaxModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return x.clamp(x.size(0), x.size(1))

    square = torch.arange(16).view(4, 4).float()
    self.run_test(ClampMaxModel(), square)
    square = torch.arange(16).view(4, 4).float()
    self.run_test(ClampMinModel(), square)
    wide = torch.arange(16).view(2, 8).float()
    self.run_test(ClampMinMaxModel(), wide)
@skipIfUnsupportedMinOpsetVersion(9)
def test_full_trace(self):
    """Traced torch.full with a tensor fill value and an int64 dtype."""
    class FullModel(torch.nn.Module):
        def forward(self, x):
            return torch.full((3, 4), x, dtype=torch.long)

    fill_value = torch.tensor(12)
    self.run_test(FullModel(), fill_value)
@skipIfUnsupportedMinOpsetVersion(9)
def test_full_script(self):
    """Scripted torch.full with a tensor fill value and an int64 dtype."""
    class FullModelScripting(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return torch.full((3, 4), x, dtype=torch.long)

    fill_value = torch.tensor(12)
    self.run_test(FullModelScripting(), fill_value)
def test_fuse_addmm(self):
    """mm followed by add, a candidate for fusion into addmm."""
    class AddmmModel(torch.nn.Module):
        def forward(self, x):
            return torch.mm(x, x) + x

    square = torch.ones(3, 3)
    self.run_test(AddmmModel(), square)
def test_maxpool(self):
    """MaxPool1d with kernel 2 and stride 1."""
    pool = torch.nn.MaxPool1d(2, stride=1)
    signal = torch.randn(20, 16, 50)
    self.run_test(pool, signal)
def test_conv(self):
    """Conv1d/2d/3d, in both a traced and a scripted module."""
    class TraceModel(torch.nn.Module):
        def __init__(self):
            super(TraceModel, self).__init__()
            self.conv1 = torch.nn.Conv1d(16, 33, 3, stride=2)
            self.conv2 = torch.nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
            self.conv3 = torch.nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))

        def forward(self, input1, input2, input3):
            return self.conv1(input1), self.conv2(input2), self.conv3(input3)

    class ScriptModel(torch.jit.ScriptModule):
        def __init__(self):
            super(ScriptModel, self).__init__()
            self.conv1 = torch.nn.Conv1d(16, 33, 3, stride=2)
            self.conv2 = torch.nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
            self.conv3 = torch.nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))

        @torch.jit.script_method
        def forward(self, input1, input2, input3):
            return self.conv1(input1), self.conv2(input2), self.conv3(input3)

    inputs = (torch.randn(20, 16, 50),
              torch.randn(20, 16, 50, 100),
              torch.randn(20, 16, 10, 50, 100))
    # Each model is constructed only after the inputs have been drawn.
    self.run_test(TraceModel(), inputs, atol=10e-5)
    self.run_test(ScriptModel(), inputs, atol=10e-5)
# TODO: Add ConvTranspose1d and ConvTranspose3d when supported in ORT
# TODO: Add a test with dilation != 1 once ORT is fixed
def test_conv_transpose(self):
    """ConvTranspose2d, in both a traced and a scripted module."""
    class TraceModel(torch.nn.Module):
        def __init__(self):
            super(TraceModel, self).__init__()
            self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(1, 1))

        def forward(self, input2):
            return self.conv2(input2)

    class ScriptModel(torch.jit.ScriptModule):
        def __init__(self):
            super(ScriptModel, self).__init__()
            self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(1, 1))

        @torch.jit.script_method
        def forward(self, input2):
            return self.conv2(input2)

    feature_map = torch.randn(20, 16, 50, 100)
    self.run_test(TraceModel(), (feature_map,), atol=10e-5)
    self.run_test(ScriptModel(), (feature_map,), atol=10e-5)
def test_squeeze(self):
    """squeeze along a negative dim (-2) of a (2, 1, 4) tensor."""
    class Squeeze(torch.nn.Module):
        def forward(self, x):
            # Call torch.squeeze directly rather than the accidental
            # torch.torch.squeeze alias.
            return torch.squeeze(x, dim=-2)

    x = torch.randn(2, 1, 4)
    self.run_test(Squeeze(), x)
def test_unsqueeze(self):
    """unsqueeze at a negative dim (-2)."""
    class Unsqueeze(torch.nn.Module):
        def forward(self, x):
            return torch.unsqueeze(x, dim=-2)

    sample = torch.randn(2, 3, 4)
    self.run_test(Unsqueeze(), sample)
def test_maxpool_default_stride(self):
    """Functional max_pool2d with kernel 2 and no explicit stride."""
    class MaxPoolModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.max_pool2d(x, 2)

    pool_model = MaxPoolModel()
    feature_map = torch.randn(10, 20, 16, 50)
    self.run_test(pool_model, feature_map)
@skipIfUnsupportedMinOpsetVersion(8)
def test_maxpool_adaptive(self):
    """AdaptiveMaxPool1d down to an output size of 5, without indices."""
    pool = torch.nn.AdaptiveMaxPool1d(5, return_indices=False)
    signal = torch.randn(20, 16, 50, requires_grad=True)
    self.run_test(pool, signal)
def test_maxpool_2d(self):
    """MaxPool2d with kernel 5 and asymmetric padding (1, 2)."""
    pool = torch.nn.MaxPool2d(5, padding=(1, 2))
    feature_map = torch.randn(1, 20, 16, 50, requires_grad=True)
    self.run_test(pool, feature_map)
def test_maxpool_1d_ceil(self):
    """MaxPool1d (kernel 3, stride 2) with ceil_mode=True."""
    pool = torch.nn.MaxPool1d(3, 2, ceil_mode=True)
    signal = torch.randn(20, 16, 50)
    self.run_test(pool, signal)
def test_maxpool_2d_ceil(self):
    """MaxPool2d (kernel 3, stride 2) with ceil_mode=True."""
    pool = torch.nn.MaxPool2d(3, 2, ceil_mode=True)
    feature_map = torch.randn(20, 16, 50, 32)
    self.run_test(pool, feature_map)
def test_maxpool_3d_ceil(self):
    """MaxPool3d (kernel 3, stride 2) with ceil_mode=True."""
    pool = torch.nn.MaxPool3d(3, 2, ceil_mode=True)
    volume = torch.randn(20, 16, 50, 44, 31)
    self.run_test(pool, volume)
@skipIfUnsupportedMinOpsetVersion(8)
def test_maxpool_with_indices(self):
    """MaxPool1d that also returns the argmax indices."""
    pool = torch.nn.MaxPool1d(2, stride=1, return_indices=True)
    signal = torch.randn(20, 16, 50)
    self.run_test(pool, signal)
@skipIfUnsupportedMinOpsetVersion(10)
def test_maxpool_dilation(self):
    """MaxPool1d with dilation 2."""
    pool = torch.nn.MaxPool1d(2, stride=1, dilation=2)
    signal = torch.randn(20, 16, 50)
    self.run_test(pool, signal)
def test_avgpool_default_stride(self):
    """Functional avg_pool2d with kernel 2 and no explicit stride."""
    class AvgPoolModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.avg_pool2d(x, 2)

    pool_model = AvgPoolModel()
    feature_map = torch.randn(10, 20, 16, 50)
    self.run_test(pool_model, feature_map)
def test_avgpool(self):
    """AvgPool1d with kernel 2 and stride 1."""
    pool = torch.nn.AvgPool1d(2, stride=1)
    signal = torch.randn(20, 16, 50)
    self.run_test(pool, signal)
def test_avgpool_1d_ceil(self):
    """AvgPool1d (kernel 3, stride 2) with ceil_mode=True on a length-7 signal."""
    pool = torch.nn.AvgPool1d(3, 2, ceil_mode=True)
    signal = torch.randn(1, 1, 7)
    self.run_test(pool, signal)
def test_avgpool_2d_ceil(self):
    """AvgPool2d (kernel 3, stride 2) with ceil_mode=True."""
    pool = torch.nn.AvgPool2d(3, 2, ceil_mode=True)
    feature_map = torch.randn(20, 16, 50, 32)
    self.run_test(pool, feature_map)
def test_avgpool_3d_ceil(self):
    """AvgPool3d (kernel 3, stride 2) with ceil_mode=True."""
    pool = torch.nn.AvgPool3d(3, 2, ceil_mode=True)
    volume = torch.randn(20, 16, 50, 44, 31)
    self.run_test(pool, volume)
@enableScriptTest()
def test_arithmetic(self):
    """Chained scalar add/sub/mul/div (also checked in scripted form)."""
    class ArithmeticModule(torch.nn.Module):
        def forward(self, x):
            x = x + 2
            x = x - 4
            x = x * 6
            x = x / 8
            return x

    sample = torch.randn(2, 3, 4)
    self.run_test(ArithmeticModule(), sample)
def test_floor_div(self):
    """Floor division across float32/float64/int64 operands and scalar or tensor divisors."""
    class FloorDivModule(torch.nn.Module):
        def forward(self, x, y):
            return (x // 3, x // 2.,
                    x.to(dtype=torch.float64) // 3, x.to(dtype=torch.float64) // 2.,
                    x.to(dtype=torch.int64) // 3, x.to(dtype=torch.int64) // 2.,
                    x // (y + 1.).to(dtype=torch.int64), x // y,
                    x.to(dtype=torch.float64) // y.to(dtype=torch.int64),
                    x.to(dtype=torch.float64) // y.to(dtype=torch.float64),
                    x.to(dtype=torch.int64) // y.to(dtype=torch.int64),
                    x.to(dtype=torch.int64) // y)

    dividend = torch.randn(2, 3, 4)
    # Strictly positive divisors 1..24 avoid division by zero.
    divisor = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4)
    self.run_test(FloorDivModule(), (dividend, divisor))
def test_floor_div_script(self):
    """Scripted floor division by scalars and by a tensor."""
    class FloorDivModule(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x, y):
            return x // 3, x // 2., x // y

    dividend = torch.randn(2, 3, 4)
    divisor = torch.randn(2, 3, 4)
    self.run_test(FloorDivModule(), (dividend, divisor))
def test_true_div(self):
    """true_divide for int/int, float/int and short/short operand pairs."""
    class TrueDivModule(torch.nn.Module):
        def forward(self, x, y):
            return torch.true_divide(x, y)

    numer = torch.randn(2, 3, 4).to(torch.int)
    denom = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
    operand_pairs = (
        (numer, denom),
        (numer.float(), denom),
        (numer.to(torch.short), denom.to(torch.short)),
    )
    for lhs, rhs in operand_pairs:
        self.run_test(TrueDivModule(), (lhs, rhs))
# Note: true_divide cannot (generally) be exported via scripting,
# since its type promotion logic depends on knowing the scalar types
# of the input tensors. That is, the ONNX graph depends on the input
# data types, which makes it appropriate for tracing only.
def test_true_div_trace(self):
    """Trace true_divide under float and double default dtypes and compare with ORT.

    ``torch.set_default_dtype`` is process-wide state, so the previous default
    is restored in a ``finally`` block. Without it, a failing export or
    comparison would leave the default at ``torch.double`` for later tests.
    """
    class TrueDivModule(torch.nn.Module):
        def forward(self, x, y):
            return torch.true_divide(x, y)

    x = torch.randn(2, 3, 4).to(torch.int)
    y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)

    prev_default = torch.get_default_dtype()
    try:
        for default_dtype in (torch.float, torch.double):
            torch.set_default_dtype(default_dtype)
            self.run_test(torch.jit.trace(TrueDivModule(), (x, y)), (x, y))
    finally:
        torch.set_default_dtype(prev_default)
def test_slice_trace(self):
    """Traced slice [0:1] of a 1-D tensor."""
    class MyModule(torch.nn.Module):
        def forward(self, x):
            return x[0:1]

    vector = torch.randn(3)
    self.run_test(MyModule(), vector)
def test_slice_neg(self):
    """Slice with a negative start ([-1:]) on dim 0."""
    class NegSlice(torch.nn.Module):
        def forward(self, x):
            return x[-1:]

    sample = torch.randn(3, 4, 5)
    self.run_test(NegSlice(), sample)
def test_slice_neg_large(self):
    """Negative start/stop slice plus a negative integer index on a 5-D tensor."""
    class NegSlice(torch.nn.Module):
        def forward(self, x):
            return x[:, :, -3:-1, :, -1]

    sample = torch.randn(3, 4, 5, 6, 7)
    self.run_test(NegSlice(), sample)
def test_slice_neg_large_negone(self):
    """Index -1 on the last dim of a 5-D tensor."""
    class NegSlice(torch.nn.Module):
        def forward(self, x):
            return x[:, :, :, :, -1]

    sample = torch.randn(3, 4, 5, 6, 7)
    self.run_test(NegSlice(), sample)
@skipIfUnsupportedMinOpsetVersion(10)
def test_slice_dynamic(self):
    """Slices with size-dependent bounds, re-checked on a differently shaped input."""
    class DynamicSliceExportMod(torch.nn.Module):
        def forward(self, x):
            results = []
            for i in range(4):
                results.append(x[:x.size(0) - i, i:x.size(2), i:3])
            return tuple(results)

    export_input = torch.rand(5, 5, 5)
    other_input = torch.randn(6, 7, 8)
    self.run_test(DynamicSliceExportMod(), export_input, test_with_inputs=[other_input],
                  input_names=['input_1'],
                  output_names=['output_1'],
                  dynamic_axes={'input_1': [0, 1, 2],
                                'output_1': [0, 1, 2]})
@skipIfUnsupportedMinOpsetVersion(10)
def test_slice_dynamic_script(self):
    """Scripted slice whose end is x.size(0); it is empty for a 1-row input."""
    class DynamicSliceModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return x[1:x.size(0)]

    single_row = torch.rand(1, 2)
    self.run_test(DynamicSliceModel(), single_row)
@skipIfUnsupportedMinOpsetVersion(10)
def test_slice_dynamic_to_end(self):
    """Open-ended slices (i:) combined with a size-dependent index."""
    class DynamicSliceExportMod(torch.nn.Module):
        def forward(self, x):
            results = []
            for i in range(4):
                results.append(x[:, i:, x.size(2) - 5])
            return tuple(results)

    cube = torch.rand(5, 5, 5)
    # NOTE(review): no input_names/output_names are passed here, so the
    # 'input_1'/'output_1' keys may not match any graph name. Confirm
    # whether these dynamic axes actually take effect.
    self.run_test(DynamicSliceExportMod(), cube,
                  dynamic_axes={'input_1': [0, 1, 2],
                                'output_1': [0, 1, 2]})
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_dynamic(self):
    """arange with size-dependent bounds, exported both traced and scripted."""
    class ArangeModel(torch.nn.Module):
        def forward(self, input):
            return torch.arange(input.shape[0]), \
                torch.arange(12), \
                torch.arange(start=input.shape[0], end=input.shape[0] + 5)

    export_input = torch.randn(5, 3, 2)
    other_input = torch.randn(8, 3, 2)

    def export_options():
        # Build fresh containers for each call so neither run sees objects
        # shared with (or modified by) the other.
        return dict(test_with_inputs=[other_input],
                    input_names=['input_1'],
                    output_names=['output_1', 'output_2', 'output_3'],
                    dynamic_axes={'input_1': [0],
                                  'output_1': [0]})

    self.run_test(ArangeModel(), export_input, **export_options())
    self.run_test(torch.jit.script(ArangeModel()), export_input, **export_options())
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange(self):
    """arange with a size-derived start, a tensor end and a float step, cast to int64."""
    class ArangeModel(torch.nn.Module):
        def forward(self, start, end):
            return torch.arange(start.size(0), end, 1.5, dtype=torch.int64)

    shape_source = torch.randn(2, 3, 4)
    end_value = torch.tensor(8.5, dtype=torch.float)
    self.run_test(ArangeModel(), (shape_source, end_value))
@skipIfUnsupportedMinOpsetVersion(11)
def test_arange_no_type(self):
    """arange with a float tensor end and no explicit dtype."""
    class ArangeModel(torch.nn.Module):
        def forward(self, end):
            return torch.arange(end), \
                torch.arange(0, end)

    end_value = torch.tensor(6.2, dtype=torch.float)
    self.run_test(ArangeModel(), end_value)
@skipIfUnsupportedMinOpsetVersion(9)
def test_size(self):
    """size() with positive and negative dims, and shape passed to ones()."""
    class SizeModel(torch.nn.Module):
        def forward(self, input):
            return torch.arange(input.size(0)), torch.arange(input.size(-1)), torch.ones(input.shape)

    sample = torch.randn(5, 3, 2)
    self.run_test(SizeModel(), sample)
def _test_index_generic(self, fn):
    """Wrap the indexing callable ``fn`` in a module and check it on a random 5-D tensor."""
    class MyModel(torch.nn.Module):
        # No custom __init__ is needed: nn.Module.__init__ is inherited.
        def forward(self, input):
            return fn(input)

    five_d = torch.randn(3, 4, 5, 6, 7)
    self.run_test(MyModel(), five_d)
def test_tensor_index_advanced_indexing(self):
    """Advanced indexing mixing tensor indices with slices, Ellipsis and None."""
    indexers = [
        lambda input: input[:, torch.tensor([[0, 2], [1, 1]]), :, torch.tensor([2, 1]), torch.tensor([0, 3])],
        lambda input: input[..., torch.tensor([2, 1]), torch.tensor([0, 3])],
        lambda input: input[:, torch.tensor([0, 2]), None, 2:4, torch.tensor([[1, 3], [4, 0]])],
        lambda input: input[:, torch.tensor([0, 2]), torch.tensor([1]), 2:4, torch.tensor([[1], [4]])],
    ]
    for indexer in indexers:
        self._test_index_generic(indexer)
def test_tensor_index_advanced_indexing_consecutive(self):
    """Two adjacent tensor indices followed by None."""
    def consecutive(input):
        return input[:, torch.tensor([0, 2]), torch.tensor([[1, 3], [4, 0]]), None]

    self._test_index_generic(consecutive)
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put(self):
    """In-place index assignment x[ind] = update."""
    class IndexPutModel(torch.nn.Module):
        def forward(self, x, ind, update):
            x[ind] = update
            return x

    target = torch.randn(3, 4)
    row_index = torch.tensor([1], dtype=torch.long)
    values = torch.ones(4)
    self.run_test(IndexPutModel(), (target, row_index, values))
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_accumulate(self):
    """index_put with accumulate=True."""
    class IndexPutModel(torch.nn.Module):
        def forward(self, x, ind, update):
            return x.index_put((ind, ), update, accumulate=True)

    target = torch.randn(3, 4)
    row_index = torch.tensor([2], dtype=torch.long)
    values = torch.ones(4)
    self.run_test(IndexPutModel(), (target, row_index, values))
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_slice_index(self):
class IndexPutModel(torch.nn.Module):
def forward(self, x, update):
x[1:2, 1:3, torch.tensor([1])] += update
return x
x = torch.randn(3, 4, 5)
update = torch.tensor([10, 15]).view(1, 2, 1)
self.run_test(IndexPutModel(), (x, update))
class IndexPutModel2(torch.nn.Module):
def forward(self, x, update):
x[torch.tensor([0, 2]), torch.tensor([1, 2])] += update
return x
x = torch.randn(3, 4, 5)
update = torch.randn(2, 5)
self.run_test(IndexPutModel2(), (x, update))
class IndexPutModel3(torch.nn.Module):
def forward(self, x, update):
x[torch.tensor([0, 2]), 1:2] += update
return x
x = torch.randn(3, 4, 5)
update = torch.tensor([10, 15]).view(2, 1, 1)
self.run_test(IndexPutModel3(), (x, update))
class IndexPutModel4(torch.nn.Module):
def forward(self, x, update):
x[torch.tensor([0, 2]), 2] += update
return x
x = torch.randn(3, 4, 5)
update = torch.tensor([10, 15]).view(2, 1)
self.run_test(IndexPutModel4(), (x, update))
class IndexPutModel5(torch.nn.Module):
def forward(self, x, update):
x[1:3, torch.tensor([0, 2]), 2] += update
return x
x = torch.randn(3, 4, 5)
update = torch.tensor([10, 15]).view(2, 1)
self.run_test(IndexPutModel5(), (x, update))
class IndexPutModel6(torch.nn.Module):
def forward(self, x, update):
x[1:3, 0] = update
return x
x = torch.randn(3, 4, 5)
update = torch.arange(2 * 5).to(torch.float).view(2, 5)
self.run_test(IndexPutModel6(), (x, update))
class IndexPutModel7(torch.nn.Module):
def forward(self, x, update):
x[1:, 0] = update
return x
x = torch.randn(3, 4, 5)
update = torch.arange(2 * 5).to(torch.float).view(2, 5)
self.run_test(IndexPutModel7(), (x, update))
class IndexPutModel8(torch.nn.Module):
def forward(self, x, update):
x[:3, 0] = update
return x
x = torch.randn(3, 4, 5)
update = torch.arange(3 * 5).to(torch.float).view(3, 5)
self.run_test(IndexPutModel8(), (x, update))