diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index ae088f5c9e63..b56491f1ef2d 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -150,6 +150,10 @@ RUN bash /install/ubuntu_install_libxsmm.sh COPY install/ubuntu_install_onnx.sh /install/ubuntu_install_onnx.sh RUN bash /install/ubuntu_install_onnx.sh +# NNEF +COPY install/ubuntu_install_nnef.sh /install/ubuntu_install_nnef.sh +RUN bash /install/ubuntu_install_nnef.sh + # AArch64 Architecture Envelope Model (AEM) COPY install/ubuntu_install_aprofile_aem.sh /install RUN bash /install/ubuntu_install_aprofile_aem.sh diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index acb0310a41e2..7e7f92c684b4 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -17,7 +17,7 @@ # CI docker GPU env # tag: v0.60 -FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 COPY utils/apt-install-and-clear.sh /usr/local/bin/apt-install-and-clear @@ -118,6 +118,9 @@ RUN bash /install/ubuntu_install_paddle.sh COPY install/ubuntu_install_oneflow.sh /install/ubuntu_install_oneflow.sh RUN bash /install/ubuntu_install_oneflow.sh +COPY install/ubuntu_install_nnef.sh /install/ubuntu_install_nnef.sh +RUN bash /install/ubuntu_install_nnef.sh + # Rust env (build early; takes a while) COPY install/ubuntu_install_rust.sh /install/ubuntu_install_rust.sh RUN bash /install/ubuntu_install_rust.sh diff --git a/docker/install/ubuntu_install_nnef.sh b/docker/install/ubuntu_install_nnef.sh new file mode 100644 index 000000000000..9e9699bf5719 --- /dev/null +++ b/docker/install/ubuntu_install_nnef.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e +set -u +set -o pipefail + +pip3 install \ + nnef_tools==1.0.5 \ + nnef==1.0.5 diff --git a/python/tvm/relax/frontend/nnef/__init__.py b/python/tvm/relax/frontend/nnef/__init__.py new file mode 100644 index 000000000000..5578a9ad5334 --- /dev/null +++ b/python/tvm/relax/frontend/nnef/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +NNEF frontend for converting graphs into Relax IRModels. +""" +from .nnef_frontend import from_nnef diff --git a/python/tvm/relax/frontend/nnef/nnef_frontend.py b/python/tvm/relax/frontend/nnef/nnef_frontend.py new file mode 100644 index 000000000000..05257b9d1d72 --- /dev/null +++ b/python/tvm/relax/frontend/nnef/nnef_frontend.py @@ -0,0 +1,307 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""NNEF: Neural Network Exchange Format frontend for TVM relay""" +import os +import typing +import nnef +import numpy as np + +import tvm +from tvm import relax +from tvm.ir import IRModule +from tvm.relax import expr as tvm_expr + +from .nnef_ops import _get_converter_map + + +def get_type(elem_type: str): + """ + Gives numpy style type for nnef primitive types, uses x32 versions. + + :param elem_type: string, (scalar, integer, logical, string) + :return: returns numpy dtype equivalent (float32, int32, bool, string) + """ + if elem_type == "scalar": + return "float32" + if elem_type == "integer": + return "int32" + if elem_type == "logical": + return "bool" + if elem_type == "string": + return "string" + raise TypeError(f'Type "{elem_type}" is not implemented') + + +# Converter class +class NNEFConverter: + """ + Helper class for class level attributes, for conversion of NNEF model. + Public method to use is from_nnef. + + Parameters + ---------- + + keep_params_in_input : bool, optional + If this parameter is true, the nnef variables will be converted to + constants, and be embedded into the relay model, allowing optimizations + at compile time. + If False the params will have to be added as inputs, + the model can't load them automatically + + """ + + def __init__(self, keep_params_in_input=False): + self._nodes = {} + self._consts = {} + self._inputs = {} + self._num_inputs = 0 + self._params = {} + self._num_params = 0 + self._keep_params_in_input = keep_params_in_input + self._bb = relax.BlockBuilder() + + def from_nnef(self, graph: nnef.Graph) -> tvm.IRModule: + """ + Convert an NNEF model into an equivalent TVM Relay IRModule. + + Parameters + ---------- + graph : nnef.Graph + An NNEF Graph object that was imported with nnef.load_graph. + Shapes should be inferred by nnef.infer_shapes on graph beforehand. 
+ + Returns + ------- + mod : tvm.IRModule + The relay module for compilation + + params : dict of str to tvm.nd.NDArray + The parameter dictionary to be used + + """ + with self._bb.function("main"): + with self._bb.dataflow(): + self._parse_inputs(graph) + self._construct_nodes(graph) + + outputs = [self._nodes[n] for n in graph.outputs] + outputs = outputs[0] if len(outputs) == 1 else tvm_expr.Tuple(outputs) + + output_var = self._bb.emit_output(outputs) + + func_attrs = {"num_input": self._num_inputs} + + input_list = [value for value in self._inputs.values() if isinstance(value, relax.Var)] + + if self._keep_params_in_input and self._params: + param_var_list, param_value_list = map(list, zip(*self._params.values())) + input_list.append(param_var_list) + func_attrs["params"] = param_value_list + + self._bb.emit_func_output(output_var, input_list) + + relax_mod = self._bb.get() + relax_mod["main"] = relax_mod["main"].with_attrs(func_attrs) + return relax_mod + + def _parse_inputs(self, graph): + """Save inputs into class from inputs attrib of graph""" + for inp in graph.inputs: + self._num_inputs += 1 + tensor = graph.tensors[inp] + self._nodes[inp] = self._new_var(inp, shape=tensor.shape, dtype=get_type(tensor.dtype)) + self._inputs[inp] = self._nodes[inp] + + def _construct_nodes(self, graph): + """Construct TVM relay calls from every operation of the nnef graph""" + for op in graph.operations: + if op.name == "external": + # externals are handled as input, not needed, + # but nnef treats them as operations as well + continue + + if op.name == "variable": + self._set_variable(graph.tensors[op.outputs["output"]]) + + elif op.name == "constant": + self._set_const(op) + + else: + # every other operator can be grouped more easily, + # as it does not need self for conversion + self._set_operator(op) + + def _set_operator(self, node): + self._set_literal_inputs(node) + inputs = [] + for ink, inv in node.inputs.items(): + if isinstance(inv, list): + for i, linv in enumerate(inv): + if linv in self._nodes.keys(): + inputs.append(self._nodes[linv]) + else: # handle literal inputs + name = f"{node.name}_{ink}_{i}" + assert name in self._nodes, f"{name} has not been properly handled" + inputs.append(self._nodes[name]) + + else: + if inv in self._nodes.keys(): + inputs.append(self._nodes[inv]) + else: # handle literal inputs + name = f"{node.name}_{ink}" + assert name in self._nodes, f"{name} has not been properly handled" + inputs.append(self._nodes[name]) + + converted = self._get_relay_op_call(node.name, inputs, node.attribs) + converted = self._bb.normalize(converted) + + if not isinstance(converted.struct_info, relax.TupleStructInfo): + outputs_num = 1 + else: + outputs_num = len(converted.struct_info.fields) + + if outputs_num == 1: + # check if the singular ret val is a list of only one element + ret_val = list(node.outputs.values())[0] + if isinstance(ret_val, list): + self._nodes[ret_val[0]] = converted + else: + self._nodes[ret_val] = converted + else: + for i, out in zip(range(outputs_num), node.outputs["values"]): + self._nodes[out] = converted[i] + + def _set_const(self, node): + """Create a tvm.relay.Constant from a nnef constant tensor""" + name = node.outputs["output"] + data = node.attribs["value"] + shape = node.attribs["shape"] + if len(data) == 1: + data = np.full(shape, data, dtype=get_type(node.dtype)) + else: + data = np.array(data, dtype=get_type(node.dtype)) + self._consts[name] = tvm_expr.const(data) + self._nodes[name] = self._consts[name] + + def _set_variable(self, 
tensor): + """Create a tvm.relay.Var (or Constant) from a nnef variable tensor""" + tens_data = tensor.data + if not self._keep_params_in_input: + self._consts[tensor.name] = tvm_expr.const(tens_data) + self._nodes[tensor.name] = self._consts[tensor.name] + else: + var = self._new_var(tensor.name, shape=tensor.shape, dtype=get_type(tensor.dtype)) + self._nodes[tensor.name] = var + self._params[tensor.name] = (var, tvm.nd.array(tens_data)) + + def _set_literal_inputs(self, node): + """Checks if node has literal inputs and saves them into a tvm.relay.Constant. + naming as {node.name}_{input field name}""" + for field_name, value in node.inputs.items(): + if isinstance(value, list): + for v in value: + if v not in self._nodes.keys(): + self._nodes[f"{node.name}_{v}"] = tvm_expr.const(v) + + else: + if value not in self._nodes.keys(): + self._nodes[f"{node.name}_{field_name}"] = tvm_expr.const(value) + + def _get_relay_op_call(self, name, inputs, attrs): + """Returns the tvm.Call equivalent to the nnef operator""" + conv_map = _get_converter_map() + if name in conv_map: + + call = conv_map[name](self._bb, *inputs, **attrs) + else: + # This error is reached if NNEF is expanded with additional ops + raise NotImplementedError( + f"Operator {name} is not implemented, as {name} has been added after 1.0.5." + ) + return call + + def _infer_type(self, val): + if isinstance(val, bool): + return "bool", True + if isinstance(val, float): + return "float32", True + if isinstance(val, int): + return "int32", True + if isinstance(val, str): + # the string vals can be names of nodes in some of the cases + if isinstance(val, nnef.Identifier): + if val in self._nodes.keys(): + node = self._nodes[val] + if isinstance(node, tvm_expr.Var): + return node.type_annotation.dtype, False + if isinstance(node, tvm_expr.Constant): + return node.data.dtype, False + if isinstance(node, tvm_expr.Call): + return node.checked_type.dtype, False + raise Exception( + f"{val} has not been loaded into the model " + "but it should have been, as a var or call." + ) + return "string", True + + raise TypeError(f'Value "{val}" is not a recognized type') + + def _new_var(self, name, shape, dtype="float32"): + return relax.Var( + name_hint=name, + struct_info=relax.TensorStructInfo(shape=shape, dtype=dtype), + ) + + +def from_nnef( + model: typing.Union[str, os.PathLike, nnef.Graph], + keep_params_in_input: bool = False, +) -> IRModule: + """ + Convert an NNEF model into an equivalent TVM Relay IRModule. + + + Parameters + ---------- + model : os.PathLike or str or nnef.Graph + Path to an NNEF model directory, containing the graph.nnef (and weight files) + + keep_params_in_input : bool, optional + If this parameter is true, the nnef variables will be converted to + constants, and be embedded into the relax model, allowing optimizations + at compile time. 
+ If False the params will have to be added as inputs, + the model can't load them automatically + + Returns + ------- + mod : tvm.IRModule + The relay module for compilation + + params : dict of str to tvm.nd.NDArray + The parameter dictionary to be used + """ + conv_clss = NNEFConverter(keep_params_in_input) + + if not isinstance(model, nnef.Graph): + model = nnef.load_graph(model) + + # fills in the nnef graph's shape information + nnef.infer_shapes(model) + + return conv_clss.from_nnef(graph=model) diff --git a/python/tvm/relax/frontend/nnef/nnef_ops.py b/python/tvm/relax/frontend/nnef/nnef_ops.py new file mode 100644 index 000000000000..c4150c64d211 --- /dev/null +++ b/python/tvm/relax/frontend/nnef/nnef_ops.py @@ -0,0 +1,1957 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""NNEF frontend converter helper funcs and ops""" +import math + +import itertools +from functools import reduce + +import numpy as np + +import tvm +from tvm import relax +from tvm.relax import expr as tvm_expr +from tvm.relax import op as tvm_op +from tvm import topi + + +# Base methods + + +def dimension_picker(prefix, kernel_shape, suffix=""): + """ + Returns the correct name for nth dimensional operator. Uses the "kernel_shape" attribute.\n + E.g.call: dimension_picker(op_name)(attr) + + :param prefix: the name of the operator (e.g. conv) + :param kernel_shape: shape of the tensor to fit the operation + :param suffix: optional suffix for ops + :return: "prefix`n`d" where n is the correct dimension for the kernel + """ + + rank = len(kernel_shape[2:]) + if rank == 1: + return prefix + "1d" + suffix + if rank == 2: + return prefix + "2d" + suffix + if rank == 3: + return prefix + "3d" + suffix + op_name = prefix + "1d/2d/3d" + msg = f"Only 1D, 2D, and 3D kernels are supported for operator {op_name}." 
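+    # note: the spatial rank is taken from kernel_shape[2:]; kernels with other than
+    # 1-3 spatial dimensions fall through to the error below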
+ raise tvm.error.OpAttributeInvalid(msg) + + +def _size_conv(size, rank): + # window of size (DH)W is only possible when it is checked outside, + # which is needed for alternative solution + if rank == 3: + if len(size) == 1: + return size + if len(size) == 3: + assert ( + size[0] == 1 and size[1] == 1 + ), "Incorrect window dimensions, first two dimensions must be 1" + return size[2] + if rank == 4: + if len(size) == 2: + return size + if len(size) == 4: + assert ( + size[0] == 1 and size[1] == 1 + ), "Incorrect window dimensions, first two dimensions must be 1" + return size[2:] + if rank == 5: + if len(size) == 3: + return size + if len(size) == 5: + assert ( + size[0] == 1 and size[1] == 1 + ), "Incorrect window dimensions, first two dimensions must be 1" + return size[2:] + + raise ValueError(f"Unexpected window size, got {len(size)}") + + +def _stride_conv(stride, rank): + if rank == 3: + # {conv style} :: [s] -> [s] + if len(stride) == 1: + return stride + # {pool style} :: [N, C, s] -> asrt N,C == 1; [s] + if len(stride) == 3: + assert ( + stride[0] == 1 and stride[1] == 1 + ), "Not supported stride dimensions, first two dimensions must be 1" + return stride[2:] + if rank == 4: + # {conv style} :: [sh, sw] -> [sh, sw] + if len(stride) == 2: + return stride + # {pool style} :: [N, C, sh, sw] -> asrt N,C == 1; [sh, sw] + if len(stride) == 4: + assert ( + stride[0] == 1 and stride[1] == 1 + ), "Not supported stride dimensions, first two dimensions must be 1" + return stride[2:] + if rank == 5: + # {conv style} :: [sd, sh, sw] -> [sd, sh, sw] + if len(stride) == 3: + return stride + # {pool style} :: [N, C, sd, sh, sw] -> asrt N,C == 1; [sd, sh, sw] + if len(stride) == 5: + assert ( + stride[0] == 1 and stride[1] == 1 + ), "Not supported stride dimensions, first two dimensions must be 1" + return stride[2:] + raise ValueError(f"Unexpected stride in {rank - 2}D, got {len(stride)}: {stride}") + + +def _padding_conv(padding, rank, keepdims=False): + if isinstance(padding[0], (tuple, list)): + # 1D + if rank == 3: + # {conv style} :: [(l,r)] -> (l,r) + if len(padding) == 1: + return padding[0] + if len(padding) == 3: + # {pool style} :: [(batch),(channel),(l,r)] -> asrt N,C == 0, (l,r) + if not keepdims: + assert padding[0] == (0, 0) and padding[1] == (0, 0), ( + "Incorrect padding. " "Padding on C,I dimensions not supported" + ) + return padding[2] + # {sliding window style} :: [(batch),(channel),(l,r)] -> [(batch),(channel),(l,r)] + else: + return padding + + # 2D + + if rank == 4: + # {conv style} :: [(u,d),(l,r)] -> (u, l, d, r) + if len(padding) == 2: + # change UDLR to ULDR padding, LC is faster here + return [x[i] for i in [0, 1] for x in padding] + + if len(padding) == 4: + # {pool style} :: [(batch size),(channel),(u,d),(l,r)] -> + # -> asrt N,C == 0, (u, l, d, r) + if not keepdims: + assert padding[0] == (0, 0) and padding[1] == (0, 0), ( + "Incorrect padding. 
" "Padding on C,I dimensions not supported" + ) + # itertools is faster than LC (slicing) + return list(itertools.chain.from_iterable(zip(padding[2], padding[3]))) + # {sliding window style} :: [(batch),(channel),(u,d),(l,r)] -> + # -> [(batch),(channel),(u,d),(l,r)] + else: + return padding + + # 3D + + if rank == 5: + # {conv style} :: [(f,b),(u,d),(l,r)] -> (f, u, l, b, d, r) + if len(padding) == 3: + # LC is faster + return [x[i] for i in [0, 1] for x in padding] + + if len(padding) == 5: + # {pool style} :: [(batch size),(channel),(f,b)(u,p),(l,r)] -> + # -> asrt N,C == 0, (f, u, l, b, d, r) + if not keepdims: + assert padding[0] == (0, 0) and padding[1] == (0, 0), ( + "Incorrect padding. " "Padding on C,I dimensions not supported" + ) + # itertools faster barely + return list( + itertools.chain.from_iterable(zip(padding[2], padding[3], padding[4])) + ) + # {s-w style} :: [(batch),(channel),(f,b),(u,d),(l,r)] -> + # -> [(batch),(channel),(f,b),(u,d),(l,r)] + else: + return padding + + raise ValueError( + f"Incorrect padding style for {rank - 2}D operand. Only length of {rank - 2}, {rank} " + f"supported, got {len(padding)}: {padding}" + ) + + raise ValueError("nnef should not have singular padding") + + +def _calculate_nnef_padding(active_shape, strides, kernel_shape, dilation): + """Ordering of nnef autopad and tvm autopad are sometimes different, + this method calculates nnef like padding from dimensions + + Parameters + ---------- + active_shape + the data dimensions + strides + the strides over the active dimensions + kernel_shape + the shape of the window, must have the same rank as active shape + dilation + the dilations over the active dimensions + """ + output = [(ui + (s - 1)) // s for ui, s in zip(active_shape, strides)] + dilated = [(f - 1) * d + 1 for f, d in zip(kernel_shape, dilation)] + total = [ + max(0, (di - 1) * s + df - ui) + for di, s, df, ui in zip(output, strides, dilated, active_shape) + ] + padding = [(pad // 2, (pad + 1) // 2) for pad in total] + return padding + + +def _calculate_nnef_padding_deconv(data_sh, strides, kernel_active_sh, dilation, output_shape): + out_sh = output_shape[2:] if output_shape else [ui * s for ui, s in zip(data_sh, strides)] + dilated = [(f - 1) * d + 1 for f, d in zip(kernel_active_sh[2:], dilation)] + total = [ + max(0, (di - 1) * s + df - ui) for di, s, df, ui in zip(data_sh, strides, dilated, out_sh) + ] + return total, out_sh + + +def __unexpected_attrs(op, kwargs): + raise NotImplementedError( + f"{op} received unexpected attributes(s), possibly mismatched versions. 
" + "Attributes(s) ignored: " + ", ".join(f"{k} := {v}" for k, v in kwargs.items()) + ) + + +# Conversion map, operator functions + + +def _get_converter_map(): + return { # Unary + "copy": copy_converter, # arithmetic + "neg": neg_converter, + "rcp": rcp_converter, + "exp": exp_converter, + "log": log_converter, + "sin": sin_converter, + "cos": cos_converter, + "tan": tan_converter, + "sinh": sinh_converter, + "cosh": cosh_converter, + "tanh": tanh_converter, + "asin": asin_converter, + "acos": acos_converter, + "atan": atan_converter, + "asinh": asinh_converter, + "acosh": acosh_converter, + "atanh": atanh_converter, + "abs": abs_converter, + "sign": sign_converter, + "not": not_converter, # logical + "floor": floor_converter, # rounding + "ceil": ceil_converter, + "round": round_converter, + # Binary + "add": add_converter, # arithmetic + "sub": sub_converter, + "mul": mul_converter, + "div": div_converter, + "pow": pow_converter, + "lt": lt_converter, # comparison + "gt": gt_converter, + "le": le_converter, + "ge": ge_converter, + "eq": eq_converter, + "ne": ne_converter, + "and": and_converter, # logical + "or": or_converter, + # select + "select": select_converter, + # simplifier + "sqr": sqr_converter, + "sqrt": sqrt_converter, + "rsqr": rsqr_converter, + "rsqrt": rsqrt_converter, + "log2": log2_converter, + "min": min_converter, + "max": max_converter, + "clamp": clamp_converter, + # sliding-window + "conv": conv_converter, + "deconv": deconv_converter, + "box": box_converter, + "debox": debox_converter, + "argmax_pool": ndop, + "sample": ndop, + "desample": ndop, + "nearest_downsample": nearest_downsample_converter, + "area_downsample": area_downsample_converter, + "nearest_upsample": nearest_upsample_converter, + "multilinear_upsample": multilinear_upsample_converter, + # reduce + "sum_reduce": sum_reduce_converter, + "max_reduce": max_reduce_converter, + "min_reduce": min_reduce_converter, + "argmax_reduce": argmax_reduce_converter, + "argmin_reduce": argmin_reduce_converter, + "all_reduce": all_reduce_converter, + "any_reduce": any_reduce_converter, + "mean_reduce": mean_reduce_converter, + # tensor shape + "reshape": reshape_converter, + "squeeze": squeeze_converter, + "unsqueeze": unsqueeze_converter, + "transpose": transpose_converter, + "split": split_converter, + "concat": concat_converter, + "stack": stack_converter, + "unstack": unstack_converter, + "slice": slice_converter, + "pad": pad_converter, + "tile": tile_converter, + # region-of-interest - not needed - not supported + "avg_roi_pool": ndop, + "max_roi_pool": ndop, + "roi_resample": ndop, + "avg_roi_align": ndop, + "max_roi_align": ndop, + # matrix multiplication + "matmul": matmul_converter, + # variables + "update": ndop, # --- not used + # Compound + "sigmoid": sigmoid_converter, # activation + "relu": relu_converter, + "prelu": prelu_converter, + "leaky_relu": leaky_relu_converter, + "elu": elu_converter, + "selu": selu_converter, + "gelu": gelu_converter, + "silu": silu_converter, + "softmax": softmax_converter, + "softplus": softplus_converter, + "linear": linear_converter, # linear + "separable_conv": separable_conv_converter, + "separable_deconv": separable_deconv_converter, + "max_pool_with_index": ndop, # pooling + "max_pool": max_pool_converter, + "avg_pool": avg_pool_converter, + "rms_pool": rms_pool_converter, + "local_response_normalization": local_response_normalization_converter, # normalization + "local_mean_normalization": local_mean_normalization_converter, + "local_variance_normalization": 
local_variance_normalization_converter, + "local_contrast_normalization": local_contrast_normalization_converter, + "l1_normalization": l1_normalization_converter, + "l2_normalization": l2_normalization_converter, + "batch_normalization": batch_normalization_converter, + "min_max_linear_quantize": ndop, # quantization + "zero_point_linear_quantize": ndop, + "linear_quantize": ndop, + "logarithmic_quantize": ndop, + # MISC + "copy_n": ndop, + "add_n": ndop, + "moments": ndop, + } + + +# pylint: disable=unused-argument + +# not implemented ops +def ndop(*args, **kwargs): + # print(args, kwargs) + raise Exception("Not supported operator was called, please check for compatibility") + + +# # Unary ops + + +def copy_converter(bbuilder, data, **kwargs): + """Copy converter""" + if kwargs: + __unexpected_attrs("copy", kwargs) + + return bbuilder.emit_te(topi.identity, data) + + +def neg_converter(bbuilder, data, **kwargs): + """Neg converter""" + if kwargs: + __unexpected_attrs("neg", kwargs) + + return relax.op.unary.negative(data) + + +def rcp_converter(bbuilder, data, **kwargs): + """Rcp converter""" + if kwargs: + __unexpected_attrs("rcp", kwargs) + + if isinstance(data, relax.Call): + d_type = data.checked_type.dtype + else: + d_type = data.struct_info.dtype + + return div_converter(bbuilder, tvm_expr.const(1, dtype=d_type), data) + + +def exp_converter(bbuilder, data, **kwargs): + """Exp converter""" + if kwargs: + __unexpected_attrs("exp", kwargs) + + return relax.op.unary.exp(data) + + +def log_converter(bbuilder, data, **kwargs): + """Log converter""" + if kwargs: + __unexpected_attrs("log", kwargs) + + return relax.op.unary.log(data) + + +def sin_converter(bbuilder, data, **kwargs): + """Sin converter""" + if kwargs: + __unexpected_attrs("sin", kwargs) + + return relax.op.unary.sin(data) + + +def cos_converter(bbuilder, data, **kwargs): + """Cos converter""" + if kwargs: + __unexpected_attrs("cos", kwargs) + + return relax.op.unary.cos(data) + + +def tan_converter(bbuilder, data, **kwargs): + """Tan converter""" + if kwargs: + __unexpected_attrs("tan", kwargs) + + return relax.op.unary.tan(data) + + +def sinh_converter(bbuilder, data, **kwargs): + """Sinh converter""" + if kwargs: + __unexpected_attrs("sinh", kwargs) + + return relax.op.unary.sinh(data) + + +def cosh_converter(bbuilder, data, **kwargs): + """Cosh converter""" + if kwargs: + __unexpected_attrs("cosh", kwargs) + + return relax.op.unary.cosh(data) + + +def tanh_converter(bbuilder, data, **kwargs): + """Tanh converter""" + if kwargs: + __unexpected_attrs("tanh", kwargs) + + return relax.op.unary.tanh(data) + + +def asin_converter(bbuilder, data, **kwargs): + """Asin converter""" + if kwargs: + __unexpected_attrs("asin", kwargs) + + return relax.op.unary.asin(data) + + +def acos_converter(bbuilder, data, **kwargs): + """Acos converter""" + if kwargs: + __unexpected_attrs("acos", kwargs) + + return relax.op.unary.acos(data) + + +def atan_converter(bbuilder, data, **kwargs): + """Atan converter""" + if kwargs: + __unexpected_attrs("atan", kwargs) + + return relax.op.unary.atan(data) + + +def asinh_converter(bbuilder, data, **kwargs): + """Asinh converter""" + if kwargs: + __unexpected_attrs("asinh", kwargs) + + return relax.op.unary.asinh(data) + + +def acosh_converter(bbuilder, data, **kwargs): + """Acosh converter""" + if kwargs: + __unexpected_attrs("acosh", kwargs) + + return relax.op.unary.acosh(data) + + +def atanh_converter(bbuilder, data, **kwargs): + """Atanh converter""" + if kwargs: + __unexpected_attrs("atanh", 
kwargs) + + return relax.op.unary.atanh(data) + + +def abs_converter(bbuilder, data, **kwargs): + """Abs converter""" + if kwargs: + __unexpected_attrs("abs", kwargs) + + return relax.op.unary.abs(data) + + +def sign_converter(bbuilder, data, **kwargs): + """Sign converter""" + if kwargs: + __unexpected_attrs("sign", kwargs) + + return relax.op.unary.sign(data) + + +def not_converter(bbuilder, data, **kwargs): + """Not converter""" + if kwargs: + __unexpected_attrs("not", kwargs) + + return relax.op.unary.logical_not(data) + + +def floor_converter(bbuilder, data, **kwargs): + """Floor converter""" + if kwargs: + __unexpected_attrs("floor", kwargs) + + return relax.op.unary.floor(data) + + +def ceil_converter(bbuilder, data, **kwargs): + """Ceil converter""" + if kwargs: + __unexpected_attrs("ceil", kwargs) + + return relax.op.unary.ceil(data) + + +def round_converter(bbuilder, data, **kwargs): + """Round converter""" + if kwargs: + __unexpected_attrs("round", kwargs) + + return relax.op.unary.round(data) + + +# # Binary ops + + +def add_converter(bbuilder, lhs, rhs, **kwargs): + """Add converter""" + if kwargs: + __unexpected_attrs("add", kwargs) + + return relax.op.binary.add(lhs, rhs) + + +def sub_converter(bbuilder, lhs, rhs, **kwargs): + """Sub converter""" + if kwargs: + __unexpected_attrs("sub", kwargs) + + return relax.op.binary.subtract(lhs, rhs) + + +def mul_converter(bbuilder, lhs, rhs, **kwargs): + """Mul converter""" + if kwargs: + __unexpected_attrs("mul", kwargs) + + lhs = bbuilder.normalize(lhs) + rhs = bbuilder.normalize(rhs) + + l_ndim = len(lhs.struct_info.shape) + r_ndim = len(rhs.struct_info.shape) + + if l_ndim > r_ndim > 0: + rhs = relax.op.expand_dims(rhs, [d + 2 for d in range(l_ndim - r_ndim)]) + if r_ndim > l_ndim > 0: + lhs = relax.op.expand_dims(lhs, [d + 2 for d in range(r_ndim - l_ndim)]) + + return relax.op.binary.multiply(lhs, rhs) + + +def div_converter(bbuilder, lhs, rhs, **kwargs): + """Div converter""" + if kwargs: + __unexpected_attrs("div", kwargs) + + return relax.op.binary.divide(lhs, rhs) + + +def pow_converter(bbuilder, lhs, rhs, **kwargs): + """Pow converter""" + if kwargs: + __unexpected_attrs("pow", kwargs) + + return relax.op.binary.power(lhs, rhs) + + +def lt_converter(bbuilder, lhs, rhs, **kwargs): + """Lt converter""" + if kwargs: + __unexpected_attrs("lt", kwargs) + + return relax.op.binary.less(lhs, rhs) + + +def gt_converter(bbuilder, lhs, rhs, **kwargs): + """Gt converter""" + if kwargs: + __unexpected_attrs("gt", kwargs) + + return relax.op.binary.greater(lhs, rhs) + + +def le_converter(bbuilder, lhs, rhs, **kwargs): + """Le converter""" + if kwargs: + __unexpected_attrs("le", kwargs) + + return relax.op.binary.less_equal(lhs, rhs) + + +def ge_converter(bbuilder, lhs, rhs, **kwargs): + """Ge converter""" + if kwargs: + __unexpected_attrs("ge", kwargs) + + return relax.op.binary.greater_equal(lhs, rhs) + + +def eq_converter(bbuilder, lhs, rhs, **kwargs): + """Eq converter""" + if kwargs: + __unexpected_attrs("eq", kwargs) + + return relax.op.binary.equal(lhs, rhs) + + +def ne_converter(bbuilder, lhs, rhs, **kwargs): + """Ne converter""" + if kwargs: + __unexpected_attrs("ne", kwargs) + + return relax.op.binary.not_equal(lhs, rhs) + + +def and_converter(bbuilder, lhs, rhs, **kwargs): + """And converter""" + if kwargs: + __unexpected_attrs("and", kwargs) + + return relax.op.binary.logical_and(lhs, rhs) + + +def or_converter(bbuilder, lhs, rhs, **kwargs): + """Or converter""" + if kwargs: + __unexpected_attrs("or", kwargs) + + return 
relax.op.binary.logical_or(lhs, rhs) + + +# # Select op + + +def select_converter(bbuilder, condition, t_val, f_val, **kwargs): + """Select converter""" + if kwargs: + __unexpected_attrs("select", kwargs) + + return relax.op.where(condition, t_val, f_val) + + +# # Simplifier ops + + +def sqr_converter(bbuilder, data, **kwargs): + """sqr converter""" + if kwargs: + __unexpected_attrs("sqr", kwargs) + + if isinstance(data, relax.Call): + d_type = data.checked_type.dtype + else: + d_type = data.struct_info.dtype + + return pow_converter(bbuilder, data, tvm_expr.const(2.0, dtype=d_type)) + + +def sqrt_converter(bbuilder, data, **kwargs): + """sqrt converter""" + if kwargs: + __unexpected_attrs("sqrt", kwargs) + + return relax.op.unary.sqrt(data) + + +def rsqr_converter(bbuilder, data, **kwargs): + """rsqr converter""" + if kwargs: + __unexpected_attrs("rsqr", kwargs) + + if isinstance(data, relax.Call): + d_type = data.checked_type.dtype + else: + d_type = data.struct_info.dtype + + return pow_converter(bbuilder, data, tvm_expr.const(-2.0, dtype=d_type)) + + +def rsqrt_converter(bbuilder, data, **kwargs): + """rsqrt converter""" + if kwargs: + __unexpected_attrs("rsqrt", kwargs) + + return relax.op.unary.rsqrt(data) + + +def log2_converter(bbuilder, data, **kwargs): + """log2 converter""" + if kwargs: + __unexpected_attrs("log2", kwargs) + + # no equivalent in Relax, using TOpI + return bbuilder.emit_te(topi.log2, data) + + +def min_converter(bbuilder, lhs, rhs, **kwargs): + """Min converter""" + if kwargs: + __unexpected_attrs("min", kwargs) + + return relax.op.binary.minimum(lhs, rhs) + + +def max_converter(bbuilder, lhs, rhs, **kwargs): + """Max converter""" + if kwargs: + __unexpected_attrs("max", kwargs) + + return relax.op.binary.maximum(lhs, rhs) + + +def clamp_converter(bbuilder, x, a, b, **kwargs): + """Clamp converter""" + if kwargs: + __unexpected_attrs("clamp", kwargs) + + # only works if b and a are Constant floats, not tensors + if isinstance(a, tvm_expr.Constant) and isinstance(b, tvm_expr.Constant): + return relax.op.clip( + x, tvm_expr.PrimValue(a.data.numpy().item()), tvm_expr.PrimValue(b.data.numpy().item()) + ) + + return max_converter(bbuilder, min_converter(bbuilder, x, b), a) + + +# # Sliding-window ops + + +def conv_converter( + bbuilder, data, kernel, bias, border, stride, padding, dilation, groups, **kwargs +): + """Convolution converter, + skips bias if it's 0.0 (no bias)""" + if kwargs: + __unexpected_attrs("conv", kwargs) + + if border != "constant": + print(f"Currently {border} border is not supported, used `constant` border") + + kernel_shape = [v.value for v in kernel.struct_info.shape.values] + dshape = [v.value for v in data.struct_info.shape.values] + + if hasattr(data.struct_info, "ndim"): + ndim = data.struct_info.ndim + else: + ndim = len(data.struct_info.shape) + + strides = _stride_conv(stride, ndim) if stride else (1,) * (ndim - 2) + + dilation = _stride_conv(dilation, ndim) if dilation else (1,) * (ndim - 2) + + if not padding: + padding = _calculate_nnef_padding(dshape[2:], strides, kernel_shape[2:], dilation) + + pad = _padding_conv(padding, ndim) + + channels = kernel_shape[0] + + if groups == 0: + groups = channels + + if ndim == 3: + op = relax.op.nn.conv1d + elif ndim == 4: + op = relax.op.nn.conv2d + elif ndim == 5: + op = relax.op.nn.conv3d + else: + raise NotImplementedError("Ndim > 5 not supported for convolution.") + + conv_out = op( + data=data, + weight=kernel, + strides=strides, + padding=pad, + dilation=dilation, + groups=groups, + ) + + 
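+    # NNEF always supplies a bias operand; an all-zero constant bias means "no bias",
+    # so the add below is skipped in that case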
res = None + if isinstance(bias, tvm_expr.Constant): + # nnef has bias of 0 if it is not needed + if (bias.data.numpy() == 0).all(): + res = conv_out + + if not res: + bias = relax.op.reshape( + bias, + [1, -1] + + [ + 1, + ] + * (ndim - 2), + ) + res = relax.op.add(conv_out, bias) + + return res + + +def deconv_converter( + bbuilder, data, kernel, bias, border, stride, padding, dilation, output_shape, groups, **kwargs +): + """Deconvolution converter, using convxd_transpose + skips bias if it's 0.0 (no bias)""" + if kwargs: + __unexpected_attrs("deconv", kwargs) + + if border != "constant": + print(f"Currently {border} border is not supported, used `constant` border") + + kernel_shape = [v.value for v in kernel.struct_info.shape.values] + + rank = len(kernel_shape) + + strides = _stride_conv(stride, rank) if stride else (1,) * (rank - 2) + + dilation = _stride_conv(dilation, rank) if dilation else (1,) * (rank - 2) + + total, out_sh = _calculate_nnef_padding_deconv( + [v.value for v in data.struct_info.shape.values], + strides, + kernel_shape, + dilation, + output_shape, + ) + + if padding: + pad = _padding_conv(padding, rank) + else: + pad = _padding_conv([(pad // 2, (pad + 1) // 2) for pad in total], rank) + + if groups == 0: + groups = kernel_shape[0] + + # limit output padding to modulo stride because of tvm checks + out_pad = ( + [(x - (y - t)) % s for x, y, t, s in zip(output_shape[2:], out_sh, total, stride)] + if output_shape + else (0, 0) + ) + + if rank == 3: + op = relax.op.nn.conv1d_transpose + elif rank == 4: + op = relax.op.nn.conv2d_transpose + else: + raise NotImplementedError("Ndim > 4 not supported for deconvolution. 3D WIP.") + + deconv_out = op( + data=data, + weight=kernel, + strides=strides, + padding=pad, + dilation=dilation, + groups=groups, + output_padding=out_pad, + ) + + res = None + if isinstance(bias, tvm_expr.Constant): + if (bias.data.numpy() == 0).all(): + res = deconv_out + + if not res: + bias = relax.op.reshape( + bias, + [1, -1] + + [ + 1, + ] + * (rank - 2), + ) + res = relax.op.add(deconv_out, bias) + + return res + + +def box_converter(bbuilder, data, size, border, padding, stride, dilation, normalize, **kwargs): + """Box operator converter, + summation over sliding window, equal to conv with constant filter""" + if kwargs: + __unexpected_attrs("box", kwargs) + + dshape = [v.value for v in data.struct_info.shape.values] + + if isinstance(data, relax.Call): + d_type = data.checked_type.dtype + else: + d_type = data.struct_info.dtype + + if size[:2] == [1, 1]: + size[0] = dshape[1] + if normalize: + kernel = relax.op.full(size, relax.const(1 / math.prod(size[2:]), d_type), d_type) + else: + kernel = relax.op.ones(size, d_type) + + kernel = bbuilder.normalize(kernel) + + out = conv_converter( + bbuilder, + data, + kernel, + tvm_expr.const(0, dtype=d_type), + border, + stride, + padding, + dilation, + dshape[1], + ) + else: + # if boxing on channel or batch dims avg pool can solve with permute + # we need permute indexes with inactive shape + active shape format, so active at the back + + def _apply_permutation(items, perm): + return [items[ind] for ind in perm] + + inactive = [i for i, s in enumerate(size) if s == 1] + active = [i for i, s in enumerate(size) if s != 1] + permuted_ins = inactive + active + inverse = [0] * len(permuted_ins) + for i, p in enumerate(permuted_ins): + inverse[p] = i + + data = relax.op.permute_dims(data, permuted_ins) + size = _apply_permutation(size, permuted_ins) + + data = bbuilder.normalize(data) + + out = 
avg_pool_converter( + bbuilder, data, size[2:], border, padding, stride[2:], dilation[2:] + ) + + out = relax.op.permute_dims(out, inverse) + + if not normalize: + out = bbuilder.normalize(out) + out = mul_converter( + bbuilder, out, tvm_expr.const(math.prod(size), dtype=out.struct_info.dtype) + ) + + return out + + +def debox_converter( + bbuilder, data, size, border, padding, stride, dilation, normalize, output_shape, **kwargs +): + """Debox operator converter, + inverse of box, equal to deconv with constant filter""" + if kwargs: + __unexpected_attrs("debox", kwargs) + + dshape = [v.value for v in data.struct_info.shape.values] + + if isinstance(data, relax.Call): + d_type = data.checked_type.dtype + else: + d_type = data.struct_info.dtype + + size[0] = dshape[1] + if normalize: + kernel = relax.op.full(relax.const(1 / math.prod(size[2:]), d_type), size, d_type) + else: + kernel = relax.op.ones(size, d_type) + + kernel = bbuilder.normalize(kernel) + + out = deconv_converter( + bbuilder, + data, + kernel, + tvm_expr.const(0, dtype=d_type), + border, + stride, + padding, + dilation, + output_shape, + groups=dshape[1], + ) + return out + + +def nearest_downsample_converter(bbuilder, data, factor, **kwargs): + """Nearest neighbour downsample converter""" + if kwargs: + __unexpected_attrs("nearest_downsample", kwargs) + + dims = 2 + len(factor) + + return box_converter( + bbuilder, + data, + size=[1] * dims, + border="constant", + padding=[(0, 0)] * dims, + stride=[1, 1] + factor, + dilation=(1,) * (dims - 2), + normalize=False, + ) + + +def area_downsample_converter(bbuilder, data, factor, **kwargs): + """Area downsample converter""" + if kwargs: + __unexpected_attrs("area_downsample", kwargs) + + dims = 2 + len(factor) + + return box_converter( + bbuilder, + data, + size=[1, 1] + factor, + border="constant", + padding=[(0, 0)] * dims, + stride=[1, 1] + factor, + dilation=(1,) * (dims - 2), + normalize=True, + ) + + +def nearest_upsample_converter(bbuilder, data, factor, **kwargs): + """Nearest neighbour upsample converter""" + if kwargs: + __unexpected_attrs("nearest_upsample", kwargs) + + dshape = [v.value for v in data.struct_info.shape.values] + new_size = [d * f for d, f in zip(dshape[2:], factor)] + + ndims = len(dshape) + + if ndims == 3: + op = topi.image.resize1d + if ndims == 4: + op = topi.image.resize2d + if ndims == 5: + op = topi.image.resize3d + + return bbuilder.emit_te( + op, + data, + [ + 0, + ] + * ndims, # dummy value so typecheck goes through, roi is not used + new_size, + method="nearest_neighbor", + rounding_method="round", + ) + + +def multilinear_upsample_converter(bbuilder, data, factor, method, border, **kwargs): + """Multilinear upsampling converter""" + if kwargs: + __unexpected_attrs("linear_upsample", kwargs) + + # for aligned and symmetric replicate resize can be used + dshape = [v.value for v in data.struct_info.shape.values] + ndims = len(dshape) + + if ndims == 3: + op = topi.image.resize1d + if ndims == 4: + op = topi.image.resize2d + if ndims == 5: + op = topi.image.resize3d + + new_size = [d * f for d, f in zip(dshape[2:], factor)] + if method == "aligned": + # conversion from nn.upsampling to image.resizexd, re: discuss:11650 + return bbuilder.emit_te( + op, + data, + [ + 0, + ] + * ndims, # dummy value so typecheck goes through, roi is not used + new_size, + method="linear", + coordinate_transformation_mode="align_corners", + ) + if method == "symmetric" and border == "replicate": + return bbuilder.emit_te( + op, + data, + [ + 0, + ] + * ndims, # 
dummy value so typecheck goes through, roi is not used + new_size, + method="linear", + coordinate_transformation_mode="half_pixel", + ) + + # other combinations need to be calculated with convolution + def _upsample_weights_1d(fact, symm): + if symm: + _weights = [1 - (i + 0.5) / fact for i in range(fact)] + _weights = list(reversed(_weights)) + _weights + else: + _weights = [1 - abs(i) / float(fact) for i in range(-fact + 1, fact)] + return np.array(_weights) + + def _upsample_weights_nd(fact, symm): + _weights = [_upsample_weights_1d(f, symm) for f in fact] + return reduce(np.multiply, np.ix_(*_weights)) + + n, c = dshape[:2] + + symmetric = method == "symmetric" + weights = _upsample_weights_nd(factor, symmetric) + weights = np.reshape(weights, newshape=(1, 1) + weights.shape) + kernel = tile_converter(bbuilder, tvm_expr.const(weights), (c, 1) + (1,) * len(factor)) + kernel = bbuilder.normalize(kernel) + + output_shape = [n, c] + [f * s for f, s in zip(factor, dshape[2:])] + + if symmetric: + return deconv_converter( + bbuilder, + data, + kernel, + tvm_expr.const(0.0), + border="constant", + stride=factor, + padding=[(f - 1, f - 1) for f in factor], + dilation=[], + groups=c, + output_shape=output_shape, + ) + else: + replicate = border == "replicate" + if replicate: + data = pad_converter( + bbuilder, + data, + [(0, 0), (0, 0)] + [(1, 0)] * len(factor), + border, + tvm_expr.const(0.0), + ) + data = bbuilder.normalize(data) + padding = factor + else: + padding = [f // 2 for f in factor] + + return deconv_converter( + bbuilder, + data, + kernel, + tvm_expr.const(0.0), + border="constant", + stride=factor, + padding=[(p, p - 1) for p in padding], + dilation=[], + groups=c, + output_shape=output_shape, + ) + + +# # Reduce ops + + +def sum_reduce_converter(bbuilder, data, axes, normalize, keepdims=True, **kwargs): + """Sum reduce converter""" + + if kwargs: + __unexpected_attrs("sum_reduce", kwargs) + + out = relax.op.sum(data, axes, keepdims=keepdims) + if normalize: + return l2_normalization_converter(bbuilder, out, 0, [x - 2 for x in axes], 0.0) + return out + + +def max_reduce_converter(bbuilder, data, axes, keepdims=True, **kwargs): + """Max reduce converter""" + if kwargs: + __unexpected_attrs("max_reduce", kwargs) + + return relax.op.max(data, axes, keepdims=keepdims) + + +def min_reduce_converter(bbuilder, data, axes, keepdims=True, **kwargs): + """Min reduce converter""" + if kwargs: + __unexpected_attrs("min_reduce", kwargs) + + return relax.op.min(data, axes, keepdims=keepdims) + + +def argmax_reduce_converter(bbuilder, data, axes, keepdims=True, **kwargs): + """Argmax reduce converter""" + if kwargs: + __unexpected_attrs("argmax_reduce", kwargs) + + # relax.op.argmax only supports singular axis, using TOpI + return bbuilder.emit_te(topi.argmax, data, axes, keepdims=keepdims) + + +def argmin_reduce_converter(bbuilder, data, axes, keepdims=True, **kwargs): + """Argmin reduce converter""" + if kwargs: + __unexpected_attrs("argmin_reduce", kwargs) + + # relax.op.argmin only supports singular axis, using TOpI + return bbuilder.emit_te(topi.argmin, data, axes, keepdims=keepdims) + + +def all_reduce_converter(bbuilder, data, axes, keepdims=True, **kwargs): + """All reduce converter""" + if kwargs: + __unexpected_attrs("all_reduce", kwargs) + + # no equivalent in Relax, using TOpI + return bbuilder.emit_te(topi.all, data, axes, keepdims) + + +def any_reduce_converter(bbuilder, data, axes, keepdims=True, **kwargs): + """Any reduce converter""" + if kwargs: + 
__unexpected_attrs("any_reduce", kwargs) + + # no equivalent in Relax, using TOpI + return bbuilder.emit_te(topi.any, data, axes, keepdims) + + +def mean_reduce_converter(bbuilder, data, axes, keepdims=True, **kwargs): + """Mean reduce converter""" + if kwargs: + __unexpected_attrs("mean_reduce", kwargs) + + return relax.op.mean(data, axes, keepdims=keepdims) + + +# # Tensor shape ops + + +def reshape_converter(bbuilder, data, shape, axis_start, axis_count, **kwargs): + """Reshape converter""" + if kwargs: + __unexpected_attrs("reshape", kwargs) + + dshape = [v.value for v in data.struct_info.shape.values] + if axis_count == -1: + newshape = dshape[:axis_start] + shape + else: + newshape = dshape + newshape[axis_start : axis_start + axis_count] = shape + + return relax.op.reshape(data, newshape) + + +def squeeze_converter(bbuilder, data, axes, **kwargs): + """Squeeze converter""" + if kwargs: + __unexpected_attrs("squeeze", kwargs) + + return relax.op.squeeze(data, axes) + + +def unsqueeze_converter(bbuilder, data, axes, **kwargs): + """Unsqueeze converter""" + if kwargs: + __unexpected_attrs("unsqueeze", kwargs) + + axes = sorted(axes) + for axis in axes: + if axis < 0 and isinstance(data, tvm_expr.Var): + axis = len(data.type_annotation.concrete_shape) + len(axes) + axis + + data = tvm_op.expand_dims(data, axis=axis) + return data + + +def transpose_converter(bbuilder, data, axes, **kwargs): + """Transpose converter""" + if kwargs: + __unexpected_attrs("transpose", kwargs) + + return relax.op.permute_dims(data, axes) + + +def split_converter(bbuilder, data, axis, ratios, **kwargs): + """Split converter""" + if kwargs: + __unexpected_attrs("split", kwargs) + + axis_len = [v.value for v in data.struct_info.shape.values][axis] + rat_mul = axis_len / sum(ratios) + ratio_list = [(r * rat_mul) for r in ratios] + + s = 0 + indices = [] + for rat in ratio_list[:-1]: + s += rat + # Strictly needs int + indices.append(int(s)) + + return relax.op.split(data, indices, axis) + + +def concat_converter(bbuilder, *data, axis, **kwargs): + """Concat converter""" + if kwargs: + __unexpected_attrs("concat", kwargs) + + return relax.op.concat(data, axis) + + +def stack_converter(bbuilder, *data, axis, **kwargs): + """Stack converter""" + if kwargs: + __unexpected_attrs("stack", kwargs) + + data = [relax.op.expand_dims(d, axis) for d in data] + + return relax.op.concat(data, axis) + + +def unstack_converter(bbuilder, data, axis, **kwargs): + """Unstack converter""" + if kwargs: + __unexpected_attrs("unstack", kwargs) + + split = split_converter( + bbuilder, data, axis, [1] * [v.value for v in data.struct_info.shape.values][axis] + ) + split = bbuilder.normalize(split) + res = [] + + for i in range(len(split.struct_info.fields)): + res.append(squeeze_converter(bbuilder, split[i], axis)) + return tvm.relax.Tuple(relax.Tuple(res)) + + +def slice_converter(bbuilder, data, axes, begin, end, stride, **kwargs): + """Slice converter""" + if kwargs: + __unexpected_attrs("slice", kwargs) + + if not stride: + stride = [1] * len(axes) + + return relax.op.strided_slice(data, begin=begin, end=end, strides=stride, axes=axes) + + +def pad_converter(bbuilder, data, padding, border, value, **kwargs): + """Pad converter""" + if kwargs: + __unexpected_attrs("pad", kwargs) + + if border not in ["constant", "replicate", "reflect"]: + print(f"{border} border type is not supported in padding. 
Assumed constant") + border = "constant" + if border == "replicate": + border = "edge" + + # padding can only be tuple even though docs say tuple> + pad = sum(padding, ()) + pad_before, pad_after = zip(*padding) + + # reflect can only work with TOPI mirror_pad + if border == "reflect": + return bbuilder.emit_te(tvm.topi.nn.mirror_pad, data, pad_before, pad_after, "REFLECT") + if border == "edge": + raise tvm.error.OpNotImplemented( + "Replicate - Edge mode is currently not supperted in TVM relax" + ) + + # constant works with normal relax.nn.pad + return relax.op.nn.pad(data, pad, value, border) + + +def tile_converter(bbuilder, data, repeats, **kwargs): + """Tile converter""" + if kwargs: + __unexpected_attrs("tile", kwargs) + + return relax.op.tile(data, repeats) + + +# # Region-of-interest ops + + +# # Matrix multiplication +def matmul_converter(bbuilder, a, b, **kwargs): + """Matmul converter + real signature: matmul_converter(a, b, transposeA, transposeB)""" + + transpose_a = kwargs.pop("transposeA") + transpose_b = kwargs.pop("transposeB") + if kwargs: + __unexpected_attrs("matmul", kwargs) + + if transpose_a: + ndim = len(a.struct_info.shape.values) + axes = list(range(ndim - 2)) + axes.append(ndim - 1) + axes.append(ndim - 2) + a = relax.op.permute_dims(a, axes) + + if transpose_b: + ndim = len(a.struct_info.shape.values) + axes = list(range(ndim - 2)) + axes.append(ndim - 1) + axes.append(ndim - 2) + b = relax.op.permute_dims(b, axes) + + a = bbuilder.normalize(a) + b = bbuilder.normalize(b) + + return relax.op.matmul(a, b) + + +# # Variable updates +# # Compound ops + + +def sigmoid_converter(bbuilder, data, **kwargs): + """Sigmoid converter""" + if kwargs: + __unexpected_attrs("sigmoid", kwargs) + + return relax.op.unary.sigmoid(data) + + +def relu_converter(bbuilder, data, **kwargs): + """RELU converter""" + if kwargs: + __unexpected_attrs("relu", kwargs) + + return relax.op.nn.relu(data) + + +def prelu_converter(bbuilder, data, alpha, **kwargs): + """PRELU converter""" + if kwargs: + __unexpected_attrs("prelu", kwargs) + + # prelu can't handle float vals but NNEF supports direct parameter, this is just in case + if isinstance(alpha, tvm_expr.Constant): + if alpha.data.numpy().size == 1: + return relax.op.nn.leakyrelu(data, alpha.data.numpy().item()) + + # alpha needs to be a tensor whose rank is the same as of data, + # and the only non 1 dim is the channel dims + axes = [ + 0, + ] + [a + 2 for a in range(data.struct_info.ndim - 2)] + alpha = relax.op.expand_dims(alpha, axes) + + # using select for prelu + return select_converter( + bbuilder, data < tvm_expr.const(0.0), mul_converter(bbuilder, alpha, data), data + ) + + +def leaky_relu_converter(bbuilder, data, alpha, **kwargs): + """Leaky RELU converter""" + if kwargs: + __unexpected_attrs("leaky_relu", kwargs) + + return relax.op.nn.leakyrelu(data, alpha) + + +def elu_converter(bbuilder, data, alpha, **kwargs): + """ELU converter""" + if kwargs: + __unexpected_attrs("elu", kwargs) + + return select_converter( + bbuilder, + lt_converter(bbuilder, data, tvm_expr.const(0.0)), + mul_converter( + bbuilder, + tvm_expr.const(alpha), + sub_converter(bbuilder, exp_converter(bbuilder, data), tvm_expr.const(1.0)), + ), + data, + ) + + +def selu_converter(bbuilder, data, alpha, **kwargs): + """SELU converter + True signature is selu_converter(data, alpha, lambda)""" + lambda_var = kwargs.pop("lambda") + + if kwargs: + __unexpected_attrs("selu", kwargs) + + return mul_converter( + bbuilder, + tvm_expr.const(lambda_var), + 
select_converter( + bbuilder, + data < tvm_expr.const(0.0), + mul_converter( + bbuilder, + tvm_expr.const(alpha), + sub_converter(bbuilder, exp_converter(bbuilder, data), tvm_expr.const(1.0)), + ), + data, + ), + ) + + +def gelu_converter(bbuilder, data, **kwargs): + """GELU converter + NNEF definition for GELU: + the exact definition of GELU is x * Phi(x) where Phi(x) is the + CDF of the standard normal distribution, which can be approximated + for example by sigmoid(1.702 * x) + + `mul_converter(data, sigmoid_converter(mul_converter(tvm_expr.const(1.702), data)))` + + But in this case we will use the erf to calculate normcdf (same as to pytorch GELU impl) + """ + if kwargs: + __unexpected_attrs("gelu", kwargs) + + return relax.op.nn.gelu(data) + + +def silu_converter(bbuilder, data, **kwargs): + """SiLU converter""" + if kwargs: + __unexpected_attrs("silu", kwargs) + + return mul_converter(bbuilder, data, sigmoid_converter(bbuilder, data)) + + +def softmax_converter(bbuilder, data, axes, **kwargs): + """Softmax converter""" + if kwargs: + __unexpected_attrs("softmax", kwargs) + + if len(axes) > 1: + print("Multiple axes not supported, operation has been done along the first axis in axes.") + axis = axes[0] + + return relax.op.nn.softmax(data, axis) + + +def softplus_converter(bbuilder, data, **kwargs): + """Softplus converter""" + if kwargs: + __unexpected_attrs("softplus", kwargs) + + return log_converter( + bbuilder, add_converter(bbuilder, exp_converter(bbuilder, data), tvm_expr.const(1.0)) + ) + + +# # linear ops + + +def linear_converter(bbuilder, data, _filter, bias, **kwargs): + """Linear converter""" + if kwargs: + __unexpected_attrs("linear", kwargs) + + out = matmul_converter(bbuilder, data, _filter, transposeA=False, transposeB=True) + out = bbuilder.normalize(out) + res = None + + if isinstance(bias, tvm_expr.Constant): + if (bias.data.numpy() == 0).all(): + res = out + + if hasattr(data.struct_info, "ndim"): + ndim = data.struct_info.ndim + else: + ndim = len(data.struct_info.shape) + + if not res: + bias = relax.op.reshape( + bias, + [1, -1] + + [ + 1, + ] + * (ndim - 2), + ) + res = relax.op.add(out, bias) + + return res + + +def separable_conv_converter( + bbuilder, + data, + plane_filter, + point_filter, + bias, + border, + padding, + stride, + dilation, + groups, + **kwargs, +): + """Separable convolution converter""" + if kwargs: + __unexpected_attrs("separable_conv", kwargs) + + if isinstance(data, relax.Call): + d_type = data.checked_type.dtype + else: + d_type = data.struct_info.dtype + + filtered = conv_converter( + bbuilder, + data, + plane_filter, + tvm_expr.const(0, dtype=d_type), + border, + stride, + padding, + dilation, + 0, + ) + + filtered = bbuilder.normalize(filtered) + + return conv_converter(bbuilder, filtered, point_filter, bias, "constant", [], [], [], groups) + + +def separable_deconv_converter( + bbuilder, + data, + plane_filter, + point_filter, + bias, + border, + padding, + stride, + dilation, + output_shape, + groups, + **kwargs, +): + """Separable deconvolution converter""" + if kwargs: + __unexpected_attrs("separable_deconv", kwargs) + + if isinstance(data, relax.Call): + d_type = data.checked_type.dtype + else: + d_type = data.struct_info.dtype + + filtered = deconv_converter( + bbuilder, + data, + point_filter, + tvm_expr.const(0, dtype=d_type), + "constant", + [], + [], + [], + [], + groups, + ) + + filtered = bbuilder.normalize(filtered) + + return deconv_converter( + bbuilder, filtered, plane_filter, bias, border, stride, padding, 
dilation, output_shape, 0 + ) + + +def max_pool_converter(bbuilder, data, size, border, padding, stride, dilation, **kwargs): + """Max pool converter""" + if kwargs: + __unexpected_attrs("max_pool", kwargs) + + if border != "constant": + print(f"Currently {border} border is not supported, used `constant` border") + + dshape = [v.value for v in data.struct_info.shape.values] + rank = len(dshape) + + pool_size = _size_conv(size, rank) + strides = _stride_conv(stride, rank) if stride else (1,) * (rank - 2) + + dilation = _stride_conv(dilation, rank) if dilation else (1,) * (rank - 2) + + if not padding: + # padding is truncated to `conv style` (only active layers are present) + padding = _calculate_nnef_padding(dshape[2:], strides, pool_size, dilation) + + pad = _padding_conv(padding, rank) + + if border == "constant": + padding = [(0, 0), (0, 0)] + padding + data = pad_converter(bbuilder, data, padding, border, tvm_expr.const(0.0)) + data = bbuilder.normalize(data) + pad = (0, 0) + + if rank == 3: + op = relax.op.nn.max_pool1d + elif rank == 4: + op = relax.op.nn.max_pool2d + elif rank == 5: + op = relax.op.nn.max_pool3d + else: + raise NotImplementedError("Ndim > 5 not supported for max pool.") + + return op( + data, + pool_size=pool_size, + strides=strides, + dilation=dilation, + padding=pad, + ) + + +def avg_pool_converter(bbuilder, data, size, border, padding, stride, dilation, **kwargs): + """Avg pool converter""" + if kwargs: + __unexpected_attrs("avg_pool", kwargs) + + if border not in ["constant", "ignore"]: + print(f"Currently {border} border is not supported, used `constant` border") + + dshape = [v.value for v in data.struct_info.shape.values] + rank = len(dshape) + pool_size = _size_conv(size, rank) + strides = _stride_conv(stride, rank) if stride else (1,) * (rank - 2) + + dilation = _stride_conv(dilation, rank) if dilation else (1,) * (rank - 2) + + # padding is truncated to `conv style` (only active layers are present) + active_shape = dshape[2:] + if not padding: + padding = _calculate_nnef_padding(active_shape, strides, pool_size, dilation) + + pad = _padding_conv(padding, rank) + + if rank == 3: + op = relax.op.nn.avg_pool1d + elif rank == 4: + op = relax.op.nn.avg_pool2d + elif rank == 5: + op = relax.op.nn.avg_pool3d + else: + raise NotImplementedError("Ndim > 5 not supported for avg pool.") + + return op( + data, + pool_size=pool_size, + strides=strides, + dilation=dilation, + padding=pad, + count_include_pad=border != "ignore", + ) + + +def rms_pool_converter(bbuilder, data, size, border, padding, stride, dilation, **kwargs): + """Rms pool converter""" + if kwargs: + __unexpected_attrs("rms_pool", kwargs) + + return sqrt_converter( + bbuilder, + avg_pool_converter( + bbuilder, + bbuilder.normalize(sqr_converter(bbuilder, data)), + size=size, + border=border, + padding=padding, + stride=stride, + dilation=dilation, + ), + ) + + +# # Normalization + + +def local_response_normalization_converter(bbuilder, data, size, alpha, beta, bias, **kwargs): + """LRN converter""" + if kwargs: + __unexpected_attrs("local_response_normalization", kwargs) + + axis = [i for i in range(len(size)) if size[i] > 1] + if len(axis) == 1: + axis = axis[0] + else: + print("Multi axis LRN is not implemented properly, using first axis where size != 1") + axis = axis[0] + size = size[axis] + + return bbuilder.emit_te(topi.nn.lrn, data, size, axis, alpha, beta, bias) + + +def local_mean_normalization_converter(bbuilder, data, size, **kwargs): + """LMN converter""" + if kwargs: + 
__unexpected_attrs("local_mean_normalization", kwargs) + + mean = box_converter(bbuilder, data, size, "constant", [], [], [], normalize=True) + mean = bbuilder.normalize(mean) + + return sub_converter(bbuilder, data, mean) + + +def local_variance_normalization_converter(bbuilder, data, size, bias, epsilon, **kwargs): + """LVN converter""" + if kwargs: + __unexpected_attrs("local_variance_normalization", kwargs) + + sigma = box_converter( + bbuilder, + bbuilder.normalize(sqr_converter(bbuilder, data)), + size, + "constant", + [], + [], + [], + normalize=True, + ) + sigma = bbuilder.normalize(sigma) + + return div_converter( + bbuilder, + data, + max_converter( + bbuilder, + add_converter(bbuilder, sqrt_converter(bbuilder, sigma), tvm_expr.const(bias)), + tvm_expr.const(epsilon), + ), + ) + + +def local_contrast_normalization_converter(bbuilder, data, size, bias, epsilon, **kwargs): + """LCN converter""" + if kwargs: + __unexpected_attrs("local_contrast_normalization", kwargs) + + centered = local_mean_normalization_converter(bbuilder, data, size) + centered = bbuilder.normalize(centered) + return local_variance_normalization_converter(bbuilder, centered, size, bias, epsilon) + + +def l1_normalization_converter(bbuilder, data, axes, bias, epsilon, **kwargs): + """L1 norm converter""" + if kwargs: + __unexpected_attrs("l1_normalization", kwargs) + + sigma = sum_reduce_converter(bbuilder, abs_converter(bbuilder, data), axes, False) + return div_converter( + bbuilder, + data, + max_converter( + bbuilder, add_converter(bbuilder, sigma, tvm_expr.const(bias)), tvm_expr.const(epsilon) + ), + ) + + +def l2_normalization_converter(bbuilder, data, axes, bias, epsilon, **kwargs): + """L2 norm converter""" + if kwargs: + __unexpected_attrs("l2_normalization", kwargs) + + # relay style l2_norm not supported, used equation from NNEF + + sigma = sum_reduce_converter( + bbuilder, sqr_converter(bbuilder, data), axes=axes, normalize=False + ) + + res = div_converter( + bbuilder, + data, + max_converter( + bbuilder, + add_converter(bbuilder, sqrt_converter(bbuilder, sigma), tvm_expr.const(bias)), + tvm_expr.const(epsilon), + ), + ) + return res + + +def batch_normalization_converter(bbuilder, data, mean, variance, offset, scale, epsilon, **kwargs): + """Batch norm converter""" + if kwargs: + __unexpected_attrs("batch_normalization", kwargs) + + mean = squeeze_converter(bbuilder, mean, 0) + variance = squeeze_converter(bbuilder, variance, 0) + offset = squeeze_converter(bbuilder, offset, 0) + scale = squeeze_converter(bbuilder, scale, 0) + + mean = bbuilder.normalize(mean) + variance = bbuilder.normalize(variance) + offset = bbuilder.normalize(offset) + scale = bbuilder.normalize(scale) + + res = bbuilder.emit_te(topi.nn.batch_norm, data, scale, offset, mean, variance, 1, epsilon) + return res[0] + + +# # Misc ops diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index fbbd4f99212d..d85d4a837188 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -34,3 +34,4 @@ from .caffe import from_caffe from .paddlepaddle import from_paddle from .change_datatype import ChangeDatatype +from .nnef import from_nnef diff --git a/python/tvm/relay/frontend/nnef.py b/python/tvm/relay/frontend/nnef.py new file mode 100644 index 000000000000..e56cdc84abc4 --- /dev/null +++ b/python/tvm/relay/frontend/nnef.py @@ -0,0 +1,323 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""NNEF: Neural Network Exchange Format frontend for TVM relay""" +import os +import typing +import nnef +import numpy as np + +import tvm +from tvm import relay +from tvm.ir import IRModule +from tvm.relay import expr as tvm_expr +from tvm.relay import analysis, function +from tvm.relay.frontend.common import new_var, fold_constant, set_span, infer_type + +from .nnef_ops import _get_converter_map + + +def get_type(elem_type: str): + """ + Gives numpy style type for nnef primitive types, uses x32 versions. + + :param elem_type: string, (scalar, integer, logical, string) + :return: returns numpy dtype equivalent (float32, int32, bool, string) + """ + if elem_type == "scalar": + return "float32" + if elem_type == "integer": + return "int32" + if elem_type == "logical": + return "bool" + if elem_type == "string": + return "string" + raise TypeError(f'Type "{elem_type}" is not implemented') + + +def make_parameter_span(source_name_list, name_sep="."): + return name_sep.join(source_name_list) + + +# Converter class +class NNEFConverter: + """ + Helper class for class level attributes, for conversion of NNEF model. + Public method to use is from_nnef. + + Parameters + ---------- + + freeze_vars : bool, optional + If this parameter is true, the nnef variables will be converted to + constants, and be embedded into the relay model, allowing optimizations + at compile time. + + """ + + def __init__(self, freeze_vars=False): + self._nodes = {} + self._consts = {} + self._inputs = {} + self._num_inputs = 0 + self._params = {} + self._num_params = 0 + self._freeze_vars = freeze_vars + + def from_nnef(self, graph: nnef.Graph) -> typing.Tuple[tvm.IRModule, dict]: + """ + Convert an NNEF model into an equivalent TVM Relay IRModule. + + Parameters + ---------- + graph : nnef.Graph + An NNEF Graph object that was imported with nnef.load_graph. + Shapes should be inferred by nnef.infer_shapes on graph beforehand. 
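A minimal preprocessing sketch of what this docstring asks for, using only nnef.load_graph and nnef.infer_shapes as referenced in the surrounding code; the model directory path is a placeholder and the snippet is illustrative rather than part of this patch:

    import nnef

    # Load the NNEF model directory (placeholder path) and run shape inference,
    # since the converter expects shape-annotated tensors on the graph.
    graph = nnef.load_graph("model.nnef/")
    nnef.infer_shapes(graph)

    # The shape-annotated graph can then be handed to the converter class
    # defined in this module.
    mod, params = NNEFConverter(freeze_vars=False).from_nnef(graph)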
+ + Returns + ------- + mod : tvm.IRModule + The relay module for compilation + + params : dict of str to tvm.nd.NDArray + The parameter dictionary to be used + + """ + self._parse_inputs(graph) + self._construct_nodes(graph) + + outputs = [self._nodes[n] for n in graph.outputs] + outputs = outputs[0] if len(outputs) == 1 else tvm_expr.Tuple(outputs) + + nodes = {v: k for k, v in self._nodes.items()} + free_vars = analysis.free_vars(outputs) + free_vars = [nodes[var] for var in free_vars] + for i_name in self._params.keys(): + if i_name in free_vars and i_name not in self._inputs: + self._inputs[i_name] = self._nodes[i_name] + func = function.Function(list(self._inputs.values()), outputs) + return IRModule.from_expr(func), self._params + + def _parse_inputs(self, graph): + """Save inputs into class from inputs attrib of graph""" + for inp in graph.inputs: + self._num_inputs += 1 + tensor = graph.tensors[inp] + self._nodes[inp] = new_var(inp, shape=tensor.shape, dtype=get_type(tensor.dtype)) + self._inputs[inp] = self._nodes[inp] + + def _construct_nodes(self, graph): + """Construct TVM relay calls from every operation of the nnef graph""" + for op in graph.operations: + if op.name == "external": + # externals are handled as input, not needed, + # but nnef treats them as operations as well + continue + + if op.name == "variable": + self._set_variable(graph.tensors[op.outputs["output"]]) + + elif op.name == "constant": + self._set_const(op) + + else: + # every other operator can be grouped more easily, + # as it does not need self for conversion + self._set_operator(op) + + def _set_operator(self, node): + self._set_literal_inputs(node) + self._set_parameter_span(node, node.name) + inputs = [] + for ink, inv in node.inputs.items(): + if isinstance(inv, list): + for i, linv in enumerate(inv): + if linv in self._nodes.keys(): + inputs.append(self._nodes[linv]) + else: # handle literal inputs + name = f"{node.name}_{ink}_{i}" + assert name in self._nodes, f"{name} has not been properly handled" + inputs.append(self._nodes[name]) + + else: + if inv in self._nodes.keys(): + inputs.append(self._nodes[inv]) + else: # handle literal inputs + name = f"{node.name}_{ink}" + assert name in self._nodes, f"{name} has not been properly handled" + inputs.append(self._nodes[name]) + + converted = self._get_relay_op_call(node.name, inputs, node.attribs) + + if not isinstance(converted, tvm_expr.TupleWrapper): + outputs_num = 1 + else: + outputs_num = len(converted) + + if outputs_num == 1: + if not isinstance(converted, tvm_expr.TupleWrapper): + converted = fold_constant(converted) + else: + converted = fold_constant(converted.astuple()) + else: + converted = tvm_expr.TupleWrapper(fold_constant(converted.astuple()), len(converted)) + + converted = set_span(converted, node.name) + + if outputs_num == 1: + # check if the singular ret val is a list of only one element + ret_val = list(node.outputs.values())[0] + if isinstance(ret_val, list): + self._nodes[ret_val[0]] = converted + else: + self._nodes[ret_val] = converted + else: + for i, out in zip(range(outputs_num), node.outputs["values"]): + self._nodes[out] = converted[i] + + def _set_const(self, node): + """Create a tvm.relay.Constant from a nnef constant tensor""" + name = node.outputs["output"] + data = node.attribs["value"] + shape = node.attribs["shape"] + if len(data) == 1: + data = np.full(shape, data, dtype=get_type(node.dtype)) + else: + data = np.array(data, dtype=get_type(node.dtype)) + self._consts[name] = tvm_expr.const(data) + self._nodes[name] 
= self._consts[name] + + def _set_variable(self, tensor): + """Create a tvm.relay.Var (or Constant if freeze_vars) from a nnef variable tensor""" + tens_data = tensor.data + if self._freeze_vars: + self._consts[tensor.name] = tvm_expr.const(tens_data) + self._nodes[tensor.name] = self._consts[tensor.name] + else: + self._nodes[tensor.name] = new_var( + tensor.name, shape=tensor.shape, dtype=get_type(tensor.dtype) + ) + self._params[tensor.name] = tens_data + + def _set_literal_inputs(self, node): + """Checks if node has literal inputs and saves them into a tvm.relay.Constant. + naming as {node.name}_{input field name}""" + for field_name, value in node.inputs.items(): + if isinstance(value, list): + for v in value: + if v not in self._nodes.keys(): + self._nodes[f"{node.name}_{v}"] = tvm_expr.const(v) + + else: + if value not in self._nodes.keys(): + self._nodes[f"{node.name}_{field_name}"] = tvm_expr.const(value) + + def _set_parameter_span(self, node, node_source_name): + for field_name, name in node.inputs.items(): + if isinstance(name, list): + for n in name: + self._set_par_span_helper(node, node_source_name, n, field_name) + else: + self._set_par_span_helper(node, node_source_name, name, field_name) + + def _set_par_span_helper(self, node, node_source_name, name, field_name): + if name not in self._nodes.keys(): + name = f"{node.name}_{field_name}" + + expr = self._nodes.get(name) + if expr: + expr_with_span = set_span(expr, make_parameter_span([node_source_name, name])) + self._nodes[name] = expr_with_span + if name in self._inputs: + self._inputs[name] = expr_with_span + if isinstance(expr, relay.Constant): + self._consts[name] = expr_with_span + + def _get_relay_op_call(self, name, inputs, attrs): + """Returns the tvm.Call equivalent to the nnef operator""" + conv_map = _get_converter_map() + if name in conv_map: + call = conv_map[name](*inputs, **attrs) + else: + # This error is reached if NNEF is expanded with additional ops + raise NotImplementedError( + f"Operator {name} is not implemented, as {name} has been added after 1.0.5." + ) + return call + + def _infer_type(self, val): + if isinstance(val, bool): + return "bool", True + if isinstance(val, float): + return "float32", True + if isinstance(val, int): + return "int32", True + if isinstance(val, str): + # the string vals can be names of nodes in some of the cases + if isinstance(val, nnef.Identifier): + if val in self._nodes.keys(): + node = self._nodes[val] + if isinstance(node, tvm_expr.Var): + return node.type_annotation.dtype, False + if isinstance(node, tvm_expr.Constant): + return node.data.dtype, False + if isinstance(node, tvm_expr.Call): + return infer_type(node).checked_type.dtype, False + raise Exception( + f"{val} has not been loaded into the model " + "but it should have been, as a var or call." + ) + return "string", True + + raise TypeError(f'Value "{val}" is not a recognized type') + + +def from_nnef( + model: typing.Union[str, os.PathLike, nnef.Graph], + freeze_vars: bool = False, +) -> typing.Tuple[IRModule, dict]: + """ + Convert an NNEF model into an equivalent TVM Relay IRModule. + + + Parameters + ---------- + model : os.PathLike or str or nnef.Graph + Path to an NNEF model directory, containing the graph.nnef (and weight files) + + freeze_vars : bool, optional + If this parameter is true, the nnef variables will be converted to + constants, and be embedded into the relay model, allowing optimizations + at compile time. 
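A short end-to-end usage sketch for the public entry point documented here; the model path and target are placeholders, and the snippet assumes the patch is applied so that relay.frontend.from_nnef is importable:

    import tvm
    from tvm import relay

    # Import an NNEF model directory. With freeze_vars=True the variables are
    # embedded as constants and the returned params dict stays empty.
    mod, params = relay.frontend.from_nnef("mobilenet.nnef/", freeze_vars=False)

    # Compile through the standard Relay pipeline (target is a placeholder).
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=params)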
+ + Returns + ------- + mod : tvm.IRModule + The relay module for compilation + + params : dict of str to tvm.nd.NDArray + The parameter dictionary to be used + """ + conv_clss = NNEFConverter(freeze_vars) + + if not isinstance(model, nnef.Graph): + model = nnef.load_graph(model) + + # fills in the nnef graph's shape information + nnef.infer_shapes(model) + + return conv_clss.from_nnef(graph=model) diff --git a/python/tvm/relay/frontend/nnef_ops.py b/python/tvm/relay/frontend/nnef_ops.py new file mode 100644 index 000000000000..49b2dfb2eabe --- /dev/null +++ b/python/tvm/relay/frontend/nnef_ops.py @@ -0,0 +1,1695 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""NNEF frontend converter helper funcs and ops""" +import math + +import itertools +from functools import reduce + +import numpy as np + +import tvm +from tvm import relay +from tvm.relay import expr as tvm_expr +from tvm.relay import op as tvm_op +from tvm.relay.frontend.common import get_relay_op, infer_shape, infer_type + + +# Base methods + + +def dimension_picker(prefix, kernel_shape, suffix=""): + """ + Returns the correct name for nth dimensional operator. Uses the "kernel_shape" attribute.\n + E.g.call: dimension_picker(op_name)(attr) + + :param prefix: the name of the operator (e.g. conv) + :param kernel_shape: shape of the tensor to fit the operation + :param suffix: optional suffix for ops + :return: "prefix`n`d" where n is the correct dimension for the kernel + """ + + rank = len(kernel_shape[2:]) + if rank == 1: + return prefix + "1d" + suffix + if rank == 2: + return prefix + "2d" + suffix + if rank == 3: + return prefix + "3d" + suffix + op_name = prefix + "1d/2d/3d" + msg = f"Only 1D, 2D, and 3D kernels are supported for operator {op_name}." 
+ raise tvm.error.OpAttributeInvalid(msg) + + +def _size_conv(size, rank): + # window of size (DH)W is only possible when it is checked outside, + # which is needed for alternative solution + if rank == 3: + if len(size) == 1: + return size + if len(size) == 3: + assert ( + size[0] == 1 and size[1] == 1 + ), "Incorrect window dimensions, first two dimensions must be 1" + return size[2] + if rank == 4: + if len(size) == 2: + return size + if len(size) == 4: + assert ( + size[0] == 1 and size[1] == 1 + ), "Incorrect window dimensions, first two dimensions must be 1" + return size[2:] + if rank == 5: + if len(size) == 3: + return size + if len(size) == 5: + assert ( + size[0] == 1 and size[1] == 1 + ), "Incorrect window dimensions, first two dimensions must be 1" + return size[2:] + + raise ValueError(f"Unexpected window size, got {len(size)}") + + +def _stride_conv(stride, rank): + if rank == 3: + # {conv style} :: [s] -> [s] + if len(stride) == 1: + return stride + # {pool style} :: [N, C, s] -> asrt N,C == 1; [s] + if len(stride) == 3: + assert ( + stride[0] == 1 and stride[1] == 1 + ), "Not supported stride dimensions, first two dimensions must be 1" + return stride[2:] + if rank == 4: + # {conv style} :: [sh, sw] -> [sh, sw] + if len(stride) == 2: + return stride + # {pool style} :: [N, C, sh, sw] -> asrt N,C == 1; [sh, sw] + if len(stride) == 4: + assert ( + stride[0] == 1 and stride[1] == 1 + ), "Not supported stride dimensions, first two dimensions must be 1" + return stride[2:] + if rank == 5: + # {conv style} :: [sd, sh, sw] -> [sd, sh, sw] + if len(stride) == 3: + return stride + # {pool style} :: [N, C, sd, sh, sw] -> asrt N,C == 1; [sd, sh, sw] + if len(stride) == 5: + assert ( + stride[0] == 1 and stride[1] == 1 + ), "Not supported stride dimensions, first two dimensions must be 1" + return stride[2:] + raise ValueError(f"Unexpected stride in {rank - 2}D, got {len(stride)}: {stride}") + + +def _padding_conv(padding, rank, keepdims=False): + if isinstance(padding[0], (tuple, list)): + # 1D + if rank == 3: + # {conv style} :: [(l,r)] -> (l,r) + if len(padding) == 1: + return padding[0] + if len(padding) == 3: + # {pool style} :: [(batch),(channel),(l,r)] -> asrt N,C == 0, (l,r) + if not keepdims: + assert padding[0] == (0, 0) and padding[1] == (0, 0), ( + "Incorrect padding. " "Padding on C,I dimensions not supported" + ) + return padding[2] + # {sliding window style} :: [(batch),(channel),(l,r)] -> [(batch),(channel),(l,r)] + else: + return padding + + # 2D + + if rank == 4: + # {conv style} :: [(u,d),(l,r)] -> (u, l, d, r) + if len(padding) == 2: + # change UDLR to ULDR padding, LC is faster here + return [x[i] for i in [0, 1] for x in padding] + + if len(padding) == 4: + # {pool style} :: [(batch size),(channel),(u,d),(l,r)] -> + # -> asrt N,C == 0, (u, l, d, r) + if not keepdims: + assert padding[0] == (0, 0) and padding[1] == (0, 0), ( + "Incorrect padding. 
" "Padding on C,I dimensions not supported" + ) + # itertools is faster than LC (slicing) + return list(itertools.chain.from_iterable(zip(padding[2], padding[3]))) + # {sliding window style} :: [(batch),(channel),(u,d),(l,r)] -> + # -> [(batch),(channel),(u,d),(l,r)] + else: + return padding + + # 3D + + if rank == 5: + # {conv style} :: [(f,b),(u,d),(l,r)] -> (f, u, l, b, d, r) + if len(padding) == 3: + # LC is faster + return [x[i] for i in [0, 1] for x in padding] + + if len(padding) == 5: + # {pool style} :: [(batch size),(channel),(f,b)(u,p),(l,r)] -> + # -> asrt N,C == 0, (f, u, l, b, d, r) + if not keepdims: + assert padding[0] == (0, 0) and padding[1] == (0, 0), ( + "Incorrect padding. " "Padding on C,I dimensions not supported" + ) + # itertools faster barely + return list( + itertools.chain.from_iterable(zip(padding[2], padding[3], padding[4])) + ) + # {s-w style} :: [(batch),(channel),(f,b),(u,d),(l,r)] -> + # -> [(batch),(channel),(f,b),(u,d),(l,r)] + else: + return padding + + raise ValueError( + f"Incorrect padding style for {rank - 2}D operand. Only length of {rank - 2}, {rank} " + f"supported, got {len(padding)}: {padding}" + ) + + raise ValueError("nnef should not have singular padding") + + +def _calculate_nnef_padding(active_shape, strides, kernel_shape, dilation): + """Ordering of nnef autopad and tvm autopad are sometimes different, + this method calculates nnef like padding from dimensions + + Parameters + ---------- + active_shape + the data dimensions + strides + the strides over the active dimensions + kernel_shape + the shape of the window, must have the same rank as active shape + dilation + the dilations over the active dimensions + """ + output = [(ui + (s - 1)) // s for ui, s in zip(active_shape, strides)] + dilated = [(f - 1) * d + 1 for f, d in zip(kernel_shape, dilation)] + total = [ + max(0, (di - 1) * s + df - ui) + for di, s, df, ui in zip(output, strides, dilated, active_shape) + ] + padding = [(pad // 2, (pad + 1) // 2) for pad in total] + return padding + + +def _calculate_nnef_padding_deconv(data_sh, strides, kernel_active_sh, dilation, output_shape): + out_sh = output_shape[2:] if output_shape else [ui * s for ui, s in zip(data_sh, strides)] + dilated = [(f - 1) * d + 1 for f, d in zip(kernel_active_sh[2:], dilation)] + total = [ + max(0, (di - 1) * s + df - ui) for di, s, df, ui in zip(data_sh, strides, dilated, out_sh) + ] + return total, out_sh + + +def __unexpected_attrs(op, kwargs): + raise NotImplementedError( + f"{op} received unexpected attributes(s), possibly mismatched versions. 
" + "Attributes(s) ignored: " + ", ".join(f"{k} := {v}" for k, v in kwargs.items()) + ) + + +# Conversion map, operator functions + + +def _get_converter_map(): + return { # Unary + "copy": copy_converter, # arithmetic + "neg": neg_converter, + "rcp": rcp_converter, + "exp": exp_converter, + "log": log_converter, + "sin": sin_converter, + "cos": cos_converter, + "tan": tan_converter, + "sinh": sinh_converter, + "cosh": cosh_converter, + "tanh": tanh_converter, + "asin": asin_converter, + "acos": acos_converter, + "atan": atan_converter, + "asinh": asinh_converter, + "acosh": acosh_converter, + "atanh": atanh_converter, + "abs": abs_converter, + "sign": sign_converter, + "not": not_converter, # logical + "floor": floor_converter, # rounding + "ceil": ceil_converter, + "round": round_converter, + # Binary + "add": add_converter, # arithmetic + "sub": sub_converter, + "mul": mul_converter, + "div": div_converter, + "pow": pow_converter, + "lt": lt_converter, # comparison + "gt": gt_converter, + "le": le_converter, + "ge": ge_converter, + "eq": eq_converter, + "ne": ne_converter, + "and": and_converter, # logical + "or": or_converter, + # select + "select": select_converter, + # simplifier + "sqr": sqr_converter, + "sqrt": sqrt_converter, + "rsqr": rsqr_converter, + "rsqrt": rsqrt_converter, + "log2": log2_converter, + "min": min_converter, + "max": max_converter, + "clamp": clamp_converter, + # sliding-window + "conv": conv_converter, + "deconv": deconv_converter, + "box": box_converter, + "debox": debox_converter, + "argmax_pool": ndop, + "sample": ndop, + "desample": ndop, + "nearest_downsample": nearest_downsample_converter, + "area_downsample": area_downsample_converter, + "nearest_upsample": nearest_upsample_converter, + "multilinear_upsample": multilinear_upsample_converter, + # reduce + "sum_reduce": sum_reduce_converter, + "max_reduce": max_reduce_converter, + "min_reduce": min_reduce_converter, + "argmax_reduce": argmax_reduce_converter, + "argmin_reduce": argmin_reduce_converter, + "all_reduce": all_reduce_converter, + "any_reduce": any_reduce_converter, + "mean_reduce": mean_reduce_converter, + # tensor shape + "reshape": reshape_converter, + "squeeze": squeeze_converter, + "unsqueeze": unsqueeze_converter, + "transpose": transpose_converter, + "split": split_converter, + "concat": concat_converter, + "stack": stack_converter, + "unstack": unstack_converter, + "slice": slice_converter, + "pad": pad_converter, + "tile": tile_converter, + # region-of-interest - not needed - not supported + "avg_roi_pool": ndop, + "max_roi_pool": ndop, + "roi_resample": ndop, + "avg_roi_align": ndop, + "max_roi_align": ndop, + # matrix multiplication + "matmul": matmul_converter, + # variables + "update": ndop, # --- not used + # Compound + "sigmoid": sigmoid_converter, # activation + "relu": relu_converter, + "prelu": prelu_converter, + "leaky_relu": leaky_relu_converter, + "elu": elu_converter, + "selu": selu_converter, + "gelu": gelu_converter, + "silu": silu_converter, + "softmax": softmax_converter, + "softplus": softplus_converter, + "linear": linear_converter, # linear + "separable_conv": separable_conv_converter, + "separable_deconv": separable_deconv_converter, + "max_pool_with_index": ndop, # pooling + "max_pool": max_pool_converter, + "avg_pool": avg_pool_converter, + "rms_pool": rms_pool_converter, + "local_response_normalization": local_response_normalization_converter, # normalization + "local_mean_normalization": local_mean_normalization_converter, + "local_variance_normalization": 
local_variance_normalization_converter, + "local_contrast_normalization": local_contrast_normalization_converter, + "l1_normalization": l1_normalization_converter, + "l2_normalization": l2_normalization_converter, + "batch_normalization": batch_normalization_converter, + "min_max_linear_quantize": ndop, # quantization + "zero_point_linear_quantize": ndop, + "linear_quantize": ndop, + "logarithmic_quantize": ndop, + # MISC + "copy_n": ndop, + "add_n": ndop, + "moments": ndop, + } + + +# not implemented ops +def ndop(*args, **kwargs): + raise Exception("Not supported operator was called, please check for compatibility") + + +# # Unary ops + + +def copy_converter(data, **kwargs): + """Copy converter""" + if kwargs: + __unexpected_attrs("copy", kwargs) + + return get_relay_op("copy")(data) + + +def neg_converter(data, **kwargs): + """Neg converter""" + if kwargs: + __unexpected_attrs("neg", kwargs) + + return get_relay_op("negative")(data) + + +def rcp_converter(data, **kwargs): + """Rcp converter""" + if kwargs: + __unexpected_attrs("rcp", kwargs) + + if isinstance(data, relay.Call): + d_type = infer_type(data).checked_type.dtype + else: + d_type = data.type_annotation.dtype + + return div_converter(tvm_expr.const(1, dtype=d_type), data) + + +def exp_converter(data, **kwargs): + """Exp converter""" + if kwargs: + __unexpected_attrs("exp", kwargs) + + return get_relay_op("exp")(data) + + +def log_converter(data, **kwargs): + """Log converter""" + if kwargs: + __unexpected_attrs("log", kwargs) + + return get_relay_op("log")(data) + + +def sin_converter(data, **kwargs): + """Sin converter""" + if kwargs: + __unexpected_attrs("sin", kwargs) + + return get_relay_op("sin")(data) + + +def cos_converter(data, **kwargs): + """Cos converter""" + if kwargs: + __unexpected_attrs("cos", kwargs) + + return get_relay_op("cos")(data) + + +def tan_converter(data, **kwargs): + """Tan converter""" + if kwargs: + __unexpected_attrs("tan", kwargs) + + return get_relay_op("tan")(data) + + +def sinh_converter(data, **kwargs): + """Sinh converter""" + if kwargs: + __unexpected_attrs("sinh", kwargs) + + return get_relay_op("sinh")(data) + + +def cosh_converter(data, **kwargs): + """Cosh converter""" + if kwargs: + __unexpected_attrs("cosh", kwargs) + + return get_relay_op("cosh")(data) + + +def tanh_converter(data, **kwargs): + """Tanh converter""" + if kwargs: + __unexpected_attrs("tanh", kwargs) + + return get_relay_op("tanh")(data) + + +def asin_converter(data, **kwargs): + """Asin converter""" + if kwargs: + __unexpected_attrs("asin", kwargs) + + return get_relay_op("asin")(data) + + +def acos_converter(data, **kwargs): + """Acos converter""" + if kwargs: + __unexpected_attrs("acos", kwargs) + + return get_relay_op("acos")(data) + + +def atan_converter(data, **kwargs): + """Atan converter""" + if kwargs: + __unexpected_attrs("atan", kwargs) + + return get_relay_op("atan")(data) + + +def asinh_converter(data, **kwargs): + """Asinh converter""" + if kwargs: + __unexpected_attrs("asinh", kwargs) + + return get_relay_op("asinh")(data) + + +def acosh_converter(data, **kwargs): + """Acosh converter""" + if kwargs: + __unexpected_attrs("acosh", kwargs) + + return get_relay_op("acosh")(data) + + +def atanh_converter(data, **kwargs): + """Atanh converter""" + if kwargs: + __unexpected_attrs("atanh", kwargs) + + return get_relay_op("atanh")(data) + + +def abs_converter(data, **kwargs): + """Abs converter""" + if kwargs: + __unexpected_attrs("abs", kwargs) + + return get_relay_op("abs")(data) + + +def sign_converter(data, 
**kwargs): + """Sign converter""" + if kwargs: + __unexpected_attrs("sign", kwargs) + + return get_relay_op("sign")(data) + + +def not_converter(data, **kwargs): + """Not converter""" + if kwargs: + __unexpected_attrs("not", kwargs) + + return get_relay_op("logical_not")(data) + + +def floor_converter(data, **kwargs): + """Floor converter""" + if kwargs: + __unexpected_attrs("floor", kwargs) + + return get_relay_op("floor")(data) + + +def ceil_converter(data, **kwargs): + """Ceil converter""" + if kwargs: + __unexpected_attrs("ceil", kwargs) + + return get_relay_op("ceil")(data) + + +def round_converter(data, **kwargs): + """Round converter""" + if kwargs: + __unexpected_attrs("round", kwargs) + + return get_relay_op("round")(data) + + +# # Binary ops + + +def add_converter(lhs, rhs, **kwargs): + """Add converter""" + if kwargs: + __unexpected_attrs("add", kwargs) + + return get_relay_op("add")(lhs, rhs) + + +def sub_converter(lhs, rhs, **kwargs): + """Sub converter""" + if kwargs: + __unexpected_attrs("sub", kwargs) + + return get_relay_op("subtract")(lhs, rhs) + + +def mul_converter(lhs, rhs, **kwargs): + """Mul converter""" + if kwargs: + __unexpected_attrs("mul", kwargs) + + return get_relay_op("multiply")(lhs, rhs) + + +def div_converter(lhs, rhs, **kwargs): + """Div converter""" + if kwargs: + __unexpected_attrs("div", kwargs) + + return get_relay_op("divide")(lhs, rhs) + + +def pow_converter(lhs, rhs, **kwargs): + """Pow converter""" + if kwargs: + __unexpected_attrs("pow", kwargs) + + return get_relay_op("power")(lhs, rhs) + + +def lt_converter(lhs, rhs, **kwargs): + """Lt converter""" + if kwargs: + __unexpected_attrs("lt", kwargs) + + return get_relay_op("less")(lhs, rhs) + + +def gt_converter(lhs, rhs, **kwargs): + """Gt converter""" + if kwargs: + __unexpected_attrs("gt", kwargs) + + return get_relay_op("greater")(lhs, rhs) + + +def le_converter(lhs, rhs, **kwargs): + """Le converter""" + if kwargs: + __unexpected_attrs("le", kwargs) + + return get_relay_op("less_equal")(lhs, rhs) + + +def ge_converter(lhs, rhs, **kwargs): + """Ge converter""" + if kwargs: + __unexpected_attrs("ge", kwargs) + + return get_relay_op("greater_equal")(lhs, rhs) + + +def eq_converter(lhs, rhs, **kwargs): + """Eq converter""" + if kwargs: + __unexpected_attrs("eq", kwargs) + + return get_relay_op("equal")(lhs, rhs) + + +def ne_converter(lhs, rhs, **kwargs): + """Ne converter""" + if kwargs: + __unexpected_attrs("ne", kwargs) + + return get_relay_op("not_equal")(lhs, rhs) + + +def and_converter(lhs, rhs, **kwargs): + """And converter""" + if kwargs: + __unexpected_attrs("and", kwargs) + + return get_relay_op("logical_and")(lhs, rhs) + + +def or_converter(lhs, rhs, **kwargs): + """Or converter""" + if kwargs: + __unexpected_attrs("or", kwargs) + + return get_relay_op("logical_or")(lhs, rhs) + + +# # Select op + + +def select_converter(condition, t_val, f_val, **kwargs): + """Select converter""" + if kwargs: + __unexpected_attrs("select", kwargs) + + return get_relay_op("where")(condition, t_val, f_val) + + +# # Simplifier ops + + +def sqr_converter(data, **kwargs): + """sqr converter""" + if kwargs: + __unexpected_attrs("sqr", kwargs) + + if isinstance(data, relay.Call): + d_type = infer_type(data).checked_type.dtype + else: + d_type = data.type_annotation.dtype + + return get_relay_op("power")(data, tvm_expr.const(2.0, dtype=d_type)) + + +def sqrt_converter(data, **kwargs): + """sqrt converter""" + if kwargs: + __unexpected_attrs("sqrt", kwargs) + + return get_relay_op("sqrt")(data) + + +def 
rsqr_converter(data, **kwargs): + """rsqr converter""" + if kwargs: + __unexpected_attrs("rsqr", kwargs) + + if isinstance(data, relay.Call): + d_type = infer_type(data).checked_type.dtype + else: + d_type = data.type_annotation.dtype + + return get_relay_op("power")(data, tvm_expr.const(-2.0, dtype=d_type)) + + +def rsqrt_converter(data, **kwargs): + """rsqrt converter""" + if kwargs: + __unexpected_attrs("rsqrt", kwargs) + + return get_relay_op("rsqrt")(data) + + +def log2_converter(data, **kwargs): + """log2 converter""" + if kwargs: + __unexpected_attrs("log2", kwargs) + + return get_relay_op("log2")(data) + + +def min_converter(lhs, rhs, **kwargs): + """Min converter""" + if kwargs: + __unexpected_attrs("min", kwargs) + + return get_relay_op("minimum")(lhs, rhs) + + +def max_converter(lhs, rhs, **kwargs): + """Max converter""" + if kwargs: + __unexpected_attrs("max", kwargs) + + return get_relay_op("maximum")(lhs, rhs) + + +def clamp_converter(x, a, b, **kwargs): + """Clamp converter""" + if kwargs: + __unexpected_attrs("clamp", kwargs) + + # only works if b and a are Constant floats, not tensors + if isinstance(a, tvm_expr.Constant) and isinstance(b, tvm_expr.Constant): + return get_relay_op("clip")(x, float(a.data.numpy()), float(b.data.numpy())) + + return max_converter(min_converter(x, b), a) + + +# # Sliding-window ops + + +def conv_converter(data, kernel, bias, border, stride, padding, dilation, groups, **kwargs): + """Convolution converter, + skips bias if it's 0.0 (no bias)""" + if kwargs: + __unexpected_attrs("conv", kwargs) + + if border != "constant": + print(f"Currently {border} border is not supported, used `constant` border") + + kernel_shape = infer_shape(kernel) + dshape = infer_shape(data) + + strides = _stride_conv(stride, len(kernel_shape)) if stride else (1,) * (len(kernel_shape) - 2) + + dilation = dilation if dilation else ((1,) * (len(kernel_shape) - 2)) + + if not padding: + padding = _calculate_nnef_padding(dshape[2:], strides, kernel_shape[2:], dilation) + + pad = _padding_conv(padding, len(kernel_shape)) + + channels = kernel_shape[0] + + if groups == 0: + groups = channels + + op = get_relay_op(dimension_picker("conv", kernel_shape)) + conv_out = op( + data=data, + weight=kernel, + strides=strides, + padding=pad, + dilation=dilation, + groups=groups, + channels=channels, + kernel_size=kernel_shape[2:], + ) + + res = None + if isinstance(bias, tvm_expr.Constant): + # nnef has bias of 0 if it is not needed + if (bias.data.numpy() == 0).all(): + res = conv_out + + if not res: + # squeeze needed as nnef has bias of shape [1, channel] + res = tvm_op.nn.bias_add(conv_out, relay.squeeze(bias, axis=0)) + + return res + + +def deconv_converter( + data, kernel, bias, border, stride, padding, dilation, output_shape, groups, **kwargs +): + """Deconvolution converter, using convxd_transpose + skips bias if it's 0.0 (no bias)""" + if kwargs: + __unexpected_attrs("deconv", kwargs) + + if border != "constant": + print(f"Currently {border} border is not supported, used `constant` border") + + kernel_shape = infer_shape(kernel) + + rank = len(kernel_shape) + + strides = _stride_conv(stride, rank) if stride else (1,) * (rank - 2) + + dilation = dilation if dilation else ((1,) * (rank - 2)) + + total, out_sh = _calculate_nnef_padding_deconv( + infer_shape(data), strides, kernel_shape, dilation, output_shape + ) + + if padding: + pad = _padding_conv(padding, rank) + else: + pad = _padding_conv([(pad // 2, (pad + 1) // 2) for pad in total], rank) + + if groups == 0: + groups = 
kernel_shape[0] + channels = kernel_shape[1] * groups + + # limit output padding to modulo stride because of tvm checks + out_pad = ( + [(x - (y - t)) % s for x, y, t, s in zip(output_shape[2:], out_sh, total, stride)] + if output_shape + else (0, 0) + ) + + op = get_relay_op(dimension_picker("conv", kernel_shape, suffix="_transpose")) + deconv_out = op( + data=data, + weight=kernel, + strides=strides, + padding=pad, + dilation=dilation, + groups=groups, + channels=channels, + kernel_size=kernel_shape[2:], + output_padding=out_pad, + ) + + res = None + if isinstance(bias, tvm_expr.Constant): + if bias.data.numpy() == np.array([0.0]): + res = deconv_out + + if not res: + # squeeze needed bc nnef has bias of shape [1, channel] + res = tvm_op.nn.bias_add(deconv_out, relay.squeeze(bias, axis=0)) + + return res + + +def box_converter(data, size, border, padding, stride, dilation, normalize, **kwargs): + """Box operator converter, + summation over sliding window, equal to conv with constant filter""" + if kwargs: + __unexpected_attrs("box", kwargs) + + dshape = infer_shape(data) + + if isinstance(data, relay.Call): + d_type = infer_type(data).checked_type.dtype + else: + d_type = data.type_annotation.dtype + + size[0] = dshape[1] + if normalize: + kernel = relay.full(tvm_op.const(1 / math.prod(size[2:]), d_type), size, d_type) + else: + kernel = relay.ones(size, d_type) + + out = conv_converter( + data, kernel, tvm_expr.const(0, dtype=d_type), border, stride, padding, dilation, dshape[1] + ) + return out + + +def debox_converter( + data, size, border, padding, stride, dilation, normalize, output_shape, **kwargs +): + """Debox operator converter, + inverse of box, equal to deconv with constant filter""" + if kwargs: + __unexpected_attrs("debox", kwargs) + + dshape = infer_shape(data) + + if isinstance(data, relay.Call): + d_type = infer_type(data).checked_type.dtype + else: + d_type = data.type_annotation.dtype + + size[0] = dshape[1] + if normalize: + kernel = relay.full(tvm_op.const(1 / math.prod(size[2:]), d_type), size, d_type) + else: + kernel = relay.ones(size, d_type) + out = deconv_converter( + data, + kernel, + tvm_expr.const(0, dtype=d_type), + border, + stride, + padding, + dilation, + output_shape, + groups=dshape[1], + ) + return out + + +def nearest_downsample_converter(data, factor, **kwargs): + """Nearest neighbour downsample converter""" + if kwargs: + __unexpected_attrs("nearest_downsample", kwargs) + + dims = 2 + len(factor) + + return box_converter( + data, + size=[1] * dims, + border="constant", + padding=[(0, 0)] * dims, + stride=[1, 1] + factor, + dilation=(1,) * (dims - 2), + normalize=False, + ) + + +def area_downsample_converter(data, factor, **kwargs): + """Area downsample converter""" + if kwargs: + __unexpected_attrs("area_downsample", kwargs) + + dims = 2 + len(factor) + + return box_converter( + data, + size=[1, 1] + factor, + border="constant", + padding=[(0, 0)] * dims, + stride=[1, 1] + factor, + dilation=(1,) * (dims - 2), + normalize=True, + ) + + +def nearest_upsample_converter(data, factor, **kwargs): + """Nearest neighbour upsample converter""" + if kwargs: + __unexpected_attrs("nearest_upsample", kwargs) + + # conversion from nn.upsampling to image.resizexd, re: discuss:11650 + # + dshape = infer_shape(data) + new_size = [d * f for d, f in zip(dshape[2:], factor)] + return get_relay_op(dimension_picker("resize", dshape))( + data, + new_size, + method="nearest_neighbor", + # coordinate_transformation_mode="asymmetric", + rounding_method="round", + ) + + +def 
multilinear_upsample_converter(data, factor, method, border, **kwargs): + """Multilinear upsampling converter""" + if kwargs: + __unexpected_attrs("linear_upsample", kwargs) + + # for aligned and symmetric replicate resize can be used + dshape = infer_shape(data) + new_size = [d * f for d, f in zip(dshape[2:], factor)] + if method == "aligned": + # conversion from nn.upsampling to image.resizexd, re: discuss:11650 + return get_relay_op(dimension_picker("resize", dshape))( + data, + new_size, + method="linear", + coordinate_transformation_mode="align_corners", + ) + if method == "symmetric" and border == "replicate": + return get_relay_op(dimension_picker("resize", dshape))( + data, + new_size, + method="linear", + coordinate_transformation_mode="half_pixel", + ) + + # other combinations need to be calculated with convolution + def _upsample_weights_1d(fact, symm): + if symm: + _weights = [1 - (i + 0.5) / fact for i in range(fact)] + _weights = list(reversed(_weights)) + _weights + else: + _weights = [1 - abs(i) / float(fact) for i in range(-fact + 1, fact)] + return np.array(_weights) + + def _upsample_weights_nd(fact, symm): + _weights = [_upsample_weights_1d(f, symm) for f in fact] + return reduce(np.multiply, np.ix_(*_weights)) + + n, c = dshape[:2] + + symmetric = method == "symmetric" + weights = _upsample_weights_nd(factor, symmetric) + weights = np.reshape(weights, newshape=(1, 1) + weights.shape) + kernel = tile_converter(tvm_expr.const(weights), (c, 1) + (1,) * len(factor)) + + output_shape = [n, c] + [f * s for f, s in zip(factor, dshape[2:])] + + if symmetric: + return deconv_converter( + data, + kernel, + tvm_expr.const(0.0), + border="constant", + stride=factor, + padding=[(f - 1, f - 1) for f in factor], + dilation=[], + groups=c, + output_shape=output_shape, + ) + else: + replicate = border == "replicate" + if replicate: + data = pad_converter( + data, [(0, 0), (0, 0)] + [(1, 0)] * len(factor), border, tvm_expr.const(0.0) + ) + padding = factor + else: + padding = [f // 2 for f in factor] + + return deconv_converter( + data, + kernel, + tvm_expr.const(0.0), + border="constant", + stride=factor, + padding=[(p, p - 1) for p in padding], + dilation=[], + groups=c, + output_shape=output_shape, + ) + + +# # Reduce ops + + +def sum_reduce_converter(data, axes, normalize, keepdims=True, **kwargs): + """Sum reduce converter""" + + if kwargs: + __unexpected_attrs("sum_reduce", kwargs) + + out = get_relay_op("sum")(data, axes, keepdims=keepdims) + if normalize: + return l2_normalization_converter(out, 0, [x - 2 for x in axes], 0.0) + return out + + +def max_reduce_converter(data, axes, keepdims=True, **kwargs): + """Max reduce converter""" + if kwargs: + __unexpected_attrs("max_reduce", kwargs) + + return get_relay_op("max")(data, axes, keepdims=keepdims) + + +def min_reduce_converter(data, axes, keepdims=True, **kwargs): + """Min reduce converter""" + if kwargs: + __unexpected_attrs("min_reduce", kwargs) + + return get_relay_op("min")(data, axes, keepdims=keepdims) + + +def argmax_reduce_converter(data, axes, keepdims=True, **kwargs): + """Argmax reduce converter""" + if kwargs: + __unexpected_attrs("argmax_reduce", kwargs) + + return get_relay_op("argmax")(data, axes, keepdims=keepdims) + + +def argmin_reduce_converter(data, axes, keepdims=True, **kwargs): + """Argmin reduce converter""" + if kwargs: + __unexpected_attrs("argmin_reduce", kwargs) + + return get_relay_op("argmin")(data, axes, keepdims=keepdims) + + +def all_reduce_converter(data, axes, keepdims=True, **kwargs): + 
"""All reduce converter""" + if kwargs: + __unexpected_attrs("all_reduce", kwargs) + + return get_relay_op("all")(data, axes, keepdims=keepdims) + + +def any_reduce_converter(data, axes, keepdims=True, **kwargs): + """Any reduce converter""" + if kwargs: + __unexpected_attrs("any_reduce", kwargs) + + return get_relay_op("any")(data, axes, keepdims=keepdims) + + +def mean_reduce_converter(data, axes, keepdims=True, **kwargs): + """Mean reduce converter""" + if kwargs: + __unexpected_attrs("mean_reduce", kwargs) + + return get_relay_op("mean")(data, axes, keepdims=keepdims) + + +# # Tensor shape ops + + +def reshape_converter(data, shape, axis_start, axis_count, **kwargs): + """Reshape converter""" + if kwargs: + __unexpected_attrs("reshape", kwargs) + + dshape = list(infer_shape(data)) + if axis_count == -1: + newshape = dshape[:axis_start] + shape + else: + newshape = dshape + newshape[axis_start : axis_start + axis_count] = shape + + return get_relay_op("reshape")(data, newshape) + + +def squeeze_converter(data, axes, **kwargs): + """Squeeze converter""" + if kwargs: + __unexpected_attrs("squeeze", kwargs) + return relay.squeeze(data, axes) + + +def unsqueeze_converter(data, axes, **kwargs): + """Unsqueeze converter""" + if kwargs: + __unexpected_attrs("unsqueeze", kwargs) + + axes = sorted(axes) + for axis in axes: + if axis < 0 and isinstance(data, tvm_expr.Var): + axis = len(data.type_annotation.concrete_shape) + len(axes) + axis + + data = tvm_op.expand_dims(data, axis=axis, num_newaxis=1) + return data + + +def transpose_converter(data, axes, **kwargs): + """Transpose converter""" + if kwargs: + __unexpected_attrs("transpose", kwargs) + + return get_relay_op("transpose")(data, axes) + + +def split_converter(data, axis, ratios, **kwargs): + """Split converter""" + if kwargs: + __unexpected_attrs("split", kwargs) + + axis_len = infer_shape(data)[axis] + rat_mul = axis_len / sum(ratios) + ratio_list = [(r * rat_mul) for r in ratios] + + s = 0 + indices = [] + for rat in ratio_list[:-1]: + s += rat + # Strictly needs int + indices.append(int(s)) + + return get_relay_op("split")(data, indices, axis) + + +def concat_converter(*data, axis, **kwargs): + """Concat converter""" + if kwargs: + __unexpected_attrs("concat", kwargs) + + return get_relay_op("concatenate")(data, axis) + + +def stack_converter(*data, axis, **kwargs): + """Stack converter""" + if kwargs: + __unexpected_attrs("stack", kwargs) + + return get_relay_op("stack")(data, axis) + + +def unstack_converter(data, axis, **kwargs): + """Unstack converter""" + if kwargs: + __unexpected_attrs("unstack", kwargs) + + split = split_converter(data, axis, [1] * infer_shape(data)[axis]) + res = [] + for i in range(len(split)): + res.append(squeeze_converter(split[i], axis)) + return tvm_expr.TupleWrapper(relay.Tuple(res), len(res)) + + +def slice_converter(data, axes, begin, end, stride, **kwargs): + """Slice converter""" + if kwargs: + __unexpected_attrs("slice", kwargs) + + if not stride: + stride = [1] * len(axes) + + return get_relay_op("strided_slice")(data, begin, end, strides=stride, axes=axes) + + +def pad_converter(data, padding, border, value, **kwargs): + """Pad converter""" + if kwargs: + __unexpected_attrs("pad", kwargs) + + if border not in ["constant", "replicate", "reflect"]: + print(f"{border} border type is not supported in padding. 
Assumed constant") + border = "constant" + if border == "replicate": + border = "edge" + + return get_relay_op("pad")(data, padding, value, border) + + +def tile_converter(data, repeats, **kwargs): + """Tile converter""" + if kwargs: + __unexpected_attrs("tile", kwargs) + + return get_relay_op("tile")(data, repeats) + + +# # Region-of-interest ops + + +# # Matrix multiplication +def matmul_converter(a, b, **kwargs): + """Matmul converter + real signature: matmul_converter(a, b, transposeA, transposeB)""" + + transpose_a = kwargs.pop("transposeA") + transpose_b = kwargs.pop("transposeB") + if kwargs: + __unexpected_attrs("matmul", kwargs) + + a_shape = infer_shape(a) + b_shape = infer_shape(b) + a_rank = len(a_shape) + b_rank = len(b_shape) + + if a_rank == 2 and b_rank == 2: + out = get_relay_op("matmul")(a, b, transpose_a=transpose_a, transpose_b=transpose_b) + else: + batch_shape = [1] * (max(a_rank, b_rank) - 2) + + for i, j in enumerate(reversed(a_shape[:-2])): + batch_shape[i] = j + + for i, j in enumerate(reversed(b_shape[:-2])): + # Need to check if axis can be broadcasted + if batch_shape[i] == 1 or j == 1 or batch_shape[i] == j: + batch_shape[i] = max(batch_shape[i], j) + else: + msg = "Batch dimensions are not broadcastable." + raise AssertionError(msg) + + batch_shape = batch_shape[::-1] + + a = tvm_op.broadcast_to(a, batch_shape + list(a_shape[-2:])) + b = tvm_op.broadcast_to(b, batch_shape + list(b_shape[-2:])) + + out = get_relay_op("batch_matmul")( + tvm_op.reshape(a, [-1, *a_shape[-2:]]), + tvm_op.reshape(b, [-1, *b_shape[-2:]]), + transpose_b=transpose_b, + transpose_a=transpose_a, + ) + + out_shape = batch_shape + [a_shape[-2]] + [b_shape[-1]] + out = tvm_op.reshape(out, out_shape) + + return out + + +# # Variable updates +# # Compound ops + + +def sigmoid_converter(data, **kwargs): + """Sigmoid converter""" + if kwargs: + __unexpected_attrs("sigmoid", kwargs) + + return get_relay_op("sigmoid")(data) + + +def relu_converter(data, **kwargs): + """RELU converter""" + if kwargs: + __unexpected_attrs("relu", kwargs) + + return get_relay_op("relu")(data) + + +def prelu_converter(data, alpha, **kwargs): + """PRELU converter""" + if kwargs: + __unexpected_attrs("prelu", kwargs) + + # prelu can"t handle float vals but NNEF supports direct parameter, this is just in case + if isinstance(alpha, tvm_expr.Constant): + if alpha.data.numpy().size == 1: + return get_relay_op("leaky_relu")(data, alpha.data.numpy().item()) + + return get_relay_op("prelu")(data, alpha) + + +def leaky_relu_converter(data, alpha, **kwargs): + """Leaky RELU converter""" + if kwargs: + __unexpected_attrs("leaky_relu", kwargs) + + return get_relay_op("leaky_relu")(data, alpha) + + +def elu_converter(data, alpha, **kwargs): + """ELU converter""" + if kwargs: + __unexpected_attrs("elu", kwargs) + + return select_converter( + lt_converter(data, tvm_expr.const(0.0)), + mul_converter( + tvm_expr.const(alpha), sub_converter(exp_converter(data), tvm_expr.const(1.0)) + ), + data, + ) + + +def selu_converter(data, alpha, **kwargs): + """SELU converter + True signature is selu_converter(data, alpha, lambda)""" + lambda_var = kwargs.pop("lambda") + + if kwargs: + __unexpected_attrs("selu", kwargs) + + return mul_converter( + tvm_expr.const(lambda_var), + select_converter( + data < tvm_expr.const(0.0), + mul_converter( + tvm_expr.const(alpha), sub_converter(exp_converter(data), tvm_expr.const(1.0)) + ), + data, + ), + ) + + +def gelu_converter(data, **kwargs): + """GELU converter + NNEF definition for GELU: + the exact 
definition of GELU is x * Phi(x) where Phi(x) is the + CDF of the standard normal distribution, which can be approximated + for example by sigmoid(1.702 * x) + + `mul_converter(data, sigmoid_converter(mul_converter(tvm_expr.const(1.702), data)))` + + But in this case we will use the erf to calculate normcdf (same as to pytorch GELU impl) + """ + if kwargs: + __unexpected_attrs("gelu", kwargs) + + return data * ( + tvm_expr.const(0.5) + tvm_op.erf(data * tvm_expr.const(0.5**0.5)) * tvm_expr.const(0.5) + ) + + +def silu_converter(data, **kwargs): + """SiLU converter""" + if kwargs: + __unexpected_attrs("silu", kwargs) + + return mul_converter(data, sigmoid_converter(data)) + + +def softmax_converter(data, axes, **kwargs): + """Softmax converter""" + if kwargs: + __unexpected_attrs("softmax", kwargs) + + if len(axes) > 1: + print("Multiple axes not supported, operation has been done along the first axis in axes.") + axis = axes[0] + + return get_relay_op("softmax")(data, axis) + + +def softplus_converter(data, **kwargs): + """Softplus converter""" + if kwargs: + __unexpected_attrs("softplus", kwargs) + + return log_converter(add_converter(exp_converter(data), tvm_expr.const(1.0))) + + +# # linear ops + + +def linear_converter(data, _filter, bias, **kwargs): + """Linear converter""" + if kwargs: + __unexpected_attrs("linear", kwargs) + + out = get_relay_op("matmul")(data, _filter, transpose_b=True) + res = None + + if isinstance(bias, tvm_expr.Constant): + if (bias.data.numpy() == 0).all(): + res = out + + if not res: + # squeeze needed because nnef has bias of shape [1, channel] + res = tvm_op.nn.bias_add(out, relay.squeeze(bias, axis=0)) + + return res + + +def separable_conv_converter( + data, plane_filter, point_filter, bias, border, padding, stride, dilation, groups, **kwargs +): + """Separable convolution converter""" + if kwargs: + __unexpected_attrs("separable_conv", kwargs) + + if isinstance(data, relay.Call): + d_type = infer_type(data).checked_type.dtype + else: + d_type = data.type_annotation.dtype + + filtered = conv_converter( + data, plane_filter, tvm_expr.const(0, dtype=d_type), border, stride, padding, dilation, 0 + ) + + return conv_converter(filtered, point_filter, bias, "constant", [], [], [], groups) + + +def separable_deconv_converter( + data, + plane_filter, + point_filter, + bias, + border, + padding, + stride, + dilation, + output_shape, + groups, + **kwargs, +): + """Separable deconvolution converter""" + if kwargs: + __unexpected_attrs("separable_deconv", kwargs) + + if isinstance(data, relay.Call): + d_type = infer_type(data).checked_type.dtype + else: + d_type = data.type_annotation.dtype + + filtered = deconv_converter( + data, point_filter, tvm_expr.const(0, dtype=d_type), "constant", [], [], [], [], groups + ) + + return deconv_converter( + filtered, plane_filter, bias, border, stride, padding, dilation, output_shape, 0 + ) + + +def max_pool_converter(data, size, border, padding, stride, dilation, **kwargs): + """Max pool converter""" + if kwargs: + __unexpected_attrs("max_pool", kwargs) + + if border != "constant": + print(f"Currently {border} border is not supported, used `constant` border") + + dshape = infer_shape(data) + rank = len(dshape) + + pool_size = _size_conv(size, rank) + strides = _stride_conv(stride, rank) if stride else (1,) * (rank - 2) + + dilation = dilation if dilation else ((1,) * (rank - 2)) + + if not padding: + # padding is truncated to `conv style` (only active layers are present) + padding = _calculate_nnef_padding(dshape[2:], strides, 
pool_size, dilation) + + pad = _padding_conv(padding, rank) + + if border == "constant": + padding = [(0, 0), (0, 0)] + padding + data = pad_converter(data, padding, border, tvm_expr.const(0.0)) + pad = (0, 0) + + op = get_relay_op(dimension_picker("max_pool", dshape)) + return op( + data, + pool_size=pool_size, + strides=strides, + dilation=dilation, + padding=pad, + ) + + +def avg_pool_converter(data, size, border, padding, stride, dilation, **kwargs): + """Avg pool converter""" + if kwargs: + __unexpected_attrs("avg_pool", kwargs) + + if border not in ["constant", "ignore"]: + print(f"Currently {border} border is not supported, used `constant` border") + + dshape = infer_shape(data) + rank = len(dshape) + pool_size = _size_conv(size, rank) + strides = _stride_conv(stride, rank) if stride else (1,) * (rank - 2) + + dilation = dilation if dilation else ((1,) * (rank - 2)) + + # padding is truncated to `conv style` (only active layers are present) + active_shape = dshape[2:] + if not padding: + padding = _calculate_nnef_padding(active_shape, strides, pool_size, dilation) + + pad = _padding_conv(padding, rank) + + op = get_relay_op(dimension_picker("avg_pool", dshape)) + return op( + data, + pool_size=pool_size, + strides=strides, + dilation=dilation, + padding=pad, + count_include_pad=border != "ignore", + ) + + +def rms_pool_converter(data, size, border, padding, stride, dilation, **kwargs): + """Rms pool converter""" + if kwargs: + __unexpected_attrs("rms_pool", kwargs) + + return sqrt_converter( + avg_pool_converter( + sqr_converter(data), + size=size, + border=border, + padding=padding, + stride=stride, + dilation=dilation, + ) + ) + + +# # Normalization + + +def local_response_normalization_converter(data, size, alpha, beta, bias): + """LRN converter""" + axis = [i for i in range(len(size)) if size[i] > 1] + if len(axis) == 1: + axis = axis[0] + else: + print("Multi axis LRN is not implemented properly, using first axis where size != 1") + axis = axis[0] + size = size[axis] + return get_relay_op("lrn")(data, size, axis, bias, alpha, beta) + + +def local_mean_normalization_converter(data, size, **kwargs): + """LMN converter""" + if kwargs: + __unexpected_attrs("local_mean_normalization", kwargs) + + mean = box_converter(data, size, "constant", [], [], [], normalize=True) + return sub_converter(data, mean) + + +def local_variance_normalization_converter(data, size, bias, epsilon, **kwargs): + """LVN converter""" + if kwargs: + __unexpected_attrs("local_variance_normalization", kwargs) + + sigma = box_converter(sqr_converter(data), size, "constant", [], [], [], normalize=True) + return div_converter( + data, + max_converter( + add_converter(sqrt_converter(sigma), tvm_expr.const(bias)), tvm_expr.const(epsilon) + ), + ) + + +def local_contrast_normalization_converter(data, size, bias, epsilon, **kwargs): + """LCN converter""" + if kwargs: + __unexpected_attrs("local_contrast_normalization", kwargs) + + centered = local_mean_normalization_converter(data, size) + return local_variance_normalization_converter(centered, size, bias, epsilon) + + +def l1_normalization_converter(data, axes, bias, epsilon, **kwargs): + """L1 norm converter""" + if kwargs: + __unexpected_attrs("l1_normalization", kwargs) + + sigma = sum_reduce_converter(abs_converter(data), axes, False) + return div_converter( + data, max_converter(add_converter(sigma, tvm_expr.const(bias)), tvm_expr.const(epsilon)) + ) + + +def l2_normalization_converter(data, axes, bias, epsilon, **kwargs): + """L2 norm converter""" + if kwargs: 
+ __unexpected_attrs("l2_normalization", kwargs) + + epsilon = epsilon**2 + if bias != 0.0: + print("Bias is not supported, assumed 0.0.") + # data = add_converter(data, tvm_expr.const(bias)) + + return get_relay_op("l2_normalize")(data, epsilon, axes) + + +# ok ish + + +def batch_normalization_converter(data, mean, variance, offset, scale, epsilon, **kwargs): + """Batch norm converter""" + if kwargs: + __unexpected_attrs("batch_normalization", kwargs) + + mean = squeeze_converter(mean, 0) + variance = squeeze_converter(variance, 0) + offset = squeeze_converter(offset, 0) + scale = squeeze_converter(scale, 0) + + return get_relay_op("batch_norm")(data, scale, offset, mean, variance, epsilon=epsilon)[0] + + +# # Misc ops diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index 757c00e0e344..47c6c6248767 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -96,6 +96,8 @@ "groovy", # Python-parseable config files "ini", + # NNEF graph file + "nnef", } # List of file names allowed diff --git a/tests/lint/rat-excludes b/tests/lint/rat-excludes index 93478df8dde0..840211aa9b09 100644 --- a/tests/lint/rat-excludes +++ b/tests/lint/rat-excludes @@ -40,6 +40,9 @@ dist .node_repl_history node_modules +# NNEF graphs +.*\.nnef + # Specific files package-list MANIFEST diff --git a/tests/python/frontend/nnef/cases/abs_2d/graph.nnef b/tests/python/frontend/nnef/cases/abs_2d/graph.nnef new file mode 100644 index 000000000000..1f101b10b4d6 --- /dev/null +++ b/tests/python/frontend/nnef/cases/abs_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = abs(input); +} diff --git a/tests/python/frontend/nnef/cases/abs_4d/graph.nnef b/tests/python/frontend/nnef/cases/abs_4d/graph.nnef new file mode 100644 index 000000000000..b4449bae13da --- /dev/null +++ b/tests/python/frontend/nnef/cases/abs_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = abs(input); +} diff --git a/tests/python/frontend/nnef/cases/acos_2d/graph.nnef b/tests/python/frontend/nnef/cases/acos_2d/graph.nnef new file mode 100644 index 000000000000..c6551c478506 --- /dev/null +++ b/tests/python/frontend/nnef/cases/acos_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = acos(input); +} diff --git a/tests/python/frontend/nnef/cases/acos_4d/graph.nnef b/tests/python/frontend/nnef/cases/acos_4d/graph.nnef new file mode 100644 index 000000000000..0a6b58a3407f --- /dev/null +++ b/tests/python/frontend/nnef/cases/acos_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = acos(input); +} diff --git a/tests/python/frontend/nnef/cases/acosh_2d/graph.nnef b/tests/python/frontend/nnef/cases/acosh_2d/graph.nnef new file mode 100644 index 000000000000..c6551c478506 --- /dev/null +++ b/tests/python/frontend/nnef/cases/acosh_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = acos(input); +} diff --git a/tests/python/frontend/nnef/cases/acosh_4d/graph.nnef b/tests/python/frontend/nnef/cases/acosh_4d/graph.nnef new file mode 100644 index 000000000000..0a6b58a3407f --- /dev/null +++ b/tests/python/frontend/nnef/cases/acosh_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = 
[4,16,32,32]); + output = acosh(input); +} diff --git a/tests/python/frontend/nnef/cases/add_2d/graph.nnef b/tests/python/frontend/nnef/cases/add_2d/graph.nnef new file mode 100644 index 000000000000..ccb1d0dbf7f8 --- /dev/null +++ b/tests/python/frontend/nnef/cases/add_2d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = add(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/add_4d/graph.nnef b/tests/python/frontend/nnef/cases/add_4d/graph.nnef new file mode 100644 index 000000000000..63ab32aeab90 --- /dev/null +++ b/tests/python/frontend/nnef/cases/add_4d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = add(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/add_4d_broadcast/graph.nnef b/tests/python/frontend/nnef/cases/add_4d_broadcast/graph.nnef new file mode 100644 index 000000000000..cc5227e78896 --- /dev/null +++ b/tests/python/frontend/nnef/cases/add_4d_broadcast/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [1,16,1,1]); + output = add(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/add_4d_constant/graph.nnef b/tests/python/frontend/nnef/cases/add_4d_constant/graph.nnef new file mode 100644 index 000000000000..a490b8bac1f7 --- /dev/null +++ b/tests/python/frontend/nnef/cases/add_4d_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = add(input, 0.5); +} diff --git a/tests/python/frontend/nnef/cases/all_reduce_channel/graph.nnef b/tests/python/frontend/nnef/cases/all_reduce_channel/graph.nnef new file mode 100644 index 000000000000..4655d7f2d03f --- /dev/null +++ b/tests/python/frontend/nnef/cases/all_reduce_channel/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = all_reduce(input, axes = [1]); +} diff --git a/tests/python/frontend/nnef/cases/all_reduce_spatial/graph.nnef b/tests/python/frontend/nnef/cases/all_reduce_spatial/graph.nnef new file mode 100644 index 000000000000..e225df3e38b9 --- /dev/null +++ b/tests/python/frontend/nnef/cases/all_reduce_spatial/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = all_reduce(input, axes = [2,3]); +} diff --git a/tests/python/frontend/nnef/cases/and_2d/graph.nnef b/tests/python/frontend/nnef/cases/and_2d/graph.nnef new file mode 100644 index 000000000000..9aab6ac9e743 --- /dev/null +++ b/tests/python/frontend/nnef/cases/and_2d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = and(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/and_4d/graph.nnef b/tests/python/frontend/nnef/cases/and_4d/graph.nnef new file mode 100644 index 000000000000..7692dd8689f7 --- /dev/null +++ b/tests/python/frontend/nnef/cases/and_4d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = and(input1, input2); +} diff --git 
a/tests/python/frontend/nnef/cases/and_4d_broadcast/graph.nnef b/tests/python/frontend/nnef/cases/and_4d_broadcast/graph.nnef new file mode 100644 index 000000000000..4010bb0b9a1f --- /dev/null +++ b/tests/python/frontend/nnef/cases/and_4d_broadcast/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [1,16,1,1]); + output = and(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/and_4d_constant/graph.nnef b/tests/python/frontend/nnef/cases/and_4d_constant/graph.nnef new file mode 100644 index 000000000000..35dee4bb4839 --- /dev/null +++ b/tests/python/frontend/nnef/cases/and_4d_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = and(input, false); +} diff --git a/tests/python/frontend/nnef/cases/any_reduce_channel/graph.nnef b/tests/python/frontend/nnef/cases/any_reduce_channel/graph.nnef new file mode 100644 index 000000000000..40c1c62adef3 --- /dev/null +++ b/tests/python/frontend/nnef/cases/any_reduce_channel/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = any_reduce(input, axes = [1]); +} diff --git a/tests/python/frontend/nnef/cases/any_reduce_spatial/graph.nnef b/tests/python/frontend/nnef/cases/any_reduce_spatial/graph.nnef new file mode 100644 index 000000000000..296877019aa8 --- /dev/null +++ b/tests/python/frontend/nnef/cases/any_reduce_spatial/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = any_reduce(input, axes = [2,3]); +} diff --git a/tests/python/frontend/nnef/cases/area_downsample/graph.nnef b/tests/python/frontend/nnef/cases/area_downsample/graph.nnef new file mode 100644 index 000000000000..df4c5c0951e6 --- /dev/null +++ b/tests/python/frontend/nnef/cases/area_downsample/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = area_downsample(input, factor = [2,2]); +} diff --git a/tests/python/frontend/nnef/cases/argmax_reduce_channel/graph.nnef b/tests/python/frontend/nnef/cases/argmax_reduce_channel/graph.nnef new file mode 100644 index 000000000000..10c00d26fda9 --- /dev/null +++ b/tests/python/frontend/nnef/cases/argmax_reduce_channel/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = argmax_reduce(input, axes = [1]); +} diff --git a/tests/python/frontend/nnef/cases/argmax_reduce_spatial/graph.nnef b/tests/python/frontend/nnef/cases/argmax_reduce_spatial/graph.nnef new file mode 100644 index 000000000000..696dba65ba60 --- /dev/null +++ b/tests/python/frontend/nnef/cases/argmax_reduce_spatial/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = argmax_reduce(input, axes = [2,3]); +} diff --git a/tests/python/frontend/nnef/cases/argmin_reduce_channel/graph.nnef b/tests/python/frontend/nnef/cases/argmin_reduce_channel/graph.nnef new file mode 100644 index 000000000000..dc048a1677ed --- /dev/null +++ b/tests/python/frontend/nnef/cases/argmin_reduce_channel/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = argmin_reduce(input, axes = [1]); +} diff --git 
a/tests/python/frontend/nnef/cases/argmin_reduce_spatial/graph.nnef b/tests/python/frontend/nnef/cases/argmin_reduce_spatial/graph.nnef new file mode 100644 index 000000000000..0f532835811e --- /dev/null +++ b/tests/python/frontend/nnef/cases/argmin_reduce_spatial/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = argmin_reduce(input, axes = [2,3]); +} diff --git a/tests/python/frontend/nnef/cases/asin_2d/graph.nnef b/tests/python/frontend/nnef/cases/asin_2d/graph.nnef new file mode 100644 index 000000000000..5855e4fc4119 --- /dev/null +++ b/tests/python/frontend/nnef/cases/asin_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = asin(input); +} diff --git a/tests/python/frontend/nnef/cases/asin_4d/graph.nnef b/tests/python/frontend/nnef/cases/asin_4d/graph.nnef new file mode 100644 index 000000000000..1eebea76d39f --- /dev/null +++ b/tests/python/frontend/nnef/cases/asin_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = asin(input); +} diff --git a/tests/python/frontend/nnef/cases/asinh_2d/graph.nnef b/tests/python/frontend/nnef/cases/asinh_2d/graph.nnef new file mode 100644 index 000000000000..95571e7e6337 --- /dev/null +++ b/tests/python/frontend/nnef/cases/asinh_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = asinh(input); +} diff --git a/tests/python/frontend/nnef/cases/asinh_4d/graph.nnef b/tests/python/frontend/nnef/cases/asinh_4d/graph.nnef new file mode 100644 index 000000000000..f42189e2e82f --- /dev/null +++ b/tests/python/frontend/nnef/cases/asinh_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = asinh(input); +} diff --git a/tests/python/frontend/nnef/cases/atan_2d/graph.nnef b/tests/python/frontend/nnef/cases/atan_2d/graph.nnef new file mode 100644 index 000000000000..71948f20d5f2 --- /dev/null +++ b/tests/python/frontend/nnef/cases/atan_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = atan(input); +} diff --git a/tests/python/frontend/nnef/cases/atan_4d/graph.nnef b/tests/python/frontend/nnef/cases/atan_4d/graph.nnef new file mode 100644 index 000000000000..444d042c3caa --- /dev/null +++ b/tests/python/frontend/nnef/cases/atan_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = atan(input); +} diff --git a/tests/python/frontend/nnef/cases/atanh_2d/graph.nnef b/tests/python/frontend/nnef/cases/atanh_2d/graph.nnef new file mode 100644 index 000000000000..859943ddda9e --- /dev/null +++ b/tests/python/frontend/nnef/cases/atanh_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = atanh(input); +} diff --git a/tests/python/frontend/nnef/cases/atanh_4d/graph.nnef b/tests/python/frontend/nnef/cases/atanh_4d/graph.nnef new file mode 100644 index 000000000000..b181be734e71 --- /dev/null +++ b/tests/python/frontend/nnef/cases/atanh_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = atanh(input); +} diff --git a/tests/python/frontend/nnef/cases/avg_pool1x1/graph.nnef 
b/tests/python/frontend/nnef/cases/avg_pool1x1/graph.nnef new file mode 100644 index 000000000000..295ee379cce7 --- /dev/null +++ b/tests/python/frontend/nnef/cases/avg_pool1x1/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = avg_pool(input, size = [1,1,1,1], stride = [1,1,2,2], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/avg_pool2x2/graph.nnef b/tests/python/frontend/nnef/cases/avg_pool2x2/graph.nnef new file mode 100644 index 000000000000..48315774032e --- /dev/null +++ b/tests/python/frontend/nnef/cases/avg_pool2x2/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = avg_pool(input, size = [1,1,2,2], stride = [1,1,2,2], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/avg_pool3x3/graph.nnef b/tests/python/frontend/nnef/cases/avg_pool3x3/graph.nnef new file mode 100644 index 000000000000..33d98645b6c4 --- /dev/null +++ b/tests/python/frontend/nnef/cases/avg_pool3x3/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = avg_pool(input, size = [1,1,3,3], stride = [1,1,2,2], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/avg_pool3x3_ignore-border/graph.nnef b/tests/python/frontend/nnef/cases/avg_pool3x3_ignore-border/graph.nnef new file mode 100644 index 000000000000..be79e6d98562 --- /dev/null +++ b/tests/python/frontend/nnef/cases/avg_pool3x3_ignore-border/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = avg_pool(input, size = [1,1,3,3], stride = [1,1,2,2], border = 'ignore'); +} diff --git a/tests/python/frontend/nnef/cases/avg_pool3x3_pad0-0/graph.nnef b/tests/python/frontend/nnef/cases/avg_pool3x3_pad0-0/graph.nnef new file mode 100644 index 000000000000..0434182b0efa --- /dev/null +++ b/tests/python/frontend/nnef/cases/avg_pool3x3_pad0-0/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = avg_pool(input, size = [1,1,3,3], stride = [1,1,2,2], padding = [(0,0), (0,0), (0,0), (0,0)], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/avg_pool3x3_pad0-1/graph.nnef b/tests/python/frontend/nnef/cases/avg_pool3x3_pad0-1/graph.nnef new file mode 100644 index 000000000000..e43442630cd6 --- /dev/null +++ b/tests/python/frontend/nnef/cases/avg_pool3x3_pad0-1/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = avg_pool(input, size = [1,1,3,3], stride = [1,1,2,2], padding = [(0,0), (0,0), (0,1), (0,1)], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/avg_pool3x3_pad1-0/graph.nnef b/tests/python/frontend/nnef/cases/avg_pool3x3_pad1-0/graph.nnef new file mode 100644 index 000000000000..09c854997f9a --- /dev/null +++ b/tests/python/frontend/nnef/cases/avg_pool3x3_pad1-0/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = avg_pool(input, size = [1,1,3,3], stride = [1,1,2,2], padding = [(0,0), (0,0), (1,0), (1,0)], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/avg_pool3x3_pad1-1/graph.nnef b/tests/python/frontend/nnef/cases/avg_pool3x3_pad1-1/graph.nnef new file mode 100644 index 000000000000..c334ba3fb807 --- /dev/null 
+++ b/tests/python/frontend/nnef/cases/avg_pool3x3_pad1-1/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = avg_pool(input, size = [1,1,3,3], stride = [1,1,2,2], padding = [(0,0), (0,0), (1,1), (1,1)], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/avg_pool3x3_stride1x1/graph.nnef b/tests/python/frontend/nnef/cases/avg_pool3x3_stride1x1/graph.nnef new file mode 100644 index 000000000000..d1fbf173a721 --- /dev/null +++ b/tests/python/frontend/nnef/cases/avg_pool3x3_stride1x1/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = avg_pool(input, size = [1,1,3,3], stride = [1,1,1,1], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/batch_norm/graph.nnef b/tests/python/frontend/nnef/cases/batch_norm/graph.nnef new file mode 100644 index 000000000000..55197bf03d60 --- /dev/null +++ b/tests/python/frontend/nnef/cases/batch_norm/graph.nnef @@ -0,0 +1,11 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + mean = variable(shape = [1,16], label = 'mean'); + variance = variable(shape = [1,16], label = 'variance'); + offset = variable(shape = [1,16], label = 'offset'); + scale = variable(shape = [1,16], label = 'scale'); + output = batch_normalization(input, mean, variance, offset, scale, epsilon = 1e-3); +} diff --git a/tests/python/frontend/nnef/cases/bilinear_upsample_aligned_constant/graph.nnef b/tests/python/frontend/nnef/cases/bilinear_upsample_aligned_constant/graph.nnef new file mode 100644 index 000000000000..6fbc55a4b61d --- /dev/null +++ b/tests/python/frontend/nnef/cases/bilinear_upsample_aligned_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = multilinear_upsample(input, factor = [2,2], method = 'aligned', border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/bilinear_upsample_aligned_replicate/graph.nnef b/tests/python/frontend/nnef/cases/bilinear_upsample_aligned_replicate/graph.nnef new file mode 100644 index 000000000000..5bdee4db665c --- /dev/null +++ b/tests/python/frontend/nnef/cases/bilinear_upsample_aligned_replicate/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = multilinear_upsample(input, factor = [2,2], method = 'aligned', border = 'replicate'); +} diff --git a/tests/python/frontend/nnef/cases/bilinear_upsample_asymmetric_constant/graph.nnef b/tests/python/frontend/nnef/cases/bilinear_upsample_asymmetric_constant/graph.nnef new file mode 100644 index 000000000000..e94572e7aa3b --- /dev/null +++ b/tests/python/frontend/nnef/cases/bilinear_upsample_asymmetric_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = multilinear_upsample(input, factor = [2,2], method = 'asymmetric', border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/bilinear_upsample_asymmetric_replicate/graph.nnef b/tests/python/frontend/nnef/cases/bilinear_upsample_asymmetric_replicate/graph.nnef new file mode 100644 index 000000000000..59a0c229a5d7 --- /dev/null +++ b/tests/python/frontend/nnef/cases/bilinear_upsample_asymmetric_replicate/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = 
multilinear_upsample(input, factor = [2,2], method = 'asymmetric', border = 'replicate'); +} diff --git a/tests/python/frontend/nnef/cases/bilinear_upsample_symmetric_constant/graph.nnef b/tests/python/frontend/nnef/cases/bilinear_upsample_symmetric_constant/graph.nnef new file mode 100644 index 000000000000..2e6cc716f9f7 --- /dev/null +++ b/tests/python/frontend/nnef/cases/bilinear_upsample_symmetric_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = multilinear_upsample(input, factor = [2,2], method = 'symmetric', border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/bilinear_upsample_symmetric_replicate/graph.nnef b/tests/python/frontend/nnef/cases/bilinear_upsample_symmetric_replicate/graph.nnef new file mode 100644 index 000000000000..a0721fe77e22 --- /dev/null +++ b/tests/python/frontend/nnef/cases/bilinear_upsample_symmetric_replicate/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = multilinear_upsample(input, factor = [2,2], method = 'symmetric', border = 'replicate'); +} diff --git a/tests/python/frontend/nnef/cases/box1x1/graph.nnef b/tests/python/frontend/nnef/cases/box1x1/graph.nnef new file mode 100644 index 000000000000..2f3c0876e950 --- /dev/null +++ b/tests/python/frontend/nnef/cases/box1x1/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = box(input, size = [1,1,1,1], stride = [1,1,2,2], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/box2x2/graph.nnef b/tests/python/frontend/nnef/cases/box2x2/graph.nnef new file mode 100644 index 000000000000..693903905ea9 --- /dev/null +++ b/tests/python/frontend/nnef/cases/box2x2/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = box(input, size = [1,1,2,2], stride = [1,1,2,2], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/box3x3/graph.nnef b/tests/python/frontend/nnef/cases/box3x3/graph.nnef new file mode 100644 index 000000000000..60135ee5b37e --- /dev/null +++ b/tests/python/frontend/nnef/cases/box3x3/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = box(input, size = [1,1,3,3], stride = [1,1,2,2], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/box3x3_pad0-0/graph.nnef b/tests/python/frontend/nnef/cases/box3x3_pad0-0/graph.nnef new file mode 100644 index 000000000000..baf67c5304e5 --- /dev/null +++ b/tests/python/frontend/nnef/cases/box3x3_pad0-0/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = box(input, size = [1,1,3,3], stride = [1,1,2,2], padding = [(0,0), (0,0), (0,0), (0,0)], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/box3x3_pad0-1/graph.nnef b/tests/python/frontend/nnef/cases/box3x3_pad0-1/graph.nnef new file mode 100644 index 000000000000..a5a86b05c09c --- /dev/null +++ b/tests/python/frontend/nnef/cases/box3x3_pad0-1/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = box(input, size = [1,1,3,3], stride = [1,1,2,2], padding = [(0,0), (0,0), (0,1), (0,1)], border = 'ignore'); +} diff --git a/tests/python/frontend/nnef/cases/box3x3_pad1-0/graph.nnef 
b/tests/python/frontend/nnef/cases/box3x3_pad1-0/graph.nnef new file mode 100644 index 000000000000..485a57b456fd --- /dev/null +++ b/tests/python/frontend/nnef/cases/box3x3_pad1-0/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = box(input, size = [1,1,3,3], stride = [1,1,2,2], padding = [(0,0), (0,0), (1,0), (1,0)], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/box3x3_pad1-1/graph.nnef b/tests/python/frontend/nnef/cases/box3x3_pad1-1/graph.nnef new file mode 100644 index 000000000000..d660e46aecb8 --- /dev/null +++ b/tests/python/frontend/nnef/cases/box3x3_pad1-1/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = box(input, size = [1,1,3,3], stride = [1,1,2,2], padding = [(0,0), (0,0), (1,1), (1,1)], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/box3x3_stride1x1/graph.nnef b/tests/python/frontend/nnef/cases/box3x3_stride1x1/graph.nnef new file mode 100644 index 000000000000..dd78ea76b5bf --- /dev/null +++ b/tests/python/frontend/nnef/cases/box3x3_stride1x1/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = box(input, size = [1,1,3,3], stride = [1,1,1,1], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/ceil_2d/graph.nnef b/tests/python/frontend/nnef/cases/ceil_2d/graph.nnef new file mode 100644 index 000000000000..1a599994c7eb --- /dev/null +++ b/tests/python/frontend/nnef/cases/ceil_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = ceil(input); +} diff --git a/tests/python/frontend/nnef/cases/ceil_4d/graph.nnef b/tests/python/frontend/nnef/cases/ceil_4d/graph.nnef new file mode 100644 index 000000000000..07cec8f89947 --- /dev/null +++ b/tests/python/frontend/nnef/cases/ceil_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = ceil(input); +} diff --git a/tests/python/frontend/nnef/cases/clamp_2d/graph.nnef b/tests/python/frontend/nnef/cases/clamp_2d/graph.nnef new file mode 100644 index 000000000000..7f6747ba4d4e --- /dev/null +++ b/tests/python/frontend/nnef/cases/clamp_2d/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input1, input2, input3 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + input3 = external(shape = [4,16]); + output = clamp(input1, input2, input3); +} diff --git a/tests/python/frontend/nnef/cases/clamp_4d/graph.nnef b/tests/python/frontend/nnef/cases/clamp_4d/graph.nnef new file mode 100644 index 000000000000..0fbf546b1cfd --- /dev/null +++ b/tests/python/frontend/nnef/cases/clamp_4d/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input1, input2, input3 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + input3 = external(shape = [4,16,32,32]); + output = clamp(input1, input2, input3); +} diff --git a/tests/python/frontend/nnef/cases/clamp_4d_constant/graph.nnef b/tests/python/frontend/nnef/cases/clamp_4d_constant/graph.nnef new file mode 100644 index 000000000000..ea73414bcb14 --- /dev/null +++ b/tests/python/frontend/nnef/cases/clamp_4d_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = clamp(input, 0.25, 
0.75); +} diff --git a/tests/python/frontend/nnef/cases/concat_channel/graph.nnef b/tests/python/frontend/nnef/cases/concat_channel/graph.nnef new file mode 100644 index 000000000000..211f1366c2e9 --- /dev/null +++ b/tests/python/frontend/nnef/cases/concat_channel/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = concat([input1, input2], axis = 1); +} diff --git a/tests/python/frontend/nnef/cases/conv1x1/graph.nnef b/tests/python/frontend/nnef/cases/conv1x1/graph.nnef new file mode 100644 index 000000000000..75d5ca91ac23 --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv1x1/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,1,1], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/conv2x2/graph.nnef b/tests/python/frontend/nnef/cases/conv2x2/graph.nnef new file mode 100644 index 000000000000..e47b8fb89772 --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv2x2/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,2,2], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/conv3x3/graph.nnef b/tests/python/frontend/nnef/cases/conv3x3/graph.nnef new file mode 100644 index 000000000000..687f05187fa8 --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv3x3/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,3,3], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/conv3x3_groups0/graph.nnef b/tests/python/frontend/nnef/cases/conv3x3_groups0/graph.nnef new file mode 100644 index 000000000000..5f169f240925 --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv3x3_groups0/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,1,3,3], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias, groups = 0); +} diff --git a/tests/python/frontend/nnef/cases/conv3x3_nobias/graph.nnef b/tests/python/frontend/nnef/cases/conv3x3_nobias/graph.nnef new file mode 100644 index 000000000000..396692d4681f --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv3x3_nobias/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,3,3], label = 'filter'); + output = conv(input, filter, 0.0); +} diff --git a/tests/python/frontend/nnef/cases/conv3x3_pad0-0/graph.nnef b/tests/python/frontend/nnef/cases/conv3x3_pad0-0/graph.nnef new file mode 100644 index 000000000000..7365760bacd0 --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv3x3_pad0-0/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,3,3], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias, padding = [(0,0), 
(0,0)]); +} diff --git a/tests/python/frontend/nnef/cases/conv3x3_pad0-1/graph.nnef b/tests/python/frontend/nnef/cases/conv3x3_pad0-1/graph.nnef new file mode 100644 index 000000000000..228fee93cb9f --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv3x3_pad0-1/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,3,3], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias, padding = [(0,1), (0,1)]); +} diff --git a/tests/python/frontend/nnef/cases/conv3x3_pad1-0/graph.nnef b/tests/python/frontend/nnef/cases/conv3x3_pad1-0/graph.nnef new file mode 100644 index 000000000000..f28b4b4a2a8a --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv3x3_pad1-0/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,3,3], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias, padding = [(1,0), (1,0)]); +} diff --git a/tests/python/frontend/nnef/cases/conv3x3_pad1-1/graph.nnef b/tests/python/frontend/nnef/cases/conv3x3_pad1-1/graph.nnef new file mode 100644 index 000000000000..4948bf379449 --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv3x3_pad1-1/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,3,3], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias, padding = [(1,1), (1,1)]); +} diff --git a/tests/python/frontend/nnef/cases/conv3x3_stride2x2/graph.nnef b/tests/python/frontend/nnef/cases/conv3x3_stride2x2/graph.nnef new file mode 100644 index 000000000000..5f4df908f330 --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv3x3_stride2x2/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,3,3], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias, stride = [2,2]); +} diff --git a/tests/python/frontend/nnef/cases/conv3x3_valid/graph.nnef b/tests/python/frontend/nnef/cases/conv3x3_valid/graph.nnef new file mode 100644 index 000000000000..7365760bacd0 --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv3x3_valid/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,3,3], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias, padding = [(0,0), (0,0)]); +} diff --git a/tests/python/frontend/nnef/cases/conv4x4/graph.nnef b/tests/python/frontend/nnef/cases/conv4x4/graph.nnef new file mode 100644 index 000000000000..ee6de3aa535e --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv4x4/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,4,4], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/conv4x4_stride2x2/graph.nnef b/tests/python/frontend/nnef/cases/conv4x4_stride2x2/graph.nnef new file mode 100644 index 000000000000..5a86b6850dd4 --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv4x4_stride2x2/graph.nnef @@ -0,0 +1,9 
@@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,4,4], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias, stride = [2,2]); +} diff --git a/tests/python/frontend/nnef/cases/conv5x5/graph.nnef b/tests/python/frontend/nnef/cases/conv5x5/graph.nnef new file mode 100644 index 000000000000..bda7b9e120fa --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv5x5/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,5,5], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/conv5x5_pad2-2/graph.nnef b/tests/python/frontend/nnef/cases/conv5x5_pad2-2/graph.nnef new file mode 100644 index 000000000000..7d121ff5e126 --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv5x5_pad2-2/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,5,5], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias, padding = [(2,2), (2,2)]); +} diff --git a/tests/python/frontend/nnef/cases/conv5x5_stride3x3/graph.nnef b/tests/python/frontend/nnef/cases/conv5x5_stride3x3/graph.nnef new file mode 100644 index 000000000000..bac1ff164ac4 --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv5x5_stride3x3/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,5,5], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias, stride = [3,3]); +} diff --git a/tests/python/frontend/nnef/cases/conv6x6/graph.nnef b/tests/python/frontend/nnef/cases/conv6x6/graph.nnef new file mode 100644 index 000000000000..157d2b73e9ee --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv6x6/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,6,6], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/conv7x7/graph.nnef b/tests/python/frontend/nnef/cases/conv7x7/graph.nnef new file mode 100644 index 000000000000..92e3cdac4404 --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv7x7/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,7,7], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/conv7x7_stride4x4/graph.nnef b/tests/python/frontend/nnef/cases/conv7x7_stride4x4/graph.nnef new file mode 100644 index 000000000000..e7c9a49cda33 --- /dev/null +++ b/tests/python/frontend/nnef/cases/conv7x7_stride4x4/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + filter = variable(shape = [16,8,7,7], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = conv(input, filter, bias, stride = [4,4]); +} diff --git a/tests/python/frontend/nnef/cases/copy_2d/graph.nnef b/tests/python/frontend/nnef/cases/copy_2d/graph.nnef new file 
mode 100644 index 000000000000..cdc5613513b4 --- /dev/null +++ b/tests/python/frontend/nnef/cases/copy_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = copy(input); +} diff --git a/tests/python/frontend/nnef/cases/copy_4d/graph.nnef b/tests/python/frontend/nnef/cases/copy_4d/graph.nnef new file mode 100644 index 000000000000..c80293c10b76 --- /dev/null +++ b/tests/python/frontend/nnef/cases/copy_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = copy(input); +} diff --git a/tests/python/frontend/nnef/cases/cos_2d/graph.nnef b/tests/python/frontend/nnef/cases/cos_2d/graph.nnef new file mode 100644 index 000000000000..d82b2731c822 --- /dev/null +++ b/tests/python/frontend/nnef/cases/cos_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = cos(input); +} diff --git a/tests/python/frontend/nnef/cases/cos_4d/graph.nnef b/tests/python/frontend/nnef/cases/cos_4d/graph.nnef new file mode 100644 index 000000000000..6e4264735a32 --- /dev/null +++ b/tests/python/frontend/nnef/cases/cos_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = cos(input); +} diff --git a/tests/python/frontend/nnef/cases/cosh_2d/graph.nnef b/tests/python/frontend/nnef/cases/cosh_2d/graph.nnef new file mode 100644 index 000000000000..538b3daab320 --- /dev/null +++ b/tests/python/frontend/nnef/cases/cosh_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = cosh(input); +} diff --git a/tests/python/frontend/nnef/cases/cosh_4d/graph.nnef b/tests/python/frontend/nnef/cases/cosh_4d/graph.nnef new file mode 100644 index 000000000000..76c83f2e2d74 --- /dev/null +++ b/tests/python/frontend/nnef/cases/cosh_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = cosh(input); +} diff --git a/tests/python/frontend/nnef/cases/debox1x1/graph.nnef b/tests/python/frontend/nnef/cases/debox1x1/graph.nnef new file mode 100644 index 000000000000..cf6b31a87e58 --- /dev/null +++ b/tests/python/frontend/nnef/cases/debox1x1/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = debox(input, size = [1,1,1,1], stride = [1,1,2,2], padding = [(0,0),(0,0),(0,-1),(0,-1)], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/debox2x2/graph.nnef b/tests/python/frontend/nnef/cases/debox2x2/graph.nnef new file mode 100644 index 000000000000..75ae129d5cc9 --- /dev/null +++ b/tests/python/frontend/nnef/cases/debox2x2/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = debox(input, size = [1,1,2,2], stride = [1,1,2,2], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/debox3x3/graph.nnef b/tests/python/frontend/nnef/cases/debox3x3/graph.nnef new file mode 100644 index 000000000000..02f1a26532f0 --- /dev/null +++ b/tests/python/frontend/nnef/cases/debox3x3/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = debox(input, size = [1,1,3,3], stride = [1,1,2,2], border = 'constant'); +} diff --git 
a/tests/python/frontend/nnef/cases/debox3x3_pad0-0/graph.nnef b/tests/python/frontend/nnef/cases/debox3x3_pad0-0/graph.nnef new file mode 100644 index 000000000000..ac127aa0bd25 --- /dev/null +++ b/tests/python/frontend/nnef/cases/debox3x3_pad0-0/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = debox(input, size = [1,1,3,3], stride = [1,1,2,2], padding = [(0,0), (0,0), (0,0), (0,0)], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/debox3x3_pad0-1/graph.nnef b/tests/python/frontend/nnef/cases/debox3x3_pad0-1/graph.nnef new file mode 100644 index 000000000000..3982739aa208 --- /dev/null +++ b/tests/python/frontend/nnef/cases/debox3x3_pad0-1/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = debox(input, size = [1,1,3,3], stride = [1,1,2,2], padding = [(0,0), (0,0), (0,1), (0,1)], border = 'ignore'); +} diff --git a/tests/python/frontend/nnef/cases/debox3x3_pad1-0/graph.nnef b/tests/python/frontend/nnef/cases/debox3x3_pad1-0/graph.nnef new file mode 100644 index 000000000000..12eb3815c833 --- /dev/null +++ b/tests/python/frontend/nnef/cases/debox3x3_pad1-0/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = debox(input, size = [1,1,3,3], stride = [1,1,2,2], padding = [(0,0), (0,0), (1,0), (1,0)], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/debox3x3_pad1-1/graph.nnef b/tests/python/frontend/nnef/cases/debox3x3_pad1-1/graph.nnef new file mode 100644 index 000000000000..6195f3ae620c --- /dev/null +++ b/tests/python/frontend/nnef/cases/debox3x3_pad1-1/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = debox(input, size = [1,1,3,3], stride = [1,1,2,2], padding = [(0,0), (0,0), (1,1), (1,1)], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/debox3x3_stride1x1/graph.nnef b/tests/python/frontend/nnef/cases/debox3x3_stride1x1/graph.nnef new file mode 100644 index 000000000000..9424ce312b26 --- /dev/null +++ b/tests/python/frontend/nnef/cases/debox3x3_stride1x1/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = debox(input, size = [1,1,3,3], stride = [1,1,1,1], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/deconv1x1/graph.nnef b/tests/python/frontend/nnef/cases/deconv1x1/graph.nnef new file mode 100644 index 000000000000..e49e37942475 --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv1x1/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,1,1], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/deconv2x2/graph.nnef b/tests/python/frontend/nnef/cases/deconv2x2/graph.nnef new file mode 100644 index 000000000000..1039bfe5aaca --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv2x2/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,2,2], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias); +} diff --git 
a/tests/python/frontend/nnef/cases/deconv3x3/graph.nnef b/tests/python/frontend/nnef/cases/deconv3x3/graph.nnef new file mode 100644 index 000000000000..c4900e0c8125 --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv3x3/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,3,3], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/deconv3x3_groups0/graph.nnef b/tests/python/frontend/nnef/cases/deconv3x3_groups0/graph.nnef new file mode 100644 index 000000000000..d817b0e8d8cf --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv3x3_groups0/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,1,3,3], label = 'filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = deconv(input, filter, bias, groups = 0); +} diff --git a/tests/python/frontend/nnef/cases/deconv3x3_nobias/graph.nnef b/tests/python/frontend/nnef/cases/deconv3x3_nobias/graph.nnef new file mode 100644 index 000000000000..dbb9c056fbe1 --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv3x3_nobias/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,3,3], label = 'filter'); + output = deconv(input, filter, 0.0); +} diff --git a/tests/python/frontend/nnef/cases/deconv3x3_pad0-0/graph.nnef b/tests/python/frontend/nnef/cases/deconv3x3_pad0-0/graph.nnef new file mode 100644 index 000000000000..9623b24e10fc --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv3x3_pad0-0/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,3,3], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias, padding = [(0,0), (0,0)]); +} diff --git a/tests/python/frontend/nnef/cases/deconv3x3_pad0-1/graph.nnef b/tests/python/frontend/nnef/cases/deconv3x3_pad0-1/graph.nnef new file mode 100644 index 000000000000..1c95c94ed91d --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv3x3_pad0-1/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,3,3], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias, padding = [(0,1), (0,1)]); +} diff --git a/tests/python/frontend/nnef/cases/deconv3x3_pad1-0/graph.nnef b/tests/python/frontend/nnef/cases/deconv3x3_pad1-0/graph.nnef new file mode 100644 index 000000000000..395e8436ba0b --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv3x3_pad1-0/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,3,3], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias, padding = [(1,0), (1,0)]); +} diff --git a/tests/python/frontend/nnef/cases/deconv3x3_pad1-1/graph.nnef b/tests/python/frontend/nnef/cases/deconv3x3_pad1-1/graph.nnef new file mode 100644 index 000000000000..97d4dfebae2c --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv3x3_pad1-1/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + 
input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,3,3], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias, padding = [(1,1), (1,1)]); +} diff --git a/tests/python/frontend/nnef/cases/deconv3x3_stride2x2/graph.nnef b/tests/python/frontend/nnef/cases/deconv3x3_stride2x2/graph.nnef new file mode 100644 index 000000000000..ee2eb0ae206c --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv3x3_stride2x2/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,3,3], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias, stride = [2,2]); +} diff --git a/tests/python/frontend/nnef/cases/deconv3x3_valid/graph.nnef b/tests/python/frontend/nnef/cases/deconv3x3_valid/graph.nnef new file mode 100644 index 000000000000..9623b24e10fc --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv3x3_valid/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,3,3], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias, padding = [(0,0), (0,0)]); +} diff --git a/tests/python/frontend/nnef/cases/deconv4x4/graph.nnef b/tests/python/frontend/nnef/cases/deconv4x4/graph.nnef new file mode 100644 index 000000000000..04eb81101fe2 --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv4x4/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,4,4], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/deconv4x4_stride2x2/graph.nnef b/tests/python/frontend/nnef/cases/deconv4x4_stride2x2/graph.nnef new file mode 100644 index 000000000000..a5d1b28e3b81 --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv4x4_stride2x2/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,4,4], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias, stride = [2,2]); +} diff --git a/tests/python/frontend/nnef/cases/deconv5x5/graph.nnef b/tests/python/frontend/nnef/cases/deconv5x5/graph.nnef new file mode 100644 index 000000000000..d928d9f53400 --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv5x5/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,5,5], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/deconv5x5_pad2-2/graph.nnef b/tests/python/frontend/nnef/cases/deconv5x5_pad2-2/graph.nnef new file mode 100644 index 000000000000..5713e9dd7d34 --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv5x5_pad2-2/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,5,5], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias, padding = [(2,2), (2,2)]); +} diff --git a/tests/python/frontend/nnef/cases/deconv5x5_stride3x3/graph.nnef 
b/tests/python/frontend/nnef/cases/deconv5x5_stride3x3/graph.nnef new file mode 100644 index 000000000000..8eb31d814dea --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv5x5_stride3x3/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,5,5], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias, stride = [3,3]); +} diff --git a/tests/python/frontend/nnef/cases/deconv6x6/graph.nnef b/tests/python/frontend/nnef/cases/deconv6x6/graph.nnef new file mode 100644 index 000000000000..6f5fd0a012ca --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv6x6/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,6,6], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/deconv7x7/graph.nnef b/tests/python/frontend/nnef/cases/deconv7x7/graph.nnef new file mode 100644 index 000000000000..1e637b499c22 --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv7x7/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,7,7], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/deconv7x7_stride4x4/graph.nnef b/tests/python/frontend/nnef/cases/deconv7x7_stride4x4/graph.nnef new file mode 100644 index 000000000000..c9974fa8d814 --- /dev/null +++ b/tests/python/frontend/nnef/cases/deconv7x7_stride4x4/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = variable(shape = [16,8,7,7], label = 'filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = deconv(input, filter, bias, stride = [4,4]); +} diff --git a/tests/python/frontend/nnef/cases/div_2d/graph.nnef b/tests/python/frontend/nnef/cases/div_2d/graph.nnef new file mode 100644 index 000000000000..c24464c92548 --- /dev/null +++ b/tests/python/frontend/nnef/cases/div_2d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = div(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/div_4d/graph.nnef b/tests/python/frontend/nnef/cases/div_4d/graph.nnef new file mode 100644 index 000000000000..b173873e6636 --- /dev/null +++ b/tests/python/frontend/nnef/cases/div_4d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = div(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/div_4d_broadcast/graph.nnef b/tests/python/frontend/nnef/cases/div_4d_broadcast/graph.nnef new file mode 100644 index 000000000000..78a3f4cc86b7 --- /dev/null +++ b/tests/python/frontend/nnef/cases/div_4d_broadcast/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [1,16,1,1]); + output = div(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/div_4d_constant/graph.nnef b/tests/python/frontend/nnef/cases/div_4d_constant/graph.nnef new file mode 100644 
index 000000000000..4391fe6dc89f --- /dev/null +++ b/tests/python/frontend/nnef/cases/div_4d_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = div(input, 0.5); +} diff --git a/tests/python/frontend/nnef/cases/elu/graph.nnef b/tests/python/frontend/nnef/cases/elu/graph.nnef new file mode 100644 index 000000000000..358b37639529 --- /dev/null +++ b/tests/python/frontend/nnef/cases/elu/graph.nnef @@ -0,0 +1,10 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [16,16,32,32]); + filter = constant(shape = [16,1,1,1], value = [1.0]); + bias = constant(shape = [1,16], value = [0.0]); + conv = conv(input, filter, bias, groups = 0); + output = elu(conv); +} diff --git a/tests/python/frontend/nnef/cases/elu_2d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/elu_2d_standalone/graph.nnef new file mode 100644 index 000000000000..96b500d63b02 --- /dev/null +++ b/tests/python/frontend/nnef/cases/elu_2d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [16,16]); + output = elu(input); +} diff --git a/tests/python/frontend/nnef/cases/elu_4d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/elu_4d_standalone/graph.nnef new file mode 100644 index 000000000000..512bd6b651c6 --- /dev/null +++ b/tests/python/frontend/nnef/cases/elu_4d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [16,16,32,32]); + output = elu(input); +} diff --git a/tests/python/frontend/nnef/cases/eq_2d/graph.nnef b/tests/python/frontend/nnef/cases/eq_2d/graph.nnef new file mode 100644 index 000000000000..2869ff5d97bc --- /dev/null +++ b/tests/python/frontend/nnef/cases/eq_2d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = eq(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/eq_4d/graph.nnef b/tests/python/frontend/nnef/cases/eq_4d/graph.nnef new file mode 100644 index 000000000000..318154c8e9ff --- /dev/null +++ b/tests/python/frontend/nnef/cases/eq_4d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = eq(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/eq_4d_broadcast/graph.nnef b/tests/python/frontend/nnef/cases/eq_4d_broadcast/graph.nnef new file mode 100644 index 000000000000..b944355b7add --- /dev/null +++ b/tests/python/frontend/nnef/cases/eq_4d_broadcast/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [1,16,1,1]); + output = eq(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/eq_4d_constant/graph.nnef b/tests/python/frontend/nnef/cases/eq_4d_constant/graph.nnef new file mode 100644 index 000000000000..e173af679c3a --- /dev/null +++ b/tests/python/frontend/nnef/cases/eq_4d_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = eq(input, 0.5); +} diff --git a/tests/python/frontend/nnef/cases/exp_2d/graph.nnef b/tests/python/frontend/nnef/cases/exp_2d/graph.nnef new file mode 100644 index 000000000000..0cc1698c9030 --- /dev/null +++ 
b/tests/python/frontend/nnef/cases/exp_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = exp(input); +} diff --git a/tests/python/frontend/nnef/cases/exp_4d/graph.nnef b/tests/python/frontend/nnef/cases/exp_4d/graph.nnef new file mode 100644 index 000000000000..f312ca4506fa --- /dev/null +++ b/tests/python/frontend/nnef/cases/exp_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = exp(input); +} diff --git a/tests/python/frontend/nnef/cases/floor_2d/graph.nnef b/tests/python/frontend/nnef/cases/floor_2d/graph.nnef new file mode 100644 index 000000000000..bc88b588ad58 --- /dev/null +++ b/tests/python/frontend/nnef/cases/floor_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = floor(input); +} diff --git a/tests/python/frontend/nnef/cases/floor_4d/graph.nnef b/tests/python/frontend/nnef/cases/floor_4d/graph.nnef new file mode 100644 index 000000000000..00815209ef4a --- /dev/null +++ b/tests/python/frontend/nnef/cases/floor_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = floor(input); +} diff --git a/tests/python/frontend/nnef/cases/ge_2d/graph.nnef b/tests/python/frontend/nnef/cases/ge_2d/graph.nnef new file mode 100644 index 000000000000..b6abf50776f6 --- /dev/null +++ b/tests/python/frontend/nnef/cases/ge_2d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = ge(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/ge_4d/graph.nnef b/tests/python/frontend/nnef/cases/ge_4d/graph.nnef new file mode 100644 index 000000000000..dadab33cbcaf --- /dev/null +++ b/tests/python/frontend/nnef/cases/ge_4d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = ge(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/ge_4d_broadcast/graph.nnef b/tests/python/frontend/nnef/cases/ge_4d_broadcast/graph.nnef new file mode 100644 index 000000000000..b1af29219053 --- /dev/null +++ b/tests/python/frontend/nnef/cases/ge_4d_broadcast/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [1,16,1,1]); + output = ge(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/ge_4d_constant/graph.nnef b/tests/python/frontend/nnef/cases/ge_4d_constant/graph.nnef new file mode 100644 index 000000000000..6d779025c607 --- /dev/null +++ b/tests/python/frontend/nnef/cases/ge_4d_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = ge(input, 0.5); +} diff --git a/tests/python/frontend/nnef/cases/gelu/graph.nnef b/tests/python/frontend/nnef/cases/gelu/graph.nnef new file mode 100644 index 000000000000..3fdfce946d91 --- /dev/null +++ b/tests/python/frontend/nnef/cases/gelu/graph.nnef @@ -0,0 +1,10 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [16,16,32,32]); + filter = constant(shape = [16,1,1,1], value = [1.0]); + bias = constant(shape = [1,16], value = [0.0]); + conv = conv(input, filter, bias, groups = 0); + 
output = gelu(conv); +} diff --git a/tests/python/frontend/nnef/cases/gelu_2d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/gelu_2d_standalone/graph.nnef new file mode 100644 index 000000000000..c903678fa9f5 --- /dev/null +++ b/tests/python/frontend/nnef/cases/gelu_2d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [16,16]); + output = gelu(input); +} diff --git a/tests/python/frontend/nnef/cases/gelu_4d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/gelu_4d_standalone/graph.nnef new file mode 100644 index 000000000000..7180e1a8de53 --- /dev/null +++ b/tests/python/frontend/nnef/cases/gelu_4d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [16,16,32,32]); + output = gelu(input); +} diff --git a/tests/python/frontend/nnef/cases/gt_2d/graph.nnef b/tests/python/frontend/nnef/cases/gt_2d/graph.nnef new file mode 100644 index 000000000000..48bc77a5bdcc --- /dev/null +++ b/tests/python/frontend/nnef/cases/gt_2d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = gt(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/gt_4d/graph.nnef b/tests/python/frontend/nnef/cases/gt_4d/graph.nnef new file mode 100644 index 000000000000..e3d392a6560d --- /dev/null +++ b/tests/python/frontend/nnef/cases/gt_4d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = gt(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/gt_4d_broadcast/graph.nnef b/tests/python/frontend/nnef/cases/gt_4d_broadcast/graph.nnef new file mode 100644 index 000000000000..bf2cf2b3ede2 --- /dev/null +++ b/tests/python/frontend/nnef/cases/gt_4d_broadcast/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [1,16,1,1]); + output = gt(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/gt_4d_constant/graph.nnef b/tests/python/frontend/nnef/cases/gt_4d_constant/graph.nnef new file mode 100644 index 000000000000..252a483af67e --- /dev/null +++ b/tests/python/frontend/nnef/cases/gt_4d_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = gt(input, 0.5); +} diff --git a/tests/python/frontend/nnef/cases/l1_normalization/graph.nnef b/tests/python/frontend/nnef/cases/l1_normalization/graph.nnef new file mode 100644 index 000000000000..0833aa1e2a13 --- /dev/null +++ b/tests/python/frontend/nnef/cases/l1_normalization/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = l1_normalization(input, axes = [1], bias = 1.0, epsilon = 1e-5); +} diff --git a/tests/python/frontend/nnef/cases/l2_normalization/graph.nnef b/tests/python/frontend/nnef/cases/l2_normalization/graph.nnef new file mode 100644 index 000000000000..3bb94b6f52f2 --- /dev/null +++ b/tests/python/frontend/nnef/cases/l2_normalization/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = l2_normalization(input, axes = [1], epsilon = 1e-3); +} diff --git a/tests/python/frontend/nnef/cases/le_2d/graph.nnef 
b/tests/python/frontend/nnef/cases/le_2d/graph.nnef new file mode 100644 index 000000000000..b89c3f19a726 --- /dev/null +++ b/tests/python/frontend/nnef/cases/le_2d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = le(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/le_4d/graph.nnef b/tests/python/frontend/nnef/cases/le_4d/graph.nnef new file mode 100644 index 000000000000..17f821d93d8c --- /dev/null +++ b/tests/python/frontend/nnef/cases/le_4d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = le(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/le_4d_broadcast/graph.nnef b/tests/python/frontend/nnef/cases/le_4d_broadcast/graph.nnef new file mode 100644 index 000000000000..e7df0cf2aab7 --- /dev/null +++ b/tests/python/frontend/nnef/cases/le_4d_broadcast/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [1,16,1,1]); + output = le(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/le_4d_constant/graph.nnef b/tests/python/frontend/nnef/cases/le_4d_constant/graph.nnef new file mode 100644 index 000000000000..328a17aab564 --- /dev/null +++ b/tests/python/frontend/nnef/cases/le_4d_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = le(input, 0.5); +} diff --git a/tests/python/frontend/nnef/cases/leaky_relu/graph.nnef b/tests/python/frontend/nnef/cases/leaky_relu/graph.nnef new file mode 100644 index 000000000000..43a829232d4c --- /dev/null +++ b/tests/python/frontend/nnef/cases/leaky_relu/graph.nnef @@ -0,0 +1,10 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [16,16,32,32]); + filter = constant(shape = [16,1,1,1], value = [1.0]); + bias = constant(shape = [1,16], value = [0.0]); + conv = conv(input, filter, bias, groups = 0); + output = leaky_relu(conv, alpha = 0.5); +} diff --git a/tests/python/frontend/nnef/cases/leaky_relu_2d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/leaky_relu_2d_standalone/graph.nnef new file mode 100644 index 000000000000..24b239e58ddc --- /dev/null +++ b/tests/python/frontend/nnef/cases/leaky_relu_2d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [16,16]); + output = leaky_relu(input, alpha = 0.5); +} diff --git a/tests/python/frontend/nnef/cases/leaky_relu_4d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/leaky_relu_4d_standalone/graph.nnef new file mode 100644 index 000000000000..04e6c4dadaef --- /dev/null +++ b/tests/python/frontend/nnef/cases/leaky_relu_4d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [16,16,32,32]); + output = leaky_relu(input, alpha = 0.5); +} diff --git a/tests/python/frontend/nnef/cases/linear/graph.nnef b/tests/python/frontend/nnef/cases/linear/graph.nnef new file mode 100644 index 000000000000..0cbef1067c1b --- /dev/null +++ b/tests/python/frontend/nnef/cases/linear/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + weights = variable(shape = [32,16], label = 'weights'); + bias = variable(shape = 
[1,32], label = 'bias'); + output = linear(input, weights, bias); +} diff --git a/tests/python/frontend/nnef/cases/linear_nobias/graph.nnef b/tests/python/frontend/nnef/cases/linear_nobias/graph.nnef new file mode 100644 index 000000000000..9a93ea8d2177 --- /dev/null +++ b/tests/python/frontend/nnef/cases/linear_nobias/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + weights = variable(shape = [32,16], label = 'weights'); + output = linear(input, weights, 0.0); +} diff --git a/tests/python/frontend/nnef/cases/linear_reshape/graph.nnef b/tests/python/frontend/nnef/cases/linear_reshape/graph.nnef new file mode 100644 index 000000000000..d17413a265ea --- /dev/null +++ b/tests/python/frontend/nnef/cases/linear_reshape/graph.nnef @@ -0,0 +1,10 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,8,8]); + weights = variable(shape = [32,1024], label = 'weights'); + bias = variable(shape = [1,32], label = 'bias'); + flattened = reshape(input, shape = [0,-1]); + output = linear(flattened, weights, bias); +} diff --git a/tests/python/frontend/nnef/cases/linear_squeeze/graph.nnef b/tests/python/frontend/nnef/cases/linear_squeeze/graph.nnef new file mode 100644 index 000000000000..2971aa570d10 --- /dev/null +++ b/tests/python/frontend/nnef/cases/linear_squeeze/graph.nnef @@ -0,0 +1,10 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,1,1]); + weights = variable(shape = [32,16], label = 'weights'); + bias = variable(shape = [1,32], label = 'bias'); + squeezed = squeeze(input, axes = [2,3]); + output = linear(squeezed, weights, bias); +} diff --git a/tests/python/frontend/nnef/cases/local_contrast_normalization/graph.nnef b/tests/python/frontend/nnef/cases/local_contrast_normalization/graph.nnef new file mode 100644 index 000000000000..ac9434e87689 --- /dev/null +++ b/tests/python/frontend/nnef/cases/local_contrast_normalization/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = local_contrast_normalization(input, size = [1, 1, 3, 3], bias = 1.0, epsilon = 1e-5); +} diff --git a/tests/python/frontend/nnef/cases/local_mean_normalization/graph.nnef b/tests/python/frontend/nnef/cases/local_mean_normalization/graph.nnef new file mode 100644 index 000000000000..2aa3a8b7d529 --- /dev/null +++ b/tests/python/frontend/nnef/cases/local_mean_normalization/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = local_mean_normalization(input, size = [1, 1, 3, 3]); +} diff --git a/tests/python/frontend/nnef/cases/local_response_normalization/graph.nnef b/tests/python/frontend/nnef/cases/local_response_normalization/graph.nnef new file mode 100644 index 000000000000..b450cc8cea90 --- /dev/null +++ b/tests/python/frontend/nnef/cases/local_response_normalization/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = local_response_normalization(input, alpha = 1e-05, beta = 0.75, bias = 1.0, size = [1, 5, 1, 1]); +} diff --git a/tests/python/frontend/nnef/cases/local_variance_normalization/graph.nnef b/tests/python/frontend/nnef/cases/local_variance_normalization/graph.nnef new file mode 100644 index 000000000000..83b0c6ebfff1 --- /dev/null +++ b/tests/python/frontend/nnef/cases/local_variance_normalization/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph 
G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = local_variance_normalization(input, size = [1, 1, 3, 3], bias = 1.0, epsilon = 1e-5); +} diff --git a/tests/python/frontend/nnef/cases/log2_2d/graph.nnef b/tests/python/frontend/nnef/cases/log2_2d/graph.nnef new file mode 100644 index 000000000000..166e05ed6a17 --- /dev/null +++ b/tests/python/frontend/nnef/cases/log2_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = log2(input); +} diff --git a/tests/python/frontend/nnef/cases/log2_4d/graph.nnef b/tests/python/frontend/nnef/cases/log2_4d/graph.nnef new file mode 100644 index 000000000000..95b71212ce00 --- /dev/null +++ b/tests/python/frontend/nnef/cases/log2_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = log2(input); +} diff --git a/tests/python/frontend/nnef/cases/log_2d/graph.nnef b/tests/python/frontend/nnef/cases/log_2d/graph.nnef new file mode 100644 index 000000000000..337102ab8e78 --- /dev/null +++ b/tests/python/frontend/nnef/cases/log_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = log(input); +} diff --git a/tests/python/frontend/nnef/cases/log_4d/graph.nnef b/tests/python/frontend/nnef/cases/log_4d/graph.nnef new file mode 100644 index 000000000000..36975b9bd94f --- /dev/null +++ b/tests/python/frontend/nnef/cases/log_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = log(input); +} diff --git a/tests/python/frontend/nnef/cases/lt_2d/graph.nnef b/tests/python/frontend/nnef/cases/lt_2d/graph.nnef new file mode 100644 index 000000000000..7ef77d6be0a8 --- /dev/null +++ b/tests/python/frontend/nnef/cases/lt_2d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = lt(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/lt_4d/graph.nnef b/tests/python/frontend/nnef/cases/lt_4d/graph.nnef new file mode 100644 index 000000000000..6cdb2285dd14 --- /dev/null +++ b/tests/python/frontend/nnef/cases/lt_4d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = lt(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/lt_4d_broadcast/graph.nnef b/tests/python/frontend/nnef/cases/lt_4d_broadcast/graph.nnef new file mode 100644 index 000000000000..7fb5764ec4b3 --- /dev/null +++ b/tests/python/frontend/nnef/cases/lt_4d_broadcast/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [1,16,1,1]); + output = lt(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/lt_4d_constant/graph.nnef b/tests/python/frontend/nnef/cases/lt_4d_constant/graph.nnef new file mode 100644 index 000000000000..a4dce93a6ccb --- /dev/null +++ b/tests/python/frontend/nnef/cases/lt_4d_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = lt(input, 0.5); +} diff --git a/tests/python/frontend/nnef/cases/matmul_2d/graph.nnef b/tests/python/frontend/nnef/cases/matmul_2d/graph.nnef new file mode 100644 index 
000000000000..8586028c3deb --- /dev/null +++ b/tests/python/frontend/nnef/cases/matmul_2d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [16,4]); + output = matmul(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/matmul_2d_transpose/graph.nnef b/tests/python/frontend/nnef/cases/matmul_2d_transpose/graph.nnef new file mode 100644 index 000000000000..4cb78911ea2d --- /dev/null +++ b/tests/python/frontend/nnef/cases/matmul_2d_transpose/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = matmul(input1, input2, transposeA = true, transposeB = false); +} diff --git a/tests/python/frontend/nnef/cases/matmul_4d/graph.nnef b/tests/python/frontend/nnef/cases/matmul_4d/graph.nnef new file mode 100644 index 000000000000..5e3263458368 --- /dev/null +++ b/tests/python/frontend/nnef/cases/matmul_4d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = matmul(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/matmul_4d_transpose/graph.nnef b/tests/python/frontend/nnef/cases/matmul_4d_transpose/graph.nnef new file mode 100644 index 000000000000..1b24655bf344 --- /dev/null +++ b/tests/python/frontend/nnef/cases/matmul_4d_transpose/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = matmul(input1, input2, transposeA = true, transposeB = false); +} diff --git a/tests/python/frontend/nnef/cases/max_2d/graph.nnef b/tests/python/frontend/nnef/cases/max_2d/graph.nnef new file mode 100644 index 000000000000..ae302f3ae735 --- /dev/null +++ b/tests/python/frontend/nnef/cases/max_2d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = max(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/max_4d/graph.nnef b/tests/python/frontend/nnef/cases/max_4d/graph.nnef new file mode 100644 index 000000000000..dc560b1a7020 --- /dev/null +++ b/tests/python/frontend/nnef/cases/max_4d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = max(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/max_4d_broadcast/graph.nnef b/tests/python/frontend/nnef/cases/max_4d_broadcast/graph.nnef new file mode 100644 index 000000000000..fe7d4ce862c0 --- /dev/null +++ b/tests/python/frontend/nnef/cases/max_4d_broadcast/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [1,16,1,1]); + output = max(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/max_4d_constant/graph.nnef b/tests/python/frontend/nnef/cases/max_4d_constant/graph.nnef new file mode 100644 index 000000000000..c1b61662daef --- /dev/null +++ b/tests/python/frontend/nnef/cases/max_4d_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = max(input, 0.5); +} diff --git 
a/tests/python/frontend/nnef/cases/max_pool1x1/graph.nnef b/tests/python/frontend/nnef/cases/max_pool1x1/graph.nnef new file mode 100644 index 000000000000..1c74044e8890 --- /dev/null +++ b/tests/python/frontend/nnef/cases/max_pool1x1/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = max_pool(input, size = [1,1,1,1], stride = [1,1,2,2], border = 'ignore'); +} diff --git a/tests/python/frontend/nnef/cases/max_pool2x2/graph.nnef b/tests/python/frontend/nnef/cases/max_pool2x2/graph.nnef new file mode 100644 index 000000000000..9df88946ab34 --- /dev/null +++ b/tests/python/frontend/nnef/cases/max_pool2x2/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = max_pool(input, size = [1,1,2,2], stride = [1,1,2,2], border = 'ignore'); +} diff --git a/tests/python/frontend/nnef/cases/max_pool3x3/graph.nnef b/tests/python/frontend/nnef/cases/max_pool3x3/graph.nnef new file mode 100644 index 000000000000..2413faa521d6 --- /dev/null +++ b/tests/python/frontend/nnef/cases/max_pool3x3/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = max_pool(input, size = [1,1,3,3], stride = [1,1,2,2], border = 'ignore'); +} diff --git a/tests/python/frontend/nnef/cases/max_pool3x3_constant-border/graph.nnef b/tests/python/frontend/nnef/cases/max_pool3x3_constant-border/graph.nnef new file mode 100644 index 000000000000..b0221c3c7d33 --- /dev/null +++ b/tests/python/frontend/nnef/cases/max_pool3x3_constant-border/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = max_pool(input, size = [1,1,3,3], stride = [1,1,2,2], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/max_pool3x3_pad0-0/graph.nnef b/tests/python/frontend/nnef/cases/max_pool3x3_pad0-0/graph.nnef new file mode 100644 index 000000000000..b59b0f166df0 --- /dev/null +++ b/tests/python/frontend/nnef/cases/max_pool3x3_pad0-0/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = max_pool(input, size = [1,1,3,3], stride = [1,1,2,2], padding = [(0,0), (0,0), (0,0), (0,0)], border = 'ignore'); +} diff --git a/tests/python/frontend/nnef/cases/max_pool3x3_pad0-1/graph.nnef b/tests/python/frontend/nnef/cases/max_pool3x3_pad0-1/graph.nnef new file mode 100644 index 000000000000..efcbfb924162 --- /dev/null +++ b/tests/python/frontend/nnef/cases/max_pool3x3_pad0-1/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = max_pool(input, size = [1,1,3,3], stride = [1,1,2,2], padding = [(0,0), (0,0), (0,1), (0,1)], border = 'ignore'); +} diff --git a/tests/python/frontend/nnef/cases/max_pool3x3_pad1-0/graph.nnef b/tests/python/frontend/nnef/cases/max_pool3x3_pad1-0/graph.nnef new file mode 100644 index 000000000000..ccb0db8245a5 --- /dev/null +++ b/tests/python/frontend/nnef/cases/max_pool3x3_pad1-0/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = max_pool(input, size = [1,1,3,3], stride = [1,1,2,2], padding = [(0,0), (0,0), (1,0), (1,0)], border = 'ignore'); +} diff --git a/tests/python/frontend/nnef/cases/max_pool3x3_pad1-1/graph.nnef b/tests/python/frontend/nnef/cases/max_pool3x3_pad1-1/graph.nnef new file mode 
100644 index 000000000000..189d708d2769 --- /dev/null +++ b/tests/python/frontend/nnef/cases/max_pool3x3_pad1-1/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = max_pool(input, size = [1,1,3,3], stride = [1,1,2,2], padding = [(0,0), (0,0), (1,1), (1,1)], border = 'ignore'); +} diff --git a/tests/python/frontend/nnef/cases/max_pool3x3_stride1x1/graph.nnef b/tests/python/frontend/nnef/cases/max_pool3x3_stride1x1/graph.nnef new file mode 100644 index 000000000000..26513627ef81 --- /dev/null +++ b/tests/python/frontend/nnef/cases/max_pool3x3_stride1x1/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = max_pool(input, size = [1,1,3,3], stride = [1,1,1,1], border = 'ignore'); +} diff --git a/tests/python/frontend/nnef/cases/max_reduce_channel/graph.nnef b/tests/python/frontend/nnef/cases/max_reduce_channel/graph.nnef new file mode 100644 index 000000000000..38f5bfd2e342 --- /dev/null +++ b/tests/python/frontend/nnef/cases/max_reduce_channel/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = max_reduce(input, axes = [1]); +} diff --git a/tests/python/frontend/nnef/cases/max_reduce_spatial/graph.nnef b/tests/python/frontend/nnef/cases/max_reduce_spatial/graph.nnef new file mode 100644 index 000000000000..18d0bf319b67 --- /dev/null +++ b/tests/python/frontend/nnef/cases/max_reduce_spatial/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = max_reduce(input, axes = [2,3]); +} diff --git a/tests/python/frontend/nnef/cases/mean_reduce_channel/graph.nnef b/tests/python/frontend/nnef/cases/mean_reduce_channel/graph.nnef new file mode 100644 index 000000000000..06821a94ed9e --- /dev/null +++ b/tests/python/frontend/nnef/cases/mean_reduce_channel/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = mean_reduce(input, axes = [1]); +} diff --git a/tests/python/frontend/nnef/cases/mean_reduce_spatial/graph.nnef b/tests/python/frontend/nnef/cases/mean_reduce_spatial/graph.nnef new file mode 100644 index 000000000000..42219e6e62e2 --- /dev/null +++ b/tests/python/frontend/nnef/cases/mean_reduce_spatial/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = mean_reduce(input, axes = [2,3]); +} diff --git a/tests/python/frontend/nnef/cases/min_2d/graph.nnef b/tests/python/frontend/nnef/cases/min_2d/graph.nnef new file mode 100644 index 000000000000..4c96becd5959 --- /dev/null +++ b/tests/python/frontend/nnef/cases/min_2d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = min(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/min_4d/graph.nnef b/tests/python/frontend/nnef/cases/min_4d/graph.nnef new file mode 100644 index 000000000000..bbc28df23314 --- /dev/null +++ b/tests/python/frontend/nnef/cases/min_4d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = min(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/min_4d_broadcast/graph.nnef 
b/tests/python/frontend/nnef/cases/min_4d_broadcast/graph.nnef new file mode 100644 index 000000000000..7befa71d83b3 --- /dev/null +++ b/tests/python/frontend/nnef/cases/min_4d_broadcast/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [1,16,1,1]); + output = min(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/min_4d_constant/graph.nnef b/tests/python/frontend/nnef/cases/min_4d_constant/graph.nnef new file mode 100644 index 000000000000..5e19520c498b --- /dev/null +++ b/tests/python/frontend/nnef/cases/min_4d_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = min(input, 0.5); +} diff --git a/tests/python/frontend/nnef/cases/min_reduce_channel/graph.nnef b/tests/python/frontend/nnef/cases/min_reduce_channel/graph.nnef new file mode 100644 index 000000000000..a2ad6680ae4d --- /dev/null +++ b/tests/python/frontend/nnef/cases/min_reduce_channel/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = min_reduce(input, axes = [1]); +} diff --git a/tests/python/frontend/nnef/cases/min_reduce_spatial/graph.nnef b/tests/python/frontend/nnef/cases/min_reduce_spatial/graph.nnef new file mode 100644 index 000000000000..08f0249c3b76 --- /dev/null +++ b/tests/python/frontend/nnef/cases/min_reduce_spatial/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = min_reduce(input, axes = [2,3]); +} diff --git a/tests/python/frontend/nnef/cases/mul_2d/graph.nnef b/tests/python/frontend/nnef/cases/mul_2d/graph.nnef new file mode 100644 index 000000000000..5d5720377c6e --- /dev/null +++ b/tests/python/frontend/nnef/cases/mul_2d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = mul(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/mul_4d/graph.nnef b/tests/python/frontend/nnef/cases/mul_4d/graph.nnef new file mode 100644 index 000000000000..ac78a91a322d --- /dev/null +++ b/tests/python/frontend/nnef/cases/mul_4d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = mul(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/mul_4d_broadcast/graph.nnef b/tests/python/frontend/nnef/cases/mul_4d_broadcast/graph.nnef new file mode 100644 index 000000000000..682f4e893e7b --- /dev/null +++ b/tests/python/frontend/nnef/cases/mul_4d_broadcast/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [1,16,1,1]); + output = mul(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/mul_4d_constant/graph.nnef b/tests/python/frontend/nnef/cases/mul_4d_constant/graph.nnef new file mode 100644 index 000000000000..1f6bdbb69732 --- /dev/null +++ b/tests/python/frontend/nnef/cases/mul_4d_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = mul(input, 0.5); +} diff --git a/tests/python/frontend/nnef/cases/ne_2d/graph.nnef b/tests/python/frontend/nnef/cases/ne_2d/graph.nnef new file mode 100644 index 
000000000000..6a8ea3e3ee7c --- /dev/null +++ b/tests/python/frontend/nnef/cases/ne_2d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = ne(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/ne_4d/graph.nnef b/tests/python/frontend/nnef/cases/ne_4d/graph.nnef new file mode 100644 index 000000000000..7dee4ad2f22c --- /dev/null +++ b/tests/python/frontend/nnef/cases/ne_4d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = ne(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/ne_4d_broadcast/graph.nnef b/tests/python/frontend/nnef/cases/ne_4d_broadcast/graph.nnef new file mode 100644 index 000000000000..7e619bdc317a --- /dev/null +++ b/tests/python/frontend/nnef/cases/ne_4d_broadcast/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [1,16,1,1]); + output = ne(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/ne_4d_constant/graph.nnef b/tests/python/frontend/nnef/cases/ne_4d_constant/graph.nnef new file mode 100644 index 000000000000..7b0d7720eb2d --- /dev/null +++ b/tests/python/frontend/nnef/cases/ne_4d_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = ne(input, 0.5); +} diff --git a/tests/python/frontend/nnef/cases/nearest_downsample/graph.nnef b/tests/python/frontend/nnef/cases/nearest_downsample/graph.nnef new file mode 100644 index 000000000000..8b1443c165e4 --- /dev/null +++ b/tests/python/frontend/nnef/cases/nearest_downsample/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = nearest_downsample(input, factor = [2,2]); +} diff --git a/tests/python/frontend/nnef/cases/nearest_upsample/graph.nnef b/tests/python/frontend/nnef/cases/nearest_upsample/graph.nnef new file mode 100644 index 000000000000..34f6fe49e45a --- /dev/null +++ b/tests/python/frontend/nnef/cases/nearest_upsample/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = nearest_upsample(input, factor = [2,2]); +} diff --git a/tests/python/frontend/nnef/cases/neg_2d/graph.nnef b/tests/python/frontend/nnef/cases/neg_2d/graph.nnef new file mode 100644 index 000000000000..b25f97cef465 --- /dev/null +++ b/tests/python/frontend/nnef/cases/neg_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = neg(input); +} diff --git a/tests/python/frontend/nnef/cases/neg_4d/graph.nnef b/tests/python/frontend/nnef/cases/neg_4d/graph.nnef new file mode 100644 index 000000000000..8c752d747860 --- /dev/null +++ b/tests/python/frontend/nnef/cases/neg_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = neg(input); +} diff --git a/tests/python/frontend/nnef/cases/not_2d/graph.nnef b/tests/python/frontend/nnef/cases/not_2d/graph.nnef new file mode 100644 index 000000000000..e7885f852a32 --- /dev/null +++ b/tests/python/frontend/nnef/cases/not_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = 
[4,16]); + output = not(input); +} diff --git a/tests/python/frontend/nnef/cases/not_4d/graph.nnef b/tests/python/frontend/nnef/cases/not_4d/graph.nnef new file mode 100644 index 000000000000..7544fb1394d3 --- /dev/null +++ b/tests/python/frontend/nnef/cases/not_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = not(input); +} diff --git a/tests/python/frontend/nnef/cases/or_2d/graph.nnef b/tests/python/frontend/nnef/cases/or_2d/graph.nnef new file mode 100644 index 000000000000..52ec1fdbdc20 --- /dev/null +++ b/tests/python/frontend/nnef/cases/or_2d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = or(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/or_4d/graph.nnef b/tests/python/frontend/nnef/cases/or_4d/graph.nnef new file mode 100644 index 000000000000..a799707f4f80 --- /dev/null +++ b/tests/python/frontend/nnef/cases/or_4d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = or(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/or_4d_broadcast/graph.nnef b/tests/python/frontend/nnef/cases/or_4d_broadcast/graph.nnef new file mode 100644 index 000000000000..74ed77cf9587 --- /dev/null +++ b/tests/python/frontend/nnef/cases/or_4d_broadcast/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [1,16,1,1]); + output = or(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/or_4d_constant/graph.nnef b/tests/python/frontend/nnef/cases/or_4d_constant/graph.nnef new file mode 100644 index 000000000000..100aedaf2487 --- /dev/null +++ b/tests/python/frontend/nnef/cases/or_4d_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = or(input, false); +} diff --git a/tests/python/frontend/nnef/cases/pad_0-1_constant/graph.nnef b/tests/python/frontend/nnef/cases/pad_0-1_constant/graph.nnef new file mode 100644 index 000000000000..89bf37e06c5c --- /dev/null +++ b/tests/python/frontend/nnef/cases/pad_0-1_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [1,16,32,32]); + output = pad(input, padding = [(0,0), (0,0), (0,1), (0,1)], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/pad_0-1_reflect/graph.nnef b/tests/python/frontend/nnef/cases/pad_0-1_reflect/graph.nnef new file mode 100644 index 000000000000..2deca2e42e5f --- /dev/null +++ b/tests/python/frontend/nnef/cases/pad_0-1_reflect/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [1,16,32,32]); + output = pad(input, padding = [(0,0), (0,0), (0,1), (0,1)], border = 'reflect'); +} diff --git a/tests/python/frontend/nnef/cases/pad_0-1_replicate/graph.nnef b/tests/python/frontend/nnef/cases/pad_0-1_replicate/graph.nnef new file mode 100644 index 000000000000..544c5704db18 --- /dev/null +++ b/tests/python/frontend/nnef/cases/pad_0-1_replicate/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [1,16,32,32]); + output = pad(input, padding = [(0,0), (0,0), (0,1), (0,1)], border = 'replicate'); +} diff --git 
a/tests/python/frontend/nnef/cases/pad_1-0_constant/graph.nnef b/tests/python/frontend/nnef/cases/pad_1-0_constant/graph.nnef new file mode 100644 index 000000000000..5b36fc86d2fc --- /dev/null +++ b/tests/python/frontend/nnef/cases/pad_1-0_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [1,16,32,32]); + output = pad(input, padding = [(0,0), (0,0), (1,0), (1,0)], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/pad_1-0_reflect/graph.nnef b/tests/python/frontend/nnef/cases/pad_1-0_reflect/graph.nnef new file mode 100644 index 000000000000..d12aa2270240 --- /dev/null +++ b/tests/python/frontend/nnef/cases/pad_1-0_reflect/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [1,16,32,32]); + output = pad(input, padding = [(0,0), (0,0), (1,0), (1,0)], border = 'reflect'); +} diff --git a/tests/python/frontend/nnef/cases/pad_1-0_replicate/graph.nnef b/tests/python/frontend/nnef/cases/pad_1-0_replicate/graph.nnef new file mode 100644 index 000000000000..d527f2e6fa23 --- /dev/null +++ b/tests/python/frontend/nnef/cases/pad_1-0_replicate/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [1,16,32,32]); + output = pad(input, padding = [(0,0), (0,0), (1,0), (1,0)], border = 'replicate'); +} diff --git a/tests/python/frontend/nnef/cases/pad_1-1_constant/graph.nnef b/tests/python/frontend/nnef/cases/pad_1-1_constant/graph.nnef new file mode 100644 index 000000000000..c5118096c957 --- /dev/null +++ b/tests/python/frontend/nnef/cases/pad_1-1_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [1,16,32,32]); + output = pad(input, padding = [(0,0), (0,0), (1,1), (1,1)], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/pad_1-1_reflect/graph.nnef b/tests/python/frontend/nnef/cases/pad_1-1_reflect/graph.nnef new file mode 100644 index 000000000000..fa2709ea354b --- /dev/null +++ b/tests/python/frontend/nnef/cases/pad_1-1_reflect/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [1,16,32,32]); + output = pad(input, padding = [(0,0), (0,0), (1,1), (1,1)], border = 'reflect'); +} diff --git a/tests/python/frontend/nnef/cases/pad_1-1_replicate/graph.nnef b/tests/python/frontend/nnef/cases/pad_1-1_replicate/graph.nnef new file mode 100644 index 000000000000..dcdead991e9a --- /dev/null +++ b/tests/python/frontend/nnef/cases/pad_1-1_replicate/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [1,16,32,32]); + output = pad(input, padding = [(0,0), (0,0), (1,1), (1,1)], border = 'replicate'); +} diff --git a/tests/python/frontend/nnef/cases/pow_2d/graph.nnef b/tests/python/frontend/nnef/cases/pow_2d/graph.nnef new file mode 100644 index 000000000000..b07c5b61a573 --- /dev/null +++ b/tests/python/frontend/nnef/cases/pow_2d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = pow(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/pow_4d/graph.nnef b/tests/python/frontend/nnef/cases/pow_4d/graph.nnef new file mode 100644 index 000000000000..f81284811043 --- /dev/null +++ b/tests/python/frontend/nnef/cases/pow_4d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = 
external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = pow(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/pow_4d_broadcast/graph.nnef b/tests/python/frontend/nnef/cases/pow_4d_broadcast/graph.nnef new file mode 100644 index 000000000000..664e8381eed5 --- /dev/null +++ b/tests/python/frontend/nnef/cases/pow_4d_broadcast/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [1,16,1,1]); + output = pow(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/pow_4d_constant/graph.nnef b/tests/python/frontend/nnef/cases/pow_4d_constant/graph.nnef new file mode 100644 index 000000000000..2d3ed54b01b5 --- /dev/null +++ b/tests/python/frontend/nnef/cases/pow_4d_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = pow(input, 0.5); +} diff --git a/tests/python/frontend/nnef/cases/prelu/graph.nnef b/tests/python/frontend/nnef/cases/prelu/graph.nnef new file mode 100644 index 000000000000..04fe7c0a3464 --- /dev/null +++ b/tests/python/frontend/nnef/cases/prelu/graph.nnef @@ -0,0 +1,11 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [16,16,32,32]); + filter = constant(shape = [16,1,1,1], value = [1.0]); + bias = constant(shape = [1,16], value = [0.0]); + conv = conv(input1, filter, bias, groups = 0); + input2 = external(shape = [16]); + output = prelu(conv, input2); +} diff --git a/tests/python/frontend/nnef/cases/prelu_2d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/prelu_2d_standalone/graph.nnef new file mode 100644 index 000000000000..1cbe5da61515 --- /dev/null +++ b/tests/python/frontend/nnef/cases/prelu_2d_standalone/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [16,16]); + input2 = external(shape = [16]); + output = prelu(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/prelu_4d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/prelu_4d_standalone/graph.nnef new file mode 100644 index 000000000000..abc6613b2ea6 --- /dev/null +++ b/tests/python/frontend/nnef/cases/prelu_4d_standalone/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [16,16,32,32]); + input2 = external(shape = [16]); + output = prelu(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/rcp_2d/graph.nnef b/tests/python/frontend/nnef/cases/rcp_2d/graph.nnef new file mode 100644 index 000000000000..aa9db7a80291 --- /dev/null +++ b/tests/python/frontend/nnef/cases/rcp_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = rcp(input); +} diff --git a/tests/python/frontend/nnef/cases/rcp_4d/graph.nnef b/tests/python/frontend/nnef/cases/rcp_4d/graph.nnef new file mode 100644 index 000000000000..f5784549bec7 --- /dev/null +++ b/tests/python/frontend/nnef/cases/rcp_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = rcp(input); +} diff --git a/tests/python/frontend/nnef/cases/relu/graph.nnef b/tests/python/frontend/nnef/cases/relu/graph.nnef new file mode 100644 index 000000000000..08a81ee886ee --- /dev/null +++ b/tests/python/frontend/nnef/cases/relu/graph.nnef @@ -0,0 +1,10 @@ +version 1.0; + +graph G( input ) -> ( 
output ) +{ + input = external(shape = [4,16,32,32]); + filter = constant(shape = [16,1,1,1], value = [1.0]); + bias = constant(shape = [1,16], value = [0.0]); + conv = conv(input, filter, bias, groups = 0); + output = relu(conv); +} diff --git a/tests/python/frontend/nnef/cases/relu_2d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/relu_2d_standalone/graph.nnef new file mode 100644 index 000000000000..fdba3f74bff6 --- /dev/null +++ b/tests/python/frontend/nnef/cases/relu_2d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = relu(input); +} diff --git a/tests/python/frontend/nnef/cases/relu_4d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/relu_4d_standalone/graph.nnef new file mode 100644 index 000000000000..347cf9665ae3 --- /dev/null +++ b/tests/python/frontend/nnef/cases/relu_4d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = relu(input); +} diff --git a/tests/python/frontend/nnef/cases/reshape_flatten/graph.nnef b/tests/python/frontend/nnef/cases/reshape_flatten/graph.nnef new file mode 100644 index 000000000000..1d39de4b26e6 --- /dev/null +++ b/tests/python/frontend/nnef/cases/reshape_flatten/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = reshape(input, shape = [0,-1]); +} diff --git a/tests/python/frontend/nnef/cases/reshape_partial/graph.nnef b/tests/python/frontend/nnef/cases/reshape_partial/graph.nnef new file mode 100644 index 000000000000..50f983e266c6 --- /dev/null +++ b/tests/python/frontend/nnef/cases/reshape_partial/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [2,3,3,3,2]); + output = reshape(input, shape = [0,-1], axis_start = 1, axis_count = 3); +} diff --git a/tests/python/frontend/nnef/cases/reshape_squeeze/graph.nnef b/tests/python/frontend/nnef/cases/reshape_squeeze/graph.nnef new file mode 100644 index 000000000000..b8471424234a --- /dev/null +++ b/tests/python/frontend/nnef/cases/reshape_squeeze/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,1,1]); + output = reshape(input, shape = [4,16]); +} diff --git a/tests/python/frontend/nnef/cases/rms_pool3x3/graph.nnef b/tests/python/frontend/nnef/cases/rms_pool3x3/graph.nnef new file mode 100644 index 000000000000..bd3972de2ed1 --- /dev/null +++ b/tests/python/frontend/nnef/cases/rms_pool3x3/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = rms_pool(input, size = [1,1,3,3], stride = [1,1,2,2], border = 'constant'); +} diff --git a/tests/python/frontend/nnef/cases/round_2d/graph.nnef b/tests/python/frontend/nnef/cases/round_2d/graph.nnef new file mode 100644 index 000000000000..6dcc91eb50a1 --- /dev/null +++ b/tests/python/frontend/nnef/cases/round_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = round(input); +} diff --git a/tests/python/frontend/nnef/cases/round_4d/graph.nnef b/tests/python/frontend/nnef/cases/round_4d/graph.nnef new file mode 100644 index 000000000000..bbbdb1bea377 --- /dev/null +++ b/tests/python/frontend/nnef/cases/round_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + 
output = round(input); +} diff --git a/tests/python/frontend/nnef/cases/rsqr_2d/graph.nnef b/tests/python/frontend/nnef/cases/rsqr_2d/graph.nnef new file mode 100644 index 000000000000..385ec228b1c6 --- /dev/null +++ b/tests/python/frontend/nnef/cases/rsqr_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = rsqr(input); +} diff --git a/tests/python/frontend/nnef/cases/rsqr_4d/graph.nnef b/tests/python/frontend/nnef/cases/rsqr_4d/graph.nnef new file mode 100644 index 000000000000..a462d27572da --- /dev/null +++ b/tests/python/frontend/nnef/cases/rsqr_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = rsqr(input); +} diff --git a/tests/python/frontend/nnef/cases/rsqrt_2d/graph.nnef b/tests/python/frontend/nnef/cases/rsqrt_2d/graph.nnef new file mode 100644 index 000000000000..f3503cfee649 --- /dev/null +++ b/tests/python/frontend/nnef/cases/rsqrt_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = rsqrt(input); +} diff --git a/tests/python/frontend/nnef/cases/rsqrt_4d/graph.nnef b/tests/python/frontend/nnef/cases/rsqrt_4d/graph.nnef new file mode 100644 index 000000000000..76583e05c7f6 --- /dev/null +++ b/tests/python/frontend/nnef/cases/rsqrt_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = rsqrt(input); +} diff --git a/tests/python/frontend/nnef/cases/select_2d/graph.nnef b/tests/python/frontend/nnef/cases/select_2d/graph.nnef new file mode 100644 index 000000000000..a771def8b45e --- /dev/null +++ b/tests/python/frontend/nnef/cases/select_2d/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( cond, input1, input2 ) -> ( output ) +{ + cond = external(shape = [4,16]); + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = select(cond, input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/select_2d_false/graph.nnef b/tests/python/frontend/nnef/cases/select_2d_false/graph.nnef new file mode 100644 index 000000000000..44669bc31ca1 --- /dev/null +++ b/tests/python/frontend/nnef/cases/select_2d_false/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = select(false, input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/select_2d_true/graph.nnef b/tests/python/frontend/nnef/cases/select_2d_true/graph.nnef new file mode 100644 index 000000000000..6df5598fa1cc --- /dev/null +++ b/tests/python/frontend/nnef/cases/select_2d_true/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = select(true, input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/select_4d/graph.nnef b/tests/python/frontend/nnef/cases/select_4d/graph.nnef new file mode 100644 index 000000000000..06ae030eb933 --- /dev/null +++ b/tests/python/frontend/nnef/cases/select_4d/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( cond, input1, input2 ) -> ( output ) +{ + cond = external(shape = [4,16,32,32]); + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = select(cond, input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/select_4d_false/graph.nnef 
b/tests/python/frontend/nnef/cases/select_4d_false/graph.nnef new file mode 100644 index 000000000000..d2f4f45b7177 --- /dev/null +++ b/tests/python/frontend/nnef/cases/select_4d_false/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = select(false, input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/select_4d_true/graph.nnef b/tests/python/frontend/nnef/cases/select_4d_true/graph.nnef new file mode 100644 index 000000000000..b6437d595376 --- /dev/null +++ b/tests/python/frontend/nnef/cases/select_4d_true/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = select(true, input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/selu/graph.nnef b/tests/python/frontend/nnef/cases/selu/graph.nnef new file mode 100644 index 000000000000..cf08d103a23a --- /dev/null +++ b/tests/python/frontend/nnef/cases/selu/graph.nnef @@ -0,0 +1,10 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [16,16,32,32]); + filter = constant(shape = [16,1,1,1], value = [1.0]); + bias = constant(shape = [1,16], value = [0.0]); + conv = conv(input, filter, bias, groups = 0); + output = selu(conv); +} diff --git a/tests/python/frontend/nnef/cases/selu_2d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/selu_2d_standalone/graph.nnef new file mode 100644 index 000000000000..cfe55aa9ca32 --- /dev/null +++ b/tests/python/frontend/nnef/cases/selu_2d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [16,16]); + output = selu(input); +} diff --git a/tests/python/frontend/nnef/cases/selu_4d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/selu_4d_standalone/graph.nnef new file mode 100644 index 000000000000..c8d6bd6b6d2b --- /dev/null +++ b/tests/python/frontend/nnef/cases/selu_4d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [16,16,32,32]); + output = selu(input); +} diff --git a/tests/python/frontend/nnef/cases/separable_conv3x3/graph.nnef b/tests/python/frontend/nnef/cases/separable_conv3x3/graph.nnef new file mode 100644 index 000000000000..30c722ba6062 --- /dev/null +++ b/tests/python/frontend/nnef/cases/separable_conv3x3/graph.nnef @@ -0,0 +1,10 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + plane_filter = variable(shape = [8,1,3,3], label = 'plane_filter'); + point_filter = variable(shape = [16,8,1,1], label = 'point_filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = separable_conv(input, plane_filter, point_filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/separable_conv3x3_with_attrs/graph.nnef b/tests/python/frontend/nnef/cases/separable_conv3x3_with_attrs/graph.nnef new file mode 100644 index 000000000000..7471ad7fce3c --- /dev/null +++ b/tests/python/frontend/nnef/cases/separable_conv3x3_with_attrs/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + plane_filter = variable(shape = [8,1,3,3], label = 'plane_filter'); + point_filter = variable(shape = [16,8,1,1], label = 'point_filter'); + output = separable_conv(input, plane_filter, point_filter, padding = [(0,1), (0,1)], stride = [2,2]); +} diff --git 
a/tests/python/frontend/nnef/cases/separable_conv5x5/graph.nnef b/tests/python/frontend/nnef/cases/separable_conv5x5/graph.nnef new file mode 100644 index 000000000000..07903799cdec --- /dev/null +++ b/tests/python/frontend/nnef/cases/separable_conv5x5/graph.nnef @@ -0,0 +1,10 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,8,32,32]); + plane_filter = variable(shape = [8,1,5,5], label = 'plane_filter'); + point_filter = variable(shape = [16,8,1,1], label = 'point_filter'); + bias = variable(shape = [1,16], label = 'bias'); + output = separable_conv(input, plane_filter, point_filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/separable_deconv3x3/graph.nnef b/tests/python/frontend/nnef/cases/separable_deconv3x3/graph.nnef new file mode 100644 index 000000000000..1d830b6bba5e --- /dev/null +++ b/tests/python/frontend/nnef/cases/separable_deconv3x3/graph.nnef @@ -0,0 +1,10 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + plane_filter = variable(shape = [8,1,3,3], label = 'plane_filter'); + point_filter = variable(shape = [16,8,1,1], label = 'point_filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = separable_deconv(input, plane_filter, point_filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/separable_deconv3x3_with_attrs/graph.nnef b/tests/python/frontend/nnef/cases/separable_deconv3x3_with_attrs/graph.nnef new file mode 100644 index 000000000000..331f733d3195 --- /dev/null +++ b/tests/python/frontend/nnef/cases/separable_deconv3x3_with_attrs/graph.nnef @@ -0,0 +1,9 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + plane_filter = variable(shape = [8,1,3,3], label = 'plane_filter'); + point_filter = variable(shape = [16,8,1,1], label = 'point_filter'); + output = separable_deconv(input, plane_filter, point_filter, padding = [(0,1), (0,1)], stride = [2,2]); +} diff --git a/tests/python/frontend/nnef/cases/separable_deconv5x5/graph.nnef b/tests/python/frontend/nnef/cases/separable_deconv5x5/graph.nnef new file mode 100644 index 000000000000..f115a9ecc105 --- /dev/null +++ b/tests/python/frontend/nnef/cases/separable_deconv5x5/graph.nnef @@ -0,0 +1,10 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + plane_filter = variable(shape = [8,1,5,5], label = 'plane_filter'); + point_filter = variable(shape = [16,8,1,1], label = 'point_filter'); + bias = variable(shape = [1,8], label = 'bias'); + output = separable_deconv(input, plane_filter, point_filter, bias); +} diff --git a/tests/python/frontend/nnef/cases/sigmoid/graph.nnef b/tests/python/frontend/nnef/cases/sigmoid/graph.nnef new file mode 100644 index 000000000000..83eb12e4cdb3 --- /dev/null +++ b/tests/python/frontend/nnef/cases/sigmoid/graph.nnef @@ -0,0 +1,10 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = constant(shape = [16,1,1,1], value = [1.0]); + bias = constant(shape = [1,16], value = [0.0]); + conv = conv(input, filter, bias, groups = 0); + output = sigmoid(conv); +} diff --git a/tests/python/frontend/nnef/cases/sigmoid_2d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/sigmoid_2d_standalone/graph.nnef new file mode 100644 index 000000000000..64ac4e44a611 --- /dev/null +++ b/tests/python/frontend/nnef/cases/sigmoid_2d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + 
output = sigmoid(input); +} diff --git a/tests/python/frontend/nnef/cases/sigmoid_4d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/sigmoid_4d_standalone/graph.nnef new file mode 100644 index 000000000000..80ddf8208c6a --- /dev/null +++ b/tests/python/frontend/nnef/cases/sigmoid_4d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = sigmoid(input); +} diff --git a/tests/python/frontend/nnef/cases/sign_2d/graph.nnef b/tests/python/frontend/nnef/cases/sign_2d/graph.nnef new file mode 100644 index 000000000000..77f0bf039bdd --- /dev/null +++ b/tests/python/frontend/nnef/cases/sign_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = sign(input); +} diff --git a/tests/python/frontend/nnef/cases/sign_4d/graph.nnef b/tests/python/frontend/nnef/cases/sign_4d/graph.nnef new file mode 100644 index 000000000000..1e0e429c4a52 --- /dev/null +++ b/tests/python/frontend/nnef/cases/sign_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = sign(input); +} diff --git a/tests/python/frontend/nnef/cases/silu/graph.nnef b/tests/python/frontend/nnef/cases/silu/graph.nnef new file mode 100644 index 000000000000..b3209da214c7 --- /dev/null +++ b/tests/python/frontend/nnef/cases/silu/graph.nnef @@ -0,0 +1,10 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [16,16,32,32]); + filter = constant(shape = [16,1,1,1], value = [1.0]); + bias = constant(shape = [1,16], value = [0.0]); + conv = conv(input, filter, bias, groups = 0); + output = silu(conv); +} diff --git a/tests/python/frontend/nnef/cases/silu_2d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/silu_2d_standalone/graph.nnef new file mode 100644 index 000000000000..c307794e1c37 --- /dev/null +++ b/tests/python/frontend/nnef/cases/silu_2d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [16,16]); + output = silu(input); +} diff --git a/tests/python/frontend/nnef/cases/silu_4d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/silu_4d_standalone/graph.nnef new file mode 100644 index 000000000000..a36fa0e18c58 --- /dev/null +++ b/tests/python/frontend/nnef/cases/silu_4d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [16,16,32,32]); + output = silu(input); +} diff --git a/tests/python/frontend/nnef/cases/sin_2d/graph.nnef b/tests/python/frontend/nnef/cases/sin_2d/graph.nnef new file mode 100644 index 000000000000..3fb5738babd4 --- /dev/null +++ b/tests/python/frontend/nnef/cases/sin_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = sin(input); +} diff --git a/tests/python/frontend/nnef/cases/sin_4d/graph.nnef b/tests/python/frontend/nnef/cases/sin_4d/graph.nnef new file mode 100644 index 000000000000..ce3cffc0ba30 --- /dev/null +++ b/tests/python/frontend/nnef/cases/sin_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = sin(input); +} diff --git a/tests/python/frontend/nnef/cases/sinh_2d/graph.nnef b/tests/python/frontend/nnef/cases/sinh_2d/graph.nnef new file mode 100644 index 000000000000..2c00c7ab9ca5 --- /dev/null +++ b/tests/python/frontend/nnef/cases/sinh_2d/graph.nnef 
@@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = sinh(input); +} diff --git a/tests/python/frontend/nnef/cases/sinh_4d/graph.nnef b/tests/python/frontend/nnef/cases/sinh_4d/graph.nnef new file mode 100644 index 000000000000..a7df179fa543 --- /dev/null +++ b/tests/python/frontend/nnef/cases/sinh_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = sinh(input); +} diff --git a/tests/python/frontend/nnef/cases/slice/graph.nnef b/tests/python/frontend/nnef/cases/slice/graph.nnef new file mode 100644 index 000000000000..52f7ac48ab35 --- /dev/null +++ b/tests/python/frontend/nnef/cases/slice/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = slice(input, axes = [2,3], begin = [1,2], end = [-1,-2]); +} diff --git a/tests/python/frontend/nnef/cases/slice_strides/graph.nnef b/tests/python/frontend/nnef/cases/slice_strides/graph.nnef new file mode 100644 index 000000000000..1f35e7e1758b --- /dev/null +++ b/tests/python/frontend/nnef/cases/slice_strides/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = slice(input, axes = [1,2,3], begin = [5,16,2], end = [1,4,-1], stride = [-1,-1,1]); +} diff --git a/tests/python/frontend/nnef/cases/softmax/graph.nnef b/tests/python/frontend/nnef/cases/softmax/graph.nnef new file mode 100644 index 000000000000..ab0d00b1a27a --- /dev/null +++ b/tests/python/frontend/nnef/cases/softmax/graph.nnef @@ -0,0 +1,10 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = constant(shape = [16,1,1,1], value = [1.0]); + bias = constant(shape = [1,16], value = [0.0]); + conv = conv(input, filter, bias, groups = 0); + output = softmax(conv, axes = [1]); +} diff --git a/tests/python/frontend/nnef/cases/softmax_2d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/softmax_2d_standalone/graph.nnef new file mode 100644 index 000000000000..76e2410a695e --- /dev/null +++ b/tests/python/frontend/nnef/cases/softmax_2d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = softmax(input); +} diff --git a/tests/python/frontend/nnef/cases/softmax_4d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/softmax_4d_standalone/graph.nnef new file mode 100644 index 000000000000..0eb2191f81eb --- /dev/null +++ b/tests/python/frontend/nnef/cases/softmax_4d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = softmax(input); +} diff --git a/tests/python/frontend/nnef/cases/softplus/graph.nnef b/tests/python/frontend/nnef/cases/softplus/graph.nnef new file mode 100644 index 000000000000..9c4c1f15b7c4 --- /dev/null +++ b/tests/python/frontend/nnef/cases/softplus/graph.nnef @@ -0,0 +1,10 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = constant(shape = [16,1,1,1], value = [1.0]); + bias = constant(shape = [1,16], value = [0.0]); + conv = conv(input, filter, bias, groups = 0); + output = softplus(conv); +} diff --git a/tests/python/frontend/nnef/cases/softplus_2d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/softplus_2d_standalone/graph.nnef new file mode 100644 index 000000000000..fca49a128dfd --- 
/dev/null +++ b/tests/python/frontend/nnef/cases/softplus_2d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = softplus(input); +} diff --git a/tests/python/frontend/nnef/cases/softplus_4d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/softplus_4d_standalone/graph.nnef new file mode 100644 index 000000000000..14972ff7530d --- /dev/null +++ b/tests/python/frontend/nnef/cases/softplus_4d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = softplus(input); +} diff --git a/tests/python/frontend/nnef/cases/split_channel/graph.nnef b/tests/python/frontend/nnef/cases/split_channel/graph.nnef new file mode 100644 index 000000000000..ae48d85891d7 --- /dev/null +++ b/tests/python/frontend/nnef/cases/split_channel/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output1, output2 ) +{ + input = external(shape = [4,16,32,32]); + [output1, output2] = split(input, axis = 1, ratios = [1,1]); +} diff --git a/tests/python/frontend/nnef/cases/split_unbalanced/graph.nnef b/tests/python/frontend/nnef/cases/split_unbalanced/graph.nnef new file mode 100644 index 000000000000..d3dda048014c --- /dev/null +++ b/tests/python/frontend/nnef/cases/split_unbalanced/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output1, output2, output3 ) +{ + input = external(shape = [4,32,3]); + [output1, output2, output3] = split(input, axis = 1, ratios = [3,1,4]); +} diff --git a/tests/python/frontend/nnef/cases/sqr_2d/graph.nnef b/tests/python/frontend/nnef/cases/sqr_2d/graph.nnef new file mode 100644 index 000000000000..b1b3fe4848a8 --- /dev/null +++ b/tests/python/frontend/nnef/cases/sqr_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = sqr(input); +} diff --git a/tests/python/frontend/nnef/cases/sqr_4d/graph.nnef b/tests/python/frontend/nnef/cases/sqr_4d/graph.nnef new file mode 100644 index 000000000000..297c1f264e34 --- /dev/null +++ b/tests/python/frontend/nnef/cases/sqr_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = sqr(input); +} diff --git a/tests/python/frontend/nnef/cases/sqrt_2d/graph.nnef b/tests/python/frontend/nnef/cases/sqrt_2d/graph.nnef new file mode 100644 index 000000000000..5c00df461686 --- /dev/null +++ b/tests/python/frontend/nnef/cases/sqrt_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = sqrt(input); +} diff --git a/tests/python/frontend/nnef/cases/sqrt_4d/graph.nnef b/tests/python/frontend/nnef/cases/sqrt_4d/graph.nnef new file mode 100644 index 000000000000..03d5845d43dc --- /dev/null +++ b/tests/python/frontend/nnef/cases/sqrt_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = sqrt(input); +} diff --git a/tests/python/frontend/nnef/cases/squeeze_spatial/graph.nnef b/tests/python/frontend/nnef/cases/squeeze_spatial/graph.nnef new file mode 100644 index 000000000000..da182b5fb217 --- /dev/null +++ b/tests/python/frontend/nnef/cases/squeeze_spatial/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,1,1]); + output = squeeze(input, axes = [2,3]); +} diff --git a/tests/python/frontend/nnef/cases/stack/graph.nnef 
b/tests/python/frontend/nnef/cases/stack/graph.nnef new file mode 100644 index 000000000000..aaf3e0c3b92e --- /dev/null +++ b/tests/python/frontend/nnef/cases/stack/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = stack([input1, input2], axis = 1); +} diff --git a/tests/python/frontend/nnef/cases/sub_2d/graph.nnef b/tests/python/frontend/nnef/cases/sub_2d/graph.nnef new file mode 100644 index 000000000000..b3c33a2cf882 --- /dev/null +++ b/tests/python/frontend/nnef/cases/sub_2d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16]); + input2 = external(shape = [4,16]); + output = sub(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/sub_4d/graph.nnef b/tests/python/frontend/nnef/cases/sub_4d/graph.nnef new file mode 100644 index 000000000000..ff8a068e4f27 --- /dev/null +++ b/tests/python/frontend/nnef/cases/sub_4d/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [4,16,32,32]); + output = sub(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/sub_4d_broadcast/graph.nnef b/tests/python/frontend/nnef/cases/sub_4d_broadcast/graph.nnef new file mode 100644 index 000000000000..1ffec0c486ac --- /dev/null +++ b/tests/python/frontend/nnef/cases/sub_4d_broadcast/graph.nnef @@ -0,0 +1,8 @@ +version 1.0; + +graph G( input1, input2 ) -> ( output ) +{ + input1 = external(shape = [4,16,32,32]); + input2 = external(shape = [1,16,1,1]); + output = sub(input1, input2); +} diff --git a/tests/python/frontend/nnef/cases/sub_4d_constant/graph.nnef b/tests/python/frontend/nnef/cases/sub_4d_constant/graph.nnef new file mode 100644 index 000000000000..c9c6abf4951e --- /dev/null +++ b/tests/python/frontend/nnef/cases/sub_4d_constant/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = sub(input, 0.5); +} diff --git a/tests/python/frontend/nnef/cases/sum_reduce_channel/graph.nnef b/tests/python/frontend/nnef/cases/sum_reduce_channel/graph.nnef new file mode 100644 index 000000000000..ba9154a2e715 --- /dev/null +++ b/tests/python/frontend/nnef/cases/sum_reduce_channel/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = sum_reduce(input, axes = [1]); +} diff --git a/tests/python/frontend/nnef/cases/sum_reduce_spatial/graph.nnef b/tests/python/frontend/nnef/cases/sum_reduce_spatial/graph.nnef new file mode 100644 index 000000000000..b46afa623754 --- /dev/null +++ b/tests/python/frontend/nnef/cases/sum_reduce_spatial/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = sum_reduce(input, axes = [2,3]); +} diff --git a/tests/python/frontend/nnef/cases/tan_2d/graph.nnef b/tests/python/frontend/nnef/cases/tan_2d/graph.nnef new file mode 100644 index 000000000000..af203dcb8a4d --- /dev/null +++ b/tests/python/frontend/nnef/cases/tan_2d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = tan(input); +} diff --git a/tests/python/frontend/nnef/cases/tan_4d/graph.nnef b/tests/python/frontend/nnef/cases/tan_4d/graph.nnef new file mode 100644 index 000000000000..6b039dd270ba --- /dev/null +++ 
b/tests/python/frontend/nnef/cases/tan_4d/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = tan(input); +} diff --git a/tests/python/frontend/nnef/cases/tanh/graph.nnef b/tests/python/frontend/nnef/cases/tanh/graph.nnef new file mode 100644 index 000000000000..1d39aec99c8c --- /dev/null +++ b/tests/python/frontend/nnef/cases/tanh/graph.nnef @@ -0,0 +1,10 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + filter = constant(shape = [16,1,1,1], value = [1.0]); + bias = constant(shape = [1,16], value = [0.0]); + conv = conv(input, filter, bias, groups = 0); + output = tanh(conv); +} diff --git a/tests/python/frontend/nnef/cases/tanh_2d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/tanh_2d_standalone/graph.nnef new file mode 100644 index 000000000000..a5dae283dfad --- /dev/null +++ b/tests/python/frontend/nnef/cases/tanh_2d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = tanh(input); +} diff --git a/tests/python/frontend/nnef/cases/tanh_4d_standalone/graph.nnef b/tests/python/frontend/nnef/cases/tanh_4d_standalone/graph.nnef new file mode 100644 index 000000000000..7c9ee3a6c14a --- /dev/null +++ b/tests/python/frontend/nnef/cases/tanh_4d_standalone/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = tanh(input); +} diff --git a/tests/python/frontend/nnef/cases/tile_batch/graph.nnef b/tests/python/frontend/nnef/cases/tile_batch/graph.nnef new file mode 100644 index 000000000000..853f7789e500 --- /dev/null +++ b/tests/python/frontend/nnef/cases/tile_batch/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [1,16]); + output = tile(input, repeats = [16,1]); +} diff --git a/tests/python/frontend/nnef/cases/tile_channel/graph.nnef b/tests/python/frontend/nnef/cases/tile_channel/graph.nnef new file mode 100644 index 000000000000..bddc2f13ad5f --- /dev/null +++ b/tests/python/frontend/nnef/cases/tile_channel/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [16,1]); + output = tile(input, repeats = [1,16]); +} diff --git a/tests/python/frontend/nnef/cases/tile_spatial/graph.nnef b/tests/python/frontend/nnef/cases/tile_spatial/graph.nnef new file mode 100644 index 000000000000..6f44e9847083 --- /dev/null +++ b/tests/python/frontend/nnef/cases/tile_spatial/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = tile(input, repeats = [1,1,3,3]); +} diff --git a/tests/python/frontend/nnef/cases/transpose_nchw_to_nhwc/graph.nnef b/tests/python/frontend/nnef/cases/transpose_nchw_to_nhwc/graph.nnef new file mode 100644 index 000000000000..7e6dbd6a7668 --- /dev/null +++ b/tests/python/frontend/nnef/cases/transpose_nchw_to_nhwc/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16,32,32]); + output = transpose(input, axes = [0,2,3,1]); +} diff --git a/tests/python/frontend/nnef/cases/transpose_nhwc_to_nchw/graph.nnef b/tests/python/frontend/nnef/cases/transpose_nhwc_to_nchw/graph.nnef new file mode 100644 index 000000000000..0e6f5172989a --- /dev/null +++ b/tests/python/frontend/nnef/cases/transpose_nhwc_to_nchw/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> 
( output ) +{ + input = external(shape = [4,32,32,16]); + output = transpose(input, axes = [0,3,1,2]); +} diff --git a/tests/python/frontend/nnef/cases/unsqueeze/graph.nnef b/tests/python/frontend/nnef/cases/unsqueeze/graph.nnef new file mode 100644 index 000000000000..ede2811723f8 --- /dev/null +++ b/tests/python/frontend/nnef/cases/unsqueeze/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output ) +{ + input = external(shape = [4,16]); + output = unsqueeze(input, axes = [2,3]); +} diff --git a/tests/python/frontend/nnef/cases/unstack/graph.nnef b/tests/python/frontend/nnef/cases/unstack/graph.nnef new file mode 100644 index 000000000000..1c37b792c4c6 --- /dev/null +++ b/tests/python/frontend/nnef/cases/unstack/graph.nnef @@ -0,0 +1,7 @@ +version 1.0; + +graph G( input ) -> ( output1, output2, output3 ) +{ + input = external(shape = [4,3,16]); + [output1, output2, output3] = unstack(input, axis = 1); +} diff --git a/tests/python/frontend/nnef/test_forward.py b/tests/python/frontend/nnef/test_forward.py new file mode 100644 index 000000000000..1e6caceabb47 --- /dev/null +++ b/tests/python/frontend/nnef/test_forward.py @@ -0,0 +1,1627 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
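Editor's note: the test driver that begins here loads each of the graph.nnef cases above with the nnef package, runs them through the nnef_tools PyTorch interpreter, converts them with relay.frontend.from_nnef, and compares outputs via tvm.testing.assert_allclose. A minimal sketch of the load step on one of the cases above (an illustration only, not part of the patch; it reuses the same nnef calls the driver makes):

    import os
    import nnef

    # Load one of the case graphs added above and infer static shapes,
    # exactly as verify_model below does before conversion.
    case_dir = os.path.join("tests", "python", "frontend", "nnef", "cases", "sigmoid_4d_standalone")
    graph = nnef.load_graph(case_dir, load_variables=False)
    nnef.infer_shapes(graph)
    print({name: graph.tensors[name].shape for name in graph.inputs})  # expected: {'input': [4, 16, 32, 32]}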
+ +import os + +import numpy as np + +import _nnef +import nnef +import nnef_tools.interpreter.pytorch as interpreter + +import tvm +import tvm.testing +from tvm import relay + + +graphs_dir = os.path.join("tests", "python", "frontend", "nnef", "cases") + + +def get_nnef_outputs(path, inputs): + ip = interpreter.Interpreter(path, None, None) + inputs = [inputs[tensor.name] for tensor in ip.input_details()] + return ip(inputs) + + +def get_type(val): + if val == "scalar": + return "float32" + if val == "integer": + return "int32" + if val == "logical": + return "bool" + if val == "string": + return "string" + + +def verify_model( + model_path, + target, + device, + rtol=1e-5, + atol=1e-5, +): + path = os.path.join(graphs_dir, model_path) + graph = nnef.load_graph(path, load_variables=False) + nnef.infer_shapes(graph) + inputs = {} + # generate inputs + for inp in graph.inputs: + intensor = graph.tensors[inp] + shape = intensor.shape + if any(exc in model_path for exc in ["log", "sqrt", "pow", "batch_norm"]): + low = 0.0 + else: + low = -1.0 + high = 1.0 + if "acosh" in model_path: + high = 2.0 + low = 1.0 + if intensor.dtype == "scalar": + inputs[inp] = np.random.uniform(low=low, high=high, size=shape).astype("float32") + elif intensor.dtype == "integer": + inputs[inp] = np.random.randint(0, 64, shape) + elif intensor.dtype == "logical": + inputs[inp] = np.random.binomial(1, 0.5, shape).astype("bool") + elif intensor.dtype == "string": + inputs[inp] = np.random.uniform(low=low, high=high, size=shape).astype("string") + + # set graph parameters + for operation in graph.operations: + if operation.name == "variable": + tensor_name = operation.outputs["output"] + + shape = operation.attribs["shape"] + + assert ( + operation.dtype == "scalar" + ), f"variable of type {operation.dtype} is not supported, please update verify_model" + + data = np.random.uniform(low=-1.0, size=shape).astype("float32") + + tensor = graph.tensors[tensor_name] + graph.tensors[tensor_name] = _nnef.Tensor( + tensor.name, tensor.dtype, shape, data, tensor.quantization + ) + + outputs = get_nnef_outputs(graph, inputs) + + mod, params = relay.frontend.from_nnef(graph) + + with tvm.transform.PassContext(opt_level=3): + # dev = tvm.device(target, 0) + executor = relay.create_executor( + "graph", mod, device=device, target=target, params=params + ).evaluate() + out = executor(**inputs) + + if not isinstance(out, (list, tuple)): + out = [out] + + for i, base_out in enumerate(outputs): + tvm.testing.assert_allclose(out[i].numpy(), outputs[base_out], rtol=rtol, atol=atol) + + +# graph tests + + +@tvm.testing.parametrize_targets +def test_ats_tan_2d(target, dev): + verify_model("tan_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_tan_4d(target, dev): + verify_model("tan_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_sinh_2d(target, dev): + verify_model("sinh_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_sinh_4d(target, dev): + verify_model("sinh_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_cosh_2d(target, dev): + verify_model("cosh_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_cosh_4d(target, dev): + verify_model("cosh_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_asin_2d(target, dev): + verify_model("asin_2d", target, dev, rtol=1e-5, atol=1e-5) + + 
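Editor's note: each test function in this file is a one-line wrapper around verify_model, with @tvm.testing.parametrize_targets supplying the (target, dev) combinations enabled in the build. Outside of pytest the same check can be run directly; llvm / tvm.cpu(0) below are only an example target, any enabled target works:

    import tvm

    # Direct, non-parametrized equivalent of one generated test case.
    verify_model("tan_2d", target="llvm", device=tvm.cpu(0), rtol=1e-5, atol=1e-5)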
+@tvm.testing.parametrize_targets +def test_ats_asin_4d(target, dev): + verify_model("asin_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_acos_2d(target, dev): + verify_model("acos_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_acos_4d(target, dev): + verify_model("acos_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_atan_2d(target, dev): + verify_model("atan_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_atan_4d(target, dev): + verify_model("atan_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_asinh_2d(target, dev): + verify_model("asinh_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_asinh_4d(target, dev): + verify_model("asinh_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_acosh_2d(target, dev): + verify_model("acosh_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_acosh_4d(target, dev): + verify_model("acosh_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_atanh_2d(target, dev): + verify_model("atanh_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_atanh_4d(target, dev): + verify_model("atanh_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_clamp_2d(target, dev): + verify_model("clamp_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_clamp_4d(target, dev): + verify_model("clamp_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_clamp_4d_constant(target, dev): + verify_model("clamp_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_reshape_partial(target, dev): + verify_model("reshape_partial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_split_unbalanced(target, dev): + verify_model("split_unbalanced", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_stack(target, dev): + verify_model("stack", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_unstack(target, dev): + verify_model("unstack", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_slice_strides(target, dev): + verify_model("slice_strides", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_matmul_2d(target, dev): + verify_model("matmul_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_matmul_2d_transpose(target, dev): + verify_model("matmul_2d_transpose", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_matmul_4d(target, dev): + verify_model("matmul_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_matmul_4d_transpose(target, dev): + verify_model("matmul_4d_transpose", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_prelu(target, dev): + verify_model("prelu", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_prelu_2d_standalone(target, dev): + verify_model("prelu_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + 
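Editor's note on the activation cases exercised below: the plain variants (silu, softplus, softmax, tanh above) route the input through a 1x1 depthwise conv (groups = 0 in NNEF) with all-one weights and zero bias before the activation, while the *_standalone variants apply the activation directly to the external input. That conv is numerically an identity, which a quick numpy check confirms (a sketch, not part of the patch):

    import numpy as np

    x = np.random.uniform(-1.0, 1.0, size=(4, 16, 32, 32)).astype("float32")  # shape of one of the cases above
    w = np.ones((16, 1, 1, 1), dtype="float32")   # one 1x1 filter per channel
    y = x * w.reshape(1, 16, 1, 1)                # depthwise 1x1 conv with unit weights == per-channel identity
    assert np.allclose(x, y)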
+@tvm.testing.parametrize_targets +def test_ats_prelu_4d_standalone(target, dev): + verify_model("prelu_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_leaky_relu(target, dev): + verify_model("leaky_relu", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_leaky_relu_2d_standalone(target, dev): + verify_model("leaky_relu_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_leaky_relu_4d_standalone(target, dev): + verify_model("leaky_relu_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_elu(target, dev): + verify_model("elu", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_elu_2d_standalone(target, dev): + verify_model("elu_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_elu_4d_standalone(target, dev): + verify_model("elu_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_selu(target, dev): + verify_model("selu", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_selu_2d_standalone(target, dev): + verify_model("selu_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_selu_4d_standalone(target, dev): + verify_model("selu_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_gelu(target, dev): + verify_model("gelu", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_gelu_2d_standalone(target, dev): + verify_model("gelu_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_gelu_4d_standalone(target, dev): + verify_model("gelu_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_silu(target, dev): + verify_model("silu", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_silu_2d_standalone(target, dev): + verify_model("silu_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_silu_4d_standalone(target, dev): + verify_model("silu_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_softplus(target, dev): + verify_model("softplus", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_softplus_2d_standalone(target, dev): + verify_model("softplus_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_softplus_4d_standalone(target, dev): + verify_model("softplus_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_separable_conv3x3(target, dev): + verify_model("separable_conv3x3", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_separable_conv3x3_with_attrs(target, dev): + verify_model("separable_conv3x3_with_attrs", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_separable_conv5x5(target, dev): + verify_model("separable_conv5x5", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_separable_deconv3x3(target, dev): + verify_model("separable_deconv3x3", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets 
+def test_ats_separable_deconv3x3_with_attrs(target, dev): + verify_model("separable_deconv3x3_with_attrs", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_ats_separable_deconv5x5(target, dev): + verify_model("separable_deconv5x5", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_ats_rms_pool3x3(target, dev): + verify_model("rms_pool3x3", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_local_response_normalization(target, dev): + verify_model("local_response_normalization", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_local_mean_normalization(target, dev): + verify_model("local_mean_normalization", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_local_variance_normalization(target, dev): + verify_model("local_variance_normalization", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_local_contrast_normalization(target, dev): + verify_model("local_contrast_normalization", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_l1_normalization(target, dev): + verify_model("l1_normalization", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_l2_normalization(target, dev): + verify_model("l2_normalization", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_pad_0_1_reflect(target, dev): + verify_model("pad_0-1_reflect", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_pad_1_0_reflect(target, dev): + verify_model("pad_1-0_reflect", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_pad_1_1_reflect(target, dev): + verify_model("pad_1-1_reflect", target, dev, rtol=1e-5, atol=1e-5) + + +# GENERATED CASES START + + +@tvm.testing.parametrize_targets +def test_cts_gt_2d(target, dev): + verify_model("gt_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_2d(target, dev): + verify_model("max_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_mean_reduce_spatial(target, dev): + verify_model("mean_reduce_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_select_4d(target, dev): + verify_model("select_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool3x3_pad1_0(target, dev): + verify_model("max_pool3x3_pad1-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_relu(target, dev): + verify_model("relu", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_split_channel(target, dev): + verify_model("split_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_rcp_4d(target, dev): + verify_model("rcp_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool2x2(target, dev): + verify_model("max_pool2x2", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool2x2(target, dev): + verify_model("avg_pool2x2", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_rcp_2d(target, dev): + verify_model("rcp_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_log2_4d(target, dev): 
+ verify_model("log2_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3_stride2x2(target, dev): + verify_model("conv3x3_stride2x2", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_lt_4d_constant(target, dev): + verify_model("lt_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_or_4d(target, dev): + verify_model("or_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv7x7(target, dev): + verify_model("deconv7x7", target, dev, rtol=1e-5, atol=1e-4) + + +@tvm.testing.parametrize_targets +def test_cts_nearest_upsample(target, dev): + verify_model("nearest_upsample", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_ceil_4d(target, dev): + verify_model("ceil_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_floor_2d(target, dev): + verify_model("floor_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool1x1(target, dev): + verify_model("avg_pool1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_log_4d(target, dev): + verify_model("log_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sum_reduce_channel(target, dev): + verify_model("sum_reduce_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_min_reduce_spatial(target, dev): + verify_model("min_reduce_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_4d_broadcast(target, dev): + verify_model("max_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool3x3_pad0_1(target, dev): + verify_model("max_pool3x3_pad0-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_cos_2d(target, dev): + verify_model("cos_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_not_4d(target, dev): + verify_model("not_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sub_4d(target, dev): + verify_model("sub_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_bilinear_upsample_aligned_replicate(target, dev): + verify_model("bilinear_upsample_aligned_replicate", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_log_2d(target, dev): + verify_model("log_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_argmin_reduce_spatial(target, dev): + verify_model("argmin_reduce_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_select_2d(target, dev): + verify_model("select_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_ne_4d(target, dev): + verify_model("ne_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_or_2d(target, dev): + verify_model("or_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_eq_2d(target, dev): + verify_model("eq_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_rsqr_2d(target, dev): + verify_model("rsqr_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets 
+def test_cts_eq_4d(target, dev): + verify_model("eq_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv7x7_stride4x4(target, dev): + verify_model("deconv7x7_stride4x4", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool3x3(target, dev): + verify_model("max_pool3x3", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_and_4d(target, dev): + verify_model("and_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_mul_4d(target, dev): + verify_model("mul_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_softmax(target, dev): + verify_model("softmax", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sign_4d(target, dev): + verify_model("sign_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_mul_4d_constant(target, dev): + verify_model("mul_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_le_4d_constant(target, dev): + verify_model("le_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_box2x2(target, dev): + verify_model("box2x2", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_or_4d_broadcast(target, dev): + verify_model("or_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv5x5(target, dev): + verify_model("deconv5x5", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_box3x3_pad1_0(target, dev): + verify_model("box3x3_pad1-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_debox3x3_pad1_0(target, dev): + verify_model("debox3x3_pad1-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_ge_4d_broadcast(target, dev): + verify_model("ge_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_linear_reshape(target, dev): + verify_model("linear_reshape", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_le_2d(target, dev): + verify_model("le_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3(target, dev): + verify_model("deconv3x3", target, dev, rtol=1e-5, atol=5e-3) + + +@tvm.testing.parametrize_targets +def test_cts_nearest_downsample(target, dev): + verify_model("nearest_downsample", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_select_4d_true(target, dev): + verify_model("select_4d_true", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_min_4d_broadcast(target, dev): + verify_model("min_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_4d(target, dev): + verify_model("max_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_4d_constant(target, dev): + verify_model("max_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sum_reduce_spatial(target, dev): + verify_model("sum_reduce_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_min_2d(target, dev): + verify_model("min_2d", target, dev, rtol=1e-5, atol=1e-5) 
+ + +@tvm.testing.parametrize_targets +def test_cts_ge_2d(target, dev): + verify_model("ge_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv2x2(target, dev): + verify_model("conv2x2", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv4x4_stride2x2(target, dev): + verify_model("conv4x4_stride2x2", target, dev, rtol=1e-5, atol=5e-3) + + +@tvm.testing.parametrize_targets +def test_cts_debox1x1(target, dev): + verify_model("debox1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_reshape_flatten(target, dev): + verify_model("reshape_flatten", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3_nobias(target, dev): + verify_model("conv3x3_nobias", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_tile_spatial(target, dev): + verify_model("tile_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_softmax_4d_standalone(target, dev): + verify_model("softmax_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_rsqrt_4d(target, dev): + verify_model("rsqrt_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_concat_channel(target, dev): + verify_model("concat_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_area_downsample(target, dev): + verify_model("area_downsample", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool3x3_pad1_1(target, dev): + verify_model("max_pool3x3_pad1-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sigmoid_2d_standalone(target, dev): + verify_model("sigmoid_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_ne_4d_constant(target, dev): + verify_model("ne_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3(target, dev): + verify_model("conv3x3", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_all_reduce_channel(target, dev): + verify_model("all_reduce_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_squeeze_spatial(target, dev): + verify_model("squeeze_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_and_4d_constant(target, dev): + verify_model("and_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool3x3_constant_border(target, dev): + verify_model("max_pool3x3_constant-border", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_argmax_reduce_spatial(target, dev): + verify_model("argmax_reduce_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_cos_4d(target, dev): + verify_model("cos_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sqr_4d(target, dev): + verify_model("sqr_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_rsqrt_2d(target, dev): + verify_model("rsqrt_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_bilinear_upsample_symmetric_replicate(target, dev): + 
verify_model("bilinear_upsample_symmetric_replicate", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_tile_channel(target, dev): + verify_model("tile_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_div_4d(target, dev): + verify_model("div_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sqrt_2d(target, dev): + verify_model("sqrt_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_and_4d_broadcast(target, dev): + verify_model("and_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_transpose_nhwc_to_nchw(target, dev): + verify_model("transpose_nhwc_to_nchw", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool3x3_pad0_1(target, dev): + verify_model("avg_pool3x3_pad0-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_round_2d(target, dev): + verify_model("round_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_box3x3_pad0_1(target, dev): + verify_model("box3x3_pad0-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv6x6(target, dev): + verify_model("deconv6x6", target, dev, rtol=1e-5, atol=1e-4) + + +@tvm.testing.parametrize_targets +def test_cts_add_4d_constant(target, dev): + verify_model("add_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_lt_2d(target, dev): + verify_model("lt_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_min_4d(target, dev): + verify_model("min_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_box3x3_stride1x1(target, dev): + verify_model("box3x3_stride1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_linear_nobias(target, dev): + verify_model("linear_nobias", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_div_2d(target, dev): + verify_model("div_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool3x3_stride1x1(target, dev): + verify_model("avg_pool3x3_stride1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv7x7(target, dev): + verify_model("conv7x7", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3_groups0(target, dev): + verify_model("conv3x3_groups0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_mul_2d(target, dev): + verify_model("mul_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3_pad1_0(target, dev): + verify_model("deconv3x3_pad1-0", target, dev, rtol=1e-5, atol=5e-3) + + +@tvm.testing.parametrize_targets +def test_cts_ne_2d(target, dev): + verify_model("ne_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool3x3_pad1_1(target, dev): + verify_model("avg_pool3x3_pad1-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_mean_reduce_channel(target, dev): + verify_model("mean_reduce_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv5x5(target, dev): + verify_model("conv5x5", target, dev, rtol=1e-5, 
atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool3x3_stride1x1(target, dev): + verify_model("max_pool3x3_stride1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_pad_1_0_replicate(target, dev): + verify_model("pad_1-0_replicate", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_debox3x3_pad1_1(target, dev): + verify_model("debox3x3_pad1-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool3x3_pad1_0(target, dev): + verify_model("avg_pool3x3_pad1-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_bilinear_upsample_symmetric_constant(target, dev): + verify_model("bilinear_upsample_symmetric_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_gt_4d_broadcast(target, dev): + verify_model("gt_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_tanh_4d_standalone(target, dev): + verify_model("tanh_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_add_2d(target, dev): + verify_model("add_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_rsqr_4d(target, dev): + verify_model("rsqr_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_div_4d_broadcast(target, dev): + verify_model("div_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_eq_4d_broadcast(target, dev): + verify_model("eq_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3_valid(target, dev): + verify_model("conv3x3_valid", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_min_4d_constant(target, dev): + verify_model("min_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_or_4d_constant(target, dev): + verify_model("or_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_min_reduce_channel(target, dev): + verify_model("min_reduce_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_reduce_spatial(target, dev): + verify_model("max_reduce_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_bilinear_upsample_asymmetric_constant(target, dev): + verify_model("bilinear_upsample_asymmetric_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3_pad0_0(target, dev): + verify_model("conv3x3_pad0-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3_pad1_0(target, dev): + verify_model("conv3x3_pad1-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_abs_2d(target, dev): + verify_model("abs_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_reduce_channel(target, dev): + verify_model("max_reduce_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_ge_4d_constant(target, dev): + verify_model("ge_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_transpose_nchw_to_nhwc(target, dev): + verify_model("transpose_nchw_to_nhwc", target, dev, 
rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3_pad1_1(target, dev): + verify_model("deconv3x3_pad1-1", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_ne_4d_broadcast(target, dev): + verify_model("ne_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sqr_2d(target, dev): + verify_model("sqr_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3_pad1_1(target, dev): + verify_model("conv3x3_pad1-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_bilinear_upsample_aligned_constant(target, dev): + verify_model("bilinear_upsample_aligned_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_log2_2d(target, dev): + verify_model("log2_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_slice(target, dev): + verify_model("slice", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv2x2(target, dev): + verify_model("deconv2x2", target, dev, rtol=1e-5, atol=5e-3) + + +@tvm.testing.parametrize_targets +def test_cts_all_reduce_spatial(target, dev): + verify_model("all_reduce_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sqrt_4d(target, dev): + verify_model("sqrt_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv7x7_stride4x4(target, dev): + verify_model("conv7x7_stride4x4", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_ge_4d(target, dev): + verify_model("ge_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_any_reduce_channel(target, dev): + verify_model("any_reduce_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_and_2d(target, dev): + verify_model("and_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_add_4d_broadcast(target, dev): + verify_model("add_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_copy_2d(target, dev): + verify_model("copy_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_ceil_2d(target, dev): + verify_model("ceil_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_linear_squeeze(target, dev): + verify_model("linear_squeeze", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sub_2d(target, dev): + verify_model("sub_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3_valid(target, dev): + verify_model("deconv3x3_valid", target, dev, rtol=1e-5, atol=5e-3) + + +@tvm.testing.parametrize_targets +def test_cts_pow_4d(target, dev): + verify_model("pow_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_pad_1_1_constant(target, dev): + verify_model("pad_1-1_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_debox3x3(target, dev): + verify_model("debox3x3", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv1x1(target, dev): + verify_model("conv1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_exp_4d(target, 
dev): + verify_model("exp_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool3x3_ignore_border(target, dev): + verify_model("avg_pool3x3_ignore-border", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3_pad0_0(target, dev): + verify_model("deconv3x3_pad0-0", target, dev, rtol=1e-5, atol=5e-3) + + +@tvm.testing.parametrize_targets +def test_cts_pow_4d_broadcast(target, dev): + verify_model("pow_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_abs_4d(target, dev): + verify_model("abs_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sin_4d(target, dev): + verify_model("sin_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_select_2d_true(target, dev): + verify_model("select_2d_true", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_relu_2d_standalone(target, dev): + verify_model("relu_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_reshape_squeeze(target, dev): + verify_model("reshape_squeeze", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sub_4d_constant(target, dev): + verify_model("sub_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_linear(target, dev): + verify_model("linear", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_pow_2d(target, dev): + verify_model("pow_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_debox3x3_pad0_1(target, dev): + verify_model("debox3x3_pad0-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_floor_4d(target, dev): + verify_model("floor_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3_nobias(target, dev): + verify_model("deconv3x3_nobias", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_batch_norm(target, dev): + verify_model("batch_norm", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3_stride2x2(target, dev): + verify_model("deconv3x3_stride2x2", target, dev, rtol=1e-5, atol=5e-3) + + +@tvm.testing.parametrize_targets +def test_cts_debox2x2(target, dev): + verify_model("debox2x2", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_pad_0_1_replicate(target, dev): + verify_model("pad_0-1_replicate", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_mul_4d_broadcast(target, dev): + verify_model("mul_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_debox3x3_pad0_0(target, dev): + verify_model("debox3x3_pad0-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_argmin_reduce_channel(target, dev): + verify_model("argmin_reduce_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_copy_4d(target, dev): + verify_model("copy_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_not_2d(target, dev): + verify_model("not_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sigmoid_4d_standalone(target, dev): + 
verify_model("sigmoid_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_exp_2d(target, dev): + verify_model("exp_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_lt_4d(target, dev): + verify_model("lt_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv4x4(target, dev): + verify_model("conv4x4", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool3x3(target, dev): + verify_model("avg_pool3x3", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool3x3_pad0_0(target, dev): + verify_model("avg_pool3x3_pad0-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3_pad0_1(target, dev): + verify_model("conv3x3_pad0-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_pad_0_1_constant(target, dev): + verify_model("pad_0-1_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv4x4(target, dev): + verify_model("deconv4x4", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_neg_2d(target, dev): + verify_model("neg_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_bilinear_upsample_asymmetric_replicate(target, dev): + verify_model("bilinear_upsample_asymmetric_replicate", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv5x5_stride3x3(target, dev): + verify_model("conv5x5_stride3x3", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_relu_4d_standalone(target, dev): + verify_model("relu_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool1x1(target, dev): + verify_model("max_pool1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv5x5_pad2_2(target, dev): + verify_model("deconv5x5_pad2-2", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_tile_batch(target, dev): + verify_model("tile_batch", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_eq_4d_constant(target, dev): + verify_model("eq_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_lt_4d_broadcast(target, dev): + verify_model("lt_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv1x1(target, dev): + verify_model("deconv1x1", target, dev, rtol=1e-5, atol=2e-3) + + +@tvm.testing.parametrize_targets +def test_cts_sign_2d(target, dev): + verify_model("sign_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_select_2d_false(target, dev): + verify_model("select_2d_false", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_div_4d_constant(target, dev): + verify_model("div_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_pow_4d_constant(target, dev): + verify_model("pow_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_round_4d(target, dev): + verify_model("round_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_debox3x3_stride1x1(target, dev): + 
verify_model("debox3x3_stride1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv5x5_stride3x3(target, dev): + verify_model("deconv5x5_stride3x3", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sub_4d_broadcast(target, dev): + verify_model("sub_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_any_reduce_spatial(target, dev): + verify_model("any_reduce_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_gt_4d_constant(target, dev): + verify_model("gt_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv6x6(target, dev): + verify_model("conv6x6", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_le_4d(target, dev): + verify_model("le_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_gt_4d(target, dev): + verify_model("gt_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv4x4_stride2x2(target, dev): + verify_model("deconv4x4_stride2x2", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_le_4d_broadcast(target, dev): + verify_model("le_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_tanh_2d_standalone(target, dev): + verify_model("tanh_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_box3x3(target, dev): + verify_model("box3x3", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_select_4d_false(target, dev): + verify_model("select_4d_false", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_tanh(target, dev): + verify_model("tanh", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sin_2d(target, dev): + verify_model("sin_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_box3x3_pad0_0(target, dev): + verify_model("box3x3_pad0-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_box1x1(target, dev): + verify_model("box1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_box3x3_pad1_1(target, dev): + verify_model("box3x3_pad1-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv5x5_pad2_2(target, dev): + verify_model("conv5x5_pad2-2", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool3x3_pad0_0(target, dev): + verify_model("max_pool3x3_pad0-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_softmax_2d_standalone(target, dev): + verify_model("softmax_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3_groups0(target, dev): + verify_model("deconv3x3_groups0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3_pad0_1(target, dev): + verify_model("deconv3x3_pad0-1", target, dev, rtol=1e-5, atol=5e-3) + + +@tvm.testing.parametrize_targets +def test_cts_sigmoid(target, dev): + verify_model("sigmoid", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_argmax_reduce_channel(target, dev): + 
verify_model("argmax_reduce_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_pad_1_1_replicate(target, dev): + verify_model("pad_1-1_replicate", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_pad_1_0_constant(target, dev): + verify_model("pad_1-0_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_unsqueeze(target, dev): + verify_model("unsqueeze", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_neg_4d(target, dev): + verify_model("neg_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_add_4d(target, dev): + verify_model("add_4d", target, dev, rtol=1e-5, atol=1e-5) diff --git a/tests/python/relax/test_frontend_nnef.py b/tests/python/relax/test_frontend_nnef.py new file mode 100644 index 000000000000..a699ad91f706 --- /dev/null +++ b/tests/python/relax/test_frontend_nnef.py @@ -0,0 +1,1634 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+import os
+
+import numpy as np
+
+import _nnef
+import nnef
+import nnef_tools.interpreter.pytorch as interpreter
+import tvm.testing
+import tvm.relax as relax
+import tvm.relax.frontend.nnef
+import pytest
+
+graphs_dir = os.path.join("tests", "python", "frontend", "nnef", "cases")
+
+
+def get_nnef_outputs(path, inputs):
+    ip = interpreter.Interpreter(path, None, None)
+    inputs = [inputs[tensor.name] for tensor in ip.input_details()]
+    return ip(inputs)
+
+
+def get_type(val):
+    if val == "scalar":
+        return "float32"
+    if val == "integer":
+        return "int32"
+    if val == "logical":
+        return "bool"
+    if val == "string":
+        return "string"
+
+
+def verify_model(
+    model_path,
+    target="llvm",
+    dev=tvm.cpu(0),
+    rtol=1e-5,
+    atol=1e-5,
+):
+    path = os.path.join(graphs_dir, model_path)
+    graph = nnef.load_graph(path, load_variables=False)
+    nnef.infer_shapes(graph)
+    inputs = {}
+    # generate inputs
+    for inp in graph.inputs:
+        intensor = graph.tensors[inp]
+        shape = intensor.shape
+        if any(exc in model_path for exc in ["log", "sqrt", "pow", "batch_norm"]):
+            low = 0.0
+        else:
+            low = -1.0
+        high = 1.0
+        if "acosh" in model_path:
+            high = 2.0
+            low = 1.0
+        if intensor.dtype == "scalar":
+            inputs[inp] = np.random.uniform(low=low, high=high, size=shape).astype("float32")
+        elif intensor.dtype == "integer":
+            inputs[inp] = np.random.randint(0, 64, shape)
+        elif intensor.dtype == "logical":
+            inputs[inp] = np.random.binomial(1, 0.5, shape).astype("bool")
+        elif intensor.dtype == "string":
+            inputs[inp] = np.random.uniform(low=low, high=high, size=shape).astype("string")
+
+    # set graph parameters
+    for operation in graph.operations:
+        if operation.name == "variable":
+            tensor_name = operation.outputs["output"]
+
+            shape = operation.attribs["shape"]
+
+            assert (
+                operation.dtype == "scalar"
+            ), f"variable of type {operation.dtype} is not supported, please update verify_model"
+
+            if any(exc in model_path for exc in ["log", "sqrt", "pow", "batch_norm"]):
+                low = 0.0
+            else:
+                low = -1.0
+            data = np.random.uniform(low=low, size=shape).astype("float32")
+
+            tensor = graph.tensors[tensor_name]
+            graph.tensors[tensor_name] = _nnef.Tensor(
+                tensor.name, tensor.dtype, shape, data, tensor.quantization
+            )
+
+    # reference outputs from the NNEF interpreter
+    outputs = get_nnef_outputs(graph, inputs)
+
+    # convert to Relax, build, and run on the requested target
+    mod = tvm.relax.frontend.nnef.from_nnef(graph)
+
+    ex = relax.build(mod, target=target)
+    vm = relax.VirtualMachine(ex, dev)
+    inputs = [tvm.nd.array(arr, device=dev) for arr in inputs.values()]
+    out = vm["main"](*inputs)
+
+    if isinstance(out, tvm.ir.container.Array):
+        out = [o.numpy() for o in out]
+    else:
+        out = out.numpy()
+
+    if not isinstance(out, (list, tuple)):
+        out = [out]
+
+    for i, base_out in enumerate(outputs):
+        tvm.testing.assert_allclose(out[i], outputs[base_out], rtol=rtol, atol=atol)
+
+
+@tvm.testing.parametrize_targets
+def test_ats_tan_2d(target, dev):
+    verify_model("tan_2d", target, dev, rtol=1e-5, atol=1e-5)
+
+
+@tvm.testing.parametrize_targets
+def test_ats_tan_4d(target, dev):
+    verify_model("tan_4d", target, dev, rtol=1e-5, atol=1e-5)
+
+
+@tvm.testing.parametrize_targets
+def test_ats_sinh_2d(target, dev):
+    verify_model("sinh_2d", target, dev, rtol=1e-5, atol=1e-5)
+
+
+@tvm.testing.parametrize_targets
+def test_ats_sinh_4d(target, dev):
+    verify_model("sinh_4d", target, dev, rtol=1e-5, atol=1e-5)
+
+
+@tvm.testing.parametrize_targets
+def test_ats_cosh_2d(target, dev):
+    verify_model("cosh_2d", target, dev, rtol=1e-5, atol=1e-5)
+
+
+@tvm.testing.parametrize_targets
+def test_ats_cosh_4d(target, dev):
+
verify_model("cosh_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_asin_2d(target, dev): + verify_model("asin_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_asin_4d(target, dev): + verify_model("asin_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_acos_2d(target, dev): + verify_model("acos_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_acos_4d(target, dev): + verify_model("acos_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_atan_2d(target, dev): + verify_model("atan_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_atan_4d(target, dev): + verify_model("atan_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_asinh_2d(target, dev): + verify_model("asinh_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_asinh_4d(target, dev): + verify_model("asinh_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_acosh_2d(target, dev): + verify_model("acosh_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_acosh_4d(target, dev): + verify_model("acosh_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_atanh_2d(target, dev): + verify_model("atanh_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_atanh_4d(target, dev): + verify_model("atanh_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_clamp_2d(target, dev): + verify_model("clamp_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_clamp_4d(target, dev): + verify_model("clamp_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_clamp_4d_constant(target, dev): + verify_model("clamp_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_reshape_partial(target, dev): + verify_model("reshape_partial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_split_unbalanced(target, dev): + verify_model("split_unbalanced", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_stack(target, dev): + verify_model("stack", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_unstack(target, dev): + verify_model("unstack", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_slice_strides(target, dev): + verify_model("slice_strides", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_matmul_2d(target, dev): + verify_model("matmul_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_matmul_2d_transpose(target, dev): + verify_model("matmul_2d_transpose", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_matmul_4d(target, dev): + verify_model("matmul_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_matmul_4d_transpose(target, dev): + verify_model("matmul_4d_transpose", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_prelu(target, dev): + verify_model("prelu", target, dev, rtol=1e-5, 
atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_prelu_2d_standalone(target, dev): + verify_model("prelu_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_prelu_4d_standalone(target, dev): + verify_model("prelu_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_leaky_relu(target, dev): + verify_model("leaky_relu", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_leaky_relu_2d_standalone(target, dev): + verify_model("leaky_relu_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_leaky_relu_4d_standalone(target, dev): + verify_model("leaky_relu_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_elu(target, dev): + verify_model("elu", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_elu_2d_standalone(target, dev): + verify_model("elu_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_elu_4d_standalone(target, dev): + verify_model("elu_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_selu(target, dev): + verify_model("selu", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_selu_2d_standalone(target, dev): + verify_model("selu_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_selu_4d_standalone(target, dev): + verify_model("selu_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_gelu(target, dev): + verify_model("gelu", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_gelu_2d_standalone(target, dev): + verify_model("gelu_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_gelu_4d_standalone(target, dev): + verify_model("gelu_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_silu(target, dev): + verify_model("silu", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_silu_2d_standalone(target, dev): + verify_model("silu_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_silu_4d_standalone(target, dev): + verify_model("silu_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_softplus(target, dev): + verify_model("softplus", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_softplus_2d_standalone(target, dev): + verify_model("softplus_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_softplus_4d_standalone(target, dev): + verify_model("softplus_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_separable_conv3x3(target, dev): + verify_model("separable_conv3x3", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_separable_conv3x3_with_attrs(target, dev): + verify_model("separable_conv3x3_with_attrs", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_separable_conv5x5(target, dev): + verify_model("separable_conv5x5", target, dev, rtol=1e-5, atol=1e-5) + + 
+@tvm.testing.parametrize_targets +def test_ats_separable_deconv3x3(target, dev): + verify_model("separable_deconv3x3", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_ats_separable_deconv3x3_with_attrs(target, dev): + verify_model("separable_deconv3x3_with_attrs", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_ats_separable_deconv5x5(target, dev): + verify_model("separable_deconv5x5", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_ats_rms_pool3x3(target, dev): + verify_model("rms_pool3x3", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_local_response_normalization(target, dev): + verify_model("local_response_normalization", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_local_mean_normalization(target, dev): + verify_model("local_mean_normalization", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_local_variance_normalization(target, dev): + verify_model("local_variance_normalization", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_local_contrast_normalization(target, dev): + verify_model("local_contrast_normalization", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_l1_normalization(target, dev): + verify_model("l1_normalization", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_ats_l2_normalization(target, dev): + verify_model("l2_normalization", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_pad_0_1_reflect(target, dev): + verify_model("pad_0-1_reflect", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_pad_1_0_reflect(target, dev): + verify_model("pad_1-0_reflect", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_pad_1_1_reflect(target, dev): + verify_model("pad_1-1_reflect", target, dev, rtol=1e-5, atol=1e-5) + + +# GENERATED CASES START + + +@tvm.testing.parametrize_targets +def test_cts_gt_2d(target, dev): + verify_model("gt_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_2d(target, dev): + verify_model("max_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_mean_reduce_spatial(target, dev): + verify_model("mean_reduce_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_select_4d(target, dev): + verify_model("select_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool3x3_pad1_0(target, dev): + verify_model("max_pool3x3_pad1-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_relu(target, dev): + verify_model("relu", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_split_channel(target, dev): + verify_model("split_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_rcp_4d(target, dev): + verify_model("rcp_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool2x2(target, dev): + verify_model("max_pool2x2", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool2x2(target, dev): + verify_model("avg_pool2x2", target, dev, rtol=1e-5, atol=1e-5) + + 
+@tvm.testing.parametrize_targets +def test_cts_rcp_2d(target, dev): + verify_model("rcp_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_log2_4d(target, dev): + verify_model("log2_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3_stride2x2(target, dev): + verify_model("conv3x3_stride2x2", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_lt_4d_constant(target, dev): + verify_model("lt_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_or_4d(target, dev): + verify_model("or_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv7x7(target, dev): + verify_model("deconv7x7", target, dev, rtol=1e-5, atol=1e-4) + + +@tvm.testing.parametrize_targets +def test_cts_nearest_upsample(target, dev): + verify_model("nearest_upsample", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_ceil_4d(target, dev): + verify_model("ceil_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_floor_2d(target, dev): + verify_model("floor_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool1x1(target, dev): + verify_model("avg_pool1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_log_4d(target, dev): + verify_model("log_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sum_reduce_channel(target, dev): + verify_model("sum_reduce_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_min_reduce_spatial(target, dev): + verify_model("min_reduce_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_4d_broadcast(target, dev): + verify_model("max_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool3x3_pad0_1(target, dev): + verify_model("max_pool3x3_pad0-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_cos_2d(target, dev): + verify_model("cos_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_not_4d(target, dev): + verify_model("not_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sub_4d(target, dev): + verify_model("sub_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_bilinear_upsample_aligned_replicate(target, dev): + verify_model("bilinear_upsample_aligned_replicate", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_log_2d(target, dev): + verify_model("log_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_argmin_reduce_spatial(target, dev): + verify_model("argmin_reduce_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_select_2d(target, dev): + verify_model("select_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_ne_4d(target, dev): + verify_model("ne_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_or_2d(target, dev): + verify_model("or_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_eq_2d(target, dev): + verify_model("eq_2d", target, 
dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_rsqr_2d(target, dev): + verify_model("rsqr_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_eq_4d(target, dev): + verify_model("eq_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv7x7_stride4x4(target, dev): + verify_model("deconv7x7_stride4x4", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool3x3(target, dev): + verify_model("max_pool3x3", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_and_4d(target, dev): + verify_model("and_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_mul_4d(target, dev): + verify_model("mul_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_softmax(target, dev): + verify_model("softmax", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sign_4d(target, dev): + verify_model("sign_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_mul_4d_constant(target, dev): + verify_model("mul_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_le_4d_constant(target, dev): + verify_model("le_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_box2x2(target, dev): + verify_model("box2x2", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_or_4d_broadcast(target, dev): + verify_model("or_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv5x5(target, dev): + verify_model("deconv5x5", target, dev, rtol=1e-5, atol=1e-4) + + +@tvm.testing.parametrize_targets +def test_cts_box3x3_pad1_0(target, dev): + verify_model("box3x3_pad1-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_debox3x3_pad1_0(target, dev): + verify_model("debox3x3_pad1-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_ge_4d_broadcast(target, dev): + verify_model("ge_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_linear_reshape(target, dev): + verify_model("linear_reshape", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_le_2d(target, dev): + verify_model("le_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3(target, dev): + verify_model("deconv3x3", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_nearest_downsample(target, dev): + verify_model("nearest_downsample", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_select_4d_true(target, dev): + verify_model("select_4d_true", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_min_4d_broadcast(target, dev): + verify_model("min_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_4d(target, dev): + verify_model("max_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_4d_constant(target, dev): + verify_model("max_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sum_reduce_spatial(target, dev): + 
verify_model("sum_reduce_spatial", target, dev, rtol=1e-5, atol=1e-4) + + +@tvm.testing.parametrize_targets +def test_cts_min_2d(target, dev): + verify_model("min_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_ge_2d(target, dev): + verify_model("ge_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv2x2(target, dev): + verify_model("conv2x2", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv4x4_stride2x2(target, dev): + verify_model("conv4x4_stride2x2", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_debox1x1(target, dev): + verify_model("debox1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_reshape_flatten(target, dev): + verify_model("reshape_flatten", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3_nobias(target, dev): + verify_model("conv3x3_nobias", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_tile_spatial(target, dev): + verify_model("tile_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_softmax_4d_standalone(target, dev): + verify_model("softmax_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_rsqrt_4d(target, dev): + verify_model("rsqrt_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_concat_channel(target, dev): + verify_model("concat_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_area_downsample(target, dev): + verify_model("area_downsample", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool3x3_pad1_1(target, dev): + verify_model("max_pool3x3_pad1-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sigmoid_2d_standalone(target, dev): + verify_model("sigmoid_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_ne_4d_constant(target, dev): + verify_model("ne_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3(target, dev): + verify_model("conv3x3", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_all_reduce_channel(target, dev): + verify_model("all_reduce_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_squeeze_spatial(target, dev): + verify_model("squeeze_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_and_4d_constant(target, dev): + verify_model("and_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool3x3_constant_border(target, dev): + verify_model("max_pool3x3_constant-border", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_argmax_reduce_spatial(target, dev): + verify_model("argmax_reduce_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_cos_4d(target, dev): + verify_model("cos_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sqr_4d(target, dev): + verify_model("sqr_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_rsqrt_2d(target, dev): + 
verify_model("rsqrt_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_bilinear_upsample_symmetric_replicate(target, dev): + verify_model("bilinear_upsample_symmetric_replicate", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_tile_channel(target, dev): + verify_model("tile_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_div_4d(target, dev): + verify_model("div_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sqrt_2d(target, dev): + verify_model("sqrt_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_and_4d_broadcast(target, dev): + verify_model("and_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_transpose_nhwc_to_nchw(target, dev): + verify_model("transpose_nhwc_to_nchw", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool3x3_pad0_1(target, dev): + verify_model("avg_pool3x3_pad0-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_round_2d(target, dev): + verify_model("round_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_box3x3_pad0_1(target, dev): + verify_model("box3x3_pad0-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv6x6(target, dev): + verify_model("deconv6x6", target, dev, rtol=1e-5, atol=1e-4) + + +@tvm.testing.parametrize_targets +def test_cts_add_4d_constant(target, dev): + verify_model("add_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_lt_2d(target, dev): + verify_model("lt_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_min_4d(target, dev): + verify_model("min_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_box3x3_stride1x1(target, dev): + verify_model("box3x3_stride1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_linear_nobias(target, dev): + verify_model("linear_nobias", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_div_2d(target, dev): + verify_model("div_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool3x3_stride1x1(target, dev): + verify_model("avg_pool3x3_stride1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv7x7(target, dev): + verify_model("conv7x7", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3_groups0(target, dev): + verify_model("conv3x3_groups0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_mul_2d(target, dev): + verify_model("mul_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3_pad1_0(target, dev): + verify_model("deconv3x3_pad1-0", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_ne_2d(target, dev): + verify_model("ne_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool3x3_pad1_1(target, dev): + verify_model("avg_pool3x3_pad1-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_mean_reduce_channel(target, dev): + 
verify_model("mean_reduce_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv5x5(target, dev): + verify_model("conv5x5", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool3x3_stride1x1(target, dev): + verify_model("max_pool3x3_stride1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@pytest.mark.skip(reason="Replicate - Edge mode is currently not supported in TVM relax") +@tvm.testing.parametrize_targets +def test_cts_pad_1_0_replicate(target, dev): + verify_model("pad_1-0_replicate", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_debox3x3_pad1_1(target, dev): + verify_model("debox3x3_pad1-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool3x3_pad1_0(target, dev): + verify_model("avg_pool3x3_pad1-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_bilinear_upsample_symmetric_constant(target, dev): + verify_model("bilinear_upsample_symmetric_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_gt_4d_broadcast(target, dev): + verify_model("gt_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_tanh_4d_standalone(target, dev): + verify_model("tanh_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_add_2d(target, dev): + verify_model("add_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_rsqr_4d(target, dev): + verify_model("rsqr_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_div_4d_broadcast(target, dev): + verify_model("div_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_eq_4d_broadcast(target, dev): + verify_model("eq_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3_valid(target, dev): + verify_model("conv3x3_valid", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_min_4d_constant(target, dev): + verify_model("min_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_or_4d_constant(target, dev): + verify_model("or_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_min_reduce_channel(target, dev): + verify_model("min_reduce_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_reduce_spatial(target, dev): + verify_model("max_reduce_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_bilinear_upsample_asymmetric_constant(target, dev): + verify_model("bilinear_upsample_asymmetric_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3_pad0_0(target, dev): + verify_model("conv3x3_pad0-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3_pad1_0(target, dev): + verify_model("conv3x3_pad1-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_abs_2d(target, dev): + verify_model("abs_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_reduce_channel(target, dev): + verify_model("max_reduce_channel", target, dev, rtol=1e-5, atol=1e-5) + + 
+@tvm.testing.parametrize_targets +def test_cts_ge_4d_constant(target, dev): + verify_model("ge_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_transpose_nchw_to_nhwc(target, dev): + verify_model("transpose_nchw_to_nhwc", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3_pad1_1(target, dev): + verify_model("deconv3x3_pad1-1", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_ne_4d_broadcast(target, dev): + verify_model("ne_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sqr_2d(target, dev): + verify_model("sqr_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3_pad1_1(target, dev): + verify_model("conv3x3_pad1-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_bilinear_upsample_aligned_constant(target, dev): + verify_model("bilinear_upsample_aligned_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_log2_2d(target, dev): + verify_model("log2_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_slice(target, dev): + verify_model("slice", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv2x2(target, dev): + verify_model("deconv2x2", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_all_reduce_spatial(target, dev): + verify_model("all_reduce_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sqrt_4d(target, dev): + verify_model("sqrt_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv7x7_stride4x4(target, dev): + verify_model("conv7x7_stride4x4", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_ge_4d(target, dev): + verify_model("ge_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_any_reduce_channel(target, dev): + verify_model("any_reduce_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_and_2d(target, dev): + verify_model("and_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_add_4d_broadcast(target, dev): + verify_model("add_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_copy_2d(target, dev): + verify_model("copy_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_ceil_2d(target, dev): + verify_model("ceil_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_linear_squeeze(target, dev): + verify_model("linear_squeeze", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sub_2d(target, dev): + verify_model("sub_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3_valid(target, dev): + verify_model("deconv3x3_valid", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_pow_4d(target, dev): + verify_model("pow_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_pad_1_1_constant(target, dev): + verify_model("pad_1-1_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def 
test_cts_debox3x3(target, dev): + verify_model("debox3x3", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv1x1(target, dev): + verify_model("conv1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_exp_4d(target, dev): + verify_model("exp_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool3x3_ignore_border(target, dev): + verify_model("avg_pool3x3_ignore-border", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3_pad0_0(target, dev): + verify_model("deconv3x3_pad0-0", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_pow_4d_broadcast(target, dev): + verify_model("pow_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_abs_4d(target, dev): + verify_model("abs_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sin_4d(target, dev): + verify_model("sin_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_select_2d_true(target, dev): + verify_model("select_2d_true", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_relu_2d_standalone(target, dev): + verify_model("relu_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_reshape_squeeze(target, dev): + verify_model("reshape_squeeze", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sub_4d_constant(target, dev): + verify_model("sub_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_linear(target, dev): + verify_model("linear", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_pow_2d(target, dev): + verify_model("pow_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_debox3x3_pad0_1(target, dev): + verify_model("debox3x3_pad0-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_floor_4d(target, dev): + verify_model("floor_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3_nobias(target, dev): + verify_model("deconv3x3_nobias", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_batch_norm(target, dev): + verify_model("batch_norm", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3_stride2x2(target, dev): + verify_model("deconv3x3_stride2x2", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_debox2x2(target, dev): + verify_model("debox2x2", target, dev, rtol=1e-5, atol=1e-5) + + +@pytest.mark.skip(reason="Replicate - Edge mode is currently not supported in TVM relax") +@tvm.testing.parametrize_targets +def test_cts_pad_0_1_replicate(target, dev): + verify_model("pad_0-1_replicate", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_mul_4d_broadcast(target, dev): + verify_model("mul_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_debox3x3_pad0_0(target, dev): + verify_model("debox3x3_pad0-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_argmin_reduce_channel(target, dev): + verify_model("argmin_reduce_channel", target, dev, 
rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_copy_4d(target, dev): + verify_model("copy_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_not_2d(target, dev): + verify_model("not_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sigmoid_4d_standalone(target, dev): + verify_model("sigmoid_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_exp_2d(target, dev): + verify_model("exp_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_lt_4d(target, dev): + verify_model("lt_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv4x4(target, dev): + verify_model("conv4x4", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool3x3(target, dev): + verify_model("avg_pool3x3", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_avg_pool3x3_pad0_0(target, dev): + verify_model("avg_pool3x3_pad0-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv3x3_pad0_1(target, dev): + verify_model("conv3x3_pad0-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_pad_0_1_constant(target, dev): + verify_model("pad_0-1_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv4x4(target, dev): + verify_model("deconv4x4", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_neg_2d(target, dev): + verify_model("neg_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@pytest.mark.skip(reason="Replicate - Edge mode is currently not supported in TVM relax") +@tvm.testing.parametrize_targets +def test_cts_bilinear_upsample_asymmetric_replicate(target, dev): + verify_model("bilinear_upsample_asymmetric_replicate", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv5x5_stride3x3(target, dev): + verify_model("conv5x5_stride3x3", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_relu_4d_standalone(target, dev): + verify_model("relu_4d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool1x1(target, dev): + verify_model("max_pool1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv5x5_pad2_2(target, dev): + verify_model("deconv5x5_pad2-2", target, dev, rtol=1e-5, atol=1e-4) + + +@tvm.testing.parametrize_targets +def test_cts_tile_batch(target, dev): + verify_model("tile_batch", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_eq_4d_constant(target, dev): + verify_model("eq_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_lt_4d_broadcast(target, dev): + verify_model("lt_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv1x1(target, dev): + verify_model("deconv1x1", target, dev, rtol=1e-5, atol=2e-3) + + +@tvm.testing.parametrize_targets +def test_cts_sign_2d(target, dev): + verify_model("sign_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_select_2d_false(target, dev): + verify_model("select_2d_false", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def 
test_cts_div_4d_constant(target, dev): + verify_model("div_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_pow_4d_constant(target, dev): + verify_model("pow_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_round_4d(target, dev): + verify_model("round_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_debox3x3_stride1x1(target, dev): + verify_model("debox3x3_stride1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv5x5_stride3x3(target, dev): + verify_model("deconv5x5_stride3x3", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sub_4d_broadcast(target, dev): + verify_model("sub_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_any_reduce_spatial(target, dev): + verify_model("any_reduce_spatial", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_gt_4d_constant(target, dev): + verify_model("gt_4d_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv6x6(target, dev): + verify_model("conv6x6", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_le_4d(target, dev): + verify_model("le_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_gt_4d(target, dev): + verify_model("gt_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv4x4_stride2x2(target, dev): + verify_model("deconv4x4_stride2x2", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_le_4d_broadcast(target, dev): + verify_model("le_4d_broadcast", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_tanh_2d_standalone(target, dev): + verify_model("tanh_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_box3x3(target, dev): + verify_model("box3x3", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_select_4d_false(target, dev): + verify_model("select_4d_false", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_tanh(target, dev): + verify_model("tanh", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_sin_2d(target, dev): + verify_model("sin_2d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_box3x3_pad0_0(target, dev): + verify_model("box3x3_pad0-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_box1x1(target, dev): + verify_model("box1x1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_box3x3_pad1_1(target, dev): + verify_model("box3x3_pad1-1", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_conv5x5_pad2_2(target, dev): + verify_model("conv5x5_pad2-2", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_max_pool3x3_pad0_0(target, dev): + verify_model("max_pool3x3_pad0-0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_softmax_2d_standalone(target, dev): + verify_model("softmax_2d_standalone", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3_groups0(target, 
dev): + verify_model("deconv3x3_groups0", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_deconv3x3_pad0_1(target, dev): + verify_model("deconv3x3_pad0-1", target, dev, rtol=1e-5, atol=1e-2) + + +@tvm.testing.parametrize_targets +def test_cts_sigmoid(target, dev): + verify_model("sigmoid", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_argmax_reduce_channel(target, dev): + verify_model("argmax_reduce_channel", target, dev, rtol=1e-5, atol=1e-5) + + +@pytest.mark.skip(reason="Replicate - Edge mode is currently not supported in TVM relax") +@tvm.testing.parametrize_targets +def test_cts_pad_1_1_replicate(target, dev): + verify_model("pad_1-1_replicate", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_pad_1_0_constant(target, dev): + verify_model("pad_1-0_constant", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_unsqueeze(target, dev): + verify_model("unsqueeze", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_neg_4d(target, dev): + verify_model("neg_4d", target, dev, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets +def test_cts_add_4d(target, dev): + verify_model("add_4d", target, dev, rtol=1e-5, atol=1e-5) diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index ee6be87b36d0..aaf2aad12626 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -61,3 +61,6 @@ run_pytest cython python-frontend-coreml tests/python/frontend/coreml echo "Running relay OneFlow frontend test..." run_pytest cython python-frontend-oneflow tests/python/frontend/oneflow + +echo "Running relay NNEF frontend test..." +run_pytest cython python-frontend-nnef tests/python/frontend/nnef diff --git a/tests/scripts/task_python_frontend_cpu.sh b/tests/scripts/task_python_frontend_cpu.sh index 52c3d1078edf..0c02ad264606 100755 --- a/tests/scripts/task_python_frontend_cpu.sh +++ b/tests/scripts/task_python_frontend_cpu.sh @@ -39,3 +39,6 @@ run_pytest cython python-frontend-keras tests/python/frontend/keras echo "Running relay Caffe frontend test..." run_pytest cython python-frontend-caffe tests/python/frontend/caffe + +echo "Running relay NNEF frontend test..." +run_pytest cython python-frontend-nnef tests/python/frontend/nnef