diff --git a/vulnerabilities/api_v2.py b/vulnerabilities/api_v2.py new file mode 100644 index 000000000..e8361c092 --- /dev/null +++ b/vulnerabilities/api_v2.py @@ -0,0 +1,193 @@ +# +# Copyright (c) nexB Inc. and others. All rights reserved. +# VulnerableCode is a trademark of nexB Inc. +# SPDX-License-Identifier: Apache-2.0 +# See http://www.apache.org/licenses/LICENSE-2.0 for the license text. +# See https://github.com/aboutcode-org/vulnerablecode for support or download. +# See https://aboutcode.org for more information about nexB OSS projects. +# + + +from rest_framework import serializers +from rest_framework import viewsets +from rest_framework.response import Response +from rest_framework.reverse import reverse + +from vulnerabilities.api import VulnerabilitySeveritySerializer +from vulnerabilities.models import Package +from vulnerabilities.models import Vulnerability +from vulnerabilities.models import VulnerabilityReference +from vulnerabilities.models import VulnerabilitySeverity +from vulnerabilities.models import Weakness + + +class WeaknessV2Serializer(serializers.ModelSerializer): + cwe_id = serializers.CharField() + name = serializers.CharField() + description = serializers.CharField() + + class Meta: + model = Weakness + fields = ["cwe_id", "name", "description"] + + +class VulnerabilityReferenceV2Serializer(serializers.ModelSerializer): + url = serializers.CharField() + reference_type = serializers.CharField() + reference_id = serializers.CharField() + + class Meta: + model = VulnerabilityReference + fields = ["url", "reference_type", "reference_id"] + + +class VulnerabilityV2Serializer(serializers.ModelSerializer): + aliases = serializers.SerializerMethodField() + weaknesses = WeaknessV2Serializer(many=True) + references = VulnerabilityReferenceV2Serializer(many=True, source="vulnerabilityreference_set") + severities = VulnerabilitySeveritySerializer(many=True) + + class Meta: + model = Vulnerability + fields = [ + "vulnerability_id", + "aliases", + "summary", + "severities", + "weaknesses", + "references", + ] + + def get_aliases(self, obj): + return [alias.alias for alias in obj.aliases.all()] + + def get_severities(self, obj): + return obj.severities + + +class VulnerabilityListSerializer(serializers.ModelSerializer): + url = serializers.SerializerMethodField() + + class Meta: + model = Vulnerability + fields = ["vulnerability_id", "url"] + + def get_url(self, obj): + request = self.context.get("request") + return reverse( + "vulnerability-v2-detail", + kwargs={"vulnerability_id": obj.vulnerability_id}, + request=request, + ) + + +class VulnerabilityV2ViewSet(viewsets.ReadOnlyModelViewSet): + queryset = Vulnerability.objects.all() + serializer_class = VulnerabilityV2Serializer + lookup_field = "vulnerability_id" + + def get_queryset(self): + queryset = super().get_queryset() + vulnerability_ids = self.request.query_params.getlist("vulnerability_id") + aliases = self.request.query_params.getlist("alias") + + if vulnerability_ids: + queryset = queryset.filter(vulnerability_id__in=vulnerability_ids) + + if aliases: + queryset = queryset.filter(aliases__alias__in=aliases).distinct() + + return queryset + + def get_serializer_class(self): + if self.action == "list": + return VulnerabilityListSerializer + return super().get_serializer_class() + + def list(self, request, *args, **kwargs): + queryset = self.get_queryset() + vulnerability_ids = request.query_params.getlist("vulnerability_id") + + # If exactly one vulnerability_id is provided, return the serialized data + if len(vulnerability_ids) == 1: + try: + vulnerability = queryset.get(vulnerability_id=vulnerability_ids[0]) + serializer = self.get_serializer(vulnerability) + return Response(serializer.data) + except Vulnerability.DoesNotExist: + return Response({"detail": "Not found."}, status=404) + + # Otherwise, return a dictionary of vulnerabilities keyed by vulnerability_id + page = self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + data = serializer.data + vulnerabilities = {item["vulnerability_id"]: item for item in data} + return self.get_paginated_response({"vulnerabilities": vulnerabilities}) + + serializer = self.get_serializer(queryset, many=True) + data = serializer.data + vulnerabilities = {item["vulnerability_id"]: item for item in data} + return Response({"vulnerabilities": vulnerabilities}) + + +class PackageV2Serializer(serializers.ModelSerializer): + purl = serializers.CharField(source="package_url") + affected_by_vulnerabilities = serializers.SerializerMethodField() + fixing_vulnerabilities = serializers.SerializerMethodField() + next_non_vulnerable_version = serializers.CharField(read_only=True) + latest_non_vulnerable_version = serializers.CharField(read_only=True) + + class Meta: + model = Package + fields = [ + "purl", + "affected_by_vulnerabilities", + "fixing_vulnerabilities", + "next_non_vulnerable_version", + "latest_non_vulnerable_version", + ] + + def get_affected_by_vulnerabilities(self, obj): + return [vuln.vulnerability_id for vuln in obj.affected_by_vulnerabilities.all()] + + def get_fixing_vulnerabilities(self, obj): + return [vuln.vulnerability_id for vuln in obj.fixing_vulnerabilities.all()] + + +class PackageV2ViewSet(viewsets.ReadOnlyModelViewSet): + queryset = Package.objects.all() + serializer_class = PackageV2Serializer + + def get_queryset(self): + queryset = super().get_queryset() + package_purls = self.request.query_params.getlist("purl") + affected_by_vulnerability = self.request.query_params.get("affected_by_vulnerability") + fixing_vulnerability = self.request.query_params.get("fixing_vulnerability") + + if package_purls: + queryset = queryset.filter(package_url__in=package_purls) + if affected_by_vulnerability: + queryset = queryset.filter( + affected_by_vulnerabilities__vulnerability_id=affected_by_vulnerability + ) + if fixing_vulnerability: + queryset = queryset.filter( + fixing_vulnerabilities__vulnerability_id=fixing_vulnerability + ) + return queryset.with_is_vulnerable() + + def list(self, request, *args, **kwargs): + queryset = self.get_queryset() + # Apply pagination + page = self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + data = serializer.data + # Use 'self.get_paginated_response' to include pagination data + return self.get_paginated_response({"packages": data}) + + # If pagination is not applied + serializer = self.get_serializer(queryset, many=True) + data = serializer.data + return Response({"packages": data}) diff --git a/vulnerabilities/tests/test_api_v2.py b/vulnerabilities/tests/test_api_v2.py new file mode 100644 index 000000000..eeaa18776 --- /dev/null +++ b/vulnerabilities/tests/test_api_v2.py @@ -0,0 +1,307 @@ +# +# Copyright (c) nexB Inc. and others. All rights reserved. +# VulnerableCode is a trademark of nexB Inc. +# SPDX-License-Identifier: Apache-2.0 +# See http://www.apache.org/licenses/LICENSE-2.0 for the license text. +# See https://github.com/aboutcode-org/vulnerablecode for support or download. +# See https://aboutcode.org for more information about nexB OSS projects. +# + +from django.urls import reverse +from packageurl import PackageURL +from rest_framework import status +from rest_framework.test import APIClient +from rest_framework.test import APITestCase + +from vulnerabilities.api_v2 import PackageV2Serializer +from vulnerabilities.api_v2 import VulnerabilityListSerializer +from vulnerabilities.models import Alias +from vulnerabilities.models import ApiUser +from vulnerabilities.models import Package +from vulnerabilities.models import Vulnerability +from vulnerabilities.models import VulnerabilityReference +from vulnerabilities.models import Weakness + + +class VulnerabilityV2ViewSetTest(APITestCase): + def setUp(self): + # Create vulnerabilities + self.vuln1 = Vulnerability.objects.create( + vulnerability_id="VCID-1234", summary="Test vulnerability 1" + ) + self.vuln2 = Vulnerability.objects.create( + vulnerability_id="VCID-5678", summary="Test vulnerability 2" + ) + + # Create aliases + Alias.objects.create(alias="CVE-2021-1234", vulnerability=self.vuln1) + Alias.objects.create(alias="CVE-2021-5678", vulnerability=self.vuln2) + + # Create weaknesses + self.weakness1 = Weakness.objects.create(cwe_id=79) + self.weakness1.vulnerabilities.add(self.vuln1) + + self.weakness2 = Weakness.objects.create(cwe_id=89) + self.weakness2.vulnerabilities.add(self.vuln2) + + # Create references + self.reference1 = VulnerabilityReference.objects.create( + url="https://example.com/ref1", reference_type="advisory", reference_id="REF-1" + ) + self.reference1.vulnerabilities.add(self.vuln1) + + self.reference2 = VulnerabilityReference.objects.create( + url="https://example.com/ref2", reference_type="exploit", reference_id="REF-2" + ) + self.reference2.vulnerabilities.add(self.vuln2) + self.user = ApiUser.objects.create_api_user(username="e@mail.com") + self.auth = f"Token {self.user.auth_token.key}" + self.client = APIClient(enforce_csrf_checks=True) + self.client.credentials(HTTP_AUTHORIZATION=self.auth) + + def test_list_vulnerabilities(self): + """ + Test listing vulnerabilities without filters. + Should return a list of vulnerabilities with IDs and URLs. + """ + url = reverse("vulnerability-v2-list") + response = self.client.get(url, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn("vulnerabilities", response.data["results"]) + self.assertEqual(len(response.data["results"]["vulnerabilities"]), 2) + self.assertTrue("url" in response.data["results"]["vulnerabilities"]["VCID-1234"]) + + def test_retrieve_vulnerability_detail(self): + """ + Test retrieving vulnerability details by vulnerability_id. + """ + url = reverse("vulnerability-v2-detail", kwargs={"vulnerability_id": "VCID-1234"}) + response = self.client.get(url, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["vulnerability_id"], "VCID-1234") + self.assertEqual(response.data["summary"], "Test vulnerability 1") + self.assertEqual(response.data["aliases"], ["CVE-2021-1234"]) + self.assertEqual(len(response.data["weaknesses"]), 1) + self.assertEqual(len(response.data["references"]), 1) + + def test_filter_vulnerability_by_vulnerability_id(self): + """ + Test filtering vulnerabilities by vulnerability_id. + """ + url = reverse("vulnerability-v2-list") + response = self.client.get(url, {"vulnerability_id": "VCID-1234"}, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["vulnerability_id"], "VCID-1234") + + def test_filter_vulnerability_by_alias(self): + """ + Test filtering vulnerabilities by alias. + """ + url = reverse("vulnerability-v2-list") + response = self.client.get(url, {"alias": "CVE-2021-5678"}, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual( + response.data["results"]["vulnerabilities"]["VCID-5678"]["vulnerability_id"], + "VCID-5678", + ) + + def test_filter_vulnerabilities_multiple_ids(self): + """ + Test filtering vulnerabilities by multiple vulnerability_ids. + """ + url = reverse("vulnerability-v2-list") + response = self.client.get( + url, {"vulnerability_id": ["VCID-1234", "VCID-5678"]}, format="json" + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]["vulnerabilities"]), 2) + + def test_filter_vulnerabilities_multiple_aliases(self): + """ + Test filtering vulnerabilities by multiple aliases. + """ + url = reverse("vulnerability-v2-list") + response = self.client.get( + url, {"alias": ["CVE-2021-1234", "CVE-2021-5678"]}, format="json" + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]["vulnerabilities"]), 2) + + def test_invalid_vulnerability_id(self): + """ + Test retrieving a vulnerability with an invalid vulnerability_id. + Should return 404 Not Found. + """ + url = reverse("vulnerability-v2-detail", kwargs={"vulnerability_id": "VCID-9999"}) + response = self.client.get(url, format="json") + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_get_url_in_serializer(self): + """ + Test that the serializer correctly includes the URL field. + """ + vulnerability = Vulnerability.objects.get(vulnerability_id="VCID-1234") + serializer = VulnerabilityListSerializer(vulnerability, context={"request": None}) + self.assertIn("url", serializer.data) + self.assertEqual(serializer.data["vulnerability_id"], "VCID-1234") + + def test_list_vulnerabilities_pagination(self): + """ + Test listing vulnerabilities with pagination. + """ + # Create additional vulnerabilities to trigger pagination + for i in range(3, 15): + Vulnerability.objects.create( + vulnerability_id=f"VCID-{i}", summary=f"Test vulnerability {i}" + ) + + url = reverse("vulnerability-v2-list") + response = self.client.get(url, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn("results", response.data) + self.assertIn("next", response.data) + self.assertIn("previous", response.data) + self.assertEqual(len(response.data["results"]), 1) # Assuming default page size is 10 + + +class PackageV2ViewSetTest(APITestCase): + def setUp(self): + # Create packages + self.package1 = Package.objects.create( + package_url="pkg:pypi/django@3.2", name="django", version="3.2", type="pypi" + ) + self.package2 = Package.objects.create( + package_url="pkg:npm/lodash@4.17.20", name="lodash", version="4.17.20", type="npm" + ) + + # Create vulnerabilities + self.vuln1 = Vulnerability.objects.create( + vulnerability_id="VCID-1234", summary="Test vulnerability 1" + ) + self.vuln2 = Vulnerability.objects.create( + vulnerability_id="VCID-5678", summary="Test vulnerability 2" + ) + + # Associate packages with vulnerabilities + self.package1.affected_by_vulnerabilities.add(self.vuln1) + self.package2.fixing_vulnerabilities.add(self.vuln2) + self.user = ApiUser.objects.create_api_user(username="e@mail.com") + self.auth = f"Token {self.user.auth_token.key}" + self.client = APIClient(enforce_csrf_checks=True) + self.client.credentials(HTTP_AUTHORIZATION=self.auth) + + def test_list_packages(self): + """ + Test listing packages without filters. + Should return a list of packages with their details. + """ + url = reverse("package-v2-list") + response = self.client.get(url, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn("packages", response.data["results"]) + self.assertEqual(len(response.data["results"]["packages"]), 2) + + def test_filter_packages_by_purl(self): + """ + Test filtering packages by one or more PURLs. + """ + url = reverse("package-v2-list") + response = self.client.get(url, {"purl": "pkg:pypi/django@3.2"}, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]["packages"]), 1) + self.assertEqual(response.data["results"]["packages"][0]["purl"], "pkg:pypi/django@3.2") + + def test_filter_packages_by_affected_vulnerability(self): + """ + Test filtering packages by affected_by_vulnerability. + """ + url = reverse("package-v2-list") + response = self.client.get(url, {"affected_by_vulnerability": "VCID-1234"}, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]["packages"]), 1) + self.assertEqual(response.data["results"]["packages"][0]["purl"], "pkg:pypi/django@3.2") + + def test_filter_packages_by_fixing_vulnerability(self): + """ + Test filtering packages by fixing_vulnerability. + """ + url = reverse("package-v2-list") + response = self.client.get(url, {"fixing_vulnerability": "VCID-5678"}, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]["packages"]), 1) + self.assertEqual(response.data["results"]["packages"][0]["purl"], "pkg:npm/lodash@4.17.20") + + def test_package_serializer_fields(self): + """ + Test that the PackageV2Serializer returns the correct fields. + """ + package = Package.objects.get(package_url="pkg:pypi/django@3.2") + serializer = PackageV2Serializer(package) + data = serializer.data + self.assertIn("purl", data) + self.assertIn("affected_by_vulnerabilities", data) + self.assertIn("fixing_vulnerabilities", data) + self.assertIn("next_non_vulnerable_version", data) + self.assertIn("latest_non_vulnerable_version", data) + self.assertEqual(data["purl"], "pkg:pypi/django@3.2") + self.assertEqual(data["affected_by_vulnerabilities"], ["VCID-1234"]) + self.assertEqual(data["fixing_vulnerabilities"], []) + + def test_list_packages_pagination(self): + """ + Test listing packages with pagination. + """ + # Create additional packages to trigger pagination + for i in range(3, 15): + Package.objects.create( + package_url=f"pkg:pypi/package{i}@1.0.{i}", + name=f"package{i}", + version=f"1.0.{i}", + type="pypi", + ) + + url = reverse("package-v2-list") + response = self.client.get(url, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn("results", response.data) + self.assertIn("next", response.data) + self.assertIn("previous", response.data) + self.assertEqual(len(response.data["results"]), 1) # Assuming default page size is 10 + + def test_invalid_vulnerability_filter(self): + """ + Test filtering packages with an invalid vulnerability ID. + Should return an empty list. + """ + url = reverse("package-v2-list") + response = self.client.get(url, {"affected_by_vulnerability": "VCID-9999"}, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]["packages"]), 0) + + def test_invalid_purl_filter(self): + """ + Test filtering packages with an invalid PURL. + Should return an empty list. + """ + url = reverse("package-v2-list") + response = self.client.get(url, {"purl": "pkg:nonexistent/package@1.0.0"}, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]["packages"]), 0) + + def test_get_affected_by_vulnerabilities(self): + """ + Test the get_affected_by_vulnerabilities method in the serializer. + """ + package = Package.objects.get(package_url="pkg:pypi/django@3.2") + serializer = PackageV2Serializer() + vulnerabilities = serializer.get_affected_by_vulnerabilities(package) + self.assertEqual(vulnerabilities, ["VCID-1234"]) + + def test_get_fixing_vulnerabilities(self): + """ + Test the get_fixing_vulnerabilities method in the serializer. + """ + package = Package.objects.get(package_url="pkg:npm/lodash@4.17.20") + serializer = PackageV2Serializer() + vulnerabilities = serializer.get_fixing_vulnerabilities(package) + self.assertEqual(vulnerabilities, ["VCID-5678"]) diff --git a/vulnerablecode/urls.py b/vulnerablecode/urls.py index 51d303138..d92d6ce40 100644 --- a/vulnerablecode/urls.py +++ b/vulnerablecode/urls.py @@ -20,6 +20,8 @@ from vulnerabilities.api import CPEViewSet from vulnerabilities.api import PackageViewSet from vulnerabilities.api import VulnerabilityViewSet +from vulnerabilities.api_v2 import PackageV2ViewSet +from vulnerabilities.api_v2 import VulnerabilityV2ViewSet from vulnerabilities.views import ApiUserCreateView from vulnerabilities.views import HomePage from vulnerabilities.views import PackageDetails @@ -43,6 +45,10 @@ def __init__(self, *args, **kwargs): api_router.register("cpes", CPEViewSet, basename="cpe") api_router.register("aliases", AliasViewSet, basename="alias") +api_v2_router = OptionalSlashRouter() +api_v2_router.register("vulnerabilities", VulnerabilityV2ViewSet, basename="vulnerability-v2") +api_v2_router.register("packages", PackageV2ViewSet, basename="package-v2") + urlpatterns = [ path( "robots.txt", @@ -98,6 +104,7 @@ def __init__(self, *args, **kwargs): TemplateView.as_view(template_name="tos.html"), name="api_tos", ), + path("api/v2/", include(api_v2_router.urls)), path( "admin/", admin.site.urls,