diff --git a/brainpy/_src/running/pathos_multiprocessing.py b/brainpy/_src/running/pathos_multiprocessing.py index 1573a541c..f652217d9 100644 --- a/brainpy/_src/running/pathos_multiprocessing.py +++ b/brainpy/_src/running/pathos_multiprocessing.py @@ -9,6 +9,7 @@ - ``cpu_unordered_parallel``: Performs a parallel unordered map. """ +import sys from collections.abc import Sized from typing import (Any, Callable, Generator, Iterable, List, Union, Optional, Sequence, Dict) @@ -20,6 +21,8 @@ try: from pathos.helpers import cpu_count # noqa from pathos.multiprocessing import ProcessPool # noqa + import multiprocess.context as ctx # noqa + ctx._force_start_method('spawn') except ModuleNotFoundError: cpu_count = None ProcessPool = None @@ -63,6 +66,10 @@ def _parallel( A generator which will apply the function to each element of the given Iterables in parallel in order with a progress bar. """ + if sys.platform == 'win32' and sys.version_info.minor >= 11: + raise NotImplementedError('Multiprocessing is not available in Python >=3.11 on Windows. ' + 'Please use Linux or MacOS, or Windows with Python <= 3.10.') + if ProcessPool is None or cpu_count is None: raise PackageMissingError( ''' diff --git a/brainpy/_src/running/tests/test_pathos_multiprocessing.py b/brainpy/_src/running/tests/test_pathos_multiprocessing.py new file mode 100644 index 000000000..6f92bda7e --- /dev/null +++ b/brainpy/_src/running/tests/test_pathos_multiprocessing.py @@ -0,0 +1,41 @@ +import sys + +import jax +import pytest +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm + +if sys.platform == 'win32' and sys.version_info.minor >= 11: + pytest.skip('python 3.11 does not support.', allow_module_level=True) +else: + pytest.skip('Cannot pass tests.', allow_module_level=True) + + +class TestParallel(parameterized.TestCase): + @parameterized.product( + duration=[1e2, 1e3, 1e4, 1e5] + ) + def test_cpu_unordered_parallel_v1(self, duration): + @jax.jit + def body(inp): + return bm.for_loop(lambda x: x + 1e-9, inp) + + input_long = bm.random.randn(1, int(duration / bm.dt), 3) / 100 + + r = bp.running.cpu_ordered_parallel(body, {'inp': [input_long, input_long]}, num_process=2) + assert bm.allclose(r[0], r[1]) + + @parameterized.product( + duration=[1e2, 1e3, 1e4, 1e5] + ) + def test_cpu_unordered_parallel_v2(self, duration): + @jax.jit + def body(inp): + return bm.for_loop(lambda x: x + 1e-9, inp) + + input_long = bm.random.randn(1, int(duration / bm.dt), 3) / 100 + + r = bp.running.cpu_unordered_parallel(body, {'inp': [input_long, input_long]}, num_process=2) + assert bm.allclose(r[0], r[1]) diff --git a/requirements-dev.txt b/requirements-dev.txt index 93fa26af3..068c38546 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,9 +3,10 @@ numba brainpylib jax jaxlib -matplotlib>=3.4 +matplotlib msgpack tqdm +pathos # test requirements pytest diff --git a/requirements-doc.txt b/requirements-doc.txt index d4fe3f43e..c399c03b0 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -4,8 +4,8 @@ msgpack numba jax jaxlib -matplotlib>=3.4 -scipy>=1.1.0 +matplotlib +scipy numba # document requirements