diff --git a/tripy/docs/conf.py b/tripy/docs/conf.py index 8b30df9a1..dfb827d9c 100644 --- a/tripy/docs/conf.py +++ b/tripy/docs/conf.py @@ -28,7 +28,7 @@ import tripy as tp from tripy.common.datatype import DATA_TYPES -from tripy.constraints import TYPE_VERIFICATION +from tripy.wrappers import TYPE_VERIFICATION PARAM_PAT = re.compile(":param .*?:") diff --git a/tripy/docs/post0_developer_guides/how-to-add-new-ops.md b/tripy/docs/post0_developer_guides/how-to-add-new-ops.md index 632345571..43accc2c0 100644 --- a/tripy/docs/post0_developer_guides/how-to-add-new-ops.md +++ b/tripy/docs/post0_developer_guides/how-to-add-new-ops.md @@ -176,7 +176,7 @@ it as a `tripy.Module` under [`frontend/module`](source:/tripy/frontend/module). ```py # doc: no-eval -from tripy import constraints, export +from tripy import export, wrappers from tripy.types import ShapeLike # We can use the `export.public_api()` decorator to automatically export this @@ -191,12 +191,12 @@ from tripy.types import ShapeLike # `autodoc_options` parameter. @export.public_api(document_under="tensor_operations") -# We can use the `constraints.interface` decorator to specify constraints on +# We can use the `wrappers.interface` decorator to specify constraints on # inputs and perform transformations on them, like automatically converting # compatible arguments (e.g., `TensorLike` or `ShapeLike`s) into tensors. # We will aim to include most constraints and transformations in this decorator # so as to avoid layering too many decorators. -@constraints.interface(convert_tensor_and_shape_likes=True) +@wrappers.interface(convert_to_tensors=True) def theta(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float32) -> "tripy.Tensor": # For any public facing interfaces, we have documentation requirements which # you can read about in the 'Docs README' (linked below). The docstring diff --git a/tripy/tests/common/test_exception.py b/tripy/tests/common/test_exception.py index a006025c3..09b04782b 100644 --- a/tripy/tests/common/test_exception.py +++ b/tripy/tests/common/test_exception.py @@ -124,14 +124,14 @@ def test_can_determine_column_range(self): in dedent(error_msg).strip() ) - def test_constraints_is_excluded(self): - from tripy import constraints + def test_wrappers_is_excluded(self): + from tripy import wrappers tensor = tp.ones((2, 3)) stack_info = tensor.stack_info - assert any(frame.module == constraints.__name__ for frame in stack_info) + assert any(frame.module == wrappers.__name__ for frame in stack_info) # Make sure that no extraneous wrapper code is included expected = dedent( @@ -146,7 +146,7 @@ def test_constraints_is_excluded(self): [0-9]+ | return full\(shape, 1, dtype\) | ^^^^^^^^^^^^^^^^^^^^^ --- required from here - --> [a-z_/\.]+:[0-9]+ in test_constraints_is_excluded\(\) + --> [a-z_/\.]+:[0-9]+ in test_wrappers_is_excluded\(\) | [0-9]+ | tensor = tp.ones\(\(2, 3\)\) | ^^^^^^^^^^^^^^^ --- required from here diff --git a/tripy/tests/constraints/object_builders.py b/tripy/tests/wrappers/object_builders.py similarity index 100% rename from tripy/tests/constraints/object_builders.py rename to tripy/tests/wrappers/object_builders.py diff --git a/tripy/tests/constraints/test_interface.py b/tripy/tests/wrappers/test_interface.py similarity index 91% rename from tripy/tests/constraints/test_interface.py rename to tripy/tests/wrappers/test_interface.py index bbb1d178b..8d089ed76 100755 --- a/tripy/tests/constraints/test_interface.py +++ b/tripy/tests/wrappers/test_interface.py @@ -24,12 +24,12 @@ import pytest from tests import helper from tests.conftest import skip_if_older_than_sm89 -from tests.constraints.object_builders import create_obj +from tests.wrappers.object_builders import create_obj import tripy as tp -from tripy import constraints +from tripy import wrappers from tripy.common.datatype import DATA_TYPES -from tripy.constraints import TYPE_VERIFICATION +from tripy.wrappers import TYPE_VERIFICATION from tripy.export import PUBLIC_APIS # Get all functions/methods which have tensors in the type signature @@ -239,7 +239,7 @@ def test_dtype_constraints(test_data): assert ret_val.dtype == namespace[return_dtype] -@constraints.interface(dtype_constraints={"tensors": "T1"}, variables={"T1": ["float32"]}) +@wrappers.interface(dtype_constraints={"tensors": "T1"}, dtype_variables={"T1": ["float32"]}) def sequence_func(tensors: List[tp.Tensor]): return @@ -255,7 +255,7 @@ def test_raises_on_mismatched_sequence_dtypes(self): class TestTensorConversion: def test_no_effect_on_non_tensor_likes(self): - @constraints.interface(convert_tensor_and_shape_likes=True) + @wrappers.interface(convert_to_tensors=True) def func(a: tp.Tensor, b: int): return a, b @@ -266,7 +266,7 @@ def func(a: tp.Tensor, b: int): assert b is 4 def test_tensor_likes(self): - @constraints.interface(convert_tensor_and_shape_likes=True) + @wrappers.interface(convert_to_tensors=True) def func(a: tp.types.TensorLike): return a @@ -277,7 +277,7 @@ def func(a: tp.types.TensorLike): def test_converts_to_dimension_size(self): # The decorator should convert to DimensionSizes when possible. - @constraints.interface(convert_tensor_and_shape_likes=True) + @wrappers.interface(convert_to_tensors=True) def func(a: tp.types.TensorLike): return a @@ -289,7 +289,7 @@ def func(a: tp.types.TensorLike): assert type(a) is tp.Tensor def test_shape_likes(self): - @constraints.interface(convert_tensor_and_shape_likes=True) + @wrappers.interface(convert_to_tensors=True) def func(a: tp.types.ShapeLike): return a @@ -310,7 +310,7 @@ def func(a: tp.types.ShapeLike): assert bool(tp.all(a == tp.Tensor([2, 2, 3, 5]))) def test_keyword_args(self): - @constraints.interface(convert_tensor_and_shape_likes=True) + @wrappers.interface(convert_to_tensors=True) def func(a: tp.types.TensorLike): return a @@ -320,7 +320,7 @@ def func(a: tp.types.TensorLike): assert a.stack_info[4].column_range == (17, 22) def test_multiple_args(self): - @constraints.interface(convert_tensor_and_shape_likes=True) + @wrappers.interface(convert_to_tensors=True) def func(a: tp.types.TensorLike, b: tp.types.TensorLike): return a, b @@ -333,7 +333,7 @@ def func(a: tp.types.TensorLike, b: tp.types.TensorLike): assert b.stack_info[4].column_range == (25, 28) def test_args_out_of_order(self): - @constraints.interface(convert_tensor_and_shape_likes=True) + @wrappers.interface(convert_to_tensors=True) def func(a: tp.types.TensorLike, b: tp.types.TensorLike): return a, b @@ -349,10 +349,10 @@ def func(a: tp.types.TensorLike, b: tp.types.TensorLike): def test_cast_dtype(self): # When type constraints are included, the decorator should automatically cast when possible. - @constraints.interface( - dtype_constraints={"a": "T1", "b": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float16"]}, - convert_tensor_and_shape_likes=True, + @wrappers.interface( + dtype_constraints={"a": "T1", "b": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float16"]}, + convert_to_tensors=True, ) def func(a: tp.Tensor, b: tp.types.TensorLike): return a, b @@ -369,10 +369,10 @@ def func(a: tp.Tensor, b: tp.types.TensorLike): @pytest.mark.parametrize("arg, dtype", [(1.0, tp.int32), (1.0, tp.int64), (2, tp.bool)]) def test_refuse_unsafe_cast(self, arg, dtype): - @constraints.interface( - dtype_constraints={"a": "T1", "b": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["int32", "int64"]}, - convert_tensor_and_shape_likes=True, + @wrappers.interface( + dtype_constraints={"a": "T1", "b": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["int32", "int64"]}, + convert_to_tensors=True, ) def func(a: tp.Tensor, b: tp.types.TensorLike): return a, b @@ -385,7 +385,7 @@ def test_preprocess_func(self): def add_a_to_b(a, b): return {"b": a + b} - @constraints.interface(convert_tensor_and_shape_likes=True, conversion_preprocess_func=add_a_to_b) + @wrappers.interface(convert_to_tensors=True, conversion_preprocess_func=add_a_to_b) def func(a: tp.types.TensorLike, b: tp.types.TensorLike): return a, b @@ -398,7 +398,7 @@ def test_variadic_args(self): def increment(a, *args): return {"a": a + 1, "args": list(map(lambda arg: arg + 1, args))} - @constraints.interface(convert_tensor_and_shape_likes=True, conversion_preprocess_func=increment) + @wrappers.interface(convert_to_tensors=True, conversion_preprocess_func=increment) def func(a: tp.Tensor, *args): return [a] + list(args) diff --git a/tripy/tripy/frontend/ops/allclose.py b/tripy/tripy/frontend/ops/allclose.py index 95933370c..fd7f1b5c0 100644 --- a/tripy/tripy/frontend/ops/allclose.py +++ b/tripy/tripy/frontend/ops/allclose.py @@ -15,12 +15,12 @@ # limitations under the License. # -from tripy import constraints, export +from tripy import export, wrappers @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", "other": "T1"}, variables={"T1": ["float32", "float16", "bfloat16"]} +@wrappers.interface( + dtype_constraints={"input": "T1", "other": "T1"}, dtype_variables={"T1": ["float32", "float16", "bfloat16"]} ) def allclose(input: "tripy.Tensor", other: "tripy.Tensor", rtol: float = 1e-05, atol: float = 1e-08) -> bool: r""" diff --git a/tripy/tripy/frontend/ops/cumsum.py b/tripy/tripy/frontend/ops/cumsum.py index b3db535ee..fc9fa888a 100644 --- a/tripy/tripy/frontend/ops/cumsum.py +++ b/tripy/tripy/frontend/ops/cumsum.py @@ -12,15 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from tripy import constraints, export +from tripy import export, wrappers from tripy.frontend import utils as frontend_utils @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "int32"], }, ) diff --git a/tripy/tripy/frontend/ops/equal.py b/tripy/tripy/frontend/ops/equal.py index 0c4ddd693..5590c387e 100644 --- a/tripy/tripy/frontend/ops/equal.py +++ b/tripy/tripy/frontend/ops/equal.py @@ -12,12 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from tripy import constraints, export +from tripy import export, wrappers from tripy.common.datatype import DATA_TYPES @export.public_api(document_under="operations/functions") -@constraints.interface(dtype_constraints={"input": "T1", "other": "T1"}, variables={"T1": list(DATA_TYPES.keys())}) +@wrappers.interface(dtype_constraints={"input": "T1", "other": "T1"}, dtype_variables={"T1": list(DATA_TYPES.keys())}) def equal(input: "tripy.Tensor", other: "tripy.Tensor") -> bool: r""" Returns ``True`` if ``input`` and ``other`` have the same shape and elements. diff --git a/tripy/tripy/frontend/ops/flatten.py b/tripy/tripy/frontend/ops/flatten.py index e2407b5a0..930a69d70 100644 --- a/tripy/tripy/frontend/ops/flatten.py +++ b/tripy/tripy/frontend/ops/flatten.py @@ -14,14 +14,14 @@ # limitations under the License. import math -from tripy import constraints, export +from tripy import export, wrappers from tripy.common.exception import raise_error @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, ) def flatten(input: "tripy.Tensor", start_dim: int = 0, end_dim: int = -1) -> "tripy.Tensor": """ diff --git a/tripy/tripy/frontend/ops/gelu.py b/tripy/tripy/frontend/ops/gelu.py index ccf88fe45..19e62d7fc 100644 --- a/tripy/tripy/frontend/ops/gelu.py +++ b/tripy/tripy/frontend/ops/gelu.py @@ -17,13 +17,13 @@ import math -from tripy import export, constraints +from tripy import export, wrappers @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16"], }, ) diff --git a/tripy/tripy/frontend/ops/outer.py b/tripy/tripy/frontend/ops/outer.py index 023a0726a..ce229ebc9 100644 --- a/tripy/tripy/frontend/ops/outer.py +++ b/tripy/tripy/frontend/ops/outer.py @@ -15,13 +15,13 @@ # limitations under the License. # -from tripy import constraints, export +from tripy import export, wrappers @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"vec1": "T1", "vec2": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "int32"]}, +@wrappers.interface( + dtype_constraints={"vec1": "T1", "vec2": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "int32"]}, ) def outer(vec1: "tripy.Tensor", vec2: "tripy.Tensor") -> "tripy.Tensor": r""" diff --git a/tripy/tripy/frontend/ops/relu.py b/tripy/tripy/frontend/ops/relu.py index d0ea28bbf..990397a51 100644 --- a/tripy/tripy/frontend/ops/relu.py +++ b/tripy/tripy/frontend/ops/relu.py @@ -15,13 +15,13 @@ # limitations under the License. # -from tripy import export, constraints +from tripy import export, wrappers @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "int4", "int32", "int64", "bool", "int8"], }, ) diff --git a/tripy/tripy/frontend/ops/repeat.py b/tripy/tripy/frontend/ops/repeat.py index 01d295a27..898804334 100644 --- a/tripy/tripy/frontend/ops/repeat.py +++ b/tripy/tripy/frontend/ops/repeat.py @@ -15,15 +15,15 @@ # from typing import Union -from tripy import constraints, export +from tripy import export, wrappers from tripy.common.exception import raise_error from tripy.frontend import utils as frontend_utils @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "int4", "float8", "int8", "int32", "int64", "bool"], }, ) diff --git a/tripy/tripy/frontend/ops/sigmoid.py b/tripy/tripy/frontend/ops/sigmoid.py index b6c3dc227..45469dcdd 100644 --- a/tripy/tripy/frontend/ops/sigmoid.py +++ b/tripy/tripy/frontend/ops/sigmoid.py @@ -15,13 +15,13 @@ # limitations under the License. # -from tripy import export, constraints +from tripy import export, wrappers @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16"], }, ) diff --git a/tripy/tripy/frontend/ops/silu.py b/tripy/tripy/frontend/ops/silu.py index cbe9f0ff2..9b3228a55 100644 --- a/tripy/tripy/frontend/ops/silu.py +++ b/tripy/tripy/frontend/ops/silu.py @@ -15,13 +15,13 @@ # limitations under the License. # -from tripy import export, constraints +from tripy import export, wrappers @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16"], }, ) diff --git a/tripy/tripy/frontend/ops/softmax.py b/tripy/tripy/frontend/ops/softmax.py index f994d2f3d..1d7c045da 100644 --- a/tripy/tripy/frontend/ops/softmax.py +++ b/tripy/tripy/frontend/ops/softmax.py @@ -15,13 +15,13 @@ # limitations under the License. # -from tripy import export, constraints +from tripy import export, wrappers @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16"], }, ) diff --git a/tripy/tripy/frontend/ops/stack.py b/tripy/tripy/frontend/ops/stack.py index 8ba51895a..558dfae63 100644 --- a/tripy/tripy/frontend/ops/stack.py +++ b/tripy/tripy/frontend/ops/stack.py @@ -15,14 +15,14 @@ from typing import Sequence -from tripy import constraints, export +from tripy import export, wrappers from tripy.common.exception import raise_error @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"tensors": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"tensors": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], }, ) diff --git a/tripy/tripy/frontend/ops/tensor_initializers.py b/tripy/tripy/frontend/ops/tensor_initializers.py index 3021503f8..244822f29 100644 --- a/tripy/tripy/frontend/ops/tensor_initializers.py +++ b/tripy/tripy/frontend/ops/tensor_initializers.py @@ -18,7 +18,7 @@ import numbers from typing import Optional, Union -from tripy import constraints, export +from tripy import export, wrappers from tripy.common import datatype from tripy.common.exception import raise_error from tripy.frontend.trace.ops.fill import full, full_like @@ -27,9 +27,9 @@ @export.public_api(document_under="operations/initializers") -@constraints.interface( - dtype_constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int8", "int4", "int32", "int64", "bool"], }, ) @@ -61,9 +61,9 @@ def ones( @export.public_api(document_under="operations/initializers") -@constraints.interface( - dtype_constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int8", "int4", "int32", "int64", "bool"], }, ) @@ -95,9 +95,9 @@ def zeros( @export.public_api(document_under="operations/initializers") -@constraints.interface( - dtype_constraints={"input": "T1", "dtype": "T2", constraints.RETURN_VALUE: "T2"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], "T2": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], }, @@ -128,9 +128,9 @@ def ones_like(input: "tripy.Tensor", dtype: Optional[datatype.dtype] = None) -> @export.public_api(document_under="operations/initializers") -@constraints.interface( - dtype_constraints={"input": "T1", "dtype": "T2", constraints.RETURN_VALUE: "T2"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], "T2": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], }, @@ -162,9 +162,9 @@ def zeros_like(input: "tripy.Tensor", dtype: Optional[datatype.dtype] = None) -> @export.public_api(document_under="operations/initializers") -@constraints.interface( - dtype_constraints={"tensor": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"tensor": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "int32", "int64", "bool"], }, ) @@ -220,9 +220,9 @@ def tril(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor": @export.public_api(document_under="operations/initializers") -@constraints.interface( - dtype_constraints={"tensor": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"tensor": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "int32", "int64", "bool"], }, ) @@ -278,9 +278,9 @@ def triu(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor": @export.public_api(document_under="operations/initializers") -@constraints.interface( - dtype_constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], }, ) @@ -343,9 +343,9 @@ def arange( @export.public_api(document_under="operations/initializers") -@constraints.interface( - dtype_constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], }, ) diff --git a/tripy/tripy/frontend/ops/transpose.py b/tripy/tripy/frontend/ops/transpose.py index 91e321920..9d37ed39e 100644 --- a/tripy/tripy/frontend/ops/transpose.py +++ b/tripy/tripy/frontend/ops/transpose.py @@ -12,16 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass - -from tripy import constraints, export +from tripy import export, wrappers from tripy.common.exception import raise_error @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, ) def transpose(input: "tripy.Tensor", dim0: int, dim1: int) -> "tripy.Tensor": """ diff --git a/tripy/tripy/frontend/ops/unsqueeze.py b/tripy/tripy/frontend/ops/unsqueeze.py index 72eca82f7..8fce4bc3e 100644 --- a/tripy/tripy/frontend/ops/unsqueeze.py +++ b/tripy/tripy/frontend/ops/unsqueeze.py @@ -15,13 +15,13 @@ # limitations under the License. # -from tripy import constraints, export +from tripy import export, wrappers @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, ) def unsqueeze(input: "tripy.Tensor", dim: int) -> "tripy.Tensor": """ diff --git a/tripy/tripy/frontend/trace/ops/binary_elementwise.py b/tripy/tripy/frontend/trace/ops/binary_elementwise.py index 038c69e90..2e62e07bc 100644 --- a/tripy/tripy/frontend/trace/ops/binary_elementwise.py +++ b/tripy/tripy/frontend/trace/ops/binary_elementwise.py @@ -18,8 +18,7 @@ from dataclasses import dataclass import tripy.frontend.trace.ops.utils as op_utils -import tripy.frontend.utils as frontend_utils -from tripy import constraints, export +from tripy import export, wrappers from tripy.common import datatype from tripy.frontend.ops.registry import register_tensor_method from tripy.frontend.trace.ops.base import BaseTraceOp @@ -186,11 +185,11 @@ def to_flat_ir(self, inputs, outputs): @register_tensor_method("__add__") @register_tensor_method("__radd__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, aliases=["__radd__"], - convert_tensor_and_shape_likes=True, + convert_to_tensors=True, ) def __add__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -218,10 +217,10 @@ def __add__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @register_tensor_method("__sub__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, - convert_tensor_and_shape_likes=True, +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, + convert_to_tensors=True, ) def __sub__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -249,10 +248,10 @@ def __sub__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @register_tensor_method("__rsub__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, - convert_tensor_and_shape_likes=True, +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, + convert_to_tensors=True, ) def __rsub__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -280,10 +279,10 @@ def __rsub__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @register_tensor_method("__pow__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "int8"]}, - convert_tensor_and_shape_likes=True, +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8"]}, + convert_to_tensors=True, ) def __pow__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -311,10 +310,10 @@ def __pow__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @register_tensor_method("__rpow__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "int8"]}, - convert_tensor_and_shape_likes=True, +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8"]}, + convert_to_tensors=True, ) def __rpow__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -343,11 +342,11 @@ def __rpow__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @register_tensor_method("__mul__") @register_tensor_method("__rmul__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"]}, aliases=["__rmul__"], - convert_tensor_and_shape_likes=True, + convert_to_tensors=True, ) def __mul__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -375,10 +374,10 @@ def __mul__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @register_tensor_method("__truediv__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, - convert_tensor_and_shape_likes=True, +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, + convert_to_tensors=True, ) def __truediv__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -406,10 +405,10 @@ def __truediv__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @register_tensor_method("__rtruediv__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, - convert_tensor_and_shape_likes=True, +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, + convert_to_tensors=True, ) def __rtruediv__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -437,10 +436,10 @@ def __rtruediv__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @register_tensor_method("__floordiv__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64"]}, - convert_tensor_and_shape_likes=True, +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64"]}, + convert_to_tensors=True, ) def __floordiv__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -473,10 +472,10 @@ def __floordiv__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @register_tensor_method("__rfloordiv__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64"]}, - convert_tensor_and_shape_likes=True, +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64"]}, + convert_to_tensors=True, ) def __rfloordiv__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -509,10 +508,10 @@ def __rfloordiv__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @register_tensor_method("__mod__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "float8"]}, - convert_tensor_and_shape_likes=True, +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8"]}, + convert_to_tensors=True, ) def __mod__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -540,10 +539,10 @@ def __mod__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @register_tensor_method("__rmod__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "float8"]}, - convert_tensor_and_shape_likes=True, +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8"]}, + convert_to_tensors=True, ) def __rmod__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -570,9 +569,9 @@ def __rmod__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"lhs": "T1", "rhs": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, +@wrappers.interface( + dtype_constraints={"lhs": "T1", "rhs": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, ) def maximum(lhs: "tripy.Tensor", rhs: "tripy.Tensor") -> "tripy.Tensor": """ @@ -600,9 +599,9 @@ def maximum(lhs: "tripy.Tensor", rhs: "tripy.Tensor") -> "tripy.Tensor": @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"lhs": "T1", "rhs": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, +@wrappers.interface( + dtype_constraints={"lhs": "T1", "rhs": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, ) def minimum(lhs: "tripy.Tensor", rhs: "tripy.Tensor") -> "tripy.Tensor": """ @@ -630,13 +629,13 @@ def minimum(lhs: "tripy.Tensor", rhs: "tripy.Tensor") -> "tripy.Tensor": @register_tensor_method("__lt__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, - variables={ +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], "T2": ["bool"], }, - convert_tensor_and_shape_likes=True, + convert_to_tensors=True, ) def __lt__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -664,13 +663,13 @@ def __lt__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @register_tensor_method("__le__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, - variables={ +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], "T2": ["bool"], }, - convert_tensor_and_shape_likes=True, + convert_to_tensors=True, ) def __le__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -698,13 +697,13 @@ def __le__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @register_tensor_method("__eq__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, - variables={ +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], "T2": ["bool"], }, - convert_tensor_and_shape_likes=True, + convert_to_tensors=True, ) def __eq__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -732,13 +731,13 @@ def __eq__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @register_tensor_method("__ne__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, - variables={ +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], "T2": ["bool"], }, - convert_tensor_and_shape_likes=True, + convert_to_tensors=True, ) def __ne__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -766,13 +765,13 @@ def __ne__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @register_tensor_method("__ge__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, - variables={ +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], "T2": ["bool"], }, - convert_tensor_and_shape_likes=True, + convert_to_tensors=True, ) def __ge__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ @@ -800,13 +799,13 @@ def __ge__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": @register_tensor_method("__gt__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, - variables={ +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T2"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], "T2": ["bool"], }, - convert_tensor_and_shape_likes=True, + convert_to_tensors=True, ) def __gt__(self: "tripy.Tensor", other: TensorLike) -> "tripy.Tensor": """ diff --git a/tripy/tripy/frontend/trace/ops/cast.py b/tripy/tripy/frontend/trace/ops/cast.py index 1cb5a3c74..7d45c198a 100644 --- a/tripy/tripy/frontend/trace/ops/cast.py +++ b/tripy/tripy/frontend/trace/ops/cast.py @@ -17,7 +17,7 @@ from dataclasses import dataclass -from tripy import constraints, export +from tripy import export, wrappers from tripy.frontend.trace.ops import utils as op_utils from tripy.frontend.trace.ops.base import BaseTraceOp @@ -91,9 +91,9 @@ def to_flat_ir(self, inputs, outputs): @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", "dtype": "T2", constraints.RETURN_VALUE: "T2"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], "T2": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], }, diff --git a/tripy/tripy/frontend/trace/ops/concatenate.py b/tripy/tripy/frontend/trace/ops/concatenate.py index 2f917a7e9..04d1dfe22 100644 --- a/tripy/tripy/frontend/trace/ops/concatenate.py +++ b/tripy/tripy/frontend/trace/ops/concatenate.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from typing import Sequence -from tripy import constraints, export +from tripy import export, wrappers from tripy.common.exception import raise_error from tripy.frontend.trace.ops.base import BaseTraceOp import tripy.frontend.trace.ops.utils as op_utils @@ -42,9 +42,9 @@ def to_flat_ir(self, inputs, outputs): @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"tensors": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"tensors": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], }, ) diff --git a/tripy/tripy/frontend/trace/ops/convolution.py b/tripy/tripy/frontend/trace/ops/convolution.py index 4582d096e..f40a70880 100644 --- a/tripy/tripy/frontend/trace/ops/convolution.py +++ b/tripy/tripy/frontend/trace/ops/convolution.py @@ -19,7 +19,7 @@ from dataclasses import dataclass import tripy.frontend.trace.ops.utils as op_utils -from tripy import constraints +from tripy import wrappers from tripy.frontend.trace.ops.base import BaseTraceOp @@ -50,9 +50,9 @@ def to_flat_ir(self, inputs, outputs): ) -@constraints.interface( - dtype_constraints={"input": "T1", "weight": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", "weight": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16"], }, ) diff --git a/tripy/tripy/frontend/trace/ops/copy.py b/tripy/tripy/frontend/trace/ops/copy.py index 83435101f..bb26202d0 100644 --- a/tripy/tripy/frontend/trace/ops/copy.py +++ b/tripy/tripy/frontend/trace/ops/copy.py @@ -18,7 +18,7 @@ from dataclasses import dataclass import tripy.frontend.trace.ops.utils as op_utils -from tripy import constraints, export +from tripy import export, wrappers from tripy.common.device import device from tripy.frontend.trace.ops.base import BaseTraceOp @@ -39,9 +39,9 @@ def to_flat_ir(self, inputs, outputs): @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], }, ) diff --git a/tripy/tripy/frontend/trace/ops/dequantize.py b/tripy/tripy/frontend/trace/ops/dequantize.py index e22dfb5ca..9f379703b 100644 --- a/tripy/tripy/frontend/trace/ops/dequantize.py +++ b/tripy/tripy/frontend/trace/ops/dequantize.py @@ -20,9 +20,8 @@ from typing import Any, Sequence, Union import tripy.frontend.trace.ops.utils as op_utils -from tripy import constraints, export +from tripy import export, wrappers from tripy.common import datatype -from tripy.frontend import utils as frontend_utils from tripy.frontend.trace.ops import utils as op_utils from tripy.frontend.trace.ops.base import BaseTraceOp import tripy.frontend.trace.ops.utils as op_utils @@ -104,10 +103,10 @@ def to_flat_ir(self, inputs, outputs): @export.public_api(document_under="operations/quantization") -@constraints.interface( - dtype_constraints={"input": "T1", "scale": "T2", "dtype": "T2", constraints.RETURN_VALUE: "T2"}, - variables={"T1": ["int4", "int8", "float8"], "T2": ["float32", "float16", "bfloat16"]}, - convert_tensor_and_shape_likes={"scale"}, +@wrappers.interface( + dtype_constraints={"input": "T1", "scale": "T2", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, + dtype_variables={"T1": ["int4", "int8", "float8"], "T2": ["float32", "float16", "bfloat16"]}, + convert_to_tensors={"scale"}, ) def dequantize( input: "tripy.Tensor", diff --git a/tripy/tripy/frontend/trace/ops/expand.py b/tripy/tripy/frontend/trace/ops/expand.py index f43dff075..87adae404 100644 --- a/tripy/tripy/frontend/trace/ops/expand.py +++ b/tripy/tripy/frontend/trace/ops/expand.py @@ -16,11 +16,9 @@ # from dataclasses import dataclass -from typing import Optional -from tripy import constraints, export +from tripy import export, wrappers from tripy.common.exception import raise_error -from tripy.frontend import utils as frontend_utils from tripy.frontend.trace.ops import utils as op_utils from tripy.frontend.trace.ops.base import BaseTraceOp from tripy.types import ShapeLike @@ -73,12 +71,12 @@ def process_sizes(input: "tripy.Tensor", sizes: ShapeLike): @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int8", "int32", "int64", "bool"], }, - convert_tensor_and_shape_likes=True, + convert_to_tensors=True, conversion_preprocess_func=process_sizes, ) def expand(input: "tripy.Tensor", sizes: ShapeLike) -> "tripy.Tensor": diff --git a/tripy/tripy/frontend/trace/ops/fill.py b/tripy/tripy/frontend/trace/ops/fill.py index f142f5d43..b4a92d856 100644 --- a/tripy/tripy/frontend/trace/ops/fill.py +++ b/tripy/tripy/frontend/trace/ops/fill.py @@ -20,7 +20,7 @@ import tripy.frontend.trace.ops.utils as op_utils import tripy.frontend.utils as frontend_utils -from tripy import constraints, export, utils +from tripy import export, utils, wrappers from tripy.common import datatype from tripy.frontend.trace.ops import utils as op_utils from tripy.frontend.trace.ops.base import BaseTraceOp @@ -69,12 +69,12 @@ def to_flat_ir(self, inputs, outputs): @export.public_api(document_under="operations/initializers") -@constraints.interface( - dtype_constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], }, - convert_tensor_and_shape_likes=True, + convert_to_tensors=True, ) def full(shape: ShapeLike, value: TensorLike, dtype: "tripy.dtype" = datatype.float32) -> "tripy.Tensor": """ @@ -100,13 +100,13 @@ def full(shape: ShapeLike, value: TensorLike, dtype: "tripy.dtype" = datatype.fl @export.public_api(document_under="operations/initializers") -@constraints.interface( - dtype_constraints={"input": "T1", "dtype": "T2", constraints.RETURN_VALUE: "T2"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], "T2": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], }, - convert_tensor_and_shape_likes=True, + convert_to_tensors=True, ) def full_like(input: "tripy.Tensor", value: TensorLike, dtype: Optional["tripy.dtype"] = None) -> "tripy.Tensor": """ diff --git a/tripy/tripy/frontend/trace/ops/flip.py b/tripy/tripy/frontend/trace/ops/flip.py index 264bcf2fc..0bc85d9f4 100644 --- a/tripy/tripy/frontend/trace/ops/flip.py +++ b/tripy/tripy/frontend/trace/ops/flip.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from typing import Optional, Sequence, Union -from tripy import export, utils, constraints +from tripy import export, utils, wrappers from tripy.common.exception import raise_error from tripy.frontend.trace.ops.base import BaseTraceOp from tripy.frontend.trace.ops import utils as op_utils @@ -37,9 +37,9 @@ def to_flat_ir(self, inputs, outputs): @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "int32", "int64", "bool"], }, ) diff --git a/tripy/tripy/frontend/trace/ops/gather.py b/tripy/tripy/frontend/trace/ops/gather.py index 302de7ea7..0e8529ac9 100644 --- a/tripy/tripy/frontend/trace/ops/gather.py +++ b/tripy/tripy/frontend/trace/ops/gather.py @@ -18,7 +18,7 @@ from dataclasses import dataclass import tripy.frontend.trace.ops.utils as op_utils -from tripy import constraints, export, utils +from tripy import export, utils, wrappers from tripy.frontend.trace.ops.base import BaseTraceOp @@ -88,9 +88,9 @@ def to_flat_ir(self, inputs, outputs): @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", "index": "T2", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", "index": "T2", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float8", "float32", "float16", "bfloat16", "int4", "int8", "int32", "int64", "bool"], "T2": ["int32"], }, diff --git a/tripy/tripy/frontend/trace/ops/iota.py b/tripy/tripy/frontend/trace/ops/iota.py index 14a99ac6a..41e104df1 100644 --- a/tripy/tripy/frontend/trace/ops/iota.py +++ b/tripy/tripy/frontend/trace/ops/iota.py @@ -19,7 +19,7 @@ from typing import Optional import tripy.frontend.trace.ops.utils as op_utils -from tripy import constraints, export, utils +from tripy import export, utils, wrappers from tripy.common import datatype from tripy.frontend import utils as frontend_utils from tripy.frontend.trace.ops.base import BaseTraceOp @@ -61,12 +61,12 @@ def iota_impl(shape: "tripy.Tensor", dim: int, dtype: datatype.dtype, output_ran @export.public_api(document_under="operations/initializers") -@constraints.interface( - dtype_constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"dtype": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], }, - convert_tensor_and_shape_likes=True, + convert_to_tensors=True, conversion_preprocess_func=lambda shape, dim=None, dtype=None: ( {"dim": frontend_utils.process_dim(dim, len(shape))} if dim is not None else {} ), @@ -96,9 +96,9 @@ def iota(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float3 @export.public_api(document_under="operations/initializers") -@constraints.interface( - dtype_constraints={"input": "T1", "dtype": "T2", constraints.RETURN_VALUE: "T2"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], "T2": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], }, diff --git a/tripy/tripy/frontend/trace/ops/matmul.py b/tripy/tripy/frontend/trace/ops/matmul.py index de7b37b44..321998d85 100644 --- a/tripy/tripy/frontend/trace/ops/matmul.py +++ b/tripy/tripy/frontend/trace/ops/matmul.py @@ -19,7 +19,7 @@ from typing import Dict, List import tripy.frontend.trace.ops.utils as op_utils -from tripy import constraints +from tripy import wrappers from tripy.common.exception import raise_error from tripy.frontend.ops.registry import register_tensor_method from tripy.frontend.trace.ops.base import BaseTraceOp @@ -149,9 +149,9 @@ def append_ones_data_tensor(input, nb_ones): @register_tensor_method("__matmul__") -@constraints.interface( - dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "int32"]}, +@wrappers.interface( + dtype_constraints={"self": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "int32"]}, ) def __matmul__(self: "tripy.Tensor", other: "tripy.Tensor") -> "tripy.Tensor": """ diff --git a/tripy/tripy/frontend/trace/ops/pad.py b/tripy/tripy/frontend/trace/ops/pad.py index f15d3ae8a..feb14caba 100644 --- a/tripy/tripy/frontend/trace/ops/pad.py +++ b/tripy/tripy/frontend/trace/ops/pad.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from typing import Sequence, Union -from tripy import constraints, export +from tripy import export, wrappers from tripy.common.exception import raise_error from tripy.frontend import utils as frontend_utils from tripy.frontend.trace.ops.base import BaseTraceOp @@ -73,9 +73,9 @@ def to_flat_ir(self, inputs, outputs): @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bool", "int32"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bool", "int32"]}, ) def pad( input: "tripy.Tensor", pad: Sequence[ShapeLike], mode: str = "constant", value: Union[int, float] = 0 diff --git a/tripy/tripy/frontend/trace/ops/permute.py b/tripy/tripy/frontend/trace/ops/permute.py index b2ffe2971..120175858 100644 --- a/tripy/tripy/frontend/trace/ops/permute.py +++ b/tripy/tripy/frontend/trace/ops/permute.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from typing import Sequence -from tripy import constraints, export +from tripy import export, wrappers from tripy.common.exception import raise_error from tripy.frontend.trace.ops import utils as op_utils from tripy.frontend.trace.ops.base import BaseTraceOp @@ -37,9 +37,9 @@ def to_flat_ir(self, inputs, outputs): @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, ) def permute(input: "tripy.Tensor", perm: Sequence[int]) -> "tripy.Tensor": """ diff --git a/tripy/tripy/frontend/trace/ops/pooling.py b/tripy/tripy/frontend/trace/ops/pooling.py index d88c7199c..efecee59c 100644 --- a/tripy/tripy/frontend/trace/ops/pooling.py +++ b/tripy/tripy/frontend/trace/ops/pooling.py @@ -19,7 +19,7 @@ from dataclasses import dataclass from typing import Optional, Sequence, Tuple -from tripy import constraints, export, utils +from tripy import export, utils, wrappers from tripy.common.exception import raise_error from tripy.frontend.trace.ops import utils as op_utils from tripy.frontend.trace.ops.base import BaseTraceOp @@ -139,9 +139,9 @@ def to_flat_ir(self, inputs, outputs): @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "int8"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "int8"]}, ) def maxpool( input: "tripy.Tensor", @@ -198,9 +198,9 @@ def maxpool( @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "int8"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "int8"]}, ) def avgpool( input: "tripy.Tensor", diff --git a/tripy/tripy/frontend/trace/ops/quantize.py b/tripy/tripy/frontend/trace/ops/quantize.py index 0c1a846b6..2467ea3c5 100644 --- a/tripy/tripy/frontend/trace/ops/quantize.py +++ b/tripy/tripy/frontend/trace/ops/quantize.py @@ -19,9 +19,8 @@ from dataclasses import dataclass from typing import Any, Sequence, Union -from tripy import constraints, export +from tripy import export, wrappers from tripy.common import datatype -from tripy.frontend import utils as frontend_utils from tripy.frontend.trace.ops import utils as op_utils from tripy.frontend.trace.ops.base import BaseTraceOp @@ -131,10 +130,10 @@ def to_flat_ir(self, inputs, outputs): @export.public_api(document_under="operations/quantization") -@constraints.interface( - dtype_constraints={"input": "T1", "scale": "T1", "dtype": "T2", constraints.RETURN_VALUE: "T2"}, - variables={"T1": ["float32", "float16", "bfloat16"], "T2": ["int4", "int8", "float8"]}, - convert_tensor_and_shape_likes={"scale"}, +@wrappers.interface( + dtype_constraints={"input": "T1", "scale": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16"], "T2": ["int4", "int8", "float8"]}, + convert_to_tensors={"scale"}, ) def quantize( input: "tripy.Tensor", diff --git a/tripy/tripy/frontend/trace/ops/reduce.py b/tripy/tripy/frontend/trace/ops/reduce.py index 490370814..d5f2144de 100644 --- a/tripy/tripy/frontend/trace/ops/reduce.py +++ b/tripy/tripy/frontend/trace/ops/reduce.py @@ -20,8 +20,7 @@ from dataclasses import dataclass from typing import Optional, Sequence, Union -import tripy.frontend.trace.ops.utils as op_utils -from tripy import constraints, export +from tripy import export, wrappers from tripy.common import datatype from tripy.frontend.trace.ops.base import BaseTraceOp from tripy.utils import make_list @@ -133,9 +132,9 @@ def _reduce_impl(input: "tripy.Tensor", kind: Reduce.Kind, dim: Union[int, Seque @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, ) def sum( input: "tripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False @@ -166,9 +165,9 @@ def sum( @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["bool"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["bool"]}, ) def all( input: "tripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False @@ -198,9 +197,9 @@ def all( @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["bool"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["bool"]}, ) def any( input: "tripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False @@ -230,9 +229,9 @@ def any( @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, ) def max( input: "tripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False @@ -263,9 +262,9 @@ def max( @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, ) def prod( input: "tripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False @@ -317,9 +316,9 @@ def mean_impl(tensor: "tripy.Tensor", dim: Union[int, Sequence] = None, keepdim: @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "int32", "int64", "float16", "bfloat16"]}, ) def mean( input: "tripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False @@ -350,9 +349,9 @@ def mean( @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, ) def var( input: "tripy.Tensor", dim: Optional[Union[int, Sequence[int]]] = None, keepdim: bool = False, correction: int = 1 @@ -417,9 +416,9 @@ def _arg_min_max_impl(tensor: "tripy.Tensor", kind: ArgMinMax.Kind, dim: Optiona @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T2"}, - variables={"T1": ["float32", "float16", "bfloat16", "int32"], "T2": ["int32"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T2"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "int32"], "T2": ["int32"]}, ) def argmax(input: "tripy.Tensor", dim: Optional[int] = None, keepdim: bool = False) -> "tripy.Tensor": """ @@ -449,9 +448,9 @@ def argmax(input: "tripy.Tensor", dim: Optional[int] = None, keepdim: bool = Fal @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T2"}, - variables={"T1": ["float32", "float16", "bfloat16", "int32"], "T2": ["int32"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T2"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "int32"], "T2": ["int32"]}, ) def argmin(input: "tripy.Tensor", dim: Optional[int] = None, keepdim: bool = False) -> "tripy.Tensor": """ diff --git a/tripy/tripy/frontend/trace/ops/reshape.py b/tripy/tripy/frontend/trace/ops/reshape.py index bd4b6caac..4061adff7 100644 --- a/tripy/tripy/frontend/trace/ops/reshape.py +++ b/tripy/tripy/frontend/trace/ops/reshape.py @@ -18,9 +18,8 @@ import math from dataclasses import dataclass -from tripy import constraints, export +from tripy import export, wrappers from tripy.common.exception import raise_error -from tripy.frontend import utils as frontend_utils from tripy.frontend.trace.ops import utils as op_utils from tripy.frontend.trace.ops.base import BaseTraceOp from tripy.types import ShapeLike @@ -61,10 +60,10 @@ def infer_dimensions(input: "tripy.Tensor", shape: ShapeLike) -> ShapeLike: @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, - convert_tensor_and_shape_likes=True, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, + convert_to_tensors=True, conversion_preprocess_func=infer_dimensions, ) def reshape(input: "tripy.Tensor", shape: ShapeLike) -> "tripy.Tensor": diff --git a/tripy/tripy/frontend/trace/ops/resize.py b/tripy/tripy/frontend/trace/ops/resize.py index 90ca64753..0e7317241 100644 --- a/tripy/tripy/frontend/trace/ops/resize.py +++ b/tripy/tripy/frontend/trace/ops/resize.py @@ -19,7 +19,7 @@ from dataclasses import dataclass from typing import Optional, Sequence -from tripy import constraints, export +from tripy import export, wrappers from tripy.common.exception import raise_error from tripy.frontend import utils as frontend_utils from tripy.frontend.trace.ops.base import BaseTraceOp @@ -105,10 +105,10 @@ def _check_mode(mode: str, align_corners: bool): @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "int8"]}, - convert_tensor_and_shape_likes=True, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "int8"]}, + convert_to_tensors=True, ) def resize(input: "tripy.Tensor", mode: str, output_shape: ShapeLike, align_corners: bool = False) -> "tripy.Tensor": r""" @@ -143,9 +143,9 @@ def resize(input: "tripy.Tensor", mode: str, output_shape: ShapeLike, align_corn @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "int8"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "int8"]}, ) def resize( input: "tripy.Tensor", mode: str, scales: Sequence[numbers.Number], align_corners: bool = False diff --git a/tripy/tripy/frontend/trace/ops/shape.py b/tripy/tripy/frontend/trace/ops/shape.py index 90edc4c40..2e86ce469 100644 --- a/tripy/tripy/frontend/trace/ops/shape.py +++ b/tripy/tripy/frontend/trace/ops/shape.py @@ -17,7 +17,7 @@ from dataclasses import dataclass -from tripy import constraints +from tripy import wrappers from tripy.common.datatype import DATA_TYPES from tripy.frontend.ops.registry import register_tensor_method from tripy.frontend.trace.ops.base import BaseTraceOp @@ -44,7 +44,7 @@ def to_flat_ir(self, inputs, outputs): @register_tensor_method("shape") @property -@constraints.interface(dtype_constraints={"self": "T1"}, variables={"T1": list(DATA_TYPES.keys())}) +@wrappers.interface(dtype_constraints={"self": "T1"}, dtype_variables={"T1": list(DATA_TYPES.keys())}) def shape(self: "tripy.Tensor") -> ShapeLike: """ Represents the shape of the tensor. diff --git a/tripy/tripy/frontend/trace/ops/slice.py b/tripy/tripy/frontend/trace/ops/slice.py index 95066f5a4..67a09830c 100644 --- a/tripy/tripy/frontend/trace/ops/slice.py +++ b/tripy/tripy/frontend/trace/ops/slice.py @@ -18,9 +18,8 @@ from dataclasses import dataclass from typing import Sequence, Union -from tripy import constraints, utils +from tripy import utils, wrappers from tripy.common.exception import raise_error -from tripy.frontend import utils as frontend_utils from tripy.frontend.ops.registry import register_tensor_method from tripy.frontend.trace.ops import utils as op_utils from tripy.frontend.trace.ops.base import BaseTraceOp @@ -130,9 +129,9 @@ def adjust_start(start_bound, end_bound): @register_tensor_method("__getitem__") -@constraints.interface( - dtype_constraints={"self": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, +@wrappers.interface( + dtype_constraints={"self": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, ) def __getitem__( self: "tripy.Tensor", index: Union[slice, int, "tripy.Tensor", Sequence[Union[slice, int, "tripy.Tensor"]]] @@ -259,7 +258,7 @@ def clamp_bound(bound: Union[int, Tensor]) -> Union[int, Tensor]: return out -@constraints.interface(convert_tensor_and_shape_likes=True) +@wrappers.interface(convert_to_tensors=True) def slice_helper(tensor, *slice_params: TensorLike): from tripy import function_registry from tripy.utils import get_arg_candidate_column_offsets diff --git a/tripy/tripy/frontend/trace/ops/split.py b/tripy/tripy/frontend/trace/ops/split.py index b1cb2f706..136a5dfdb 100644 --- a/tripy/tripy/frontend/trace/ops/split.py +++ b/tripy/tripy/frontend/trace/ops/split.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from typing import Sequence, Union -from tripy import constraints, export, utils +from tripy import export, utils, wrappers from tripy.common.exception import raise_error from tripy.frontend.trace.ops import utils as op_utils from tripy.frontend.trace.ops.base import BaseTraceOp @@ -174,9 +174,9 @@ def __str__(self) -> str: @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], }, ) diff --git a/tripy/tripy/frontend/trace/ops/squeeze.py b/tripy/tripy/frontend/trace/ops/squeeze.py index 03a14c489..9d8cdaefe 100644 --- a/tripy/tripy/frontend/trace/ops/squeeze.py +++ b/tripy/tripy/frontend/trace/ops/squeeze.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from typing import Tuple, Union -from tripy import constraints, export, utils +from tripy import export, utils, wrappers from tripy.frontend.trace.ops import utils as op_utils from tripy.frontend.trace.ops.base import BaseTraceOp @@ -50,9 +50,9 @@ def to_flat_ir(self, inputs, outputs): @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "float8", "int8", "int32", "int64", "bool"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int8", "int32", "int64", "bool"]}, ) def squeeze(input: "tripy.Tensor", dims: Union[Tuple, int]) -> "tripy.Tensor": """ diff --git a/tripy/tripy/frontend/trace/ops/unary_elementwise.py b/tripy/tripy/frontend/trace/ops/unary_elementwise.py index bdc0c6bf0..fc334b35f 100644 --- a/tripy/tripy/frontend/trace/ops/unary_elementwise.py +++ b/tripy/tripy/frontend/trace/ops/unary_elementwise.py @@ -18,7 +18,7 @@ import enum from dataclasses import dataclass -from tripy import export, constraints +from tripy import export, wrappers from tripy.frontend.trace.ops.base import BaseTraceOp import tripy.frontend.trace.ops.utils as op_utils @@ -59,9 +59,9 @@ def to_flat_ir(self, inputs, outputs): @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, ) def exp(input: "tripy.Tensor") -> "tripy.Tensor": r""" @@ -88,9 +88,9 @@ def exp(input: "tripy.Tensor") -> "tripy.Tensor": @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, ) def tanh(input: "tripy.Tensor") -> "tripy.Tensor": """ @@ -115,9 +115,9 @@ def tanh(input: "tripy.Tensor") -> "tripy.Tensor": @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, ) def sin(input: "tripy.Tensor") -> "tripy.Tensor": """ @@ -142,9 +142,9 @@ def sin(input: "tripy.Tensor") -> "tripy.Tensor": @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, ) def cos(input: "tripy.Tensor") -> "tripy.Tensor": """ @@ -169,9 +169,9 @@ def cos(input: "tripy.Tensor") -> "tripy.Tensor": @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "float8"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8"]}, ) def rsqrt(input: "tripy.Tensor") -> "tripy.Tensor": """ @@ -196,9 +196,9 @@ def rsqrt(input: "tripy.Tensor") -> "tripy.Tensor": @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "float8"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8"]}, ) def sqrt(input: "tripy.Tensor") -> "tripy.Tensor": """ @@ -223,9 +223,9 @@ def sqrt(input: "tripy.Tensor") -> "tripy.Tensor": @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16"]}, ) def log(input: "tripy.Tensor") -> "tripy.Tensor": """ @@ -250,9 +250,9 @@ def log(input: "tripy.Tensor") -> "tripy.Tensor": @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, - variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, +@wrappers.interface( + dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, ) def abs(input: "tripy.Tensor") -> "tripy.Tensor": r""" diff --git a/tripy/tripy/frontend/trace/ops/where.py b/tripy/tripy/frontend/trace/ops/where.py index 5f20f72f9..44415fa3e 100644 --- a/tripy/tripy/frontend/trace/ops/where.py +++ b/tripy/tripy/frontend/trace/ops/where.py @@ -19,7 +19,7 @@ from dataclasses import dataclass import tripy.frontend.trace.ops.utils as op_utils -from tripy import constraints, export +from tripy import export, wrappers from tripy.frontend.trace.ops.base import BaseTraceOp @@ -91,9 +91,9 @@ def to_flat_ir(self, inputs, outputs): @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"condition": "T2", "input": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"condition": "T2", "input": "T1", "other": "T1", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], "T2": ["bool"], }, @@ -130,9 +130,9 @@ def where(condition: "tripy.Tensor", input: "tripy.Tensor", other: "tripy.Tensor @export.public_api(document_under="operations/functions") -@constraints.interface( - dtype_constraints={"input": "T1", "mask": "T2", constraints.RETURN_VALUE: "T1"}, - variables={ +@wrappers.interface( + dtype_constraints={"input": "T1", "mask": "T2", wrappers.RETURN_VALUE: "T1"}, + dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], "T2": ["bool"], }, diff --git a/tripy/tripy/utils/stack_info.py b/tripy/tripy/utils/stack_info.py index d037a2bda..addd4c114 100644 --- a/tripy/tripy/utils/stack_info.py +++ b/tripy/tripy/utils/stack_info.py @@ -140,6 +140,6 @@ def get_module_names_to_exclude_from_stack_info(): or trying to retrieve column information from code. """ import tripy.function_registry - import tripy.constraints + import tripy.wrappers - return {mod.__name__ for mod in [tripy.function_registry, tripy.constraints]} + return {mod.__name__ for mod in [tripy.function_registry, tripy.wrappers]} diff --git a/tripy/tripy/constraints.py b/tripy/tripy/wrappers.py similarity index 94% rename from tripy/tripy/constraints.py rename to tripy/tripy/wrappers.py index e1db4b409..2fe0d4f05 100644 --- a/tripy/tripy/constraints.py +++ b/tripy/tripy/wrappers.py @@ -224,10 +224,10 @@ def add_arg(arg): def interface( dtype_constraints: Dict[str, str] = {}, - variables: Dict[str, List[str]] = {}, + dtype_variables: Dict[str, List[str]] = {}, dtype_exceptions: List[Dict[str, str]] = [], aliases: List[str] = [], - convert_tensor_and_shape_likes: Union[bool, Set[str]] = False, + convert_to_tensors: Union[bool, Set[str]] = False, conversion_preprocess_func: Optional[Callable] = None, ): """ @@ -242,11 +242,11 @@ def interface( Args: dtype_constraints: Maps parameters and return values to data type constraint variables. - Use the special value `constraints.RETURN_VALUE` to denote return values. + Use the special value `wrappers.RETURN_VALUE` to denote return values. For example: - {"input": "T1", "other": T2, constraints.RETURN_VALUE: "T1"} + {"input": "T1", "other": T2, wrappers.RETURN_VALUE: "T1"} - variables: Maps data type constraints variables to their supported data types. + dtype_variables: Maps data type constraints variables to their supported data types. For example: {"T1": ["float32", "float16"], "T2": ["int32", "int64"]} @@ -260,7 +260,7 @@ def interface( (e.g. `__add__` and `__radd__`), this will enable type information to be added to the documentation for the aliases as well. - convert_tensor_and_shape_likes: If False or an empty set, no argument types will be converted. + convert_to_tensors: If False or an empty set, no argument types will be converted. If True, all arguments with the `TensorLike` or `ShapeLike` annotations will be converted into `Tensor`s or, whenever possible, `DimensionSize`. If the argument is a set of argument names, conversions will be done only for those arguments. @@ -268,7 +268,7 @@ def interface( The conversions will respect any datatype constraints, casting the `TensorLike` values as necessary, but will raise an exception for lossy casts like float to int (but *not* for, e.g., `float32` to `float16`). - conversion_preprocess_func: If `convert_tensor_and_shape_likes` is true, this argument is a callback that is + conversion_preprocess_func: If `convert_to_tensors` is true, this argument is a callback that is used to preprocess the arguments before potential conversion. In this case, if provided, the callback will be called regardless of whether the decorator performs any conversions. @@ -282,18 +282,18 @@ def decorator(func): return_dtype = dtype_constraints.get(RETURN_VALUE, None) VerifInfo = namedtuple("VerifInfo", ["obj", "inputs", "exceptions", "return_dtype", "dtypes", "constraints"]) - verif_info = VerifInfo(func, {}, dtype_exceptions, return_dtype, variables, dtype_constraints) + verif_info = VerifInfo(func, {}, dtype_exceptions, return_dtype, dtype_variables, dtype_constraints) signature = inspect.signature(func) conversion_targets = ( - convert_tensor_and_shape_likes - if isinstance(convert_tensor_and_shape_likes, Set) + convert_to_tensors + if isinstance(convert_to_tensors, Set) else {name for name, param in signature.parameters.items() if param.annotation in {TensorLike, ShapeLike}} ) shape_likes = {name for name, param in signature.parameters.items() if param.annotation is ShapeLike} # if no dtype constraints have been specified at all, do not add to the table so we don't generate invalid tests - if dtype_constraints or variables or dtype_exceptions: + if dtype_constraints or dtype_variables or dtype_exceptions: for key in [func.__qualname__] + aliases: TYPE_VERIFICATION[key] = verif_info @@ -301,7 +301,7 @@ def decorator(func): def wrapper(*args, **kwargs): merged_args, var_arg_info = utils.merge_function_arguments(func, *args, **kwargs) - if convert_tensor_and_shape_likes: + if convert_to_tensors: args, kwargs, merged_args = convert_input_types( func, args, @@ -339,7 +339,7 @@ def wrapper(*args, **kwargs): arg_dtype = arg_dtype.value # Check if the type is supported at all - supported_dtypes = variables[type_var] + supported_dtypes = dtype_variables[type_var] if arg_dtype.name not in supported_dtypes: raise_error( f"Unsupported data type for '{func.__qualname__}'.",