Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Fix potential precision bug when using math vector and matrix types #5032

Merged
merged 11 commits into from
May 25, 2022
Merged
5 changes: 3 additions & 2 deletions python/taichi/_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,11 @@ def vector_to_fast_image(img: template(), out: ndarray_type.ndarray()):
r, g, b = 0, 0, 0
color = img[i, img.shape[1] - 1 - j]
if static(img.dtype in [f16, f32, f64]):
r, g, b = min(255, max(0, int(color * 255)))
r, g, b = min(255, max(0, int(color * 255)))[:3]
neozhaoliang marked this conversation as resolved.
Show resolved Hide resolved
else:
static_assert(img.dtype == u8)
r, g, b = color
r, g, b = color[:3]

idx = j * img.shape[0] + i
# We use i32 for |out| since OpenGL and Metal doesn't support u8 types
if static(get_os_name() != 'osx'):
Expand Down
122 changes: 87 additions & 35 deletions python/taichi/math/mathimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,44 +10,44 @@

import taichi as ti

vec2 = ti.types.vector(2, float) # pylint: disable=E1101
"""2D float vector type.
"""
_get_uint_ip = lambda: ti.u32 if impl.get_runtime(
).default_ip == ti.i32 else ti.u64

vec3 = ti.types.vector(3, float) # pylint: disable=E1101
"""3D float vector type.
"""

vec4 = ti.types.vector(4, float) # pylint: disable=E1101
"""4D float vector type.
"""
def vec2(*args):
"""2D floating vector type.
"""
return ti.types.vector(2, float)(*args) # pylint: disable=E1101

ivec2 = ti.types.vector(2, int) # pylint: disable=E1101
"""2D int vector type.
"""

ivec3 = ti.types.vector(3, int) # pylint: disable=E1101
"""3D int vector type.
"""
def vec3(*args):
"""3D floating vector type.
"""
return ti.types.vector(3, float)(*args) # pylint: disable=E1101

ivec4 = ti.types.vector(4, int) # pylint: disable=E1101
"""4D int vector type.
"""

mat2 = ti.types.matrix(2, 2, float) # pylint: disable=E1101
"""2x2 float matrix type.
"""
def vec4(*args):
"""4D floating vector type.
"""
return ti.types.vector(4, float)(*args) # pylint: disable=E1101

mat3 = ti.types.matrix(3, 3, float) # pylint: disable=E1101
"""3x3 float matrix type.
"""

mat4 = ti.types.matrix(4, 4, float) # pylint: disable=E1101
"""4x4 float matrix type.
"""
def ivec2(*args):
"""2D signed int vector type.
"""
return ti.types.vector(2, int)(*args) # pylint: disable=E1101

_get_uint_ip = lambda: ti.u32 if impl.get_runtime(
).default_ip == ti.i32 else ti.u64

def ivec3(*args):
"""3D signed int vector type.
"""
return ti.types.vector(3, int)(*args) # pylint: disable=E1101


def ivec4(*args):
"""4D signed int vector type.
"""
return ti.types.vector(4, int)(*args) # pylint: disable=E1101


def uvec2(*args):
Expand All @@ -68,6 +68,24 @@ def uvec4(*args):
return ti.types.vector(4, _get_uint_ip())(*args) # pylint: disable=E1101


def mat2(*args):
"""2x2 floating matrix type.
"""
return ti.types.matrix(2, 2, float)(*args) # pylint: disable=E1101


def mat3(*args):
"""3x3 floating matrix type.
"""
return ti.types.matrix(3, 3, float)(*args) # pylint: disable=E1101


def mat4(*args):
"""4x4 floating matrix type.
"""
return ti.types.matrix(4, 4, float)(*args) # pylint: disable=E1101


@ti.func
def mix(x, y, a):
"""Performs a linear interpolation between `x` and `y` using
Expand Down Expand Up @@ -611,12 +629,46 @@ def length(x):
return x.norm()


@ti.func
def determinant(m):
"""Alias for :func:`taichi.Matrix.determinant`.
"""
return m.determinant()


@ti.func
def inverse(mat): # pylint: disable=R1710
neozhaoliang marked this conversation as resolved.
Show resolved Hide resolved
"""Calculate the inverse of a matrix.

This function is equivalent to the `inverse` function in GLSL.

Args:
mat (:class:`taichi.Matrix`): The matrix of which to take the inverse.

Returns:
Inverse of the input matrix.

Example::

>>> @ti.kernel
>>> def test():
>>> m = mat3([(1, 1, 0), (0, 1, 1), (0, 0, 1)])
>>> print(inverse(m))
>>>
>>> test()
[[1.000000, -1.000000, 1.000000],
[0.000000, 1.000000, -1.000000],
[0.000000, 0.000000, 1.000000]]
"""
return mat.inverse()


__all__ = [
"acos", "asin", "atan2", "ceil", "clamp", "cos", "cross", "degrees",
"distance", "dot", "e", "exp", "eye", "floor", "fract", "ivec2", "ivec3",
"ivec4", "length", "log", "log2", "mat2", "mat3", "mat4", "max", "min",
"mix", "mod", "normalize", "pi", "pow", "radians", "reflect", "refract",
"rot2", "rot3", "rotate2d", "rotate3d", "round", "sign", "sin",
"smoothstep", "sqrt", "step", "tan", "tanh", "uvec2", "uvec3", "uvec4",
"vec2", "vec3", "vec4"
"determinant", "distance", "dot", "e", "exp", "eye", "floor", "fract",
"inverse", "ivec2", "ivec3", "ivec4", "length", "log", "log2", "mat2",
"mat3", "mat4", "max", "min", "mix", "mod", "normalize", "pi", "pow",
"radians", "reflect", "refract", "rot2", "rot3", "rotate2d", "rotate3d",
"round", "sign", "sin", "smoothstep", "sqrt", "step", "tan", "tanh",
"uvec2", "uvec3", "uvec4", "vec2", "vec3", "vec4"
]
14 changes: 7 additions & 7 deletions tests/python/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,13 @@ def _get_expected_matrix_apis():
]
user_api[ti.math] = [
'acos', 'asin', 'atan2', 'cconj', 'cdiv', 'ceil', 'cexp', 'cinv', 'clamp',
'clog', 'cmul', 'cos', 'cpow', 'cross', 'csqrt', 'degrees', 'distance',
'dot', 'e', 'exp', 'eye', 'floor', 'fract', 'ivec2', 'ivec3', 'ivec4',
'length', 'log', 'log2', 'mat2', 'mat3', 'mat4', 'max', 'min', 'mix',
'mod', 'normalize', 'pi', 'pow', 'radians', 'reflect', 'refract', 'rot2',
'rot3', 'rotate2d', 'rotate3d', 'round', 'sign', 'sin', 'smoothstep',
'sqrt', 'step', 'tan', 'tanh', 'uvec2', 'uvec3', 'uvec4', 'vec2', 'vec3',
'vec4'
'clog', 'cmul', 'cos', 'cpow', 'cross', 'csqrt', 'degrees', 'determinant',
'distance', 'dot', 'e', 'exp', 'eye', 'floor', 'fract', 'inverse', 'ivec2',
'ivec3', 'ivec4', 'length', 'log', 'log2', 'mat2', 'mat3', 'mat4', 'max',
'min', 'mix', 'mod', 'normalize', 'pi', 'pow', 'radians', 'reflect',
'refract', 'rot2', 'rot3', 'rotate2d', 'rotate3d', 'round', 'sign', 'sin',
'smoothstep', 'sqrt', 'step', 'tan', 'tanh', 'uvec2', 'uvec3', 'uvec4',
'vec2', 'vec3', 'vec4'
]
user_api[ti.Matrix] = _get_expected_matrix_apis()
user_api[ti.MatrixField] = [
Expand Down