Skip to content

Commit

Permalink
[Runtime] Implemented Datatype.itemsize() (#16880)
Browse files Browse the repository at this point in the history
* [Runtime] Implemented Datatype.itemsize()
  • Loading branch information
vinx13 authored Apr 14, 2024
1 parent d0cbb02 commit 64911ab
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 6 deletions.
14 changes: 14 additions & 0 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,20 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)

def itemsize(self):
"""Get the number of bytes of a single element of this data type. When the number of lanes
is greater than 1, the itemsize is the size of the vector type.
Returns
-------
itemsize : int
The number of bytes of a single element of this data type
"""
lanes_as_int = ctypes.c_int16(self.lanes).value
if lanes_as_int < 0:
raise ValueError("Cannot determine itemsize for scalable vector types")
return (self.bits * self.lanes + 7) // 8


if ml_dtypes is not None:
DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/dlight/gpu/gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV):
def get_bytes(dtype: Union[DataType, str]) -> int:
if isinstance(dtype, str):
dtype = DataType(dtype)
return dtype.bits * dtype.lanes // 8
return dtype.itemsize()


def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]:
Expand Down
8 changes: 3 additions & 5 deletions python/tvm/dlight/gpu/low_batch_gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""A rule for low-batch GEMM / decode-GEMM using GEMV schedule."""
import re
from functools import reduce
from typing import List, Optional, Set, Union

Expand Down Expand Up @@ -55,10 +54,9 @@ def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV):


def get_bytes(dtype: Union[DataType, str]) -> int:
num = re.findall(r"\d+", dtype)
if len(num) != 1:
raise ValueError(f"Cannot get bytes from {dtype}")
return int(num[0]) // 8
if isinstance(dtype, str):
dtype = DataType(dtype)
return dtype.itemsize()


def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]:
Expand Down
40 changes: 40 additions & 0 deletions tests/python/ir/test_dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test data type related API"""
import tvm
from tvm import DataType
import tvm.testing
import pytest


@pytest.mark.parametrize(
"dtype_str, expected_size",
[("float32", 4), ("float32x4", 16), ("e5m2_float8x4", 4), ("uint8", 1)],
)
def test_dtype_itemsize(dtype_str, expected_size):
dtype = DataType(dtype_str)
assert dtype.itemsize() == expected_size


@pytest.mark.parametrize("dtype_str", [("int32xvscalex4")])
def test_dtype_itemmize_error(dtype_str):
with pytest.raises(ValueError):
size = DataType(dtype_str).itemsize()


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 64911ab

Please sign in to comment.