Merge pull request #363 from sbs2001/remove_advisory_hash_method
Refactor codebase and tests to treat the Advisory class as mutable
sbs2001 authored Mar 4, 2021
2 parents f262ef1 + 0e0fc2b commit a187f90
Showing 34 changed files with 1,703 additions and 1,390 deletions.
48 changes: 34 additions & 14 deletions vulnerabilities/data_source.py
@@ -47,13 +47,13 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass
@dataclasses.dataclass(order=True)
class VulnerabilitySeverity:
system: ScoringSystem
value: str


@dataclasses.dataclass
@dataclasses.dataclass(order=True)
class Reference:

reference_id: str = ""
@@ -64,8 +64,16 @@ def __post_init__(self):
if not any([self.url, self.reference_id]):
raise TypeError

def normalized(self):
severities = sorted(self.severities)
return Reference(
reference_id=self.reference_id,
url=self.url,
severities=severities
)

@dataclasses.dataclass

@dataclasses.dataclass(order=True)
class Advisory:
"""
This data class expresses the contract between data sources and the import runner.
Expand All @@ -78,19 +86,27 @@ class Advisory:
"""

summary: str
impacted_package_urls: Iterable[PackageURL]
vulnerability_id: Optional[str] = None
impacted_package_urls: Iterable[PackageURL] = dataclasses.field(default_factory=list)
resolved_package_urls: Iterable[PackageURL] = dataclasses.field(default_factory=list)
vuln_references: List[Reference] = dataclasses.field(default_factory=list)
vulnerability_id: Optional[str] = None

def __hash__(self):
s = "{}{}{}{}".format(
self.summary,
''.join(sorted([str(p) for p in self.impacted_package_urls])),
''.join(sorted([str(p) for p in self.resolved_package_urls])),
self.vulnerability_id,
def normalized(self):
impacted_package_urls = {package_url for package_url in self.impacted_package_urls}
resolved_package_urls = {package_url for package_url in self.resolved_package_urls}
vuln_references = sorted(
self.vuln_references, key=lambda reference: (reference.reference_id, reference.url)
)
for index, _ in enumerate(self.vuln_references):
vuln_references[index] = (vuln_references[index].normalized())

return Advisory(
summary=self.summary,
vulnerability_id=self.vulnerability_id,
impacted_package_urls=impacted_package_urls,
resolved_package_urls=resolved_package_urls,
vuln_references=vuln_references,
)
return hash(s)


class InvalidConfigurationError(Exception):
@@ -205,11 +221,15 @@ def batch_advisories(self, advisories: List[Advisory]) -> Set[Advisory]:
"""
Yield batches of the passed in list of advisories.
"""
advisories = advisories[:] # copy the list as we are mutating it in the loop below

# TODO make this less cryptic and efficient

advisories = advisories[:]
# copy the list as we are mutating it in the loop below

while advisories:
b, advisories = advisories[: self.batch_size], advisories[self.batch_size:]
yield set(b)
yield b


@dataclasses.dataclass
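A minimal usage sketch (not part of the commit) of the new normalized() method: two advisories listing the same packages and references in different orders compare equal once normalized. The summary, CVE id, package URLs, and reference values below are made up for illustration.

    from packageurl import PackageURL

    from vulnerabilities.data_source import Advisory, Reference

    a = Advisory(
        summary="Example overflow",
        vulnerability_id="CVE-2021-0000",
        impacted_package_urls=[
            PackageURL(type="pypi", name="foo", version="1.0"),
            PackageURL(type="pypi", name="bar", version="2.0"),
        ],
        vuln_references=[Reference(reference_id="XSA-1"), Reference(url="https://example.com/1")],
    )
    b = Advisory(
        summary="Example overflow",
        vulnerability_id="CVE-2021-0000",
        impacted_package_urls=[
            PackageURL(type="pypi", name="bar", version="2.0"),
            PackageURL(type="pypi", name="foo", version="1.0"),
        ],
        vuln_references=[Reference(url="https://example.com/1"), Reference(reference_id="XSA-1")],
    )

    # normalized() collects the package URLs into sets and sorts the references,
    # so ordering differences in the raw data no longer matter.
    assert a.normalized() == b.normalized()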
3 changes: 2 additions & 1 deletion vulnerabilities/importers/alpine_linux.py
@@ -181,13 +181,14 @@ def _load_advisories(
)
)

# TODO: Handle the CVE-????-????? case
advisories.append(
Advisory(
summary="",
impacted_package_urls=[],
resolved_package_urls=resolved_purls,
vuln_references=references,
vulnerability_id=vuln_ids[0] if vuln_ids[0] != "CVE-????-?????" else None,
vulnerability_id=vuln_ids[0] if vuln_ids[0] != "CVE-????-?????" else "",
)
)

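A plausible reason (not stated in the commit) for switching the missing-CVE placeholder from None to an empty string: with order=True, advisories are compared field by field when sorted, and None cannot be ordered against a str. A toy sketch under that assumption, with made-up summaries and ids:

    from vulnerabilities.data_source import Advisory

    # Both ids are strings, so sorting advisories (as the updated tests do)
    # never hits a None-versus-str comparison, which would raise TypeError.
    advisories = sorted([
        Advisory(summary="libfoo overflow", vulnerability_id="CVE-2021-1234"),
        Advisory(summary="libfoo overflow", vulnerability_id=""),
    ])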
4 changes: 2 additions & 2 deletions vulnerabilities/importers/rust.py
@@ -75,11 +75,11 @@ def _load_advisories(self, files) -> Set[Advisory]:

while files:
batch, files = files[: self.batch_size], files[self.batch_size:]
advisories = set()
advisories = []
for path in batch:
advisory = self._load_advisory(path)
if advisory:
advisories.add(advisory)
advisories.append(advisory)
yield advisories

def collect_packages(self, paths):
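The switch from a set to a plain list follows from dropping __hash__ on Advisory: a dataclass with the default eq=True and no explicit __hash__ gets __hash__ = None, so its instances are unhashable. A quick sketch assuming the Advisory class from the data_source.py hunk above, with a made-up summary and id:

    from vulnerabilities.data_source import Advisory

    advisory = Advisory(summary="example advisory", vulnerability_id="CVE-2021-0000")

    try:
        {advisory}  # set members must be hashable
    except TypeError:
        pass  # "unhashable type: 'Advisory'" -- hence the move to lists

    advisories = [advisory]  # lists impose no hashability requirement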
2 changes: 1 addition & 1 deletion vulnerabilities/severity_systems.py
@@ -1,7 +1,7 @@
import dataclasses


@dataclasses.dataclass
@dataclasses.dataclass(order=True)
class ScoringSystem:

# a short identifier for the scoring system.
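order=True on ScoringSystem appears to support the sorted(self.severities) call in Reference.normalized(): the first field of VulnerabilitySeverity is a ScoringSystem, so sorting severities needs scoring systems to be orderable too. A toy stand-in (not the real class) showing what order=True provides:

    import dataclasses

    # order=True generates field-by-field comparison methods, which is
    # what lets dataclass instances be passed to sorted().
    @dataclasses.dataclass(order=True)
    class Score:
        system: str
        value: str

    assert sorted([Score("b", "2"), Score("a", "9")]) == [Score("a", "9"), Score("b", "2")]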
4 changes: 2 additions & 2 deletions vulnerabilities/tests/conftest.py
@@ -28,9 +28,9 @@

@pytest.fixture
def no_mkdir(monkeypatch):
monkeypatch.delattr('os.mkdir')
monkeypatch.delattr("os.mkdir")


@pytest.fixture
def no_rmtree(monkeypatch):
monkeypatch.delattr('shutil.rmtree')
monkeypatch.delattr("shutil.rmtree")
7 changes: 5 additions & 2 deletions vulnerabilities/tests/test_alpine.py
@@ -141,7 +141,7 @@ def test__process_link(self):
url="https://xenbits.xen.org/xsa/advisory-295.html", reference_id="XSA-295"
)
],
vulnerability_id=None,
vulnerability_id="",
),
]
mock_requests = MagicMock()
@@ -151,4 +151,7 @@ def test__process_link(self):
mock_content.content = f
with patch("vulnerabilities.importers.alpine_linux.requests", new=mock_requests):
found_advisories = self.data_source._process_link("does not matter")
assert expected_advisories == found_advisories

found_advisories = list(map(Advisory.normalized, found_advisories))
expected_advisories = list(map(Advisory.normalized, expected_advisories))
assert sorted(found_advisories) == sorted(expected_advisories)
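The same normalize-then-sort comparison recurs in the other updated tests; a hypothetical helper (not part of the commit) that captures the pattern:

    from typing import Iterable

    from vulnerabilities.data_source import Advisory

    def assert_advisories_equal(expected: Iterable[Advisory], found: Iterable[Advisory]) -> None:
        # Normalize both sides, then compare order-insensitively by sorting,
        # since mutable Advisory objects can no longer be collected into sets.
        assert sorted(map(Advisory.normalized, found)) == sorted(map(Advisory.normalized, expected))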
12 changes: 6 additions & 6 deletions vulnerabilities/tests/test_apache_kafka.py
@@ -56,10 +56,8 @@ def test_to_version_ranges(self):

def test_to_advisory(self):
data_source = ApacheKafkaDataSource(batch_size=1)
data_source.version_api = GitHubTagsAPI(
cache={"apache/kafka": ["2.1.2", "0.10.2.2"]}
)
expected_data = [
data_source.version_api = GitHubTagsAPI(cache={"apache/kafka": ["2.1.2", "0.10.2.2"]})
expected_advisories = [
Advisory(
summary="In Apache Kafka versions between 0.11.0.0 and 2.1.0, it is possible to "
"manually\n craft a Produce request which bypasses transaction/idempotent ACL "
@@ -97,6 +95,8 @@ def test_to_advisory(self):
)
]
with open(TEST_DATA) as f:
found_data = data_source.to_advisory(f)
found_advisories = data_source.to_advisory(f)

assert found_data == expected_data
found_advisories = list(map(Advisory.normalized, found_advisories))
expected_advisories = list(map(Advisory.normalized, expected_advisories))
assert sorted(found_advisories) == sorted(expected_advisories)