Skip to content

Commit

Permalink
Rename constraints to wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed Dec 5, 2024
1 parent f4c2068 commit f046a23
Show file tree
Hide file tree
Showing 47 changed files with 364 additions and 374 deletions.
2 changes: 1 addition & 1 deletion tripy/docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import tripy as tp
from tripy.common.datatype import DATA_TYPES
from tripy.constraints import TYPE_VERIFICATION
from tripy.wrappers import TYPE_VERIFICATION

PARAM_PAT = re.compile(":param .*?:")

Expand Down
6 changes: 3 additions & 3 deletions tripy/docs/post0_developer_guides/how-to-add-new-ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ it as a `tripy.Module` under [`frontend/module`](source:/tripy/frontend/module).

```py
# doc: no-eval
from tripy import constraints, export
from tripy import export, wrappers
from tripy.types import ShapeLike

# We can use the `export.public_api()` decorator to automatically export this
Expand All @@ -191,12 +191,12 @@ from tripy.types import ShapeLike
# `autodoc_options` parameter.
@export.public_api(document_under="tensor_operations")

# We can use the `constraints.interface` decorator to specify constraints on
# We can use the `wrappers.interface` decorator to specify constraints on
# inputs and perform transformations on them, like automatically converting
# compatible arguments (e.g., `TensorLike` or `ShapeLike`s) into tensors.
# We will aim to include most constraints and transformations in this decorator
# so as to avoid layering too many decorators.
@constraints.interface(convert_tensor_and_shape_likes=True)
@wrappers.interface(convert_to_tensors=True)
def theta(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float32) -> "tripy.Tensor":
# For any public facing interfaces, we have documentation requirements which
# you can read about in the 'Docs README' (linked below). The docstring
Expand Down
8 changes: 4 additions & 4 deletions tripy/tests/common/test_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,14 @@ def test_can_determine_column_range(self):
in dedent(error_msg).strip()
)

def test_constraints_is_excluded(self):
from tripy import constraints
def test_wrappers_is_excluded(self):
from tripy import wrappers

tensor = tp.ones((2, 3))

stack_info = tensor.stack_info

assert any(frame.module == constraints.__name__ for frame in stack_info)
assert any(frame.module == wrappers.__name__ for frame in stack_info)

# Make sure that no extraneous wrapper code is included
expected = dedent(
Expand All @@ -146,7 +146,7 @@ def test_constraints_is_excluded(self):
[0-9]+ | return full\(shape, 1, dtype\)
| ^^^^^^^^^^^^^^^^^^^^^ --- required from here
--> [a-z_/\.]+:[0-9]+ in test_constraints_is_excluded\(\)
--> [a-z_/\.]+:[0-9]+ in test_wrappers_is_excluded\(\)
|
[0-9]+ | tensor = tp.ones\(\(2, 3\)\)
| ^^^^^^^^^^^^^^^ --- required from here
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
import pytest
from tests import helper
from tests.conftest import skip_if_older_than_sm89
from tests.constraints.object_builders import create_obj
from tests.wrappers.object_builders import create_obj

import tripy as tp
from tripy import constraints
from tripy import wrappers
from tripy.common.datatype import DATA_TYPES
from tripy.constraints import TYPE_VERIFICATION
from tripy.wrappers import TYPE_VERIFICATION
from tripy.export import PUBLIC_APIS

# Get all functions/methods which have tensors in the type signature
Expand Down Expand Up @@ -239,7 +239,7 @@ def test_dtype_constraints(test_data):
assert ret_val.dtype == namespace[return_dtype]


@constraints.interface(dtype_constraints={"tensors": "T1"}, variables={"T1": ["float32"]})
@wrappers.interface(dtype_constraints={"tensors": "T1"}, dtype_variables={"T1": ["float32"]})
def sequence_func(tensors: List[tp.Tensor]):
return

Expand All @@ -255,7 +255,7 @@ def test_raises_on_mismatched_sequence_dtypes(self):

class TestTensorConversion:
def test_no_effect_on_non_tensor_likes(self):
@constraints.interface(convert_tensor_and_shape_likes=True)
@wrappers.interface(convert_to_tensors=True)
def func(a: tp.Tensor, b: int):
return a, b

Expand All @@ -266,7 +266,7 @@ def func(a: tp.Tensor, b: int):
assert b is 4

def test_tensor_likes(self):
@constraints.interface(convert_tensor_and_shape_likes=True)
@wrappers.interface(convert_to_tensors=True)
def func(a: tp.types.TensorLike):
return a

Expand All @@ -277,7 +277,7 @@ def func(a: tp.types.TensorLike):

def test_converts_to_dimension_size(self):
# The decorator should convert to DimensionSizes when possible.
@constraints.interface(convert_tensor_and_shape_likes=True)
@wrappers.interface(convert_to_tensors=True)
def func(a: tp.types.TensorLike):
return a

Expand All @@ -289,7 +289,7 @@ def func(a: tp.types.TensorLike):
assert type(a) is tp.Tensor

def test_shape_likes(self):
@constraints.interface(convert_tensor_and_shape_likes=True)
@wrappers.interface(convert_to_tensors=True)
def func(a: tp.types.ShapeLike):
return a

Expand All @@ -310,7 +310,7 @@ def func(a: tp.types.ShapeLike):
assert bool(tp.all(a == tp.Tensor([2, 2, 3, 5])))

def test_keyword_args(self):
@constraints.interface(convert_tensor_and_shape_likes=True)
@wrappers.interface(convert_to_tensors=True)
def func(a: tp.types.TensorLike):
return a

Expand All @@ -320,7 +320,7 @@ def func(a: tp.types.TensorLike):
assert a.stack_info[4].column_range == (17, 22)

def test_multiple_args(self):
@constraints.interface(convert_tensor_and_shape_likes=True)
@wrappers.interface(convert_to_tensors=True)
def func(a: tp.types.TensorLike, b: tp.types.TensorLike):
return a, b

Expand All @@ -333,7 +333,7 @@ def func(a: tp.types.TensorLike, b: tp.types.TensorLike):
assert b.stack_info[4].column_range == (25, 28)

def test_args_out_of_order(self):
@constraints.interface(convert_tensor_and_shape_likes=True)
@wrappers.interface(convert_to_tensors=True)
def func(a: tp.types.TensorLike, b: tp.types.TensorLike):
return a, b

Expand All @@ -349,10 +349,10 @@ def func(a: tp.types.TensorLike, b: tp.types.TensorLike):

def test_cast_dtype(self):
# When type constraints are included, the decorator should automatically cast when possible.
@constraints.interface(
dtype_constraints={"a": "T1", "b": "T1", constraints.RETURN_VALUE: "T1"},
variables={"T1": ["float16"]},
convert_tensor_and_shape_likes=True,
@wrappers.interface(
dtype_constraints={"a": "T1", "b": "T1", wrappers.RETURN_VALUE: "T1"},
dtype_variables={"T1": ["float16"]},
convert_to_tensors=True,
)
def func(a: tp.Tensor, b: tp.types.TensorLike):
return a, b
Expand All @@ -369,10 +369,10 @@ def func(a: tp.Tensor, b: tp.types.TensorLike):

@pytest.mark.parametrize("arg, dtype", [(1.0, tp.int32), (1.0, tp.int64), (2, tp.bool)])
def test_refuse_unsafe_cast(self, arg, dtype):
@constraints.interface(
dtype_constraints={"a": "T1", "b": "T1", constraints.RETURN_VALUE: "T1"},
variables={"T1": ["int32", "int64"]},
convert_tensor_and_shape_likes=True,
@wrappers.interface(
dtype_constraints={"a": "T1", "b": "T1", wrappers.RETURN_VALUE: "T1"},
dtype_variables={"T1": ["int32", "int64"]},
convert_to_tensors=True,
)
def func(a: tp.Tensor, b: tp.types.TensorLike):
return a, b
Expand All @@ -385,7 +385,7 @@ def test_preprocess_func(self):
def add_a_to_b(a, b):
return {"b": a + b}

@constraints.interface(convert_tensor_and_shape_likes=True, conversion_preprocess_func=add_a_to_b)
@wrappers.interface(convert_to_tensors=True, conversion_preprocess_func=add_a_to_b)
def func(a: tp.types.TensorLike, b: tp.types.TensorLike):
return a, b

Expand All @@ -398,7 +398,7 @@ def test_variadic_args(self):
def increment(a, *args):
return {"a": a + 1, "args": list(map(lambda arg: arg + 1, args))}

@constraints.interface(convert_tensor_and_shape_likes=True, conversion_preprocess_func=increment)
@wrappers.interface(convert_to_tensors=True, conversion_preprocess_func=increment)
def func(a: tp.Tensor, *args):
return [a] + list(args)

Expand Down
6 changes: 3 additions & 3 deletions tripy/tripy/frontend/ops/allclose.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
# limitations under the License.
#

from tripy import constraints, export
from tripy import export, wrappers


@export.public_api(document_under="operations/functions")
@constraints.interface(
dtype_constraints={"input": "T1", "other": "T1"}, variables={"T1": ["float32", "float16", "bfloat16"]}
@wrappers.interface(
dtype_constraints={"input": "T1", "other": "T1"}, dtype_variables={"T1": ["float32", "float16", "bfloat16"]}
)
def allclose(input: "tripy.Tensor", other: "tripy.Tensor", rtol: float = 1e-05, atol: float = 1e-08) -> bool:
r"""
Expand Down
8 changes: 4 additions & 4 deletions tripy/tripy/frontend/ops/cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tripy import constraints, export
from tripy import export, wrappers

from tripy.frontend import utils as frontend_utils


@export.public_api(document_under="operations/functions")
@constraints.interface(
dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"},
variables={
@wrappers.interface(
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
dtype_variables={
"T1": ["float32", "float16", "bfloat16", "int32"],
},
)
Expand Down
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/ops/equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tripy import constraints, export
from tripy import export, wrappers
from tripy.common.datatype import DATA_TYPES


@export.public_api(document_under="operations/functions")
@constraints.interface(dtype_constraints={"input": "T1", "other": "T1"}, variables={"T1": list(DATA_TYPES.keys())})
@wrappers.interface(dtype_constraints={"input": "T1", "other": "T1"}, dtype_variables={"T1": list(DATA_TYPES.keys())})
def equal(input: "tripy.Tensor", other: "tripy.Tensor") -> bool:
r"""
Returns ``True`` if ``input`` and ``other`` have the same shape and elements.
Expand Down
8 changes: 4 additions & 4 deletions tripy/tripy/frontend/ops/flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
# limitations under the License.
import math

from tripy import constraints, export
from tripy import export, wrappers
from tripy.common.exception import raise_error


@export.public_api(document_under="operations/functions")
@constraints.interface(
dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"},
variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]},
@wrappers.interface(
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]},
)
def flatten(input: "tripy.Tensor", start_dim: int = 0, end_dim: int = -1) -> "tripy.Tensor":
"""
Expand Down
8 changes: 4 additions & 4 deletions tripy/tripy/frontend/ops/gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@

import math

from tripy import export, constraints
from tripy import export, wrappers


@export.public_api(document_under="operations/functions")
@constraints.interface(
dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"},
variables={
@wrappers.interface(
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
dtype_variables={
"T1": ["float32", "float16", "bfloat16"],
},
)
Expand Down
8 changes: 4 additions & 4 deletions tripy/tripy/frontend/ops/outer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
# limitations under the License.
#

from tripy import constraints, export
from tripy import export, wrappers


@export.public_api(document_under="operations/functions")
@constraints.interface(
dtype_constraints={"vec1": "T1", "vec2": "T1", constraints.RETURN_VALUE: "T1"},
variables={"T1": ["float32", "float16", "bfloat16", "int32"]},
@wrappers.interface(
dtype_constraints={"vec1": "T1", "vec2": "T1", wrappers.RETURN_VALUE: "T1"},
dtype_variables={"T1": ["float32", "float16", "bfloat16", "int32"]},
)
def outer(vec1: "tripy.Tensor", vec2: "tripy.Tensor") -> "tripy.Tensor":
r"""
Expand Down
8 changes: 4 additions & 4 deletions tripy/tripy/frontend/ops/relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
# limitations under the License.
#

from tripy import export, constraints
from tripy import export, wrappers


@export.public_api(document_under="operations/functions")
@constraints.interface(
dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"},
variables={
@wrappers.interface(
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
dtype_variables={
"T1": ["float32", "float16", "bfloat16", "int4", "int32", "int64", "bool", "int8"],
},
)
Expand Down
8 changes: 4 additions & 4 deletions tripy/tripy/frontend/ops/repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
#
from typing import Union

from tripy import constraints, export
from tripy import export, wrappers
from tripy.common.exception import raise_error
from tripy.frontend import utils as frontend_utils


@export.public_api(document_under="operations/functions")
@constraints.interface(
dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"},
variables={
@wrappers.interface(
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
dtype_variables={
"T1": ["float32", "float16", "bfloat16", "int4", "float8", "int8", "int32", "int64", "bool"],
},
)
Expand Down
8 changes: 4 additions & 4 deletions tripy/tripy/frontend/ops/sigmoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
# limitations under the License.
#

from tripy import export, constraints
from tripy import export, wrappers


@export.public_api(document_under="operations/functions")
@constraints.interface(
dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"},
variables={
@wrappers.interface(
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
dtype_variables={
"T1": ["float32", "float16", "bfloat16"],
},
)
Expand Down
8 changes: 4 additions & 4 deletions tripy/tripy/frontend/ops/silu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
# limitations under the License.
#

from tripy import export, constraints
from tripy import export, wrappers


@export.public_api(document_under="operations/functions")
@constraints.interface(
dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"},
variables={
@wrappers.interface(
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
dtype_variables={
"T1": ["float32", "float16", "bfloat16"],
},
)
Expand Down
8 changes: 4 additions & 4 deletions tripy/tripy/frontend/ops/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
# limitations under the License.
#

from tripy import export, constraints
from tripy import export, wrappers


@export.public_api(document_under="operations/functions")
@constraints.interface(
dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"},
variables={
@wrappers.interface(
dtype_constraints={"input": "T1", wrappers.RETURN_VALUE: "T1"},
dtype_variables={
"T1": ["float32", "float16", "bfloat16"],
},
)
Expand Down
Loading

0 comments on commit f046a23

Please sign in to comment.