Skip to content

Commit

Permalink
[Lang] [type] Fix parameter name 'range' for ti.types.quant.fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
strongoier committed May 18, 2022
1 parent 84b6c89 commit c8f494a
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
8 changes: 4 additions & 4 deletions python/taichi/types/quantized_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ def int(bits, signed=False, compute=None): # pylint: disable=W0622
return _custom_int(bits, signed, compute)


def fixed(frac, signed=True, num_range=1.0, compute=None):
def fixed(frac, signed=True, range=1.0, compute=None): # pylint: disable=W0622
"""Generates a quantized type for fixed-point real numbers.
Args:
frac (int): Number of bits.
signed (bool): Signed or unsigned.
num_range (float): Range of the number.
range (float): Range of the number.
compute (DataType): Type for computation.
Returns:
Expand All @@ -83,9 +83,9 @@ def fixed(frac, signed=True, num_range=1.0, compute=None):
# TODO: handle cases with frac > 32
frac_type = int(bits=frac, signed=signed, compute=i32)
if signed:
scale = num_range / 2**(frac - 1)
scale = range / 2**(frac - 1)
else:
scale = num_range / 2**frac
scale = range / 2**frac
if compute is None:
compute = impl.get_runtime().default_fp
return _custom_float(frac_type, None, compute, scale)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_bit_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_bit_struct_struct_for():
block_size = 16
N = 64
cell = ti.root.pointer(ti.i, N // block_size)
fixed32 = ti.types.quant.fixed(frac=32, num_range=1024)
fixed32 = ti.types.quant.fixed(frac=32, range=1024)

x = ti.field(dtype=fixed32)
cell.dense(ti.i, block_size).bit_struct(32).place(x)
Expand Down
4 changes: 2 additions & 2 deletions tests/python/test_custom_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@test_utils.test(require=ti.extension.quant_basic)
def test_custom_float():
cft = ti.types.quant.fixed(frac=32, num_range=2)
cft = ti.types.quant.fixed(frac=32, range=2)
x = ti.field(dtype=cft)

ti.root.bit_struct(num_bits=32).place(x)
Expand All @@ -29,7 +29,7 @@ def foo():

@test_utils.test(require=ti.extension.quant_basic)
def test_custom_matrix_rotation():
cft = ti.types.quant.fixed(frac=16, num_range=1.2)
cft = ti.types.quant.fixed(frac=16, range=1.2)

x = ti.Matrix.field(2, 2, dtype=cft)

Expand Down

0 comments on commit c8f494a

Please sign in to comment.