diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index c67567c1..3f6167c9 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -5974,28 +5974,24 @@ class PolarAzimuthalComponent(operators.AzimuthalComponent): basis_type = IntervalBasis def subproblem_matrix(self, subproblem): - # I'm not sure how to generalize this to higher order tensors, since we do - # not have spin_weights for the S1 basis. - matrix = np.array([[1,0]]) + operand = self.args[0] + input_dim = len(operand.tensorsig) + output_dim = len(self.tensorsig) + matrix = [] + for output in range(2**output_dim): + index_out = np.unravel_index(output, [2]*output_dim) + matrix_row = [] + for input in range(2**input_dim): + index_in = np.unravel_index(input, [2]*input_dim) + if tuple(index_in[:self.index] + index_in[self.index+1:]) == index_out and index_in[self.index] == 0: + matrix_row.append(1) + else: + matrix_row.append(0) + matrix.append(matrix_row) + matrix = np.array(matrix) if self.dtype == np.float64: # Block-diag for sin/cos parts for real dtype matrix = sparse.kron(matrix, sparse.eye(2)) - -# operand = self.args[0] -# basis = self.domain.get_basis(self.coordsys) -# S_in = basis.spin_weights(operand.tensorsig) -# S_out = basis.spin_weights(self.tensorsig) -# -# matrix = [] -# for spinindex_out, spintotal_out in np.ndenumerate(S_out): -# matrix_row = [] -# for spinindex_in, spintotal_in in np.ndenumerate(S_in): -# if tuple(spinindex_in[:self.index] + spinindex_in[self.index+1:]) == spinindex_out and spinindex_in[self.index] == 2: -# matrix_row.append( 1 ) -# else: -# matrix_row.append( 0 ) -# matrix.append(matrix_row) -# matrix = np.array(matrix) return matrix def operate(self, out): @@ -6012,28 +6008,24 @@ class PolarRadialComponent(operators.RadialComponent): basis_type = IntervalBasis def subproblem_matrix(self, subproblem): - # I'm not sure how to generalize this to higher order tensors, since we do - # not have spin_weights for the S1 basis. - matrix = np.array([[0,1]]) + operand = self.args[0] + input_dim = len(operand.tensorsig) + output_dim = len(self.tensorsig) + matrix = [] + for output in range(2**output_dim): + index_out = np.unravel_index(output, [2]*output_dim) + matrix_row = [] + for input in range(2**input_dim): + index_in = np.unravel_index(input, [2]*input_dim) + if tuple(index_in[:self.index] + index_in[self.index+1:]) == index_out and index_in[self.index] == 1: + matrix_row.append(1) + else: + matrix_row.append(0) + matrix.append(matrix_row) + matrix = np.array(matrix) if self.dtype == np.float64: # Block-diag for sin/cos parts for real dtype matrix = sparse.kron(matrix, sparse.eye(2)) - -# operand = self.args[0] -# basis = self.domain.get_basis(self.coordsys) -# S_in = basis.spin_weights(operand.tensorsig) -# S_out = basis.spin_weights(self.tensorsig) -# -# matrix = [] -# for spinindex_out, spintotal_out in np.ndenumerate(S_out): -# matrix_row = [] -# for spinindex_in, spintotal_in in np.ndenumerate(S_in): -# if tuple(spinindex_in[:self.index] + spinindex_in[self.index+1:]) == spinindex_out and spinindex_in[self.index] == 2: -# matrix_row.append( 1 ) -# else: -# matrix_row.append( 0 ) -# matrix.append(matrix_row) -# matrix = np.array(matrix) return matrix def operate(self, out):