From ef7ad7ab80c6a0df8288253bd0129d23597e74ed Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 30 Sep 2024 15:30:20 -0600 Subject: [PATCH] Add maximum and minimum torch wrappers (for fixed two-arg type promotion) --- array_api_compat/torch/_aliases.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 10725253..5ac66bcb 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -177,6 +177,8 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: logaddexp = _two_arg(torch.logaddexp) # logical functions are not included here because they only accept bool in the # spec, so type promotion is irrelevant. +maximum = _two_arg(torch.maximum) +minimum = _two_arg(torch.minimum) multiply = _two_arg(torch.multiply) not_equal = _two_arg(torch.not_equal) pow = _two_arg(torch.pow) @@ -735,15 +737,16 @@ def sign(x: array, /) -> array: 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide', 'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot', - 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow', - 'remainder', 'subtract', 'max', 'min', 'clip', 'unstack', - 'cumulative_sum', 'sort', 'prod', 'sum', 'any', 'all', 'mean', - 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll', - 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', - 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', - 'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult', - 'UniqueInverseResult', 'unique_all', 'unique_counts', - 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', - 'vecdot', 'tensordot', 'isdtype', 'take', 'sign'] + 'less', 'less_equal', 'logaddexp', 'maximum', 'minimum', + 'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max', + 'min', 'clip', 'unstack', 'cumulative_sum', 'sort', 'prod', 'sum', + 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', + 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', + 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', + 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays', + 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', + 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', + 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', + 'take', 'sign'] _all_ignore = ['torch', 'get_xp']