Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6873 data analyzer histogram_only=True fix #6874

Merged
merged 2 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions monai/apps/auto3dseg/data_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ def _check_data_uniformity(keys: list[str], result: dict) -> bool:

"""

if DataStatsKeys.SUMMARY not in result or DataStatsKeys.IMAGE_STATS not in result[DataStatsKeys.SUMMARY]:
return True
constant_props = [result[DataStatsKeys.SUMMARY][DataStatsKeys.IMAGE_STATS][key] for key in keys]
for prop in constant_props:
if "stdev" in prop and np.any(prop["stdev"]):
Expand Down Expand Up @@ -358,10 +360,11 @@ def _get_all_case_stats(
stats_by_cases = {
DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH],
DataStatsKeys.BY_CASE_LABEL_PATH: d[DataStatsKeys.BY_CASE_LABEL_PATH],
DataStatsKeys.IMAGE_STATS: d[DataStatsKeys.IMAGE_STATS],
}
if not self.histogram_only:
stats_by_cases[DataStatsKeys.IMAGE_STATS] = d[DataStatsKeys.IMAGE_STATS]
if self.hist_bins != 0:
stats_by_cases.update({DataStatsKeys.IMAGE_HISTOGRAM: d[DataStatsKeys.IMAGE_HISTOGRAM]})
stats_by_cases[DataStatsKeys.IMAGE_HISTOGRAM] = d[DataStatsKeys.IMAGE_HISTOGRAM]

if self.label_key is not None:
stats_by_cases.update(
Expand Down
16 changes: 10 additions & 6 deletions monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class ImageStats(Analyzer):

"""

def __init__(self, image_key: str, stats_name: str = "image_stats") -> None:
def __init__(self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_STATS) -> None:
if not isinstance(image_key, str):
raise ValueError("image_key input must be str")

Expand Down Expand Up @@ -296,7 +296,7 @@ class FgImageStats(Analyzer):

"""

def __init__(self, image_key: str, label_key: str, stats_name: str = "image_foreground_stats"):
def __init__(self, image_key: str, label_key: str, stats_name: str = DataStatsKeys.FG_IMAGE_STATS):
self.image_key = image_key
self.label_key = label_key

Expand Down Expand Up @@ -378,7 +378,9 @@ class LabelStats(Analyzer):

"""

def __init__(self, image_key: str, label_key: str, stats_name: str = "label_stats", do_ccp: bool | None = True):
def __init__(
self, image_key: str, label_key: str, stats_name: str = DataStatsKeys.LABEL_STATS, do_ccp: bool | None = True
):
self.image_key = image_key
self.label_key = label_key
self.do_ccp = do_ccp
Expand Down Expand Up @@ -533,7 +535,7 @@ class ImageStatsSumm(Analyzer):

"""

def __init__(self, stats_name: str = "image_stats", average: bool | None = True):
def __init__(self, stats_name: str = DataStatsKeys.IMAGE_STATS, average: bool | None = True):
self.summary_average = average
report_format = {
ImageStatsKeys.SHAPE: None,
Expand Down Expand Up @@ -623,7 +625,7 @@ class FgImageStatsSumm(Analyzer):

"""

def __init__(self, stats_name: str = "image_foreground_stats", average: bool | None = True):
def __init__(self, stats_name: str = DataStatsKeys.FG_IMAGE_STATS, average: bool | None = True):
self.summary_average = average

report_format = {ImageStatsKeys.INTENSITY: None}
Expand Down Expand Up @@ -687,7 +689,9 @@ class LabelStatsSumm(Analyzer):

"""

def __init__(self, stats_name: str = "label_stats", average: bool | None = True, do_ccp: bool | None = True):
def __init__(
self, stats_name: str = DataStatsKeys.LABEL_STATS, average: bool | None = True, do_ccp: bool | None = True
):
self.summary_average = average
self.do_ccp = do_ccp

Expand Down
4 changes: 2 additions & 2 deletions monai/auto3dseg/seg_summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def __init__(
self.summary_analyzers: list[Any] = []
super().__init__()

self.add_analyzer(FilenameStats(image_key, DataStatsKeys.BY_CASE_IMAGE_PATH), None)
self.add_analyzer(FilenameStats(label_key, DataStatsKeys.BY_CASE_LABEL_PATH), None)
if not self.histogram_only:
self.add_analyzer(FilenameStats(image_key, DataStatsKeys.BY_CASE_IMAGE_PATH), None)
self.add_analyzer(FilenameStats(label_key, DataStatsKeys.BY_CASE_LABEL_PATH), None)
self.add_analyzer(ImageStats(image_key), ImageStatsSumm(average=average))

if label_key is None:
Expand Down
15 changes: 15 additions & 0 deletions tests/test_auto3dseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,21 @@ def test_data_analyzer_cpu(self, input_params):

assert len(datastat["stats_by_cases"]) == len(sim_datalist["training"])

def test_data_analyzer_histogram(self):
create_sim_data(
self.dataroot_dir, sim_datalist, [32] * 3, image_only=True, rad_max=8, rad_min=1, num_seg_classes=1
)
analyser = DataAnalyzer(
self.datalist_file,
self.dataroot_dir,
output_path=self.datastat_file,
label_key=None,
device=device,
histogram_only=True,
)
datastat = analyser.get_all_case_stats()
assert len(datastat["stats_by_cases"]) == len(sim_datalist["training"])

@parameterized.expand(SIM_GPU_TEST_CASES)
@skip_if_no_cuda
def test_data_analyzer_gpu(self, input_params):
Expand Down