Skip to content

Commit

Permalink
Just a few small fixes (#468)
Browse files Browse the repository at this point in the history
  • Loading branch information
brownbaerchen authored Aug 26, 2024
1 parent 48bf389 commit 5b3a12e
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 4 deletions.
17 changes: 17 additions & 0 deletions pySDC/helpers/NCCL_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __getattr__(self, name):
"""
if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split']:
cp.cuda.get_current_stream().synchronize()

return getattr(self.commMPI, name)

@staticmethod
Expand Down Expand Up @@ -71,6 +72,9 @@ def get_op(self, MPI_op):
raise NotImplementedError('Don\'t know what NCCL operation to use to replace this MPI operation!')

def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
if not hasattr(sendbuf.data, 'ptr'):
return self.commMPI.Reduce(sendbuf=sendbuf, recvbuf=recvbuf, op=op, root=root)

dtype = self.get_dtype(sendbuf)
count = self.get_count(sendbuf)
op = self.get_op(op)
Expand All @@ -88,6 +92,9 @@ def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
)

def Allreduce(self, sendbuf, recvbuf, op=MPI.SUM):
if not hasattr(sendbuf.data, 'ptr'):
return self.commMPI.Allreduce(sendbuf=sendbuf, recvbuf=recvbuf, op=op)

dtype = self.get_dtype(sendbuf)
count = self.get_count(sendbuf)
op = self.get_op(op)
Expand All @@ -96,3 +103,13 @@ def Allreduce(self, sendbuf, recvbuf, op=MPI.SUM):
self.commNCCL.allReduce(
sendbuf=sendbuf.data.ptr, recvbuf=recvbuf.data.ptr, count=count, datatype=dtype, op=op, stream=stream.ptr
)

def Bcast(self, buf, root=0):
if not hasattr(buf.data, 'ptr'):
return self.commMPI.Bcast(buf=buf, root=root)

dtype = self.get_dtype(buf)
count = self.get_count(buf)
stream = cp.cuda.get_current_stream()

self.commNCCL.bcast(buff=buf.data.ptr, count=count, datatype=dtype, root=root, stream=stream.ptr)
2 changes: 2 additions & 0 deletions pySDC/helpers/plot_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ def figsize_by_journal(journal, scale, ratio): # pragma: no cover
textwidths = {
'JSC_beamer': 426.79135,
'Springer_Numerical_Algorithms': 338.58778,
'JSC_thesis': 434.26027,
}
# store text height in points here, get this from LaTeX using \the\textheight
textheights = {
'JSC_beamer': 214.43411,
'JSC_thesis': 635.5,
}
assert (
journal in textwidths.keys()
Expand Down
10 changes: 6 additions & 4 deletions pySDC/implementations/problem_classes/AllenCahn_MPIFFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class allencahn_imex(IMEX_Laplacian_MPIFFT):
r"""
Example implementing the :math:`N`-dimensional Allen-Cahn equation with periodic boundary conditions :math:`u \in [0, 1]^2`
Example implementing the :math:`2`-dimensional Allen-Cahn equation with periodic boundary conditions :math:`u \in [0, 1]^2`
.. math::
\frac{\partial u}{\partial t} = \Delta u - \frac{2}{\varepsilon^2} u (1 - u) (1 - 2u)
Expand Down Expand Up @@ -153,7 +153,9 @@ def u_exact(self, t, **kwargs):
# build radius
r2 = (self.X[0] + i - L + 0.5) ** 2 + (self.X[1] + j - L + 0.5) ** 2
# add this blob, shifted by 1 to avoid issues with adding up negative contributions
tmp += self.xp.tanh((rand_radii[i, j] - np.sqrt(r2)) / (np.sqrt(2) * self.eps)) + 1
tmp += self.xp.tanh((rand_radii[i, j] - self.xp.sqrt(r2)) / (np.sqrt(2) * self.eps)) + 1
else:
raise NotImplementedError
# normalize to [0,1]
tmp *= 0.5
assert self.xp.all(tmp <= 1.0)
Expand Down Expand Up @@ -183,8 +185,8 @@ class allencahn_imex_timeforcing(allencahn_imex):
u({\bf x}, 0) = \tanh\left(\frac{r - \sqrt{(x_i-0.5)^2 + (y_j-0.5)^2}}{\sqrt{2}\varepsilon}\right),
for :math:`i, j=0,..,N-1`, where :math:`N` is the number of spatial grid points. For time-stepping, the problem is treated
*semi-implicitly*, i.e., the linear part is solved with Fast-Fourier Transform (FFT) and the nonlinear part in the right-hand
side will be treated explicitly using ``mpi4py-fft`` [1]_ to solve them.
*semi-implicitly*, i.e., the linear part is solved with Fast-Fourier Transform (FFT) using ``mpi4py-fft`` [1]_ and the nonlinear part in the right-hand
side will be treated explicitly.
"""

def eval_f(self, u, t):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ class nonlinearschroedinger_fully_implicit(nonlinearschroedinger_imex):
dtype_f = mesh

def __init__(self, lintol=1e-9, liniter=99, **kwargs):
assert kwargs.get('useGPU', False) is False

super().__init__(**kwargs)
self._makeAttributeAndRegister('liniter', 'lintol', localVars=locals(), readOnly=False)

Expand Down

0 comments on commit 5b3a12e

Please sign in to comment.