diff --git a/fluid_onnx/ops.py b/fluid_onnx/ops.py
index d91315268..5516e3075 100644
--- a/fluid_onnx/ops.py
+++ b/fluid_onnx/ops.py
@@ -58,6 +58,11 @@


 def activation_ops(act_type, operator, block):
+    """ Convert common activations with type specified by 'act_type', including
+        'abs', 'ceil', 'exp', 'floor', 'log', 'reciprocal', 'relu', 'sigmoid',
+        'softplus', 'softsign', 'sqrt' and 'tanh'.
+    """
+
     inputs, _, outputs = op_io_info(operator)
     return make_node(
         act_type, inputs=inputs.values()[0], outputs=outputs.values()[0])
@@ -114,16 +119,32 @@ def batch_norm_op(operator, block):
         bn_node)


-def cast_op():
-    pass
+def cast_op(operator, block):
+    inputs, attrs, outputs = op_io_info(operator)
+    return make_node(
+        'Cast',
+        inputs=inputs['X'],
+        outputs=outputs['Out'],
+        to=PADDLE_TO_ONNX_DTYPE[attrs['out_dtype']])


-def clip_op():
-    pass
+def clip_op(operator, block):
+    inputs, attrs, outputs = op_io_info(operator)
+    return make_node(
+        'Clip',
+        inputs=inputs['X'],
+        outputs=outputs['Out'],
+        min=attrs['min'],
+        max=attrs['max'])


-def concat_op():
-    pass
+def concat_op(operator, block):
+    inputs, attrs, outputs = op_io_info(operator)
+    return make_node(
+        'Concat',
+        inputs=inputs['X'],
+        outputs=outputs['Out'],
+        axis=attrs['axis'])


 def constant_op(var, scope):
@@ -156,8 +177,20 @@ def conv2d_op(operator, block):
     return conv2d


-def convtranspose_op():
-    pass
+def conv2d_transpose_op(operator, block):
+    inputs, attrs, outputs = op_io_info(operator)
+
+    kernel_shape = block.vars[inputs['Filter'][0]].shape
+    conv2d_transpose = make_node(
+        'ConvTranspose',
+        inputs=inputs['Input'] + inputs['Filter'],
+        outputs=outputs['Output'],
+        dilations=attrs['dilations'],
+        kernel_shape=kernel_shape[-2:],
+        strides=attrs['strides'],
+        group=1,
+        pads=attrs['paddings'] + attrs['paddings'])
+    return conv2d_transpose


 def depthtospace_op():
@@ -184,12 +217,19 @@ def dropout_op(operator, block):


 def elementwise_ops(op_type, operator, block):
+    """Convert elementwise operators from Fluid to ONNX. Supported elementwise
+       'op_type' includes 'Add', 'Div', 'Mul', 'Pow' and 'Sub'.
+    """
+
     inputs, attrs, outputs = op_io_info(operator)
+    rank_x = len(block.vars[inputs['X'][0]].shape)
+    rank_y = len(block.vars[inputs['Y'][0]].shape)
+    axis = rank_x - rank_y if attrs['axis'] == -1 else attrs['axis']
     return make_node(
         op_type,
         inputs=inputs['X'] + inputs['Y'],
         outputs=outputs['Out'],
-        axis=attrs['axis'],
+        axis=axis,
         broadcast=1)
@@ -260,8 +300,21 @@ def less_op():
     pass


-def log_op():
-    pass
+def binary_logical_ops(op_type, operator, block):
+    """Convert binary logical operators, i.e. 'And', 'Or' and 'Xor'.
+    """
+
+    inputs, _, outputs = op_io_info(operator)
+    return make_node(
+        op_type, inputs=inputs['X'] + inputs['Y'], outputs=outputs['Out'])
+
+
+def unary_logical_ops(op_type, operator, block):
+    """Convert unary logical operators, i.e. 'Not'.
+    """
+
+    inputs, _, outputs = op_io_info(operator)
+    return make_node(op_type, inputs=inputs['X'], outputs=outputs['Out'])


 def logsoftmax_op():
@@ -422,6 +475,11 @@ def randomuniformlike_op():


 def reduce_ops(op_type, operator, block):
+    """Convert reduce operators in Fluid to ONNX. 'op_type' specifies the
+       target ONNX operator type, supporting 'Reduce{Max, Mean, Min, Sum}'
+       right now.
+    """
+
     inputs, attrs, outputs = op_io_info(operator)
     rank = len(block.vars[inputs['X'][0]].shape)
     dim = attrs['dim']
@@ -549,19 +607,17 @@ def xor_op():

 node_maker = {
     # Paddle op name : (ONNX op name, modifier)
     'abs': partial(activation_ops, 'Abs'),
-    # '': 'And',  # ?
     # 'ArgMax', NEEDS ATTENTION.
     # 'ArgMin', NEEDS ATTENTION.
     'batch_norm': batch_norm_op,
-    'cast': ('Cast', cast_op),
+    'cast': cast_op,
     'ceil': partial(activation_ops, 'Ceil'),
-    'clip': ('Clip', clip_op),
-    'concat': ('Concat', concat_op),
+    'clip': clip_op,
+    'concat': concat_op,
     'constant': constant_op,
     'conv2d': conv2d_op,
-    # Need to continue the mapping below.
-    '': 'ConvTranspose',
+    'conv2d_transpose': conv2d_transpose_op,
     '': 'DepthToSpace',
     'depthwise_conv2d': conv2d_op,
     'dropout': dropout_op,
@@ -588,6 +644,10 @@
     '': 'LeakyRelu',
     '': 'Less',
     'log': partial(activation_ops, 'Log'),
+    'logical_and': partial(binary_logical_ops, 'And'),
+    'logical_or': partial(binary_logical_ops, 'Or'),
+    'logical_not': partial(unary_logical_ops, 'Not'),
+    'logical_xor': partial(binary_logical_ops, 'Xor'),
     ',': 'LogSoftmax',
     '': 'LpNormalization',
     '': 'LpPool',
@@ -599,8 +659,6 @@
     '': 'Min',
     'mul': mul_op,
     ',': 'Neg',
-    '': 'Not',
-    '': 'Or',
     '': 'PRelu',
     '': 'Pad',
     'pool2d': pool2d_op,
@@ -640,7 +698,6 @@
     '': 'TopK',
     '': 'Transpose',
     # 'Unsqueeze', NEEDS ATTENTION.
-    '': 'Xor',
     # 'experimental ATen'
     # ',': 'experimental Affine'
     # 'experimental ConstantFill'
diff --git a/fluid_onnx/variables.py b/fluid_onnx/variables.py
index 4db7f3aa8..47af770ce 100644
--- a/fluid_onnx/variables.py
+++ b/fluid_onnx/variables.py
@@ -36,7 +36,7 @@ def paddle_onnx_shape(paddle_shape):

 PADDLE_TO_ONNX_DTYPE = {
     core.VarDesc.VarType.FP32: onnx_pb2.TensorProto.FLOAT,
-    core.VarDesc.VarType.FP64: onnx_pb2.TensorProto.FLOAT16,
+    core.VarDesc.VarType.FP64: onnx_pb2.TensorProto.DOUBLE,
     # '': onnx_pb2.TensorProto.DOUBLE,
     core.VarDesc.VarType.INT32: onnx_pb2.TensorProto.INT32,
     core.VarDesc.VarType.INT16: onnx_pb2.TensorProto.INT16,
diff --git a/tests/op_test.py b/tests/op_test.py
index 22bae505b..42b973701 100644
--- a/tests/op_test.py
+++ b/tests/op_test.py
@@ -77,7 +77,7 @@ def create_var(block, name, np_list, var_proto):
     persistable = True if name in persistable_list else False

     return block.create_var(
-        dtype="float32",
+        dtype='float32',
         shape=shape,
         persistable=persistable,
         lod_level=lod_level,
@@ -263,9 +263,17 @@ def eval_onnx_node(self):
         onnx_graph = make_graph(node_list, self.op_type, inputs, outputs)
         onnx_model = make_model(onnx_graph, producer_name='unittest')

+        # Expand input dictionary if there are tensor arrays
+        input_map = {}
+        for v in self.inputs:
+            if isinstance(self.inputs[v], list):
+                input_map.update(self.inputs[v])
+            else:
+                input_map[v] = self.inputs[v]
+
         # Run the Caffe2Backend with the ONNX model.
         rep = Caffe2Backend.prepare(onnx_model, device='CPU')
-        in_vals = [self.inputs[input.name] for input in inputs]
+        in_vals = [input_map[input.name] for input in inputs]
         outs = rep.run(in_vals)
         return outs
diff --git a/tests/test_activation_ops.py b/tests/test_activation_ops.py
index 2b0f55de2..b1310c2bd 100644
--- a/tests/test_activation_ops.py
+++ b/tests/test_activation_ops.py
@@ -19,9 +19,9 @@

 class TestAbsOp(OpTest):
     def setUp(self):
-        X = np.random.random((13, 15)).astype("float32")
+        X = np.random.random((13, 15)).astype('float32')
         self.inputs = {'X': X}
-        self.outputs = {'Out': np.zeros((1, 1)).astype("float32")}
+        self.outputs = {'Out': np.zeros((1, 1)).astype('float32')}
         self.init_op_type()

     def init_op_type(self):
diff --git a/tests/test_cast_op.py b/tests/test_cast_op.py
new file mode 100644
index 000000000..67903dd9c
--- /dev/null
+++ b/tests/test_cast_op.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import paddle.fluid.core as core
+from op_test import OpTest
+
+
+class TestCastOp(OpTest):
+    def setUp(self):
+        input = np.random.random((10, 10))
+        self.inputs = {'X': input.astype('float32')}
+        self.outputs = {'Out': input.astype('float64')}
+        self.attrs = {
+            'in_dtype': int(core.VarDesc.VarType.FP32),
+            'out_dtype': int(core.VarDesc.VarType.FP64)
+        }
+        self.op_type = 'cast'
+
+    def test_check_output(self):
+        self.check_output()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_clip_op.py b/tests/test_clip_op.py
new file mode 100644
index 000000000..d39c8af43
--- /dev/null
+++ b/tests/test_clip_op.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestClipOp(OpTest):
+    def setUp(self):
+        input = np.random.random((4, 5, 6)).astype('float32')
+        self.op_type = 'clip'
+        self.inputs = {'X': input}
+        self.attrs = {'min': 0.2, 'max': 0.8}
+        self.outputs = {'Out': np.zeros((1, 1)).astype('float32')}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_concat_op.py b/tests/test_concat_op.py
new file mode 100644
index 000000000..40943962f
--- /dev/null
+++ b/tests/test_concat_op.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestConcatOp(OpTest):
+    def setUp(self):
+        self.op_type = 'concat'
+        self.init_test_data()
+        self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
+        self.attrs = {'axis': self.axis}
+        self.outputs = {'Out': np.zeros((1, 1)).astype('float32')}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def init_test_data(self):
+        self.x0 = np.random.random((2, 1, 4, 5)).astype('float32')
+        self.x1 = np.random.random((2, 2, 4, 5)).astype('float32')
+        self.x2 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.axis = 1
+
+
+class TestConcatOp2(TestConcatOp):
+    def init_test_data(self):
+        self.x0 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.x1 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.x2 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.axis = 2
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_conv2d_op.py b/tests/test_conv2d_op.py
index 40aeed9bf..e5d797240 100644
--- a/tests/test_conv2d_op.py
+++ b/tests/test_conv2d_op.py
@@ -57,7 +57,7 @@ def setUp(self):
         self.outputs = {'Output': output}

     def init_conv_type(self):
-        self.op_type = "conv2d"
+        self.op_type = 'conv2d'

     def test_check_output(self):
         self.check_output(decimal=5)
@@ -65,7 +65,7 @@ def test_check_output(self):

 class TestDepthwiseConv2dOp(TestConv2dOp):
     def init_conv_type(self):
-        self.op_type = "depthwise_conv2d"
+        self.op_type = 'depthwise_conv2d'


 if __name__ == '__main__':
diff --git a/tests/test_conv2d_transpose_op.py b/tests/test_conv2d_transpose_op.py
new file mode 100644
index 000000000..2fb7f39ab
--- /dev/null
+++ b/tests/test_conv2d_transpose_op.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestConv2dTransposeOp(OpTest):
+    def setUp(self):
+        self.use_cudnn = False
+        self.init_op_type()
+        self.init_test_case()
+
+        input_ = np.random.random(self.input_size).astype('float32')
+        filter_ = np.random.random(self.filter_size).astype('float32')
+
+        self.inputs = {'Input': input_, 'Filter': filter_}
+        self.attrs = {
+            'strides': self.stride,
+            'paddings': self.pad,
+            'dilations': self.dilations,
+            'use_cudnn': self.use_cudnn,
+            'data_format': 'AnyLayout'
+        }
+
+        self.outputs = {'Output': np.zeros((1, 1))}
+
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+
+    def init_op_type(self):
+        self.op_type = 'conv2d_transpose'
+
+    def test_check_output(self):
+        self.check_output()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_dropout_op.py b/tests/test_dropout_op.py
index d503b7e60..397c2e482 100644
--- a/tests/test_dropout_op.py
+++ b/tests/test_dropout_op.py
@@ -19,8 +19,8 @@

 class TestDropoutOp(OpTest):
     def setUp(self):
-        self.op_type = "dropout"
-        self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
+        self.op_type = 'dropout'
+        self.inputs = {'X': np.random.random((32, 64, 2)).astype('float32')}
         self.attrs = {'dropout_prob': 0.8, 'is_test': True}
         self.outputs = {'Out': np.zeros((1, 1))}
diff --git a/tests/test_elementwise_ops.py b/tests/test_elementwise_ops.py
index 76c189391..37b01932a 100644
--- a/tests/test_elementwise_ops.py
+++ b/tests/test_elementwise_ops.py
@@ -19,41 +19,47 @@

 class TestElementwiseAddOp(OpTest):
     def setUp(self):
-        self.init_op_type()
-        self.attrs = {"axis": 1}
+        self.attrs = {'axis': 2}
+        self.init()
         self.inputs = {
-            'X': np.random.random((4, 2)).astype(np.float32),
-            'Y': np.random.random((2, )).astype(np.float32)
+            'X': np.random.random((2, 3, 4, 5)).astype(np.float32),
+            'Y': np.random.random((4, 5)).astype(np.float32)
         }
         self.outputs = {'Out': np.zeros((1, 1))}

-    def init_op_type(self):
-        self.op_type = "elementwise_add"
+    def init(self):
+        self.op_type = 'elementwise_add'

     def test_check_output(self):
         self.check_output()


+class TestElementwiseAddOpNegAxis(TestElementwiseAddOp):
+    def init(self):
+        self.op_type = 'elementwise_add'
+        self.attrs = {'axis': -1}
+
+
 class TestElementwiseSubOp(TestElementwiseAddOp):
-    def init_op_type(self):
-        self.op_type = "elementwise_sub"
+    def init(self):
+        self.op_type = 'elementwise_sub'


 class TestElementwiseMulOp(TestElementwiseAddOp):
-    def init_op_type(self):
-        self.op_type = "elementwise_mul"
+    def init(self):
+        self.op_type = 'elementwise_mul'


 class TestElementwiseDivOp(TestElementwiseAddOp):
-    def init_op_type(self):
-        self.op_type = "elementwise_div"
+    def init(self):
+        self.op_type = 'elementwise_div'


 class TestElementwisePowOp(TestElementwiseAddOp):
-    def init_op_type(self):
-        self.op_type = "elementwise_pow"
+    def init(self):
+        self.op_type = 'elementwise_pow'


 if __name__ == '__main__':
diff --git a/tests/test_logical_ops.py b/tests/test_logical_ops.py
new file mode 100644
index 000000000..1d7dfe60f
--- /dev/null
+++ b/tests/test_logical_ops.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import op_test
+import unittest
+import numpy as np
+
+
+def create_test_class(op_type, callback, binary_op=True):
+    class Cls(op_test.OpTest):
+        def setUp(self):
+            a = np.random.choice(a=[True, False], size=(10, 7)).astype(bool)
+            if binary_op:
+                b = np.random.choice(a=[True, False], size=(10, 7)).astype(bool)
+                c = callback(a, b)
+            else:
+                c = callback(a)
+            self.outputs = {'Out': c}
+            self.op_type = op_type
+            if binary_op:
+                self.inputs = {'X': a, 'Y': b}
+            else:
+                self.inputs = {'X': a}
+
+        def test_output(self):
+            self.check_output()
+
+    Cls.__name__ = op_type
+    globals()[op_type] = Cls
+
+
+create_test_class('logical_and', lambda _a, _b: np.logical_and(_a, _b))
+create_test_class('logical_or', lambda _a, _b: np.logical_or(_a, _b))
+create_test_class('logical_not', lambda _a: np.logical_not(_a), False)
+create_test_class('logical_xor', lambda _a, _b: np.logical_xor(_a, _b))
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_mul_op.py b/tests/test_mul_op.py
index 0b2e73f30..e36a8a6bd 100644
--- a/tests/test_mul_op.py
+++ b/tests/test_mul_op.py
@@ -19,10 +19,10 @@

 class TestMulOp(OpTest):
     def setUp(self):
-        self.op_type = "mul"
+        self.op_type = 'mul'
         self.inputs = {
-            'X': np.random.random((15, 4, 12, 10)).astype("float32"),
-            'Y': np.random.random((4, 30, 8, 2, 9)).astype("float32")
+            'X': np.random.random((15, 4, 12, 10)).astype('float32'),
+            'Y': np.random.random((4, 30, 8, 2, 9)).astype('float32')
         }
         self.attrs = {'x_num_col_dims': 2, 'y_num_col_dims': 2}
         result = np.dot(self.inputs['X'].reshape(15 * 4, 12 * 10),
diff --git a/tests/test_pool2d_op.py b/tests/test_pool2d_op.py
index bff216d9e..2ab6c71b6 100644
--- a/tests/test_pool2d_op.py
+++ b/tests/test_pool2d_op.py
@@ -19,7 +19,7 @@

 class TestPool2dOp(OpTest):
     def setUp(self):
-        self.op_type = "pool2d"
+        self.op_type = 'pool2d'
         self.use_cudnn = False
         self.use_mkldnn = False
         self.dtype = np.float32
@@ -79,7 +79,7 @@ def init_test_case(self):
         self.paddings = [0, 0]

     def init_pool_type(self):
-        self.pool_type = "avg"
+        self.pool_type = 'avg'

     def init_global_pool(self):
         self.global_pool = False
@@ -93,7 +93,7 @@ def init_test_case(self):
         self.paddings = [1, 1]

     def init_pool_type(self):
-        self.pool_type = "avg"
+        self.pool_type = 'avg'

     def init_global_pool(self):
         self.global_pool = False
@@ -101,17 +101,17 @@ def init_global_pool(self):

 class TestPool2dOp3(TestPool2dOp):
     def init_pool_type(self):
-        self.pool_type = "max"
+        self.pool_type = 'max'


 class TestPool2dOp4(TestPool2dOp1):
     def init_pool_type(self):
-        self.pool_type = "max"
+        self.pool_type = 'max'


 class TestPool2dOp5(TestPool2dOp2):
     def init_pool_type(self):
-        self.pool_type = "max"
+        self.pool_type = 'max'


 if __name__ == '__main__':
diff --git a/tests/test_reduce_ops.py b/tests/test_reduce_ops.py
index 7f43f3a6a..5dba8321e 100644
--- a/tests/test_reduce_ops.py
+++ b/tests/test_reduce_ops.py
@@ -22,7 +22,7 @@ def setUp(self):
         self.init_op_type()
         self.init_keep_dim()
         self.init_reduce_all()
-        self.inputs = {'X': np.random.random((5, 6, 7, 8)).astype("float32")}
+        self.inputs = {'X': np.random.random((5, 6, 7, 8)).astype('float32')}
         self.attrs = {
             'dim': 2,
             'keep_dim': self.keep_dim,
@@ -31,7 +31,7 @@ def setUp(self):
         self.outputs = {'Out': np.zeros((1, 1))}

     def init_op_type(self):
-        self.op_type = "reduce_sum"
+        self.op_type = 'reduce_sum'

     def init_keep_dim(self):
         self.keep_dim = True
@@ -45,7 +45,7 @@ def test_check_output(self):

 class TestReduceMeanOp(TestReduceSumOp):
     def init_op_type(self):
-        self.op_type = "reduce_mean"
+        self.op_type = 'reduce_mean'

     def init_reduce_all(self):
         self.reduce_all = True
@@ -53,7 +53,7 @@ def init_reduce_all(self):

 class TestReduceMaxOp(TestReduceSumOp):
     def init_op_type(self):
-        self.op_type = "reduce_max"
+        self.op_type = 'reduce_max'

     def init_keep_dim(self):
         self.keep_dim = False
@@ -64,7 +64,7 @@ def init_reduce_all(self):

 class TestReduceMinOp(TestReduceSumOp):
     def init_op_type(self):
-        self.op_type = "reduce_min"
+        self.op_type = 'reduce_min'

     def init_keep_dim(self):
         self.keep_dim = False
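
Reviewer note: below is a minimal sketch (not part of the patch) of the broadcast-axis rule that elementwise_ops now applies. In Fluid, axis == -1 means "align Y with the trailing dimensions of X"; because ONNX's broadcast=1 form needs that offset spelled out, the converter rewrites it as rank(X) - rank(Y). The helper name fluid_axis_to_onnx is hypothetical, used only for this illustration, and the numpy reshape stands in for the ONNX runtime's broadcasting:

import numpy as np

def fluid_axis_to_onnx(rank_x, rank_y, axis):
    # Fluid's -1 means "broadcast Y over X's trailing dims"; ONNX wants the
    # explicit offset of Y's first dimension within X's shape.
    return rank_x - rank_y if axis == -1 else axis

x = np.random.random((2, 3, 4, 5)).astype(np.float32)
y = np.random.random((4, 5)).astype(np.float32)

axis = fluid_axis_to_onnx(x.ndim, y.ndim, -1)  # -> 2
# Reshaping Y to (1, 1, 4, 5) reproduces what the converted Add node computes.
out = x + y.reshape((1,) * axis + y.shape)
assert out.shape == x.shape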