[WebGPU] Add tir.dp4a (#17124)
* [WebGPU] Add `tir.dp4a`

This patch adds `tir.dp4a` as a new TIR built-in operator in
preparation for supporting int8 computation with `dot4I8Packed`
in the WebGPU backend.

* Fix format issues

* Fix format issue

* Replace `accumulation` with `accumulator`
Jiawei-Shao authored Jul 1, 2024
1 parent ab7c1a9 commit 4247433
Showing 6 changed files with 46 additions and 0 deletions.
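For reference, `dp4a` computes a 4-element dot product of two packed int8 vectors and adds it to a 32-bit accumulator. Below is a minimal sketch of the intended semantics on plain Python integers; `dp4a_reference` is a hypothetical helper for illustration, not part of this patch:

```python
def dp4a_reference(vec1, vec2, acc=0):
    """Reference semantics of dp4a on plain Python ints.

    vec1, vec2: sequences of four signed 8-bit integers (the int8x4 lanes).
    acc: the int32 accumulator.
    """
    assert len(vec1) == 4 and len(vec2) == 4
    return acc + sum(a * b for a, b in zip(vec1, vec2))


# Example: 1*5 + 2*6 + 3*7 + 4*8 + 10 == 80
assert dp4a_reference([1, 2, 3, 4], [5, 6, 7, 8], acc=10) == 80
```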
5 changes: 5 additions & 0 deletions include/tvm/tir/builtin.h
@@ -816,6 +816,11 @@ TVM_DLL const Op& vectorlow();
*/
TVM_DLL const Op& vectorcombine();

/*!
 * \brief Dot product of two int8x4 vectors, adding an optional int32 accumulator
 */
TVM_DLL const Op& dp4a();

/*!
* \brief atomic add instruction, corresponding e.g. to atomicAdd in CUDA
*/
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
@@ -1932,6 +1932,7 @@ def wrapped(*args, **kwargs):
vectorhigh = _dtype_forward(_tir_op.vectorhigh)
vectorcombine = _dtype_forward(_tir_op.vectorcombine)
get_active_lane_mask = _dtype_forward(_tir_op.get_active_lane_mask)
dp4a = _dtype_forward(_tir_op.dp4a)


broadcast = Broadcast
@@ -2191,6 +2192,7 @@ def wrapped(*args, **kwargs):
"vectorlow",
"vectorhigh",
"vectorcombine",
"dp4a",
"assume",
"undef",
"tvm_call_packed",
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
@@ -95,6 +95,7 @@
from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
from .op import start_profile_intrinsic, end_profile_intrinsic
from .op import vscale, get_active_lane_mask, get_vscale_expr
from .op import dp4a
from .generic import add, subtract, multiply

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
25 changes: 25 additions & 0 deletions python/tvm/tir/op.py
@@ -1813,6 +1813,31 @@ def vectorcombine(dtype, vec1, vec2):
    return call_intrin(dtype, "tir.vectorcombine", vec1, vec2)


def dp4a(vec1, vec2, acc=0):
    """Dot product of two int8x4 vectors, adding an optional int32 accumulator.

    Parameters
    ----------
    vec1 : int8x4
        The first input vector.
    vec2 : int8x4
        The second input vector.
    acc : int32
        The accumulator. Defaults to 0.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    vec1 = convert(vec1)
    vec2 = convert(vec2)
    acc = convert(acc)
    return call_intrin("int32", "tir.dp4a", vec1, vec2, acc)


def ret(val):
"""Create a tir return expression
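With `dp4a` exported from `tvm.tir`, the intrinsic call can be constructed directly from Python. A small usage sketch, mirroring the unit test added in this patch (assumes a TVM build that includes this change):

```python
import tvm
from tvm import tir

vec1 = tir.Var("vec1", dtype="int8x4")
vec2 = tir.Var("vec2", dtype="int8x4")
acc = tir.Var("acc", dtype="int32")

# Builds a CallNode for the tir.dp4a builtin; the result dtype is int32.
expr = tir.dp4a(vec1, vec2, acc)
assert expr.op.name == "tir.dp4a"
assert expr.dtype == "int32"
```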
5 changes: 5 additions & 0 deletions src/tir/op/builtin.cc
@@ -355,6 +355,11 @@ TIR_DEFINE_BUILTIN_FUNC(vectorcombine)
    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
                                         Integer(ScriptDtypePrintLocation::kFirst));

TIR_DEFINE_BUILTIN_FUNC(dp4a)
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
                                         Integer(ScriptDtypePrintLocation::kFirst));

TIR_DEFINE_BUILTIN_FUNC(atomic_add)
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
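Once registered via `TIR_DEFINE_BUILTIN_FUNC`, the operator is discoverable by its global name from Python. A quick check, again assuming a build with this patch:

```python
import tvm

# Look up the registered builtin by name.
op = tvm.ir.Op.get("tir.dp4a")
# The op is marked pure (CallEffectKind::kPure), so passes may
# safely deduplicate or reorder calls to it.
print(op.name)  # "tir.dp4a"
```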
8 changes: 8 additions & 0 deletions tests/python/tir-base/test_tir_op_types.py
@@ -295,6 +295,14 @@ def test_tir_op_vectorhigh():
    assert expr.op.name == "tir.vectorhigh"


def test_tir_op_dp4a():
    vec1 = tir.Var("vec1", dtype="int8x4")
    vec2 = tir.Var("vec2", dtype="int8x4")
    acc = tir.Var("acc", dtype="int32")
    expr = tir.dp4a(vec1, vec2, acc)
    assert expr.op.name == "tir.dp4a"


def test_tir_op_vectorcombine():
    buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1)
    vec = buffer.vload([0, 0], dtype="int8x16")
