From 78bedd865e2e6afaab9e876d53af0f717e3e4748 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 26 Jun 2021 12:02:21 -0500 Subject: [PATCH] correct pyopencl arithmetic to limit only for scalars --- pyopencl/array.py | 48 +++++++++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/pyopencl/array.py b/pyopencl/array.py index 9a66895fa..9627986d7 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -1067,7 +1067,7 @@ def __add__(self, other): other.dtype.type(1), other)) return result - else: + elif np.isscalar(other): # add a scalar if other == 0: return self.copy() @@ -1078,6 +1078,8 @@ def __add__(self, other): self._axpbz(result, self.dtype.type(1), self, common_dtype.type(other))) return result + else: + return NotImplemented __radd__ = __add__ @@ -1093,7 +1095,7 @@ def __sub__(self, other): result.dtype.type(-1), other)) return result - else: + elif np.isscalar(other): # subtract a scalar if other == 0: return self.copy() @@ -1103,6 +1105,8 @@ def __sub__(self, other): result.add_event( self._axpbz(result, self.dtype.type(1), self, -other)) return result + else: + return NotImplemented def __rsub__(self, other): """Substracts an array by a scalar or an array:: @@ -1135,9 +1139,11 @@ def __isub__(self, other): self._axpbyz(self, self.dtype.type(1), self, other.dtype.type(-1), other)) return self - else: + elif np.isscalar(other): self._axpbz(self, self.dtype.type(1), self, -other) return self + else: + return NotImplemented def __pos__(self): return self @@ -1154,13 +1160,15 @@ def __mul__(self, other): result.add_event( self._elwise_multiply(result, self, other)) return result - else: + elif np.isscalar(other): common_dtype = _get_common_dtype(self, other, self.queue) result = self._new_like_me(common_dtype) result.add_event( self._axpbz(result, common_dtype.type(other), self, self.dtype.type(0))) return result + else: + return NotImplemented def __rmul__(self, scalar): common_dtype = _get_common_dtype(self, scalar, self.queue) @@ -1174,12 +1182,14 @@ def __imul__(self, other): if isinstance(other, Array): self.add_event( self._elwise_multiply(self, self, other)) - else: + return self + elif np.isscalar(other): # scalar self.add_event( self._axpbz(self, other, self, self.dtype.type(0))) - - return self + return self + else: + return NotImplemented def __div__(self, other): """Divides an array by an array or a scalar, i.e. ``self / other``. @@ -1188,7 +1198,8 @@ def __div__(self, other): if isinstance(other, Array): result = self._new_like_me(common_dtype) result.add_event(self._div(result, self, other)) - else: + return result + elif np.isscalar(other): if other == 1: return self.copy() else: @@ -1198,8 +1209,9 @@ def __div__(self, other): self._axpbz(result, np.true_divide(common_dtype.type(1), other), self, self.dtype.type(0))) - - return result + return result + else: + return NotImplemented __truediv__ = __div__ @@ -1211,13 +1223,15 @@ def __rdiv__(self, other): if isinstance(other, Array): result = self._new_like_me(common_dtype) result.add_event(other._div(result, self)) - else: + return result + elif np.isscalar(other): # create a new array for the result result = self._new_like_me(common_dtype) result.add_event( self._rdiv_scalar(result, self, common_dtype.type(other))) - - return result + return result + else: + return NotImplemented __rtruediv__ = __rdiv__ @@ -1231,15 +1245,17 @@ def __itruediv__(self, other): if isinstance(other, Array): self.add_event( self._div(self, self, other)) - else: + return self + elif np.isscalar(other): if other == 1: return self else: self.add_event( self._axpbz(self, common_dtype.type(np.true_divide(1, other)), self, self.dtype.type(0))) - - return self + return self + else: + return NotImplemented def __and__(self, other): common_dtype = _get_common_dtype(self, other, self.queue)