diff --git a/pySDC/helpers/spectral_helper.py b/pySDC/helpers/spectral_helper.py index 7381b33e6..0ec096707 100644 --- a/pySDC/helpers/spectral_helper.py +++ b/pySDC/helpers/spectral_helper.py @@ -882,6 +882,7 @@ def __init__(self, comm=None, useGPU=False, debug=False): self.BCs = None self.fft_cache = {} + self.fft_dealias_shape_cache = {} @property def u_init(self): @@ -1470,7 +1471,9 @@ def _transform_dct(self, u, axes, padding=None, **kwargs): if padding is not None: shape = list(v.shape) - if self.comm: + if ('forward', *padding) in self.fft_dealias_shape_cache.keys(): + shape[0] = self.fft_dealias_shape_cache[('forward', *padding)] + elif self.comm: send_buf = np.array(v.shape[0]) recv_buf = np.array(v.shape[0]) self.comm.Allreduce(send_buf, recv_buf) @@ -1645,7 +1648,9 @@ def _transform_idct(self, u, axes, padding=None, **kwargs): if padding is not None: if padding[axis] != 1: shape = list(v.shape) - if self.comm: + if ('backward', *padding) in self.fft_dealias_shape_cache.keys(): + shape[0] = self.fft_dealias_shape_cache[('backward', *padding)] + elif self.comm: send_buf = np.array(v.shape[0]) recv_buf = np.array(v.shape[0]) self.comm.Allreduce(send_buf, recv_buf)