Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[microNPU] Merge LUT activation with binary elementwise operation #13935

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self):
"contrib.ethosu.conv2d": op.ethosu_conv2d,
"contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
"contrib.ethosu.pooling": op.ethosu_pooling,
"contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise,
}

def create_op_with_lut(self, call):
Expand Down
13 changes: 11 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,14 @@ def binary_elementwise_compute(
}
broadcast = [value == 1 for value in dmaed_ifm2.shape]

has_lut = activation in ("TANH", "LUT", "SIGMOID")
# This is a trick to insert the LUT tensor into the TE graph if LUT is present
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0

# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
if has_lut:
binary_elementwise_attrs["lut"] = lut

if reversed_operands:
binary_elementwise = te.compute(
(1, ofm_height, ofm_width, ifm_channels),
Expand All @@ -188,7 +196,7 @@ def binary_elementwise_compute(
0 if broadcast[2] else ww,
0 if broadcast[3] else cc,
).astype(ifm.dtype),
dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype),
dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype) + lut_expr,
).astype(ofm_dtype),
name="ethosu_binary_elementwise",
attrs=binary_elementwise_attrs,
Expand All @@ -203,7 +211,8 @@ def binary_elementwise_compute(
0 if broadcast[1] else hh,
0 if broadcast[2] else ww,
0 if broadcast[3] else cc,
).astype(ifm.dtype),
).astype(ifm.dtype)
+ lut_expr,
).astype(ofm_dtype),
name="ethosu_binary_elementwise",
attrs=binary_elementwise_attrs,
Expand Down
27 changes: 5 additions & 22 deletions python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,12 @@
"""Extract information from the binary_elementwise operators in TIR."""
from typing import Tuple
import tvm
from .utils import get_outer_loops, get_op_attrs
from .utils import get_outer_loops, get_op_attrs, get_loads
from .dma import get_ifm_params, get_ofm_params
from .spec import SerialActivation, SerialBinaryElementwise, SerialRescaleConfig
from .producers_consumers import ProducersConsumers


def ignore_cast(tir_load: tvm.tir.expr.Load) -> tvm.tir.Var:
"""When the datatype of the ifm, ifm2 and ofm do not match,
casts are inserted in TE to handle the difference in these types.
Since TIR is not directly run on the NPU we can simply ignore
these, and allow the NPU to handle the difference in datatypes
itself.

Parameters
----------
tir_load : tvm.tir.expr.Load

Returns
-------
tvm.tir.Var
"""
return tir_load.value if isinstance(tir_load, tvm.tir.Cast) else tir_load


def get_binary_elementwise_params(
stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers
) -> Tuple[SerialBinaryElementwise, tvm.tir.Var, tvm.tir.Var]:
Expand Down Expand Up @@ -72,9 +54,10 @@ def get_binary_elementwise_params(
reversed_operands = attrs["reversed_operands"]

_, _, _, _, _, inner = get_outer_loops(body, "NHWC")
op = ignore_cast(inner.value)
input_pointer = ignore_cast(op.a).buffer.data
input_pointer1 = ignore_cast(op.b).buffer.data
# loads = [input, input, LUT, LUT]
loads = get_loads(inner)
input_pointer = loads[0].buffer.data
input_pointer1 = loads[1].buffer.data

if reversed_operands:
input_pointer, input_pointer1 = input_pointer1, input_pointer
Expand Down
3 changes: 2 additions & 1 deletion tests/python/contrib/test_ethosu/infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,11 +694,12 @@ def make_ethosu_binary_elementwise(
use_rescale: bool = False,
rescale_scale: int = 0,
rescale_shift: int = 0,
lut=relay.const([], dtype="int8"),
lhutton1 marked this conversation as resolved.
Show resolved Hide resolved
):
ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise(
ifm=ifm,
ifm2=ifm2,
lut=relay.const([], dtype="int8"),
lut=lut,
operator_type=operator_type,
ifm_scale=1,
ifm_zero_point=0,
Expand Down
19 changes: 19 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1313,5 +1313,24 @@ def fully_connected(x):
)


@pytest.mark.parametrize("accel_type", ["ethos-u55-256", "ethos-u65-256"])
def test_tflite_subtract_sigmoid(accel_type):
np.random.seed(0)
ifm_shape = [1, 6, 8, 4]

@tf.function
def subtract_sigmoid_function(lhs, rhs):
op = tf.math.subtract(lhs, rhs)
op = tf.nn.sigmoid(op)
return op

infra.compare_tvm_with_tflite(
subtract_sigmoid_function,
[ifm_shape, ifm_shape],
accel_type,
enable_cascader=is_u55_accel_type(accel_type),
)


if __name__ == "__main__":
tvm.testing.main()
42 changes: 42 additions & 0 deletions tests/python/contrib/test_ethosu/test_lut_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,48 @@ def after():
assert tvm.ir.structural_equal(mod, after())


def test_merge_lut_into_binary_elementwise():
"""If an binary elementwise operator is followed by an identity operator
with LUT, we can merge the two operataors."""

shape = (1, 8, 8, 4)
dtype = "int8"
ifm = relay.var("x", shape=shape, dtype=dtype)
ifm2 = relay.var("x", shape=shape, dtype=dtype)
lut1 = relay.const([i for i in range(256)], dtype=dtype)
lut2 = relay.const([i for i in reversed(range(256))], dtype=dtype)

def before():
sub = infra.make_ethosu_binary_elementwise(ifm, ifm2, shape[-1], shape[-1], "SUB", dtype)
id1 = infra.make_ethosu_identity(sub, lut=lut1, activation="TANH")
add = infra.make_ethosu_binary_elementwise(id1, ifm2, shape[-1], shape[-1], "ADD", dtype)
id2 = infra.make_ethosu_identity(add, lut=lut2, activation="SIGMOID")

func = relay.Function(relay.analysis.free_vars(id2), id2)
func = func.with_attr("Compiler", "ethos-u")
mod = tvm.IRModule.from_expr(func)
return mod

def after():
sub = infra.make_ethosu_binary_elementwise(
ifm, ifm2, shape[-1], shape[-1], "SUB", dtype, lut=lut1, activation="TANH"
)
add = infra.make_ethosu_binary_elementwise(
sub, ifm2, shape[-1], shape[-1], "ADD", dtype, lut=lut2, activation="SIGMOID"
)

func = relay.Function(relay.analysis.free_vars(add), add)
func = func.with_attr("Compiler", "ethos-u")
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)
return mod

mod = LUTsOptimizer()(before())
mod = relay.transform.InferType()(mod)

assert tvm.ir.structural_equal(mod, after())


def test_multiple_luts():
"""Test that when an operation already has a LUT, we don't overwrite that LUT"""

Expand Down