add output_dir to import_dataset and test merge
sinhaharsh committed Nov 22, 2023
1 parent 91307cf commit b60cad4
Showing 2 changed files with 63 additions and 56 deletions.
5 changes: 4 additions & 1 deletion mrQA/tests/conftest.py
@@ -1,3 +1,4 @@
+import tempfile
 import typing as tp
 from pathlib import Path
 from typing import Tuple
@@ -42,9 +43,11 @@ def create_dataset(draw_from: st.DrawFn) -> Tuple:
                                     repetition_time,
                                     echo_train_length,
                                     flip_angle)
+    temp_dir = Path(tempfile.mkdtemp())
     ds = DicomDataset(name=name,
                       data_source=fake_ds_dir,
-                      config_path=THIS_DIR / 'resources/mri-config.json')
+                      config_path=THIS_DIR / 'resources/mri-config.json',
+                      output_dir=temp_dir)
     ref_protocol_path = sample_protocol()
     attributes = {
         'name': name,
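Note: the fixture uses tempfile.mkdtemp(), which creates the directory but never deletes it, while test_merging below switches to tempfile.TemporaryDirectory(), which cleans up on exit. A minimal sketch of the two standard-library patterns (paths illustrative):

import tempfile
from pathlib import Path

# mkdtemp() hands back a fresh directory and leaves cleanup to the caller
scratch = Path(tempfile.mkdtemp())

# TemporaryDirectory() removes the directory when the with-block exits
with tempfile.TemporaryDirectory() as tmp:
    out_dir = Path(tmp)
    # ... write test outputs under out_dir ...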
114 changes: 59 additions & 55 deletions mrQA/tests/test_parallel.py
@@ -29,7 +29,8 @@ def test_equivalence_seq_vs_parallel():
         sequential_ds = import_dataset(data_source=data_source,
                                        ds_format='dicom',
                                        name='sequential',
-                                       config_path=config_path)
+                                       config_path=config_path,
+                                       output_dir=output_dir)
         save_mr_dataset(output_path['sequential'], sequential_ds)
     else:
         sequential_ds = load_mr_dataset(output_path['sequential'])
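The surrounding test wraps import_dataset in a load-or-build cache so reruns skip the slow DICOM scan. A compact sketch of that pattern as a hypothetical helper (cached_import is not part of mrQA; import_dataset, load_mr_dataset, save_mr_dataset, and MRDS_EXT are the names used in this diff):

from pathlib import Path

def cached_import(data_source, name, config_path, output_dir: Path):
    # Hypothetical helper: reuse a saved .mrds file when present,
    # otherwise import the dataset and save it for the next run.
    mrds_path = output_dir / (name + MRDS_EXT)
    if mrds_path.exists():
        return load_mr_dataset(mrds_path)
    ds = import_dataset(data_source=data_source,
                        ds_format='dicom',
                        name=name,
                        config_path=config_path,
                        output_dir=output_dir)
    save_mr_dataset(mrds_path, ds)
    return ds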
@@ -83,58 +84,61 @@
 def test_merging():
     # Sequential complete processing of the dataset
     data_source = sample_dicom_dataset()
-    output_dir = Path(data_source).parent / 'test_merge_mrqa_files'
-    output_path_seq = output_dir / ('sequential' + MRDS_EXT)
-    config_path = THIS_DIR / 'resources/mri-config.json'
-
-    if not output_path_seq.exists():
-        sequential_ds = import_dataset(data_source=data_source,
-                                       ds_format='dicom',
-                                       name='sequential',
-                                       config_path=config_path)
-        save_mr_dataset(output_path_seq, sequential_ds)
-    else:
-        sequential_ds = load_mr_dataset(output_path_seq)
-
-    # Start processing in batches
-    folder_paths, files_per_batch, all_ids_path = _make_file_folders(output_dir)
-
-    # For each batch create the list of ids to be processed
-    ids_path_list = split_folders_list(
-        data_source,
-        all_fnames_path=all_ids_path,
-        per_batch_ids=files_per_batch['fnames'],
-        output_dir=folder_paths['fnames'],
-        folders_per_job=5
-    )
-
-    # The paths to the output files
-    output_path = {i: output_dir/f'seq{i}{MRDS_EXT}'
-                   for i in range(len(ids_path_list))}
-    ds_list = []
-    for i, filepath in enumerate(ids_path_list):
-        # Read the list of subject ids to be processed
-        subject_folders_list = txt2list(filepath)
-        if not output_path[i].exists():
-            # Process the batch of subjects
-            ds = import_dataset(data_source=subject_folders_list,
-                                ds_format='dicom',
-                                name=f'seq{i}',
-                                config_path=config_path)
-            save_mr_dataset(output_path[i], ds)
-        else:
-            ds = load_mr_dataset(output_path[i])
-        ds_list.append(ds)
-
-    # Merge batches
-    combined_mrds = None
-    for ds in ds_list:
-        if combined_mrds is None:
-            # Add the first partial dataset
-            combined_mrds = ds
-        else:
-            # otherwise, keep aggregating
-            combined_mrds.merge(ds)
-    # Check if both datasets are the same
-    assert combined_mrds == sequential_ds
+    with tempfile.TemporaryDirectory() as tempdir:
+        output_dir = Path(tempdir)
+        output_path_seq = output_dir / ('sequential' + MRDS_EXT)
+        config_path = THIS_DIR / 'resources/mri-config.json'
+
+        if not output_path_seq.exists():
+            sequential_ds = import_dataset(data_source=data_source,
+                                           ds_format='dicom',
+                                           name='sequential',
+                                           config_path=config_path,
+                                           output_dir=output_dir)
+            save_mr_dataset(output_path_seq, sequential_ds)
+        else:
+            sequential_ds = load_mr_dataset(output_path_seq)
+
+        # Start processing in batches
+        folder_paths, files_per_batch, all_ids_path = _make_file_folders(output_dir)
+
+        # For each batch create the list of ids to be processed
+        ids_path_list = split_folders_list(
+            data_source,
+            all_fnames_path=all_ids_path,
+            per_batch_ids=files_per_batch['fnames'],
+            output_dir=folder_paths['fnames'],
+            folders_per_job=5
+        )
+
+        # The paths to the output files
+        output_path = {i: folder_paths['mrds']/f'seq{i}{MRDS_EXT}'
+                       for i in range(len(ids_path_list))}
+        ds_list = []
+        for i, filepath in enumerate(ids_path_list):
+            # Read the list of subject ids to be processed
+            subject_folders_list = txt2list(filepath)
+            if not output_path[i].exists():
+                # Process the batch of subjects
+                ds = import_dataset(data_source=subject_folders_list,
+                                    ds_format='dicom',
+                                    name=f'seq{i}',
+                                    config_path=config_path,
+                                    output_dir=folder_paths['mrds'])
+                save_mr_dataset(output_path[i], ds)
+            else:
+                ds = load_mr_dataset(output_path[i])
+            ds_list.append(ds)
+
+        # Merge batches
+        combined_mrds = None
+        for ds in ds_list:
+            if combined_mrds is None:
+                # Add the first partial dataset
+                combined_mrds = ds
+            else:
+                # otherwise, keep aggregating
+                combined_mrds.merge(ds)
+        save_mr_dataset(output_dir / ('parallel' + MRDS_EXT), combined_mrds)
+        # Check if both datasets are the same
+        assert combined_mrds == sequential_ds
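The merge loop in test_merging is a left fold over the per-batch datasets. An equivalent sketch with functools.reduce, assuming (as the test's usage implies) that merge() mutates its receiver in place and that ds_list is non-empty:

from functools import reduce

def merge_pair(acc, ds):
    acc.merge(ds)  # merge() updates acc in place, so pass it along
    return acc

combined_mrds = reduce(merge_pair, ds_list[1:], ds_list[0])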
