Skip to content

Commit

Permalink
Enable Chaining in Auto3DSeg CLI (Project-MONAI#7168)
Browse files Browse the repository at this point in the history
Fixes Project-MONAI#7167

### Description
Make all the setting methods in `AutoRunner` to return `self` to enable
chaining in cli.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).

---------

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
  • Loading branch information
drbeh authored Oct 30, 2023
1 parent 798570c commit 1c17f0e
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions monai/apps/auto3dseg/auto_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def inspect_datalist_folds(self, datalist_filename: str) -> int:

def set_gpu_customization(
self, gpu_customization: bool = False, gpu_customization_specs: dict[str, Any] | None = None
) -> None:
) -> AutoRunner:
"""
Set options for GPU-based parameter customization/optimization.
Expand Down Expand Up @@ -442,7 +442,9 @@ def set_gpu_customization(
if gpu_customization_specs is not None:
self.gpu_customization_specs = gpu_customization_specs

def set_num_fold(self, num_fold: int = 5) -> None:
return self

def set_num_fold(self, num_fold: int = 5) -> AutoRunner:
"""
Set the number of cross validation folds for all algos.
Expand All @@ -454,7 +456,9 @@ def set_num_fold(self, num_fold: int = 5) -> None:
raise ValueError(f"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}")
self.num_fold = num_fold

def set_training_params(self, params: dict[str, Any] | None = None) -> None:
return self

def set_training_params(self, params: dict[str, Any] | None = None) -> AutoRunner:
"""
Set the training params for all algos.
Expand All @@ -474,13 +478,15 @@ def set_training_params(self, params: dict[str, Any] | None = None) -> None:
DeprecationWarning,
)

return self

def set_device_info(
self,
cuda_visible_devices: list[int] | str | None = None,
num_nodes: int | None = None,
mn_start_method: str | None = None,
cmd_prefix: str | None = None,
) -> None:
) -> AutoRunner:
"""
Set the device related info
Expand Down Expand Up @@ -531,7 +537,9 @@ def set_device_info(
if cmd_prefix is not None:
logger.info(f"Using user defined command running prefix {cmd_prefix}, will override other settings")

def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFold", **kwargs: Any) -> None:
return self

def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFold", **kwargs: Any) -> AutoRunner:
"""
Set the bundle ensemble method name and parameters for save image transform parameters.
Expand All @@ -546,7 +554,9 @@ def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFol
)
self.kwargs.update(kwargs)

def set_image_save_transform(self, **kwargs: Any) -> None:
return self

def set_image_save_transform(self, **kwargs: Any) -> AutoRunner:
"""
Set the ensemble output transform.
Expand All @@ -565,7 +575,9 @@ def set_image_save_transform(self, **kwargs: Any) -> None:
"Check https://docs.monai.io/en/stable/transforms.html#saveimage for more information."
)

def set_prediction_params(self, params: dict[str, Any] | None = None) -> None:
return self

def set_prediction_params(self, params: dict[str, Any] | None = None) -> AutoRunner:
"""
Set the prediction params for all algos.
Expand All @@ -581,7 +593,9 @@ def set_prediction_params(self, params: dict[str, Any] | None = None) -> None:
"""
self.pred_params = deepcopy(params) if params is not None else {}

def set_analyze_params(self, params: dict[str, Any] | None = None) -> None:
return self

def set_analyze_params(self, params: dict[str, Any] | None = None) -> AutoRunner:
"""
Set the data analysis extra params.
Expand All @@ -595,7 +609,9 @@ def set_analyze_params(self, params: dict[str, Any] | None = None) -> None:
else:
self.analyze_params = deepcopy(params)

def set_hpo_params(self, params: dict[str, Any] | None = None) -> None:
return self

def set_hpo_params(self, params: dict[str, Any] | None = None) -> AutoRunner:
"""
Set parameters for the HPO module and the algos before the training. It will attempt to (1) override bundle
templates with the key-value pairs in ``params`` (2) change the config of the HPO module (e.g. NNI) if the
Expand All @@ -621,7 +637,9 @@ def set_hpo_params(self, params: dict[str, Any] | None = None) -> None:
"""
self.hpo_params = self.train_params if params is None else params

def set_nni_search_space(self, search_space):
return self

def set_nni_search_space(self, search_space: dict[str, Any]) -> AutoRunner:
"""
Set the search space for NNI parameter search.
Expand All @@ -638,6 +656,8 @@ def set_nni_search_space(self, search_space):
self.search_space = search_space
self.hpo_tasks = value_combinations

return self

def _train_algo_in_sequence(self, history: list[dict[str, Any]]) -> None:
"""
Train the Algos in a sequential scheme. The order of training is randomized.
Expand Down

0 comments on commit 1c17f0e

Please sign in to comment.