diff --git a/python/tvm/topi/hexagon/__init__.py b/python/tvm/topi/hexagon/__init__.py index dfe739288187..a3768a6e809e 100644 --- a/python/tvm/topi/hexagon/__init__.py +++ b/python/tvm/topi/hexagon/__init__.py @@ -26,4 +26,5 @@ from .pooling import * from .reduce import * from .resize2d import * +from .tensor_intrin import * from .qnn import * diff --git a/python/tvm/topi/hexagon/injective.py b/python/tvm/topi/hexagon/injective.py index 9ced0ac7d399..b1d1e1541961 100644 --- a/python/tvm/topi/hexagon/injective.py +++ b/python/tvm/topi/hexagon/injective.py @@ -19,6 +19,8 @@ import tvm +import numpy as np + def schedule_injective(outs): """Schedule for injective op. @@ -37,11 +39,10 @@ def schedule_injective(outs): outs = [outs] if isinstance(outs, tvm.te.tensor.Tensor) else outs s = tvm.te.create_schedule([x.op for x in outs]) tvm.te.schedule.AutoInlineInjective(s) - - # Fuse axes and vectorize inner 128 elements + # Fuse axes and vectorize inner elements for x in outs: fused = s[x].fuse(*x.op.axis) - _, inner = s[x].split(fused, factor=128) + _, inner = s[x].split(fused, factor=128 // np.dtype(x.dtype).itemsize) s[x].vectorize(inner) return s diff --git a/python/tvm/topi/hexagon/tensor_intrin.py b/python/tvm/topi/hexagon/tensor_intrin.py new file mode 100644 index 000000000000..bdc63854328b --- /dev/null +++ b/python/tvm/topi/hexagon/tensor_intrin.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Optimized implementation of q_multiply_shift based on LLVM intrinsics"""

import tvm
from tvm.ir import register_intrin_lowering


def _q_multiply_shift_hexagon(op):
    """Lower a ``tir.q_multiply_shift`` call to Hexagon LLVM intrinsics when applicable.

    The call is rewritten in terms of the ``vmpyewuh`` / ``vmpyowh`` intrinsics only when
    the first operand has dtype ``int32x32`` and the fractional-bits operand equals 31.
    In every other case the incoming call is handed back untouched.

    Please note that this lowering introduces a small round-up error for some corner cases
    with a negative shift argument, because rounding happens twice instead of only once:

    * original q_multiply_shift: round(x*y*2^-s)
    * hexagon q_multiply_shift: round(round(x*y)*2^-s)

    Parameters
    ----------
    op : call expression
        The call being lowered. Its ``args`` are read positionally as
        ``(x, y, fractional_bits, shift)``.

    Returns
    -------
    result : expression
        Either ``op`` itself (lowering not applicable) or a ``tvm.tir.Select`` that picks
        the negative-shift or the positive-shift variant depending on ``shift < 0``.
    """
    lhs = op.args[0]
    rhs = op.args[1]
    frac_bits = op.args[2]
    shift = op.args[3]

    # The intrinsic path is only valid for an int32x32 vector of q31 numbers;
    # anything else keeps the generic implementation.
    lowering_applies = lhs.dtype == "int32x32" and frac_bits.value == 31
    if not lowering_applies:
        return op

    def _multiply(multiplicand):
        """Emit the vmpyewuh call followed by the accumulating vmpyowh call."""
        # In each call the leading uint32 constant equals the number of operands after it.
        even_part = tvm.tir.call_llvm_intrin(
            op.dtype,
            "llvm.hexagon.V6.vmpyewuh.128B",
            tvm.tir.const(2, "uint32"),
            multiplicand,
            rhs,
        )
        return tvm.tir.call_llvm_intrin(
            op.dtype,
            "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B",
            tvm.tir.const(3, "uint32"),
            even_part,
            multiplicand,
            rhs,
        )

    # Negative shift: multiply first, add the fix-up term, then shift with vaslwv.
    product = _multiply(lhs)
    rounded_product = product + (product & (-shift))
    negative_shift_result = tvm.tir.call_llvm_intrin(
        op.dtype,
        "llvm.hexagon.V6.vaslwv.128B",
        tvm.tir.const(2, "uint32"),
        rounded_product,
        shift,
    )

    # Positive shift: scale the first operand up front, then multiply.
    positive_shift_result = _multiply(lhs * (1 << (shift)))

    # Both variants are built; the sign of the shift selects between them.
    return tvm.tir.Select(shift < 0, negative_shift_result, positive_shift_result)


# Install the lowering above for "tir.q_multiply_shift" on the "hexagon" target.
register_intrin_lowering(
    "tir.q_multiply_shift", target="hexagon", f=_q_multiply_shift_hexagon, level=99
)

# ---- tests/python/contrib/test_hexagon/test_fixed_point_multiply.py (new file) ----
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the lowering of relay ``fixed_point_multiply`` on Hexagon."""

import re

import numpy as np
import tvm.testing
from tvm import relay
from tvm.contrib.hexagon.session import Session
from tvm.relay.backend import Executor


@tvm.testing.requires_hexagon
def test_vmpy_intrinsic_presence():
    """
    Check intrinsic lowering for the fixed_point_multiply operation.

    A (1, 128) int32 fixed_point_multiply is built for Hexagon v68 at opt_level 3 and the
    generated assembly is searched for the 'vmpye' and 'vmpyo' instructions.
    """
    ishape = (1, 128)
    a = relay.var("a", relay.TensorType(ishape, "int32"))
    # 1395864320 / 2**31 * 2**1 == 1.3
    relay_mod = tvm.IRModule.from_expr(relay.fixed_point_multiply(a, 1395864320, 1))

    target_hexagon = tvm.target.hexagon("v68")
    with tvm.transform.PassContext(opt_level=3):
        hexagon_lowered = build_module(relay_mod, target_hexagon)

    asm = hexagon_lowered.lib.get_source("asm")

    expected_instructions = (
        # 'vmpye' instruction must be present in the asm file.
        r"v\d{1,2}.w = vmpye\(v\d{1,2}.w,v\d{1,2}.uh\)",
        # 'vmpyo' instruction must be present in the asm file.
        r"v\d{1,2}.w \+= vmpyo\(v\d{1,2}.w,v\d{1,2}.h\):<<1:rnd:sat:shift",
    )
    for pattern in expected_instructions:
        assert re.search(pattern, asm) is not None


def build_module(relay_mod, target):
    """Build ``relay_mod`` for ``target`` (also used as host) with a link-params graph executor.

    Returns whatever ``tvm.relay.build`` returns; no parameters are bound.
    """
    return tvm.relay.build(
        relay_mod,
        tvm.target.Target(target, host=target),
        executor=Executor("graph", {"link-params": True}),
        params={},
    )


def run_module(graph_mod, inputs):
    """Feed ``inputs`` (name -> array) to ``graph_mod``, run it, and return output 0 as numpy."""
    graph_mod.set_input(**inputs)
    graph_mod.run()
    return graph_mod.get_output(0).numpy()


def _evaluate_on_hexagon_and_llvm(hexagon_session, multiplier, shift):
    """Run a (6, 32) int32 fixed_point_multiply on Hexagon and on LLVM.

    Returns the pair ``(hexagon_output, llvm_output)`` so callers can compare them with the
    tolerance appropriate for their case.
    """
    ishape = (6, 32)
    a = relay.var("a", relay.TensorType(ishape, "int32"))
    relay_mod = tvm.IRModule.from_expr(relay.fixed_point_multiply(a, multiplier, shift))

    with tvm.transform.PassContext(opt_level=3):
        # Compile for Hexagon, then for LLVM (the reference).
        hexagon_lowered = build_module(relay_mod, tvm.target.hexagon("v68"))
        llvm_lowered = build_module(relay_mod, tvm.target.Target("llvm"))

    # 192 consecutive values centred on zero, so both signs are exercised.
    inputs = {"a": np.arange(-96, 96).reshape(ishape)}

    # Run hexagon...
    hexagon_graph_mod = hexagon_session.get_executor_from_factory(hexagon_lowered)
    hexagon_output = run_module(hexagon_graph_mod, inputs)

    # Run llvm...
    llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0)))
    llvm_output = run_module(llvm_graph_mod, inputs)

    return hexagon_output, llvm_output


@tvm.testing.requires_hexagon
def test_fixed_point_multiply_positive_shift(hexagon_session: Session):
    """Hexagon and LLVM results must agree (default tolerance) for a positive shift.

    ``hexagon_session`` is not defined in this file; presumably it is a fixture provided by
    the Hexagon test infrastructure.
    """
    # 1395864320 / 2**31 * 2**1 == 1.3
    hexagon_output, expected_output = _evaluate_on_hexagon_and_llvm(
        hexagon_session, 1395864320, 1
    )
    tvm.testing.assert_allclose(hexagon_output, expected_output)


@tvm.testing.requires_hexagon
def test_fixed_point_multiply_negative_shift(hexagon_session: Session):
    """Hexagon and LLVM results must agree within ``atol=1`` for a negative shift.

    ``hexagon_session`` is not defined in this file; presumably it is a fixture provided by
    the Hexagon test infrastructure.
    """
    # 1288490240 / 2**31 * 2**-2 == 0.15
    hexagon_output, expected_output = _evaluate_on_hexagon_and_llvm(
        hexagon_session, 1288490240, -2
    )
    tvm.testing.assert_allclose(hexagon_output, expected_output, atol=1)


if __name__ == "__main__":
    tvm.testing.main()