diff --git a/mrQA/tests/conftest.py b/mrQA/tests/conftest.py
index 3772bc2..3a10141 100644
--- a/mrQA/tests/conftest.py
+++ b/mrQA/tests/conftest.py
@@ -1,3 +1,4 @@
+import tempfile
 import typing as tp
 from pathlib import Path
 from typing import Tuple
@@ -42,9 +43,11 @@ def create_dataset(draw_from: st.DrawFn) -> Tuple:
                                        repetition_time,
                                        echo_train_length,
                                        flip_angle)
+    temp_dir = Path(tempfile.mkdtemp())
     ds = DicomDataset(name=name,
                       data_source=fake_ds_dir,
-                      config_path=THIS_DIR / 'resources/mri-config.json')
+                      config_path=THIS_DIR / 'resources/mri-config.json',
+                      output_dir=temp_dir)
     ref_protocol_path = sample_protocol()
     attributes = {
         'name': name,
diff --git a/mrQA/tests/test_parallel.py b/mrQA/tests/test_parallel.py
index 52bbb67..01de999 100644
--- a/mrQA/tests/test_parallel.py
+++ b/mrQA/tests/test_parallel.py
@@ -29,7 +29,8 @@ def test_equivalence_seq_vs_parallel():
         sequential_ds = import_dataset(data_source=data_source,
                                        ds_format='dicom',
                                        name='sequential',
-                                       config_path=config_path)
+                                       config_path=config_path,
+                                       output_dir=output_dir)
         save_mr_dataset(output_path['sequential'], sequential_ds)
     else:
         sequential_ds = load_mr_dataset(output_path['sequential'])
@@ -83,58 +84,61 @@
 def test_merging():
     # Sequential complete processing of the dataset
     data_source = sample_dicom_dataset()
-    output_dir = Path(data_source).parent / 'test_merge_mrqa_files'
-    output_path_seq = output_dir / ('sequential' + MRDS_EXT)
-    config_path = THIS_DIR / 'resources/mri-config.json'
-
-    if not output_path_seq.exists():
-        sequential_ds = import_dataset(data_source=data_source,
-                                       ds_format='dicom',
-                                       name='sequential',
-                                       config_path=config_path)
-        save_mr_dataset(output_path_seq, sequential_ds)
-    else:
-        sequential_ds = load_mr_dataset(output_path_seq)
-
-    # Start processing in batches
-    folder_paths, files_per_batch, all_ids_path = _make_file_folders(output_dir)
-
-    # For each batch create the list of ids to be processed
-    ids_path_list = split_folders_list(
-        data_source,
-        all_fnames_path=all_ids_path,
-        per_batch_ids=files_per_batch['fnames'],
-        output_dir=folder_paths['fnames'],
-        folders_per_job=5
-    )
-
-    # The paths to the output files
-    output_path = {i: output_dir/f'seq{i}{MRDS_EXT}'
-                   for i in range(len(ids_path_list))}
-    ds_list = []
-    for i, filepath in enumerate(ids_path_list):
-        # Read the list of subject ids to be processed
-        subject_folders_list = txt2list(filepath)
-        if not output_path[i].exists():
-            # Process the batch of subjects
-            ds = import_dataset(data_source=subject_folders_list,
-                                ds_format='dicom',
-                                name=f'seq{i}',
-                                config_path=config_path)
-            save_mr_dataset(output_path[i], ds)
-        else:
-            ds = load_mr_dataset(output_path[i])
-        ds_list.append(ds)
-
-    # Merge batches
-    combined_mrds = None
-    for ds in ds_list:
-        if combined_mrds is None:
-            # Add the first partial dataset
-            combined_mrds = ds
-        else:
-            # otherwise, keep aggregating
-            combined_mrds.merge(ds)
+    with tempfile.TemporaryDirectory() as tempdir:
+        output_dir = Path(tempdir)
+        output_path_seq = output_dir / ('sequential' + MRDS_EXT)
+        config_path = THIS_DIR / 'resources/mri-config.json'
 
-    # Check if both datasets are the same
-    assert combined_mrds == sequential_ds
+        if not output_path_seq.exists():
+            sequential_ds = import_dataset(data_source=data_source,
+                                           ds_format='dicom',
+                                           name='sequential',
+                                           config_path=config_path,
+                                           output_dir=output_dir)
+            save_mr_dataset(output_path_seq, sequential_ds)
+        else:
+            sequential_ds = load_mr_dataset(output_path_seq)
+
+        # Start processing in batches
+        folder_paths, files_per_batch, all_ids_path = _make_file_folders(output_dir)
+
+        # For each batch create the list of ids to be processed
+        ids_path_list = split_folders_list(
+            data_source,
+            all_fnames_path=all_ids_path,
+            per_batch_ids=files_per_batch['fnames'],
+            output_dir=folder_paths['fnames'],
+            folders_per_job=5
+        )
+
+        # The paths to the output files
+        output_path = {i: folder_paths['mrds']/f'seq{i}{MRDS_EXT}'
+                       for i in range(len(ids_path_list))}
+        ds_list = []
+        for i, filepath in enumerate(ids_path_list):
+            # Read the list of subject ids to be processed
+            subject_folders_list = txt2list(filepath)
+            if not output_path[i].exists():
+                # Process the batch of subjects
+                ds = import_dataset(data_source=subject_folders_list,
+                                    ds_format='dicom',
+                                    name=f'seq{i}',
+                                    config_path=config_path,
+                                    output_dir=folder_paths['mrds'])
+                save_mr_dataset(output_path[i], ds)
+            else:
+                ds = load_mr_dataset(output_path[i])
+            ds_list.append(ds)
+
+        # Merge batches
+        combined_mrds = None
+        for ds in ds_list:
+            if combined_mrds is None:
+                # Add the first partial dataset
+                combined_mrds = ds
+            else:
+                # otherwise, keep aggregating
+                combined_mrds.merge(ds)
+        save_mr_dataset(output_dir / ('parallel' + MRDS_EXT), combined_mrds)
+        # Check if both datasets are the same
+        assert combined_mrds == sequential_ds
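
Note on the two temp-dir strategies above: test_merging uses tempfile.TemporaryDirectory, which deletes its whole tree when the with block exits, and because the directory starts out empty, the exists() checks always take the import path, keeping the test hermetic across runs. The create_dataset strategy in conftest.py uses tempfile.mkdtemp() instead, which leaves one directory behind per generated example; a hypothesis @st.composite strategy cannot consume a pytest fixture, so a context manager is harder to thread through there. For plain test functions, a cleanup-safe fixture is the usual alternative. A minimal sketch, assuming pytest as the runner (the temp_output_dir name and test_uses_temp_dir are illustrative, not part of mrQA; pytest's built-in tmp_path fixture gives equivalent behavior out of the box):

import tempfile
from pathlib import Path

import pytest


@pytest.fixture
def temp_output_dir():
    # TemporaryDirectory removes the whole tree when the context exits,
    # unlike tempfile.mkdtemp(), which leaves its directory behind.
    with tempfile.TemporaryDirectory() as tmpdir:
        yield Path(tmpdir)


def test_uses_temp_dir(temp_output_dir):
    # Hypothetical usage: anything written here is cleaned up afterwards.
    assert temp_output_dir.exists()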