Skip to content

Commit

Permalink
Merge pull request #1136 from asagilmore/paraLocalTracking
Browse files Browse the repository at this point in the history
Streamlines task in parallel using ray
  • Loading branch information
arokem authored May 24, 2024
2 parents 2889d7a + 6340976 commit 7473e6e
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 11 deletions.
127 changes: 122 additions & 5 deletions AFQ/tasks/tractography.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,24 @@
import logging

import pimms
import multiprocessing

from AFQ.tasks.decorators import as_file, as_img
from AFQ.tasks.utils import with_name
from AFQ.definitions.utils import Definition
import AFQ.tractography.tractography as aft
from AFQ.tasks.utils import get_default_args
from AFQ.definitions.image import ScalarImage
from AFQ.tractography.utils import gen_seeds

try:
import ray
has_ray = True
except ModuleNotFoundError:
has_ray = False
try:
from trx.trx_file_memmap import TrxFile
from trx.trx_file_memmap import concatenate as trx_concatenate
has_trx = True
except ModuleNotFoundError:
has_trx = False
Expand Down Expand Up @@ -134,14 +142,108 @@ def streamlines(data_imap, seed, stop,

is_trx = this_tracking_params.get("trx", False)

num_chunks = this_tracking_params.pop("num_chunks", False)

if num_chunks is True:
num_chunks = multiprocessing.cpu_count() - 1

if is_trx:
start_time = time()
dtype_dict = {'positions': np.float16, 'offsets': np.uint32}
lazyt = aft.track(params_file, **this_tracking_params)
sft = TrxFile.from_lazy_tractogram(
lazyt,
seed,
dtype_dict=dtype_dict)
if num_chunks and num_chunks > 1:
if not has_ray:
raise ImportError("Ray is required to perform tractography in"
"parallel, please install ray or remove the"
" 'num_chunks' arg")

@ray.remote
class TractActor():
    """Ray actor that creates lazy tractograms and converts them to TRX.

    Lazy tractograms created by this actor are kept in ``self.objects``,
    keyed by the id the caller supplies, so that creation, conversion and
    cleanup can be issued as separate remote calls against the same actor.
    """

    def __init__(self):
        # Registry of lazy tractograms created by this actor, keyed by id.
        self.objects = {}
        # Keep the tracking module and TrxFile class on the instance; the
        # methods below reach them through ``self``.
        self.aft = aft
        self.TrxFile = TrxFile

    def trx_from_lazy_tractogram(self, lazyt_id, seed, dtype_dict):
        """Build a TrxFile from the lazy tractogram stored under ``lazyt_id``.

        ``seed`` and ``dtype_dict`` are forwarded unchanged to
        ``TrxFile.from_lazy_tractogram``. Raises ``KeyError`` if nothing is
        stored under ``lazyt_id``.
        """
        lazyt = self.objects[lazyt_id]
        return self.TrxFile.from_lazy_tractogram(
            lazyt, seed, dtype_dict=dtype_dict)

    def create_lazyt(self, id, *args, **kwargs):
        """Run ``aft.track(*args, **kwargs)`` and store the result under ``id``.

        Returns ``id`` so the caller can wait on completion of the call.
        """
        tractogram = self.aft.track(*args, **kwargs)
        self.objects[id] = tractogram
        return id

    def delete_lazyt(self, id):
        """Drop the lazy tractogram stored under ``id``; no-op if absent."""
        self.objects.pop(id, None)
actors = [TractActor.remote() for _ in range(num_chunks)]
object_id = 1
tracking_params_list = []

# random seeds case
if isinstance(this_tracking_params.get("n_seeds"), int) and \
this_tracking_params.get("random_seeds"):

remainder = this_tracking_params['n_seeds'] % num_chunks
for i in range(num_chunks):
# create copy of tracking params
copy = this_tracking_params.copy()
n_seeds = this_tracking_params['n_seeds']
copy['n_seeds'] = n_seeds // num_chunks
# add remainder to 1st list
if i == 0:
copy['n_seeds'] += remainder
tracking_params_list.append(copy)

elif isinstance(this_tracking_params['n_seeds'], (np.ndarray,
list)):
n_seeds = np.array(this_tracking_params['n_seeds'])
seed_chunks = np.array_split(n_seeds, num_chunks)
tracking_params_list = [this_tracking_params.copy() for _ in
range(num_chunks)]

for i in range(num_chunks):
tracking_params_list[i]['n_seeds'] = seed_chunks[i]

else:
seeds = gen_seeds(
this_tracking_params['seed_mask'],
this_tracking_params['seed_threshold'],
this_tracking_params['n_seeds'],
this_tracking_params['thresholds_as_percentages'],
this_tracking_params['random_seeds'],
this_tracking_params['rng_seed'],
data_imap["dwi_affine"])
seed_chunks = np.array_split(seeds, num_chunks)
tracking_params_list = [this_tracking_params.copy() for _
in range(num_chunks)]
for i in range(num_chunks):
tracking_params_list[i]['n_seeds'] = seed_chunks[i]

# create lazyt inside each actor
tasks = [ray_actor.create_lazyt.remote(object_id, params_file,
**tracking_params_list[i]) for i, ray_actor in
enumerate(actors)]
ray.get(tasks)

# create trx from lazyt
tasks = [ray_actor.trx_from_lazy_tractogram.remote(object_id, seed,
dtype_dict=dtype_dict) for ray_actor in actors]
sfts = ray.get(tasks)

# cleanup objects
tasks = [ray_actor.delete_lazyt.remote(object_id) for ray_actor in
actors]
ray.get(tasks)

sft = trx_concatenate(sfts)
else:
lazyt = aft.track(params_file, **this_tracking_params)
sft = TrxFile.from_lazy_tractogram(
lazyt,
seed,
dtype_dict=dtype_dict)
n_streamlines = len(sft)

else:
Expand Down Expand Up @@ -301,3 +403,18 @@ def get_tractography_plan(kwargs):
seed_mask.get_image_getter("tractography")))

return pimms.plan(**tractography_tasks)


def _gen_seeds(n_seeds, params_file, seed_mask=None, seed_threshold=0,
               thresholds_as_percentages=False,
               random_seeds=False, rng_seed=None):
    """
    Call ``gen_seeds`` using the affine taken from ``params_file``.

    Parameters
    ----------
    n_seeds
        Forwarded unchanged to ``gen_seeds``.
    params_file : str or image object
        If a str, it is loaded with ``nib.load``. Any other value is used
        as-is and must expose an ``affine`` attribute.
    seed_mask, seed_threshold, thresholds_as_percentages, random_seeds, \
rng_seed
        Forwarded unchanged to ``gen_seeds``.

    Returns
    -------
    The value returned by ``gen_seeds``.
    """
    # Only string inputs are loaded from disk; anything else is assumed to
    # already be an image carrying an affine.
    params_img = (
        nib.load(params_file) if isinstance(params_file, str)
        else params_file)

    return gen_seeds(
        seed_mask, seed_threshold, n_seeds,
        thresholds_as_percentages,
        random_seeds, rng_seed, params_img.affine)
6 changes: 4 additions & 2 deletions AFQ/tractography/tractography.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,10 @@ def track(params_file, directions="prob", max_angle=30., sphere=None,
if thresholds_as_percentages:
stop_threshold = get_percentile_threshold(
stop_mask, stop_threshold)
stopping_criterion = ThresholdStoppingCriterion(stop_mask,
stop_threshold)
stop_mask_copy = np.copy(stop_mask)
stop_thresh_copy = np.copy(stop_threshold)
stopping_criterion = ThresholdStoppingCriterion(stop_mask_copy,
stop_thresh_copy)

my_tracker = LocalTracking

Expand Down
3 changes: 3 additions & 0 deletions docs/source/reference/kwargs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ tracking_params: dict, optional
this dict may be ``AFQ.definitions.image.ImageFile`` instances.
If ``tracker`` is set to "pft" then ``stop_mask`` should be
an instance of ``AFQ.definitions.image.PFTImage``.
``num_chunks`` can be specified to cause tracking to be done in
parallel using ray. If set to ``True``, the number of chunks is the
number of cores available on the machine minus one; an integer
greater than 1 sets the number of chunks directly.

import_tract: dict or str or None, optional
BIDS filters for inputting a user made tractography file,
Expand Down
11 changes: 7 additions & 4 deletions examples/tutorial_examples/plot_001_afq_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,17 @@
# ---------------------------------------
# We create a `tracking_params` variable, which we will pass to the
# GroupAFQ object which specifies that we want 25,000 seeds randomly
# distributed in the white matter.
#
# We only do this to make this example faster and consume less space.
# distributed in the white matter. We only do this to make this example
# faster and consume less space. We also set ``num_chunks`` to ``True``,
# which will use ray to parallelize the tracking across all but one of
# the available cores. This can be removed to process in serial, or set
# to use a particular distribution of work by setting ``num_chunks`` to
# an integer number.

tracking_params = dict(n_seeds=25000,
random_seeds=True,
rng_seed=2022,
trx=True)
trx=True,
num_chunks=True)

##########################################################################
# Initialize a GroupAFQ object:
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ plot =
ipython>=7.13.0,<=7.20.0
trx =
trx-python
ray

all =
%(dev)s
Expand Down

0 comments on commit 7473e6e

Please sign in to comment.