Skip to content

Commit

Permalink
Rearrange ray initialization to avoid RuntimeWarning (#515)
Browse files Browse the repository at this point in the history
* Rearrange ray initialization to avoid os.fork() RuntimeWarning

* Locate ray initialization to avoid os.fork() RuntimeWarning

* Remove module-level ray.init call

* Attempt to clarify comment

* Bug fix
  • Loading branch information
bwohlberg authored Apr 29, 2024
1 parent e1f6c58 commit 738b4a0
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 46 deletions.
32 changes: 25 additions & 7 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,39 @@
"""
Configure the environment in which doctests run. This is necessary
because `np` is used in doc strings for jax functions
(e.g. `linear_transpose`) that get pulled into `scico/__init__.py`.
Also allow `snp` to be used without explicitly importing, and add
`level` parameter.
Configure pytest.
"""

import numpy as np

import pytest

try:
import ray # noqa: F401
except ImportError:
have_ray = False
else:
have_ray = True
ray.init(num_cpus=1) # call required to be here: see ray-project/ray#44087

import scico.numpy as snp


def pytest_sessionstart(session):
"""Initialize before start of test session."""
# placeholder: currently unused


def pytest_sessionfinish(session, exitstatus):
"""Clean up after end of test session."""
ray.shutdown()


@pytest.fixture(autouse=True)
def add_modules(doctest_namespace):
"""Add common modules for use in docstring examples."""
"""Add common modules for use in docstring examples.
Necessary because `np` is used in doc strings for jax functions
(e.g. `linear_transpose`) that get pulled into `scico/__init__.py`.
Also allow `snp` to be used without explicitly importing.
"""
doctest_namespace["np"] = np
doctest_namespace["snp"] = snp
5 changes: 5 additions & 0 deletions examples/scripts/ct_abel_tv_admm_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@

import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
Expand Down Expand Up @@ -151,6 +155,7 @@ def step(self):
num_iterations=10, # perform at most 10 steps for each parameter evaluation
)
results = tuner.fit()
ray.shutdown()


"""
Expand Down
7 changes: 6 additions & 1 deletion examples/scripts/deconv_tv_admm_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,13 @@
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["JAX_PLATFORMS"] = "cpu"


from xdesign import SiemensStar, discrete_phantom

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, metric, plot
Expand Down Expand Up @@ -119,6 +123,7 @@ def eval_params(config, x_gt, psf, y):
num_samples=100, # perform 100 parameter evaluations
)
results = tuner.fit()
ray.shutdown()


"""
Expand Down
45 changes: 14 additions & 31 deletions scico/flax/examples/data_generation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# Copyright (C) 2022-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Functionality to generate training data for Flax example scripts.
Computation is distributed via ray (if available) or jax or to reduce
Computation is distributed via ray (if available) or JAX or to reduce
processing time.
"""

Expand All @@ -17,26 +17,26 @@

import numpy as np

import jax
import jax.numpy as jnp

try:
import xdesign # noqa: F401
import ray # noqa: F401
except ImportError:
have_xdesign = False
have_ray = False
else:
have_xdesign = True
have_ray = True

try:
import ray # noqa: F401
import xdesign # noqa: F401
except ImportError:
have_ray = False
have_xdesign = False
else:
have_ray = True
have_xdesign = True

if have_xdesign:
from xdesign import Foam, SimpleMaterial, UnitCircle, discrete_phantom

import jax
import jax.numpy as jnp

from scico.linop import CircularConvolve
from scico.numpy import Array

Expand Down Expand Up @@ -159,7 +159,6 @@ def generate_ct_data(
imgfunc: Callable = generate_foam2_images,
seed: int = 1234,
verbose: bool = False,
test_flag: bool = False,
prefer_ray: bool = True,
) -> Tuple[Array, ...]:
"""Generate batch of computed tomography (CT) data.
Expand All @@ -175,9 +174,6 @@ def generate_ct_data(
seed: Seed for data generation.
verbose: Flag indicating whether to print status messages.
Default: ``False``.
test_flag: Flag to indicate if running in testing mode. Testing
mode requires a different initialization of ray. Default:
``False``.
prefer_ray: Use ray for distributed processing if available.
Default: ``True``.
Expand All @@ -194,7 +190,7 @@ def generate_ct_data(
# Generate input data.
if have_ray and prefer_ray:
start_time = time()
img = ray_distributed_data_generation(imgfunc, size, nimg, seed, test_flag)
img = ray_distributed_data_generation(imgfunc, size, nimg, seed)
time_dtgen = time() - start_time
else:
start_time = time()
Expand Down Expand Up @@ -249,7 +245,6 @@ def generate_blur_data(
imgfunc: Callable,
seed: int = 4321,
verbose: bool = False,
test_flag: bool = False,
prefer_ray: bool = True,
) -> Tuple[Array, ...]:
"""Generate batch of blurred data.
Expand All @@ -266,9 +261,6 @@ def generate_blur_data(
seed: Seed for data generation.
verbose: Flag indicating whether to print status messages.
Default: ``False``.
test_flag: Flag to indicate if running in testing mode. Testing
mode requires a different initialization of ray.
Default: ``False``.
prefer_ray: Use ray for distributed processing if available.
Default: ``True``.
Expand All @@ -280,7 +272,7 @@ def generate_blur_data(
"""
if have_ray and prefer_ray:
start_time = time()
img = ray_distributed_data_generation(imgfunc, size, nimg, seed, test_flag)
img = ray_distributed_data_generation(imgfunc, size, nimg, seed)
time_dtgen = time() - start_time
else:
start_time = time()
Expand Down Expand Up @@ -353,7 +345,7 @@ def distributed_data_generation(


def ray_distributed_data_generation(
imgenf: Callable, size: int, nimg: int, seedg: float = 123, test_flag: bool = False
imgenf: Callable, size: int, nimg: int, seedg: float = 123
) -> Array:
"""Data generation distributed among processes using ray.
Expand All @@ -362,21 +354,13 @@ def ray_distributed_data_generation(
size: Size of image to generate.
ndata: Number of images to generate.
seedg: Base seed for data generation. Default: 123.
test_flag: Flag to indicate if running in testing mode. Testing
mode requires a different initialization of ray. Default:
``False``.
Returns:
Array of generated data.
"""
if not have_ray:
raise RuntimeError("Package ray is required for use of this function.")

if test_flag:
ray.init(ignore_reinit_error=True)
else:
ray.init()

@ray.remote
def data_gen(seed, size, ndata, imgf):
return imgf(seed, size, ndata)
Expand All @@ -398,6 +382,5 @@ def data_gen(seed, size, ndata, imgf):
[data_gen.remote(seed + seedg, size, ndata_per_proc, imgenf) for seed in range(nproc)]
)
imgs = np.vstack([t for t in ray_return])
ray.shutdown()

return imgs
10 changes: 7 additions & 3 deletions scico/test/flax/test_examples_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_distdatagen_exception():
def test_ray_distdatagen():
N = 16
nimg = 8
dt = ray_distributed_data_generation(fake_data_gen, N, nimg, test_flag=True)
dt = ray_distributed_data_generation(fake_data_gen, N, nimg)
assert dt.ndim == 4
assert dt.shape == (nimg, N, N, 1)

Expand All @@ -115,7 +115,7 @@ def random_img_gen(seed, size, ndata):
return np.random.randn(ndata, size, size, 1)

try:
img, sino, fbp = generate_ct_data(nimg, N, nproj, imgfunc=random_img_gen, test_flag=True)
img, sino, fbp = generate_ct_data(nimg, N, nproj, imgfunc=random_img_gen)
except Exception as e:
print(e)
assert 0
Expand Down Expand Up @@ -158,7 +158,11 @@ def random_img_gen(seed, size, ndata):

try:
img, blurn = generate_blur_data(
nimg, N, blur_kernel, noise_sigma=0.01, imgfunc=random_img_gen, test_flag=True
nimg,
N,
blur_kernel,
noise_sigma=0.01,
imgfunc=random_img_gen,
)
except Exception as e:
print(e)
Expand Down
2 changes: 0 additions & 2 deletions scico/test/test_ray_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
try:
import ray
from scico.ray import report, tune

ray.init(num_cpus=1)
except ImportError as e:
pytest.skip("ray.tune not installed", allow_module_level=True)

Expand Down
4 changes: 2 additions & 2 deletions scico/test/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ def _internet_connected(host="8.8.8.8", port=53, timeout=3):

@pytest.mark.skipif(not _internet_connected(), reason="No internet connection")
def test_url_get():
url = "https://github.com/lanl/scico/blob/main/README.rst"
url = "https://github.com/lanl/scico/blob/main/README.md"
assert not url_get(url).getvalue().find(b"SCICO") == -1

url = "about:blank"
np.testing.assert_raises(urlerror.URLError, url_get, url)

url = "https://github.com/lanl/scico/blob/main/README.rst"
url = "https://github.com/lanl/scico/blob/main/README.md"
np.testing.assert_raises(ValueError, url_get, url, -1)


Expand Down

0 comments on commit 738b4a0

Please sign in to comment.