-
Notifications
You must be signed in to change notification settings - Fork 0
/
pytest_mpi.py
281 lines (217 loc) · 9.33 KB
/
pytest_mpi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
import collections
import numbers
import os
import subprocess
import sys
from pathlib import Path
from warnings import warn
import pytest
from mpi4py import MPI
# Name of the environment variable that ``pytest_generate_tests`` reads to cap
# the process counts that parallel markers may request.
MAX_NPROCS_FLAG = "PYTEST_MPI_MAX_NPROCS"
"""Environment variable that can be set to limit the maximum number of processes.
If set then requesting a parallel test with more processes than it will raise an
error. If unset then any value is accepted.
"""
# Name of the environment variable that ``_set_parallel_callback`` hands to
# ``mpiexec`` (through ``-genv``); ``_is_parallel_child_process`` only checks
# whether it is present in ``os.environ``.
CHILD_PROCESS_FLAG = "_PYTEST_MPI_CHILD_PROCESS"
"""Environment variable set for the processes spawned by the mpiexec call."""
# Reassigned by ``pytest_collection_modifyitems`` to whether any collected item
# has a ``parallel`` marker. The setup hook and the autouse fixtures in this
# file do nothing while it is False.
_plugin_in_use = False
"""Global variable set internally to indicate that parallel markers are used."""
@pytest.hookimpl()
def pytest_configure(config):
    """Add the ``parallel`` marker description to the ``markers`` ini value.

    Parameters
    ----------
    config
        The pytest configuration object passed to the hook.

    """
    marker_description = (
        "parallel(nprocs): mark test to run in parallel on nprocs processors "
        "(default: 3)"
    )
    config.addinivalue_line("markers", marker_description)
@pytest.hookimpl(trylast=True)
def pytest_sessionstart(session):
    """Refuse to run under an outer ``mpiexec`` while pytest-xdist is active.

    The error is only raised when all three conditions hold: the world
    communicator has more than one process, this process was not spawned by
    the plugin itself, and xdist is controlling or running the session.

    Raises
    ------
    pytest.UsageError
        If pytest was wrapped in ``mpiexec`` while pytest-xdist is in use.

    """
    # Guard clauses are evaluated in the same order as the combined condition
    if MPI.COMM_WORLD.size <= 1:
        return
    if _is_parallel_child_process():
        return
    if not _xdist_active(session):
        return

    raise pytest.UsageError(
        "Wrapping pytest calls in mpiexec is only supported if pytest-xdist "
        "is not in use"
    )
@pytest.hookimpl()
def pytest_generate_tests(metafunc):
    """Identify tests with parallel markers and break them apart if necessary.

    This hook turns tests with marks like ``@pytest.mark.parallel([2, 3, 4])``
    into multiple tests, one for each requested size. The tests are then
    distinguished by ID. For example ``test_abc[nprocs=2]``, ``test_abc[nprocs=3]``
    and ``test_abc[nprocs=4]``. If only one parallel size is requested then this
    is skipped.

    Raises
    ------
    pytest.UsageError
        If the ``PYTEST_MPI_MAX_NPROCS`` environment variable is set but is not
        an integer, or if a requested process count exceeds it.

    """
    # Look the marker up the same way as the other hooks in this file
    # (``get_closest_marker``) so that markers applied to a class or module
    # are split up here too, not only those set directly on the function.
    marker = metafunc.definition.get_closest_marker("parallel")
    if marker is None:
        return

    nprocss = _parse_marker_nprocs(marker)

    if MAX_NPROCS_FLAG in os.environ:
        raw_limit = os.environ[MAX_NPROCS_FLAG]
        try:
            max_nprocs = int(raw_limit)
        except ValueError:
            # Report a misconfigured limit as a usage error rather than
            # letting a bare ValueError escape from collection
            raise pytest.UsageError(
                f"{MAX_NPROCS_FLAG} must be an integer, got {raw_limit!r}"
            ) from None

        for nprocs in nprocss:
            if nprocs > max_nprocs:
                raise pytest.UsageError(
                    "Requested a parallel test with too many ranks "
                    f"({nprocs} > {MAX_NPROCS_FLAG}={max_nprocs})"
                )

    # Only label tests if more than one parallel size is requested
    if len(nprocss) > 1:
        # Trick the function into thinking that it needs an extra fixture argument
        metafunc.fixturenames.append("_nprocs")
        metafunc.parametrize("_nprocs", nprocss, ids=lambda n: f"nprocs={n}")
@pytest.hookimpl()
def pytest_collection_modifyitems(config, items):
    """Record whether parallel markers are in use and tag each parallel test.

    The module-level ``_plugin_in_use`` flag is set to whether any collected
    item has a ``parallel`` marker. Every such item additionally receives a
    ``parallel[N]`` marker, where ``N`` is its process count, to allow for
    querying specific levels of parallelism (e.g. ``-m parallel[3]``).

    Parameters
    ----------
    config
        The pytest configuration object; used to declare the new marker names.
    items
        The collected test items.

    """
    global _plugin_in_use

    parallel_items = [
        item for item in items if item.get_closest_marker("parallel")
    ]
    _plugin_in_use = bool(parallel_items)

    for item in parallel_items:
        nprocs = _extract_nprocs_for_single_test(item)
        sized_marker = f"parallel[{nprocs}]"

        # Declare the marker name unless pytest.mark already knows it
        if sized_marker not in pytest.mark._markers:
            config.addinivalue_line("markers", f"{sized_marker}: internal marker")

        item.add_marker(getattr(pytest.mark, sized_marker))
@pytest.hookimpl()
def pytest_runtest_setup(item):
    """Prepare a test item according to how pytest was launched.

    Does nothing unless parallel markers are in use. Otherwise:

    * a serial test fails if the world communicator has more than one process;
    * a parallel test on a single process has its callable replaced by one
      that calls ``mpiexec``;
    * a parallel test in a process spawned by the plugin runs unmodified;
    * a parallel test under an outer ``mpiexec`` fails unless the requested
      process count equals the size of the world communicator.

    Raises
    ------
    pytest.UsageError
        For a serial test run by multiple processes, or a parallel test whose
        requested and provided process counts do not match.

    """
    if not _plugin_in_use:
        return

    world_size = MPI.COMM_WORLD.size

    if not item.get_closest_marker("parallel"):
        # Serial test
        if world_size != 1:
            raise pytest.UsageError(
                "Serial tests should not be run by multiple processes, consider "
                "adding a parallel marker to the test"
            )
        return

    if world_size == 1:
        # "Forking" mode: swap in a callback that calls mpiexec
        assert not _is_parallel_child_process()
        _set_parallel_callback(item)
        return

    if _is_parallel_child_process():
        # Already a spawned subprocess, run the test as it is
        return

    # An outer mpiexec was used: do not spawn anything, but insist that the
    # requested parallelism matches what was provided
    if _extract_nprocs_for_single_test(item) != world_size:
        raise pytest.UsageError(
            "Attempting to run parallel tests inside an mpiexec call "
            "where the requested and provided process counts do not match"
        )
@pytest.fixture(scope="function", autouse=True)
def barrier_finalize(request):
    """Register an MPI barrier to run when each test finishes.

    This should help localise tests that are not fully collective. Nothing is
    registered unless parallel markers are in use.
    """
    if not _plugin_in_use:
        return

    def _barrier():
        return MPI.COMM_WORLD.barrier()

    request.addfinalizer(_barrier)
@pytest.fixture(scope="session", autouse=True)
def spawn_finalize(request):
    """Report back to, and disconnect from, a parent process at session end.

    If ``MPI.Comm.Get_parent()`` returns a real communicator when the session
    finishes, the session's exit status is gathered to root 0 of that
    communicator, which is then disconnected. Nothing is registered unless
    parallel markers are in use.
    """
    def _session_status():
        # Map the session counters onto a pytest exit code
        session = request.session
        if session.testsfailed:
            return pytest.ExitCode.TESTS_FAILED
        if session.testscollected == 0:
            return pytest.ExitCode.NO_TESTS_COLLECTED
        return pytest.ExitCode.OK

    def _disconnect():
        parent_comm = MPI.Comm.Get_parent()
        status = _session_status()

        if parent_comm != MPI.COMM_NULL:
            parent_comm.gather(status, root=0)
            parent_comm.Disconnect()

    if not _plugin_in_use:
        return

    request.addfinalizer(_disconnect)
def _is_parallel_child_process():
    """Return whether ``CHILD_PROCESS_FLAG`` is present in the environment."""
    flag_value = os.environ.get(CHILD_PROCESS_FLAG)
    return flag_value is not None
def _xdist_active(session):
try:
import xdist
return xdist.is_xdist_controller(session) or xdist.is_xdist_worker(session)
except ImportError:
return False
def _set_parallel_callback(item):
    """Replace the callback for a test item with one that calls ``mpiexec``.

    If the number of processes requested is 1 then this function does nothing.

    The command runs pytest on exactly this test: one process with the
    regular quiet arguments, and the remaining ``nprocs - 1`` processes with
    extra flags that suppress most of their output. The replacement callback
    raises ``subprocess.CalledProcessError`` if ``mpiexec`` exits with a
    non-zero status.

    Parameters
    ----------
    item : _pytest.nodes.Item
        The test item to run.

    """
    nprocs = _extract_nprocs_for_single_test(item)
    assert isinstance(nprocs, numbers.Integral)

    if nprocs == 1:
        return

    # Set the executable by sniffing sys.argv[0]
    # This is necessary since invoking pytest in different ways leads to different behaviour:
    # https://docs.pytest.org/en/latest/how-to/usage.html#calling-pytest-through-python-m-pytest
    full_path = Path(sys.argv[0])
    if full_path.name == "pytest":
        # If pytest was launched as `pytest ...`
        executable = [sys.argv[0]]
    else:
        # Otherwise assume pytest was launched as `python -m pytest ...`
        executable = [sys.executable, "-m", "pytest"]
        if Path('/'.join(full_path.parts[-2:])) != Path('pytest/__main__.py'):
            # But warn users if it doesn't look right!
            warn(f"Unrecognised pytest invocation, trying {' '.join(executable)}")

    # Address the test through everything after the file part of its node ID
    # so that an enclosing class is kept (``file::Class::test``); the bare
    # ``file::test`` form cannot select a test defined inside a class.
    _, separator, node_tail = item.nodeid.partition("::")
    if separator:
        test_id = f"{item.fspath}::{node_tail}"
    else:
        # No "::" in the node ID, fall back to the plain item name
        test_id = f"{item.fspath}::{item.name}"

    # Run xfailing tests to ensure that errors are reported to calling process
    pytest_args = ["--runxfail", "-s", "-q", test_id]
    # Try to generate less output on other ranks so stdout is easier to read
    quieter_pytest_args = pytest_args + [
        "--tb=no", "--no-summary", "--no-header",
        "--disable-warnings", "--show-capture=no"
    ]

    cmd = [
        "mpiexec", "-n", "1", "-genv", CHILD_PROCESS_FLAG, "1", *executable
    ] + pytest_args + [
        ":", "-n", f"{nprocs-1}", *executable
    ] + quieter_pytest_args

    def parallel_callback(*args, **kwargs):
        # check=True turns a failing mpiexec run into an exception
        subprocess.run(cmd, check=True)

    item.obj = parallel_callback
def _extract_nprocs_for_single_test(item):
"""Extract the number of processes that a test is supposed to be run with.
Unlike `_parse_marker_nprocs`, this function applies to tests that have already
been set to require a fixed level of parallelism. In other words, if the
parallel marker requested, say, ``[2, 3]`` processes, the tests input to
this function have already been split into ``[nprocs=2]`` and ``[nprocs=3]``
versions. Therefore, this function returns an integer, rather than a tuple.
"""
# First check to see if we have parametrised nprocs (if multiple were requested)
if hasattr(item, "callspec") and "_nprocs" in item.callspec.params:
nprocs = item.callspec.params["_nprocs"]
else:
# The parallel marker must just want one value of nprocs
marker = item.get_closest_marker("parallel")
nprocs, = _parse_marker_nprocs(marker)
return nprocs
def _parse_marker_nprocs(marker):
"""Return the number of processes requested from a parallel marker.
This function enables one to use the parallel marker with or without
using the ``nprocs`` keyword argument.
The returned process counts are provided as a tuple, even if only a
single value is requested.
"""
assert marker.name == "parallel"
if len(marker.args) == 1 and not marker.kwargs:
return _as_tuple(marker.args[0])
elif len(marker.kwargs) == 1 and not marker.args:
return _as_tuple(marker.kwargs["nprocs"])
elif not marker.args and not marker.kwargs:
return (3,)
else:
raise pytest.UsageError("Bad arguments given to parallel marker")
def _as_tuple(arg):
return tuple(arg) if isinstance(arg, collections.abc.Iterable) else (arg,)