Skip to content

Commit

Permalink
[Lang] [type] Rename some parameters in quant APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
strongoier committed Jul 29, 2022
1 parent d88dacf commit 53a258f
Show file tree
Hide file tree
Showing 10 changed files with 38 additions and 38 deletions.
4 changes: 2 additions & 2 deletions python/taichi/_snode/fields_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@ def bitmasked(self, indices: Union[Sequence[_Axis], _Axis],
return self.root.bitmasked(indices, dimensions)

def quant_array(self, indices: Union[Sequence[_Axis], _Axis],
dimensions: Union[Sequence[int], int], num_bits: int):
dimensions: Union[Sequence[int], int], max_num_bits: int):
"""Same as :func:`taichi.lang.snode.SNode.quant_array`"""
self._check_not_finalized()
self.empty = False
return self.root.quant_array(indices, dimensions, num_bits)
return self.root.quant_array(indices, dimensions, max_num_bits)

def place(self,
*args: Any,
Expand Down
6 changes: 3 additions & 3 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,21 +96,21 @@ def bitmasked(self, axes, dimensions):
self.ptr.bitmasked(axes, dimensions,
impl.current_cfg().packed))

def quant_array(self, axes, dimensions, num_bits):
def quant_array(self, axes, dimensions, max_num_bits):
"""Adds a quant_array SNode as a child component of `self`.
Args:
axes (List[Axis]): Axes to activate.
dimensions (Union[List[int], int]): Shape of each axis.
num_bits (int): Number of bits to use.
max_num_bits (int): Maximum number of bits it can hold.
Returns:
The added :class:`~taichi.lang.SNode` instance.
"""
if isinstance(dimensions, int):
dimensions = [dimensions] * len(axes)
return SNode(
self.ptr.quant_array(axes, dimensions, num_bits,
self.ptr.quant_array(axes, dimensions, max_num_bits,
impl.current_cfg().packed))

def place(self, *args, offset=None):
Expand Down
16 changes: 8 additions & 8 deletions python/taichi/types/quantized_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ def int(bits, signed=True, compute=None): # pylint: disable=W0622
return _type_factory.get_quant_int_type(bits, signed, compute)


def fixed(frac, signed=True, range=1.0, compute=None, scale=None): # pylint: disable=W0622
def fixed(bits, signed=True, max_value=1.0, compute=None, scale=None):
"""Generates a quantized type for fixed-point real numbers.
Args:
frac (int): Number of bits.
bits (int): Number of bits.
signed (bool): Signed or unsigned.
range (float): Range of the number.
max_value (float): Maximum value of the number.
compute (DataType): Type for computation.
scale (float): Scaling factor. The argument is prioritized over range.
Expand All @@ -44,14 +44,14 @@ def fixed(frac, signed=True, range=1.0, compute=None, scale=None): # pylint: di
compute = impl.get_runtime().default_fp
if isinstance(compute, _ti_python_core.DataType):
compute = compute.get_ptr()
# TODO: handle cases with frac > 32
frac_type = int(bits=frac, signed=signed, compute=i32)
# TODO: handle cases with bits > 32
underlying_type = int(bits=bits, signed=signed, compute=i32)
if scale is None:
if signed:
scale = range / 2**(frac - 1)
scale = max_value / 2 ** (bits - 1)
else:
scale = range / 2**frac
return _type_factory.get_quant_fixed_type(frac_type, compute, scale)
scale = max_value / 2 ** bits
return _type_factory.get_quant_fixed_type(underlying_type, compute, scale)


def float(exp, frac, signed=True, compute=None): # pylint: disable=W0622
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_bitpacked_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_bitpacked_fields_struct_for():
block_size = 16
N = 64
cell = ti.root.pointer(ti.i, N // block_size)
fixed32 = ti.types.quant.fixed(frac=32, range=1024)
fixed32 = ti.types.quant.fixed(bits=32, max_value=1024)

x = ti.field(dtype=fixed32)
bitpack = ti.BitpackedFields(max_num_bits=32)
Expand Down
4 changes: 2 additions & 2 deletions tests/python/test_matrix_different_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def verify():
@test_utils.test(require=ti.extension.quant_basic)
def test_quant_type():
qit1 = ti.types.quant.int(bits=10, signed=True)
qfxt1 = ti.types.quant.fixed(frac=10, signed=True, scale=0.1)
qfxt1 = ti.types.quant.fixed(bits=10, signed=True, scale=0.1)
qit2 = ti.types.quant.int(bits=22, signed=False)
qfxt2 = ti.types.quant.fixed(frac=22, signed=False, scale=0.1)
qfxt2 = ti.types.quant.fixed(bits=22, signed=False, scale=0.1)
type_list = [[qit1, qfxt2], [qfxt1, qit2]]
a = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list)
b = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list)
Expand Down
12 changes: 6 additions & 6 deletions tests/python/test_quant_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_1D_quant_array():

N = 32

ti.root.quant_array(ti.i, N, num_bits=32).place(x)
ti.root.quant_array(ti.i, N, max_num_bits=32).place(x)

@ti.kernel
def set_val():
Expand All @@ -31,7 +31,7 @@ def test_1D_quant_array_negative():
N = 4
qi7 = ti.types.quant.int(7)
x = ti.field(dtype=qi7)
ti.root.quant_array(ti.i, N, num_bits=32).place(x)
ti.root.quant_array(ti.i, N, max_num_bits=32).place(x)

@ti.kernel
def assign():
Expand All @@ -45,13 +45,13 @@ def assign():

@test_utils.test(require=ti.extension.quant, debug=True)
def test_1D_quant_array_fixed():
qfxt = ti.types.quant.fixed(frac=8, range=2)
qfxt = ti.types.quant.fixed(bits=8, max_value=2)

x = ti.field(dtype=qfxt)

N = 4

ti.root.quant_array(ti.i, N, num_bits=32).place(x)
ti.root.quant_array(ti.i, N, max_num_bits=32).place(x)

@ti.kernel
def set_val():
Expand All @@ -75,7 +75,7 @@ def test_2D_quant_array():

M, N = 4, 8

ti.root.quant_array(ti.ij, (M, N), num_bits=32).place(x)
ti.root.quant_array(ti.ij, (M, N), max_num_bits=32).place(x)

@ti.kernel
def set_val():
Expand All @@ -102,7 +102,7 @@ def test_quant_array_struct_for():

x = ti.field(dtype=qi7)
cell.dense(ti.i, block_size // 4).quant_array(ti.i, 4,
num_bits=32).place(x)
max_num_bits=32).place(x)

@ti.kernel
def activate():
Expand Down
16 changes: 8 additions & 8 deletions tests/python/test_quant_array_vectorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def test_vectorized_struct_for():

block = ti.root.pointer(ti.ij, (n_blocks, n_blocks))
block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks))).quant_array(
ti.j, bits, num_bits=bits).place(x)
ti.j, bits, max_num_bits=bits).place(x)
block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks))).quant_array(
ti.j, bits, num_bits=bits).place(y)
ti.j, bits, max_num_bits=bits).place(y)

@ti.kernel
def init():
Expand Down Expand Up @@ -61,11 +61,11 @@ def test_offset_load():

block = ti.root.pointer(ti.ij, (n_blocks, n_blocks))
block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks))).quant_array(
ti.j, bits, num_bits=bits).place(x)
ti.j, bits, max_num_bits=bits).place(x)
block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks))).quant_array(
ti.j, bits, num_bits=bits).place(y)
ti.j, bits, max_num_bits=bits).place(y)
block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks))).quant_array(
ti.j, bits, num_bits=bits).place(z)
ti.j, bits, max_num_bits=bits).place(z)

@ti.kernel
def init():
Expand Down Expand Up @@ -121,11 +121,11 @@ def test_evolve():

block = ti.root.pointer(ti.ij, (n_blocks, n_blocks))
block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks))).quant_array(
ti.j, bits, num_bits=bits).place(x)
ti.j, bits, max_num_bits=bits).place(x)
block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks))).quant_array(
ti.j, bits, num_bits=bits).place(y)
ti.j, bits, max_num_bits=bits).place(y)
block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks))).quant_array(
ti.j, bits, num_bits=bits).place(z)
ti.j, bits, max_num_bits=bits).place(z)

@ti.kernel
def init():
Expand Down
6 changes: 3 additions & 3 deletions tests/python/test_quant_atomics.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_quant_int_atomics_b64():

x = ti.field(dtype=qi13)

ti.root.quant_array(ti.i, 4, num_bits=64).place(x)
ti.root.quant_array(ti.i, 4, max_num_bits=64).place(x)

x[0] = 100
x[1] = 200
Expand All @@ -70,8 +70,8 @@ def foo():

@test_utils.test(require=ti.extension.quant_basic, debug=True)
def test_quant_fixed_atomics():
qfxt13 = ti.types.quant.fixed(frac=13, signed=True, scale=0.1)
qfxt19 = ti.types.quant.fixed(frac=19, signed=False, scale=0.1)
qfxt13 = ti.types.quant.fixed(bits=13, signed=True, scale=0.1)
qfxt19 = ti.types.quant.fixed(bits=19, signed=False, scale=0.1)

x = ti.field(dtype=qfxt13)
y = ti.field(dtype=qfxt19)
Expand Down
8 changes: 4 additions & 4 deletions tests/python/test_quant_fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@test_utils.test(require=ti.extension.quant_basic)
def test_quant_fixed():
qfxt = ti.types.quant.fixed(frac=32, range=2)
qfxt = ti.types.quant.fixed(bits=32, max_value=2)
x = ti.field(dtype=qfxt)

bitpack = ti.BitpackedFields(max_num_bits=32)
Expand All @@ -31,7 +31,7 @@ def foo():

@test_utils.test(require=ti.extension.quant_basic)
def test_quant_fixed_matrix_rotation():
qfxt = ti.types.quant.fixed(frac=16, range=1.2)
qfxt = ti.types.quant.fixed(bits=16, max_value=1.2)

x = ti.Matrix.field(2, 2, dtype=qfxt)

Expand Down Expand Up @@ -61,7 +61,7 @@ def rotate_18_degrees():

@test_utils.test(require=ti.extension.quant_basic)
def test_quant_fixed_implicit_cast():
qfxt = ti.types.quant.fixed(frac=13, scale=0.1)
qfxt = ti.types.quant.fixed(bits=13, scale=0.1)
x = ti.field(dtype=qfxt)

bitpack = ti.BitpackedFields(max_num_bits=32)
Expand All @@ -78,7 +78,7 @@ def foo():

@test_utils.test(require=ti.extension.quant_basic)
def test_quant_fixed_cache_read_only():
qfxt = ti.types.quant.fixed(frac=15, scale=0.1)
qfxt = ti.types.quant.fixed(bits=15, scale=0.1)
x = ti.field(dtype=qfxt)

bitpack = ti.BitpackedFields(max_num_bits=32)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_quant_time_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_quant_time_integration(use_quant, use_exponent, use_shared_exp):
bitpack.place(x.get_scalar_field(1))
ti.root.place(bitpack)
else:
qfxt = ti.types.quant.fixed(frac=16, range=2)
qfxt = ti.types.quant.fixed(bits=16, max_value=2)
x = ti.Vector.field(2, dtype=qfxt)
bitpack = ti.BitpackedFields(max_num_bits=32)
bitpack.place(x)
Expand Down

0 comments on commit 53a258f

Please sign in to comment.