Skip to content

Commit

Permalink
Combine the functionality of the convert_to_tensors decorator and t…
Browse files Browse the repository at this point in the history
…he `dtypes` constraint. (#420)

Address #408. Since all of the functions that used the
`convert_to_tensors` decorator also applied a `dtypes` constraint, the
approach here is to simply include the option to do conversions in the
`dtypes` constraint and rename the decorator to `interface`, to indicate
a more general function.
  • Loading branch information
slyubomirsky authored Dec 10, 2024
1 parent 51cb635 commit b82c737
Show file tree
Hide file tree
Showing 56 changed files with 1,326 additions and 935 deletions.
2 changes: 1 addition & 1 deletion tripy/docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import tripy as tp
from tripy.common.datatype import DATA_TYPES
from tripy.constraints import TYPE_VERIFICATION
from tripy.wrappers import TYPE_VERIFICATION

PARAM_PAT = re.compile(":param .*?:")

Expand Down
12 changes: 7 additions & 5 deletions tripy/docs/post0_developer_guides/how-to-add-new-ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,7 @@ it as a `tripy.Module` under [`frontend/module`](source:/tripy/frontend/module).

```py
# doc: no-eval
from tripy import export
import tripy.frontend.utils as frontend_utils
from tripy import export, wrappers
from tripy.types import ShapeLike

# We can use the `export.public_api()` decorator to automatically export this
Expand All @@ -192,9 +191,12 @@ from tripy.types import ShapeLike
# `autodoc_options` parameter.
@export.public_api(document_under="tensor_operations")

# The `convert_to_tensors` decorator automatically converts compatible
# arguments, like `TensorLike` or `ShapeLike`s, into tensors.
@frontend_utils.convert_to_tensors()
# We can use the `wrappers.interface` decorator to specify constraints on
# inputs and perform transformations on them, like automatically converting
# compatible arguments (e.g., `TensorLike` or `ShapeLike`s) into tensors.
# We will aim to include most constraints and transformations in this decorator
# so as to avoid layering too many decorators.
@wrappers.interface(convert_to_tensors=True)
def theta(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float32) -> "tripy.Tensor":
# For any public facing interfaces, we have documentation requirements which
# you can read about in the 'Docs README' (linked below). The docstring
Expand Down
12 changes: 5 additions & 7 deletions tripy/tests/common/test_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import tripy as tp
from tripy.common.exception import TripyException, _get_function_file_and_lines, str_from_stack_info, raise_error
from tripy.frontend.utils import convert_to_tensors
from tripy.utils import StackInfo, get_stack_info
from tripy.utils.stack_info import SourceInfo

Expand Down Expand Up @@ -125,15 +124,14 @@ def test_can_determine_column_range(self):
in dedent(error_msg).strip()
)

def test_convert_to_tensors_is_excluded(self):
filename, start_line, end_line = _get_function_file_and_lines(convert_to_tensors)
def test_wrappers_is_excluded(self):
from tripy import wrappers

tensor = tp.ones((2, 3))

stack_info = tensor.stack_info

assert any(
frame.file == filename and frame.line >= start_line and frame.line <= end_line for frame in stack_info
)
assert any(frame.module == wrappers.__name__ for frame in stack_info)

# Make sure that no extraneous wrapper code is included
expected = dedent(
Expand All @@ -148,7 +146,7 @@ def test_convert_to_tensors_is_excluded(self):
[0-9]+ | return full\(shape, 1, dtype\)
| ^^^^^^^^^^^^^^^^^^^^^ --- required from here
--> [a-z_/\.]+:[0-9]+ in test_convert_to_tensors_is_excluded\(\)
--> [a-z_/\.]+:[0-9]+ in test_wrappers_is_excluded\(\)
|
[0-9]+ | tensor = tp.ones\(\(2, 3\)\)
| ^^^^^^^^^^^^^^^ --- required from here
Expand Down
24 changes: 24 additions & 0 deletions tripy/tests/frontend/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,27 @@ def test_explicit_cast_copy(self):
)
def test_tolist(self, tensor, expected):
assert np.allclose(tensor.tolist(), expected)

# testing the invariant that stack trace of build is not past the limit
@pytest.mark.parametrize(
"tensor",
[
tp.Tensor([1, 2, 3]),
tp.Tensor([1, 2, 3]) + tp.Tensor([4, 5, 6]),
# This case should trigger datatype conversions.
(4 * tp.Tensor([1, 2, 3])) + (3 * tp.Tensor([4, 5, 6])),
# Slice is an interesting case because it adds slice_helper to the stack.
# Additionally, the use of slices may also require more ops, increasing the total stack depth.
(tp.Tensor([1, 2, 3]) + tp.Tensor([4, 5, 6]))[:],
(tp.Tensor([1, 2, 3]) + tp.Tensor([4, 5, 6]))[0:],
(tp.Tensor([1, 2, 3]) + tp.Tensor([4, 5, 6]))[:3],
(tp.Tensor([1, 2, 3]) + tp.Tensor([4, 5, 6]))[0:3:1],
(tp.Tensor([[1], [2], [3]]) + tp.Tensor([[4], [5], [6]]))[0],
],
)
def test_stack_depth_of_build(self, tensor):
if any(info.function == "build" for info in tensor.stack_info):
# + 1 for inclusive bound
assert any(
info.function == "build" for info in tensor.stack_info[: tp.frontend.tensor.STACK_DEPTH_OF_BUILD + 1]
)
160 changes: 1 addition & 159 deletions tripy/tests/frontend/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,8 @@
#

import pytest
from tests import helper

import tripy as tp
from tripy import constraints
from tripy.frontend.utils import convert_to_tensors, tensor_from_shape_like
from tripy.frontend.utils import tensor_from_shape_like


@pytest.mark.parametrize(
Expand All @@ -37,158 +34,3 @@ def test_tensor_from_shape_like(shape, expected):
tensor = tensor_from_shape_like(shape)

assert tensor.tolist() == expected


class TestConvertToTensors:
def test_no_effect_on_non_tensor_likes(self):
@convert_to_tensors()
def func(a: tp.Tensor, b: int):
return a, b

original_a = tp.Tensor([1, 2])
a, b = func(original_a, 4)

assert a is original_a
assert b is 4

def test_tensor_likes(self):
@convert_to_tensors()
def func(a: tp.types.TensorLike):
return a

a = func(1.0)

assert isinstance(a, tp.Tensor)
assert a.stack_info[3].column_range == (17, 20)

def test_converts_to_dimension_size(self):
# The decorator should convert to DimensionSizes when possible.
@convert_to_tensors()
def func(a: tp.types.TensorLike):
return a

a = func(1)
assert type(a) is tp.DimensionSize

# floats cannot be DimensionSizes
a = func(1.0)
assert type(a) is tp.Tensor

def test_shape_likes(self):
@convert_to_tensors()
def func(a: tp.types.ShapeLike):
return a

a = func([1, 2, 3])

assert isinstance(a, tp.Tensor)
assert a.shape == [3]
assert bool(tp.all(a == tp.Tensor([1, 2, 3])))

# Should also work from shapes of tensors
inp = tp.Tensor([[1, 2], [2, 3]])
a = inp.shape + [3, 5] # Should yield: [2, 2, 3, 5]

a = func(a)

assert isinstance(a, tp.Tensor)
assert a.shape == [4]
assert bool(tp.all(a == tp.Tensor([2, 2, 3, 5])))

def test_keyword_args(self):
@convert_to_tensors()
def func(a: tp.types.TensorLike):
return a

a = func(a=1.0)

assert isinstance(a, tp.Tensor)
assert a.stack_info[3].column_range == (17, 22)

def test_multiple_args(self):
@convert_to_tensors()
def func(a: tp.types.TensorLike, b: tp.types.TensorLike):
return a, b

a, b = func(1.0, 2.0)

assert isinstance(a, tp.Tensor)
assert a.stack_info[3].column_range == (20, 23)

assert isinstance(b, tp.Tensor)
assert b.stack_info[3].column_range == (25, 28)

def test_args_out_of_order(self):
@convert_to_tensors()
def func(a: tp.types.TensorLike, b: tp.types.TensorLike):
return a, b

a, b = func(b=1.0, a=2.0)

assert isinstance(a, tp.Tensor)
assert a.stack_info[3].column_range == (27, 32)
assert a.tolist() == 2.0

assert isinstance(b, tp.Tensor)
assert b.stack_info[3].column_range == (20, 25)
assert b.tolist() == 1.0

def test_cast_dtype(self):
# When type constraints are included, the decorator should automatically cast when possible.
@convert_to_tensors()
@constraints.dtypes(
constraints={"a": "T1", "b": "T1", constraints.RETURN_VALUE: "T1"},
variables={"T1": ["float16"]},
)
def func(a: tp.Tensor, b: tp.types.TensorLike):
return a, b

a, b = func(tp.Tensor([1.0], dtype=tp.float16), 4.0)

assert isinstance(b, tp.Tensor)
assert b.dtype == tp.float16

a, b = func(tp.Tensor([1.0], dtype=tp.float16), 4)

assert isinstance(b, tp.Tensor)
assert b.dtype == tp.float16

@pytest.mark.parametrize("arg, dtype", [(1.0, tp.int32), (1.0, tp.int64), (2, tp.bool)])
def test_refuse_unsafe_cast(self, arg, dtype):
@convert_to_tensors()
@constraints.dtypes(
constraints={"a": "T1", "b": "T1", constraints.RETURN_VALUE: "T1"},
variables={"T1": ["int32", "int64"]},
)
def func(a: tp.Tensor, b: tp.types.TensorLike):
return a, b

with helper.raises(tp.TripyException, "Refusing to automatically cast"):
func(tp.Tensor([1, 2], dtype=dtype), arg)

def test_preprocess_args(self):

def add_a_to_b(a, b):
return {"b": a + b}

@convert_to_tensors(preprocess_args=add_a_to_b)
def func(a: tp.types.TensorLike, b: tp.types.TensorLike):
return a, b

a, b = func(1, 2)

assert b.tolist() == 3

def test_variadic_args(self):

def increment(a, *args):
return {"a": a + 1, "args": list(map(lambda arg: arg + 1, args))}

@convert_to_tensors(preprocess_args=increment)
def func(a: tp.Tensor, *args):
return [a] + list(args)

a, b, c = func(tp.Tensor(1), tp.Tensor(2), tp.Tensor(3))
assert a.tolist() == 2
assert b.tolist() == 3
assert c.tolist() == 4
28 changes: 28 additions & 0 deletions tripy/tests/frontend/trace/ops/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,34 @@ def test_op_func_all_partial(self):
assert isinstance(a, tp.Tensor)
assert isinstance(a.trace_tensor.producer, Slice)

def test_slice_of_inline_output(self):
a = tp.Tensor([1, 2, 3, 4])
# The start and stop params use clamp bound, but the step parameter doesn't.
# The result is that the stack traces for the slice params are of different lengths.
s = (a + a)[3:4:]
assert isinstance(s, tp.Tensor)
assert isinstance(s.trace_tensor.producer, Slice)

# input 0 is a + a, so it's not one of the slice params
slice_inputs = s.trace_tensor.producer.inputs[1:]
assert len(slice_inputs) == 3

assert any(frame.function == "clamp_bound" for frame in slice_inputs[0].stack_info)
assert any(frame.function == "clamp_bound" for frame in slice_inputs[1].stack_info)
assert not any(frame.function == "clamp_bound" for frame in slice_inputs[2].stack_info)

# Consequently, the frame corresponding to the caller is at different depths.
def index_of_caller(trace_input):
for i, frame in enumerate(trace_input.stack_info):
if frame.function == TestSlice.test_slice_of_inline_output.__name__:
return i
return -1

caller_idxs = [index_of_caller(inp) for inp in slice_inputs]
assert all(idx != -1 for idx in caller_idxs)
assert caller_idxs[0] == caller_idxs[1]
assert caller_idxs[2] != caller_idxs[1]

def test_incorrect_index_size(self):
with helper.raises(
tp.TripyException,
Expand Down
Loading

0 comments on commit b82c737

Please sign in to comment.