Skip to content

Commit

Permalink
correct pyopencl arithmetic to limit only for scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd authored and inducer committed Jun 27, 2021
1 parent eb0cbbb commit 78bedd8
Showing 1 changed file with 32 additions and 16 deletions.
48 changes: 32 additions & 16 deletions pyopencl/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ def __add__(self, other):
other.dtype.type(1), other))

return result
else:
elif np.isscalar(other):
# add a scalar
if other == 0:
return self.copy()
Expand All @@ -1078,6 +1078,8 @@ def __add__(self, other):
self._axpbz(result, self.dtype.type(1),
self, common_dtype.type(other)))
return result
else:
return NotImplemented

__radd__ = __add__

Expand All @@ -1093,7 +1095,7 @@ def __sub__(self, other):
result.dtype.type(-1), other))

return result
else:
elif np.isscalar(other):
# subtract a scalar
if other == 0:
return self.copy()
Expand All @@ -1103,6 +1105,8 @@ def __sub__(self, other):
result.add_event(
self._axpbz(result, self.dtype.type(1), self, -other))
return result
else:
return NotImplemented

def __rsub__(self, other):
"""Substracts an array by a scalar or an array::
Expand Down Expand Up @@ -1135,9 +1139,11 @@ def __isub__(self, other):
self._axpbyz(self, self.dtype.type(1), self,
other.dtype.type(-1), other))
return self
else:
elif np.isscalar(other):
self._axpbz(self, self.dtype.type(1), self, -other)
return self
else:
return NotImplemented

def __pos__(self):
return self
Expand All @@ -1154,13 +1160,15 @@ def __mul__(self, other):
result.add_event(
self._elwise_multiply(result, self, other))
return result
else:
elif np.isscalar(other):
common_dtype = _get_common_dtype(self, other, self.queue)
result = self._new_like_me(common_dtype)
result.add_event(
self._axpbz(result,
common_dtype.type(other), self, self.dtype.type(0)))
return result
else:
return NotImplemented

def __rmul__(self, scalar):
common_dtype = _get_common_dtype(self, scalar, self.queue)
Expand All @@ -1174,12 +1182,14 @@ def __imul__(self, other):
if isinstance(other, Array):
self.add_event(
self._elwise_multiply(self, self, other))
else:
return self
elif np.isscalar(other):
# scalar
self.add_event(
self._axpbz(self, other, self, self.dtype.type(0)))

return self
return self
else:
return NotImplemented

def __div__(self, other):
"""Divides an array by an array or a scalar, i.e. ``self / other``.
Expand All @@ -1188,7 +1198,8 @@ def __div__(self, other):
if isinstance(other, Array):
result = self._new_like_me(common_dtype)
result.add_event(self._div(result, self, other))
else:
return result
elif np.isscalar(other):
if other == 1:
return self.copy()
else:
Expand All @@ -1198,8 +1209,9 @@ def __div__(self, other):
self._axpbz(result,
np.true_divide(common_dtype.type(1), other),
self, self.dtype.type(0)))

return result
return result
else:
return NotImplemented

__truediv__ = __div__

Expand All @@ -1211,13 +1223,15 @@ def __rdiv__(self, other):
if isinstance(other, Array):
result = self._new_like_me(common_dtype)
result.add_event(other._div(result, self))
else:
return result
elif np.isscalar(other):
# create a new array for the result
result = self._new_like_me(common_dtype)
result.add_event(
self._rdiv_scalar(result, self, common_dtype.type(other)))

return result
return result
else:
return NotImplemented

__rtruediv__ = __rdiv__

Expand All @@ -1231,15 +1245,17 @@ def __itruediv__(self, other):
if isinstance(other, Array):
self.add_event(
self._div(self, self, other))
else:
return self
elif np.isscalar(other):
if other == 1:
return self
else:
self.add_event(
self._axpbz(self, common_dtype.type(np.true_divide(1, other)),
self, self.dtype.type(0)))

return self
return self
else:
return NotImplemented

def __and__(self, other):
common_dtype = _get_common_dtype(self, other, self.queue)
Expand Down

0 comments on commit 78bedd8

Please sign in to comment.