Skip to content

Commit

Permalink
[running] fix multiprocessing bugs (#547)
Browse files Browse the repository at this point in the history
* [running] fix multiprocessing bugs

* fix tests
  • Loading branch information
chaoming0625 authored Nov 18, 2023
1 parent 20c1f40 commit d8d5793
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 3 deletions.
7 changes: 7 additions & 0 deletions brainpy/_src/running/pathos_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- ``cpu_unordered_parallel``: Performs a parallel unordered map.
"""

import sys
from collections.abc import Sized
from typing import (Any, Callable, Generator, Iterable, List,
Union, Optional, Sequence, Dict)
Expand All @@ -20,6 +21,8 @@
try:
from pathos.helpers import cpu_count # noqa
from pathos.multiprocessing import ProcessPool # noqa
import multiprocess.context as ctx # noqa
ctx._force_start_method('spawn')
except ModuleNotFoundError:
cpu_count = None
ProcessPool = None
Expand Down Expand Up @@ -63,6 +66,10 @@ def _parallel(
A generator which will apply the function to each element of the given Iterables
in parallel in order with a progress bar.
"""
if sys.platform == 'win32' and sys.version_info.minor >= 11:
raise NotImplementedError('Multiprocessing is not available in Python >=3.11 on Windows. '
'Please use Linux or MacOS, or Windows with Python <= 3.10.')

if ProcessPool is None or cpu_count is None:
raise PackageMissingError(
'''
Expand Down
41 changes: 41 additions & 0 deletions brainpy/_src/running/tests/test_pathos_multiprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import sys

import jax
import pytest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm

if sys.platform == 'win32' and sys.version_info.minor >= 11:
pytest.skip('python 3.11 does not support.', allow_module_level=True)
else:
pytest.skip('Cannot pass tests.', allow_module_level=True)


class TestParallel(parameterized.TestCase):
@parameterized.product(
duration=[1e2, 1e3, 1e4, 1e5]
)
def test_cpu_unordered_parallel_v1(self, duration):
@jax.jit
def body(inp):
return bm.for_loop(lambda x: x + 1e-9, inp)

input_long = bm.random.randn(1, int(duration / bm.dt), 3) / 100

r = bp.running.cpu_ordered_parallel(body, {'inp': [input_long, input_long]}, num_process=2)
assert bm.allclose(r[0], r[1])

@parameterized.product(
duration=[1e2, 1e3, 1e4, 1e5]
)
def test_cpu_unordered_parallel_v2(self, duration):
@jax.jit
def body(inp):
return bm.for_loop(lambda x: x + 1e-9, inp)

input_long = bm.random.randn(1, int(duration / bm.dt), 3) / 100

r = bp.running.cpu_unordered_parallel(body, {'inp': [input_long, input_long]}, num_process=2)
assert bm.allclose(r[0], r[1])
3 changes: 2 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ numba
brainpylib
jax
jaxlib
matplotlib>=3.4
matplotlib
msgpack
tqdm
pathos

# test requirements
pytest
Expand Down
4 changes: 2 additions & 2 deletions requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ msgpack
numba
jax
jaxlib
matplotlib>=3.4
scipy>=1.1.0
matplotlib
scipy
numba

# document requirements
Expand Down

0 comments on commit d8d5793

Please sign in to comment.