Skip to content

Commit

Permalink
autorunner params from config (#7175)
Browse files Browse the repository at this point in the history
allows setting AutoRunner params from config
allows specifying number of folds in config

---------

Signed-off-by: myron <amyronenko@nvidia.com>
  • Loading branch information
myron authored Nov 3, 2023
1 parent 2658b00 commit 4847df2
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 28 deletions.
79 changes: 52 additions & 27 deletions monai/apps/auto3dseg/auto_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,22 +214,11 @@ def __init__(
mlflow_tracking_uri: str | None = None,
**kwargs: Any,
):
logger.info(f"AutoRunner using work directory {work_dir}")
os.makedirs(work_dir, exist_ok=True)

self.work_dir = os.path.abspath(work_dir)
self.data_src_cfg = dict()
self.data_src_cfg_name = os.path.join(self.work_dir, "input.yaml")
self.algos = algos
self.templates_path_or_url = templates_path_or_url
self.allow_skip = allow_skip
self.mlflow_tracking_uri = mlflow_tracking_uri
self.kwargs = deepcopy(kwargs)

if input is None and os.path.isfile(self.data_src_cfg_name):
input = self.data_src_cfg_name
if input is None and os.path.isfile(os.path.join(os.path.abspath(work_dir), "input.yaml")):
input = os.path.join(os.path.abspath(work_dir), "input.yaml")
logger.info(f"Input config is not provided, using the default {input}")

self.data_src_cfg = dict()
if isinstance(input, dict):
self.data_src_cfg = input
elif isinstance(input, str) and os.path.isfile(input):
Expand All @@ -238,6 +227,51 @@ def __init__(
else:
raise ValueError(f"{input} is not a valid file or dict")

if "work_dir" in self.data_src_cfg: # override from config
work_dir = self.data_src_cfg["work_dir"]
self.work_dir = os.path.abspath(work_dir)

logger.info(f"AutoRunner using work directory {self.work_dir}")
os.makedirs(self.work_dir, exist_ok=True)
self.data_src_cfg_name = os.path.join(self.work_dir, "input.yaml")

self.algos = algos
self.templates_path_or_url = templates_path_or_url
self.allow_skip = allow_skip

# cache.yaml
self.not_use_cache = not_use_cache
self.cache_filename = os.path.join(self.work_dir, "cache.yaml")
self.cache = self.read_cache()
self.export_cache()

# determine if we need to analyze, algo_gen or train from cache, unless manually provided
self.analyze = not self.cache["analyze"] if analyze is None else analyze
self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen
self.train = train
self.ensemble = ensemble # last step, no need to check
self.hpo = hpo and has_nni
self.hpo_backend = hpo_backend
self.mlflow_tracking_uri = mlflow_tracking_uri
self.kwargs = deepcopy(kwargs)

# parse input config for AutoRunner param overrides
for param in [
"analyze",
"algo_gen",
"train",
"hpo",
"ensemble",
"not_use_cache",
"allow_skip",
]: # override from config
if param in self.data_src_cfg and isinstance(self.data_src_cfg[param], bool):
setattr(self, param, self.data_src_cfg[param]) # e.g. self.analyze = self.data_src_cfg["analyze"]

for param in ["algos", "hpo_backend", "templates_path_or_url", "mlflow_tracking_uri"]: # override from config
if param in self.data_src_cfg:
setattr(self, param, self.data_src_cfg[param]) # e.g. self.algos = self.data_src_cfg["algos"]

missing_keys = {"dataroot", "datalist", "modality"}.difference(self.data_src_cfg.keys())
if len(missing_keys) > 0:
raise ValueError(f"Config keys are missing {missing_keys}")
Expand All @@ -256,6 +290,8 @@ def __init__(

# inspect and update folds
num_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename)
if "num_fold" in self.data_src_cfg:
num_fold = int(self.data_src_cfg["num_fold"]) # override from config

self.data_src_cfg["datalist"] = datalist_filename # update path to a version in work_dir and save user input
ConfigParser.export_config_file(
Expand All @@ -266,17 +302,6 @@ def __init__(
self.datastats_filename = os.path.join(self.work_dir, "datastats.yaml")
self.datalist_filename = datalist_filename

self.not_use_cache = not_use_cache
self.cache_filename = os.path.join(self.work_dir, "cache.yaml")
self.cache = self.read_cache()
self.export_cache()

# determine if we need to analyze, algo_gen or train from cache, unless manually provided
self.analyze = not self.cache["analyze"] if analyze is None else analyze
self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen
self.train = train
self.ensemble = ensemble # last step, no need to check

self.set_training_params()
self.set_device_info()
self.set_prediction_params()
Expand All @@ -288,9 +313,9 @@ def __init__(
self.gpu_customization_specs: dict[str, Any] = {}

# hpo
if hpo_backend.lower() != "nni":
if self.hpo_backend.lower() != "nni":
raise NotImplementedError("HPOGen backend only supports NNI")
self.hpo = hpo and has_nni
self.hpo = self.hpo and has_nni
self.set_hpo_params()
self.search_space: dict[str, dict[str, Any]] = {}
self.hpo_tasks = 0
Expand Down
3 changes: 2 additions & 1 deletion tests/test_vis_gradcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from monai.networks.nets import DenseNet, DenseNet121, SEResNet50
from monai.visualize import GradCAM, GradCAMpp
from tests.utils import assert_allclose
from tests.utils import assert_allclose, skip_if_quick


class DenseNetAdjoint(DenseNet121):
Expand Down Expand Up @@ -147,6 +147,7 @@ def __call__(self, x, adjoint_info):
TESTS_ILL.append([cam])


@skip_if_quick
class TestGradientClassActivationMap(unittest.TestCase):
@parameterized.expand(TESTS)
def test_shape(self, cam_class, input_data, expected_shape):
Expand Down

0 comments on commit 4847df2

Please sign in to comment.