Skip to content

Commit

Permalink
Support 'dim in Tensor'
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Aug 9, 2024
1 parent 1bbf9f3 commit 79322b6
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,18 @@ def __index__(self):
assert self.dtype.kind == int, f"Only int tensors can be converted to index but dtype is {self.dtype}"
return int(self.native())

def __contains__(self, item):
if isinstance(item, Shape):
return item in self.shape
elif isinstance(item, BoundDim):
return item.name in self.shape
elif isinstance(item, _BoundDims):
return item.dims in self.shape
elif isinstance(item, str):
assert self.dtype.kind != object, "str in Tensor not allowed for object-type Tensors"
return item in self.shape
raise ValueError(f"'dim in Tensor' requires dim to be a Shape or str but got {item}")

def __repr__(self):
return format_tensor(self, PrintOptions())

Expand Down

0 comments on commit 79322b6

Please sign in to comment.