From 7f96025086fcc6eeeb9293400178f3f34bcfa839 Mon Sep 17 00:00:00 2001 From: Erik De Bonte Date: Thu, 3 Oct 2024 17:54:52 -0700 Subject: [PATCH] Fix crash when using `--check` switch (#74) --- sarif/cmdline/main.py | 3 +-- sarif/sarif_file.py | 15 --------------- tests/test_check_switch.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 17 deletions(-) create mode 100644 tests/test_check_switch.py diff --git a/sarif/cmdline/main.py b/sarif/cmdline/main.py index 9d53176..b964f20 100644 --- a/sarif/cmdline/main.py +++ b/sarif/cmdline/main.py @@ -226,9 +226,8 @@ def _create_arg_parser(): def _check(input_files: sarif_file.SarifFileSet, check_level): ret = 0 if check_level: - counts = input_files.get_result_count_by_severity() for severity in sarif_file.SARIF_SEVERITIES_WITH_NONE: - ret += counts.get(severity, 0) + ret += input_files.get_report().get_issue_count_for_severity(severity) if severity == check_level: break if ret > 0: diff --git a/sarif/sarif_file.py b/sarif/sarif_file.py index 8c94e23..181572a 100644 --- a/sarif/sarif_file.py +++ b/sarif/sarif_file.py @@ -477,21 +477,6 @@ def get_result_count(self) -> int: """ return sum(run.get_result_count() for run in self.runs) - def get_result_count_by_severity(self, severities=None) -> Dict[str, int]: - """ - Return a dict from SARIF severity to number of records. - """ - severities = severities or self.get_severities() - result_count_by_severity_per_run = [ - run.get_result_count_by_severity(severities) for run in self.runs - ] - return { - severity: sum( - rc.get(severity, 0) for rc in result_count_by_severity_per_run - ) - for severity in severities - } - def get_filter_stats(self) -> Optional[FilterStats]: """ Get the number of records that were included or excluded by the filter. diff --git a/tests/test_check_switch.py b/tests/test_check_switch.py new file mode 100644 index 0000000..e726c13 --- /dev/null +++ b/tests/test_check_switch.py @@ -0,0 +1,28 @@ +import datetime +from sarif.cmdline.main import _check +from sarif import sarif_file + +SARIF = { + "runs": [ + { + "tool": {"driver": {"name": "Tool"}}, + "results": [{"level": "warning", "ruleId": "rule"}], + } + ] +} + + +def test_check(): + fileSet = sarif_file.SarifFileSet() + fileSet.add_file( + sarif_file.SarifFile("SARIF", SARIF, mtime=datetime.datetime.now()) + ) + + result = _check(fileSet, "error") + assert result == 0 + + result = _check(fileSet, "warning") + assert result == 1 + + result = _check(fileSet, "note") + assert result == 1