From f662e41157c4beb7ed843edcc0ee9f89c8271e69 Mon Sep 17 00:00:00 2001 From: Tushar Goel Date: Wed, 17 Aug 2022 17:50:28 +0530 Subject: [PATCH] Add tests Signed-off-by: Tushar Goel --- vulnerabilities/api.py | 13 ++-- vulnerabilities/templates/base.html | 3 +- vulnerabilities/tests/test_auth.py | 90 +++++++++++++++++++++++++++ vulnerabilities/tests/test_fix_api.py | 89 +++++++++++++++++--------- vulnerablecode/settings.py | 4 +- 5 files changed, 156 insertions(+), 43 deletions(-) create mode 100644 vulnerabilities/tests/test_auth.py diff --git a/vulnerabilities/api.py b/vulnerabilities/api.py index 0fc0962a6..c2a45e357 100644 --- a/vulnerabilities/api.py +++ b/vulnerabilities/api.py @@ -22,7 +22,6 @@ from vulnerabilities.models import Vulnerability from vulnerabilities.models import VulnerabilityReference from vulnerabilities.models import VulnerabilitySeverity -from vulnerablecode.auth import ConditionalLoginRequired class VulnerabilitySeveritySerializer(serializers.ModelSerializer): @@ -31,10 +30,6 @@ class Meta: fields = ["value", "scoring_system"] -class AuthenticatedAPIViewSet(ConditionalLoginRequired, viewsets.ReadOnlyModelViewSet): - pass - - class VulnerabilityReferenceSerializer(serializers.ModelSerializer): scores = VulnerabilitySeveritySerializer(many=True, source="vulnerabilityseverity_set") reference_url = serializers.CharField(source="url") @@ -211,7 +206,7 @@ def filter_purl(self, queryset, name, value): return self.queryset.filter(**attrs) -class PackageViewSet(AuthenticatedAPIViewSet): +class PackageViewSet(viewsets.ReadOnlyModelViewSet): queryset = Package.objects.all() serializer_class = PackageSerializer filter_backends = (filters.DjangoFilterBackend,) @@ -258,7 +253,7 @@ class Meta: fields = ["vulnerability_id"] -class VulnerabilityViewSet(AuthenticatedAPIViewSet): +class VulnerabilityViewSet(viewsets.ReadOnlyModelViewSet): def get_fixed_packages_qs(self): """ Filter the packages that fixes a vulnerability @@ -300,7 +295,7 @@ def filter_cpe(self, queryset, name, value): return self.queryset.filter(vulnerabilityreference__reference_id__startswith=cpe).distinct() -class CPEViewSet(AuthenticatedAPIViewSet): +class CPEViewSet(viewsets.ReadOnlyModelViewSet): queryset = Vulnerability.objects.filter( vulnerabilityreference__reference_id__startswith="cpe" ).distinct() @@ -341,7 +336,7 @@ def filter_alias(self, queryset, name, value): return self.queryset.filter(aliases__alias__icontains=alias) -class AliasViewSet(AuthenticatedAPIViewSet): +class AliasViewSet(viewsets.ReadOnlyModelViewSet): queryset = Vulnerability.objects.all() serializer_class = VulnerabilitySerializer filter_backends = (filters.DjangoFilterBackend,) diff --git a/vulnerabilities/templates/base.html b/vulnerabilities/templates/base.html index 089d17d7f..2bb426a50 100644 --- a/vulnerabilities/templates/base.html +++ b/vulnerabilities/templates/base.html @@ -63,5 +63,4 @@ {% block javascript %} -{% endblock %} - \ No newline at end of file +{% endblock %} \ No newline at end of file diff --git a/vulnerabilities/tests/test_auth.py b/vulnerabilities/tests/test_auth.py new file mode 100644 index 000000000..3001b6602 --- /dev/null +++ b/vulnerabilities/tests/test_auth.py @@ -0,0 +1,90 @@ +# +# Copyright (c) nexB Inc. and others. All rights reserved. +# VulnerableCode is a trademark of nexB Inc. +# SPDX-License-Identifier: Apache-2.0 +# See http://www.apache.org/licenses/LICENSE-2.0 for the license text. +# See https://github.com/nexB/vulnerablecode for support or download. +# See https://aboutcode.org for more information about nexB OSS projects. +# This is copied from https://github.com/nexB/scancode.io/commit/eab8eeb13989c26a1600cc64e8b054f171341063 +# + +from django.conf import settings +from django.contrib.auth import get_user_model +from django.contrib.auth.models import AnonymousUser +from django.test import TestCase +from django.test import override_settings +from django.urls import reverse + +from vulnerablecode.auth import is_authenticated_when_required + +TEST_PASSWORD = "secret" + +User = get_user_model() + +login_url = reverse("login") +logout_url = reverse("logout") +profile_url = reverse("account_profile") +api_package_url = "/api/packages/" +login_redirect_url = settings.LOGIN_REDIRECT_URL + + +class VulnerableCodeAuthTest(TestCase): + def setUp(self): + self.anonymous_user = AnonymousUser() + self.basic_user = User.objects.create_user(username="basic_user", password=TEST_PASSWORD) + + def test_vulnerablecode_auth_is_authenticated_when_required(self): + self.assertFalse(self.anonymous_user.is_authenticated) + self.assertFalse(is_authenticated_when_required(user=self.anonymous_user)) + + self.assertTrue(self.basic_user.is_authenticated) + self.assertTrue(is_authenticated_when_required(user=self.basic_user)) + + with override_settings(VULNERABLECODEIO_REQUIRE_AUTHENTICATION=False): + self.assertTrue(is_authenticated_when_required(user=None)) + + def test_vulnerablecode_auth_login_view(self): + data = {"username": self.basic_user.username, "password": ""} + response = self.client.post(login_url, data) + form = response.context_data["form"] + expected_error = {"password": ["This field is required."]} + self.assertEqual(expected_error, form.errors) + + data = {"username": self.basic_user.username, "password": "wrong"} + response = self.client.post(login_url, data) + form = response.context_data["form"] + expected_error = { + "__all__": [ + "Please enter a correct username and password. " + "Note that both fields may be case-sensitive." + ] + } + self.assertEqual(expected_error, form.errors) + + data = {"username": self.basic_user.username, "password": TEST_PASSWORD} + response = self.client.post(login_url, data, follow=True) + self.assertRedirects(response, login_redirect_url) + expected = 'basic_user' + self.assertContains(response, expected, html=True) + + def test_vulnerablecode_auth_logout_view(self): + response = self.client.get(logout_url) + self.assertRedirects(response, login_url) + + self.client.login(username=self.basic_user.username, password=TEST_PASSWORD) + response = self.client.get(logout_url) + self.assertRedirects(response, login_url) + + def test_vulnerablecode_account_profile_view(self): + self.client.login(username=self.basic_user.username, password=TEST_PASSWORD) + response = self.client.get(profile_url) + expected = '' + self.assertContains(response, expected, html=True) + expected = '' + self.assertContains(response, self.basic_user.auth_token.key) + + def test_vulnerablecode_auth_api_required_authentication(self): + response = self.client.get(api_package_url) + expected = {"detail": "Authentication credentials were not provided."} + self.assertEqual(expected, response.json()) + self.assertEqual(401, response.status_code) diff --git a/vulnerabilities/tests/test_fix_api.py b/vulnerabilities/tests/test_fix_api.py index 0b6aa47d5..2cd53db12 100644 --- a/vulnerabilities/tests/test_fix_api.py +++ b/vulnerabilities/tests/test_fix_api.py @@ -7,10 +7,15 @@ # See https://aboutcode.org for more information about nexB OSS projects. # +import json + +from django.contrib.auth import get_user_model from django.test import TestCase +from django.test import TransactionTestCase from django.utils.http import int_to_base36 from packageurl import PackageURL from rest_framework import status +from rest_framework.test import APIClient from vulnerabilities.models import Alias from vulnerabilities.models import Package @@ -19,9 +24,15 @@ from vulnerabilities.models import VulnerabilityReference from vulnerabilities.models import VulnerabilityRelatedReference +User = get_user_model() + -class APITestCaseVulnerability(TestCase): +class APITestCaseVulnerability(TransactionTestCase): def setUp(self): + self.user = User.objects.create_user("username", "e@mail.com", "secret") + self.auth = f"Token {self.user.auth_token.key}" + self.csrf_client = APIClient(enforce_csrf_checks=True) + self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth) for i in range(0, 200): Vulnerability.objects.create( summary=str(i), @@ -35,15 +46,15 @@ def setUp(self): ) def test_api_status(self): - response = self.client.get("/api/vulnerabilities/", format="json") + response = self.csrf_client.get("/api/vulnerabilities/") self.assertEqual(status.HTTP_200_OK, response.status_code) def test_api_response(self): - response = self.client.get("/api/vulnerabilities/", format="json").data + response = self.csrf_client.get("/api/vulnerabilities/").data self.assertEqual(response["count"], 201) def test_api_with_single_vulnerability(self): - response = self.client.get( + response = self.csrf_client.get( f"/api/vulnerabilities/{self.vulnerability.id}", format="json" ).data assert response == { @@ -66,7 +77,7 @@ def test_api_with_single_vulnerability(self): } def test_api_with_single_vulnerability_with_filters(self): - response = self.client.get( + response = self.csrf_client.get( f"/api/vulnerabilities/{self.vulnerability.id}?type=pypi", format="json" ).data assert response == { @@ -87,6 +98,10 @@ def test_api_with_single_vulnerability_with_filters(self): class APITestCasePackage(TestCase): def setUp(self): + self.user = User.objects.create_user("username", "e@mail.com", "secret") + self.auth = f"Token {self.user.auth_token.key}" + self.csrf_client = APIClient(enforce_csrf_checks=True) + self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth) vuln = Vulnerability.objects.create( summary="test-vuln", ) @@ -123,15 +138,15 @@ def setUp(self): ) def test_api_status(self): - response = self.client.get("/api/packages/", format="json") + response = self.csrf_client.get("/api/packages/", format="json") self.assertEqual(status.HTTP_200_OK, response.status_code) def test_api_response(self): - response = self.client.get("/api/packages/", format="json").data + response = self.csrf_client.get("/api/packages/", format="json").data self.assertEqual(response["count"], 11) def test_api_with_single_vulnerability_and_fixed_package(self): - response = self.client.get(f"/api/packages/{self.package.id}", format="json").data + response = self.csrf_client.get(f"/api/packages/{self.package.id}", format="json").data assert response == { "url": f"http://testserver/api/packages/{self.package.id}", "purl": "pkg:generic/nginx/test@11", @@ -160,7 +175,7 @@ def test_api_with_single_vulnerability_and_fixed_package(self): } def test_api_with_single_vulnerability_and_vulnerable_package(self): - response = self.client.get(f"/api/packages/{self.vuln_package.id}", format="json").data + response = self.csrf_client.get(f"/api/packages/{self.vuln_package.id}", format="json").data assert response == { "url": f"http://testserver/api/packages/{self.vuln_package.id}", "purl": "pkg:generic/nginx/test@9", @@ -204,6 +219,10 @@ def test_api_with_single_vulnerability_and_vulnerable_package(self): class CPEApi(TestCase): def setUp(self): + self.user = User.objects.create_user("username", "e@mail.com", "secret") + self.auth = f"Token {self.user.auth_token.key}" + self.csrf_client = APIClient(enforce_csrf_checks=True) + self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth) self.vulnerability = Vulnerability.objects.create(summary="test") for i in range(0, 10): ref, _ = VulnerabilityReference.objects.get_or_create( @@ -214,31 +233,39 @@ def setUp(self): ) def test_api_status(self): - response = self.client.get("/api/cpes/", format="json") + response = self.csrf_client.get("/api/cpes/", format="json") self.assertEqual(status.HTTP_200_OK, response.status_code) def test_api_response(self): - response = self.client.get("/api/cpes/?cpe=cpe:/a:nginx:9", format="json").data + response = self.csrf_client.get("/api/cpes/?cpe=cpe:/a:nginx:9", format="json").data self.assertEqual(response["count"], 1) class AliasApi(TestCase): def setUp(self): + self.user = User.objects.create_user("username", "e@mail.com", "secret") + self.auth = f"Token {self.user.auth_token.key}" + self.csrf_client = APIClient(enforce_csrf_checks=True) + self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth) self.vulnerability = Vulnerability.objects.create(summary="test") for i in range(0, 10): Alias.objects.create(alias=f"CVE-{i}", vulnerability=self.vulnerability) def test_api_status(self): - response = self.client.get("/api/alias/", format="json") + response = self.csrf_client.get("/api/alias/", format="json") self.assertEqual(status.HTTP_200_OK, response.status_code) def test_api_response(self): - response = self.client.get("/api/alias?alias=CVE-9", format="json").data + response = self.csrf_client.get("/api/alias?alias=CVE-9", format="json").data self.assertEqual(response["count"], 1) -class BulkSearchAPI(TestCase): +class BulkSearchAPIPackage(TestCase): def setUp(self): + self.user = User.objects.create_user("username", "e@mail.com", "secret") + self.auth = f"Token {self.user.auth_token.key}" + self.csrf_client = APIClient(enforce_csrf_checks=True) + self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth) packages = [ "pkg:nginx/nginx@0.6.18", "pkg:nginx/nginx@1.20.0", @@ -264,16 +291,20 @@ def test_api_response(self): request_body = { "purls": self.packages, } - response = self.client.post( + response = self.csrf_client.post( "/api/packages/bulk_search", - data=request_body, + data=json.dumps(request_body), content_type="application/json", ).json() assert len(response) == 13 -class BulkSearchAPI(TestCase): +class BulkSearchAPICPE(TestCase): def setUp(self): + self.user = User.objects.create_user("username", "e@mail.com", "secret") + self.auth = f"Token {self.user.auth_token.key}" + self.csrf_client = APIClient(enforce_csrf_checks=True) + self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth) self.exclusive_cpes = [ "cpe:/a:nginx:1.0.7", "cpe:/a:nginx:1.0.15", @@ -305,9 +336,9 @@ def test_api_response_with_with_exclusive_cpes_associated_with_two_vulnerabiliti request_body = { "cpes": self.exclusive_cpes, } - response = self.client.post( + response = self.csrf_client.post( "/api/cpes/bulk_search", - data=request_body, + data=json.dumps(request_body), content_type="application/json", ).json() assert len(response) == 1 @@ -320,9 +351,9 @@ def test_api_response_with_no_cpe_associated(self): request_body = { "cpes": ["cpe:/a:nginx:1.10.7"], } - response = self.client.post( + response = self.csrf_client.post( "/api/cpes/bulk_search", - data=request_body, + data=json.dumps(request_body), content_type="application/json", ).json() assert len(response) == 0 @@ -331,9 +362,9 @@ def test_api_response_with_with_non_exclusive_cpes_associated_with_two_vulnerabi request_body = { "cpes": self.non_exclusive_cpes, } - response = self.client.post( + response = self.csrf_client.post( "/api/cpes/bulk_search", - data=request_body, + data=json.dumps(request_body), content_type="application/json", ).json() assert len(response) == 2 @@ -342,20 +373,18 @@ def test_with_empty_list(self): request_body = { "cpes": [], } - response = self.client.post( + response = self.csrf_client.post( "/api/cpes/bulk_search", - data=request_body, + data=json.dumps(request_body), content_type="application/json", ).json() assert response == {"Error": "A non-empty 'cpe' list of package URLs is required."} def test_with_invalid_cpes(self): - request_body = { - "cpes": ["CVE-2022-2022"], - } - response = self.client.post( + request_body = {"cpes": ["CVE-2022-2022"]} + response = self.csrf_client.post( "/api/cpes/bulk_search", - data=request_body, + data=json.dumps(request_body), content_type="application/json", ).json() assert response == {"Error": "Invalid CPE: CVE-2022-2022"} diff --git a/vulnerablecode/settings.py b/vulnerablecode/settings.py index b7cf0b16f..5f89b8776 100644 --- a/vulnerablecode/settings.py +++ b/vulnerablecode/settings.py @@ -155,8 +155,8 @@ # Django restframework REST_FRAMEWORK = { - "DEFAULT_AUTHENTICATION_CLASSES": ("rest_framework.authentication.SessionAuthentication",), - "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.AllowAny",), + "DEFAULT_AUTHENTICATION_CLASSES": ("rest_framework.authentication.TokenAuthentication",), + "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",), "DEFAULT_RENDERER_CLASSES": ( "rest_framework.renderers.JSONRenderer", "rest_framework.renderers.BrowsableAPIRenderer",