From 823e26fc57d85393f083ca869996efef571ebf3e Mon Sep 17 00:00:00 2001 From: Olwe Samuel Date: Tue, 23 Jul 2024 23:58:01 +0300 Subject: [PATCH] Added `cves` field using numerical_id --- tests/__init__.py | 18 ++++--- tests/fixtures/models.py | 106 +++++++++++++++++++++++++++------------ tests/test_routes.py | 26 ++++++++++ webapp/app.py | 3 ++ webapp/commands.py | 18 +++++-- webapp/models.py | 10 ++++ webapp/schemas.py | 18 ++++++- webapp/views.py | 60 +++++++++++++++------- 8 files changed, 197 insertions(+), 62 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 521129e..a0ec41d 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -47,6 +47,8 @@ class BaseTestCase(unittest.TestCase): + db = db + def setUp(self): app.testing = True @@ -55,7 +57,7 @@ def setUp(self): self.context.push() # Clear DB - db.drop_all() + self.db.drop_all() with redirect_stderr(io.StringIO()): flask_migrate.stamp(revision="base") @@ -65,18 +67,18 @@ def setUp(self): # Import data self.models = make_models() - db.session.add(self.models["cve"]) - db.session.add(self.models["notice"]) - db.session.add(self.models["release"]) - db.session.add(self.models["package"]) - db.session.add(self.models["status"]) - db.session.commit() + self.db.session.add(self.models["cve"]) + self.db.session.add(self.models["notice"]) + self.db.session.add(self.models["release"]) + self.db.session.add(self.models["package"]) + self.db.session.add(self.models["status"]) + self.db.session.commit() self.client = app.test_client() return super().setUp() def tearDown(self): - db.session.close() + self.db.session.close() self.context.pop() diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index 8bab63e..b03273e 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -1,6 +1,76 @@ from datetime import datetime -from webapp.models import Notice, Release, Status, CVE, Package +from webapp.models import ( + Notice, + Release, + Status, + CVE, + Package, +) + + +def make_cve( + id, + published=datetime.now(), + description="", + ubuntu_description="", + notes=[ + { + "author": "mysql", + "note": "mysql-1.2 is not affected by this CVE", + } + ], + priority="critical", + cvss3=2.3, + impact={}, + codename="test_name", + mitigation="", + references={}, + patches={}, + tags={}, + bugs={}, + status="active", +): + cve = CVE( + id=id, + published=published, + description=description, + ubuntu_description=ubuntu_description, + notes=notes, + priority=priority, + cvss3=cvss3, + impact=impact, + codename=codename, + mitigation=mitigation, + references=references, + patches=patches, + tags=tags, + bugs=bugs, + status=status, + ) + return cve + + +def make_notice( + id, + is_hidden=False, + published=datetime.now(), + summary="", + details="", + instructions="", + releases=[], + cves=[], +): + return Notice( + id=id, + is_hidden=is_hidden, + published=published, + summary=summary, + details=details, + instructions=instructions, + releases=releases, + cves=cves, + ) def make_models(): @@ -23,28 +93,7 @@ def make_models(): debian="test-package-debian", ) - cve = CVE( - id="CVE-1111-0001", - published=datetime.now(), - description="", - ubuntu_description="", - notes=[ - { - "author": "mysql", - "note": "mysql-1.2 is not affected by this CVE", - } - ], - priority="critical", - cvss3=2.3, - impact={}, - codename="test_name", - mitigation="", - references={}, - patches={}, - tags={}, - bugs={}, - status="active", - ) + cve = make_cve("CVE-1111-0001") status = Status( status="pending", @@ -53,16 +102,7 @@ def make_models(): release=release, ) - notice = Notice( - id="USN-1111-01", - is_hidden=False, - published=datetime.now(), - summary="", - details="", - instructions="", - releases=[release], - cves=[cve], - ) + notice = make_notice("USN-1111-01", releases=[release], cves=[cve]) return { "release": release, diff --git a/tests/test_routes.py b/tests/test_routes.py index 361bf79..ccf4f60 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -1,6 +1,7 @@ import unittest from tests import BaseTestCase from tests.fixtures import payloads +from tests.fixtures.models import make_cve, make_notice class TestRoutes(BaseTestCase): @@ -945,6 +946,31 @@ def test_usn(self): assert response.status_code == 200 assert response.json["cves_ids"] == self.models["notice"].cves_ids + # Test cves field + cve_id = self.models["notice"].cves[0].id + + test_cve = make_cve("CVE-9999-0001") + test_notice = make_notice("USN-9999-0001", cves=[test_cve]) + self.db.session.add(test_cve) + self.db.session.add(test_notice) + self.db.session.commit() + + response = self.client.get( + f"/security/notices.json?cves={cve_id},{test_cve.id}" + ) + + assert response.status_code == 200 + assert len(response.json["notices"]) == 2 + # Check for either cve_id in the returned notices + assert ( + cve_id in response.json["notices"][0]["cves_ids"] + or test_cve.id in response.json["notices"][0]["cves_ids"] + ) + assert ( + test_cve.id in response.json["notices"][0]["cves_ids"] + or cve_id in response.json["notices"][0]["cves_ids"] + ) + def test_multiple_usn(self): response = self.client.get("/security/notices.json") diff --git a/webapp/app.py b/webapp/app.py index ed49ec3..086884b 100644 --- a/webapp/app.py +++ b/webapp/app.py @@ -7,6 +7,7 @@ from flask_migrate import Migrate from webapp.api_spec import WebappFlaskApiSpec +from webapp.commands import register_commands from webapp.database import db from webapp.views import ( create_notice, @@ -50,6 +51,8 @@ db.init_app(app) migrate = Migrate(app, db) +register_commands(app) + app.add_url_rule( "/security/cves/.json", view_func=get_cve, diff --git a/webapp/commands.py b/webapp/commands.py index 0a550a0..c3dc716 100644 --- a/webapp/commands.py +++ b/webapp/commands.py @@ -1,12 +1,24 @@ import click -from webapp.app import app from webapp.models import ( upsert_numerical_cve_ids, ) -@app.cli.command("insert-numerical-cve-ids") +@click.command("insert_numerical_cve_ids") def insert_numerical_cve_ids(): - """Management script for the Wiki application.""" + """ + For each cve, update cves.numerical_field with the numerical value + of the CVE id e.g 'CVE-2025-12345' -> 202512345. + """ + upsert_numerical_cve_ids() + click.echo("Numerical CVE ids inserted successfully.") + + +def register_commands(app): + """Register Click commands.""" + # Set up app context + app.app_context().push() + + app.cli.add_command(insert_numerical_cve_ids) diff --git a/webapp/models.py b/webapp/models.py index 18f53c0..f64b149 100644 --- a/webapp/models.py +++ b/webapp/models.py @@ -163,12 +163,22 @@ def upsert_numerical_cve_ids(): all_cves = db.session.query(CVE).all() updated_cves = [] for cve in all_cves: + print(f"Updating numerical_id for {cve.id}") cve.numerical_id = convert_cve_id_to_numerical_id(cve.id) updated_cves.append(cve) db.session.add_all(updated_cves) db.session.commit() +@db.event.listens_for(CVE, "after_insert") +def insert_numerical_id(mapper, connection, target): + """ + Update the numerical_id column using the CVE id whenever a new CVE is + inserted. + """ + target.numerical_id = convert_cve_id_to_numerical_id(target.id) + + class Notice(db.Model): __tablename__ = "notice" diff --git a/webapp/schemas.py b/webapp/schemas.py index 0c331e5..6c09861 100644 --- a/webapp/schemas.py +++ b/webapp/schemas.py @@ -1,5 +1,5 @@ import dateutil.parser -from marshmallow import Schema +from marshmallow import Schema, ValidationError from marshmallow.fields import ( Boolean, DateTime, @@ -197,6 +197,21 @@ def _deserialize(self, value, attr, data, **kwargs): return super()._deserialize(value, attr, data, **kwargs) +class StringDelimitedList(String): + """ + Support lists of strings that are delimited by commas e.g + "foo,bar" -> ["foo", "bar",] + """ + + def _deserialize(self, value, attr, data, **kwargs): + try: + return value.split(",") + except AttributeError: + raise ValidationError( + f"{attr} is not a string delimited list.\n value: {value}." + ) + + # Notices # -- class NoticePackage(Schema): @@ -264,6 +279,7 @@ class NoticeAPISchema(NoticeSchema): allow_none=True, ), "cve_id": String(allow_none=True), + "cves": StringDelimitedList(allow_none=True), "release": String(allow_none=True), "limit": Int( validate=Range(min=1, max=100), diff --git a/webapp/views.py b/webapp/views.py index 615f7cf..1bc13e8 100644 --- a/webapp/views.py +++ b/webapp/views.py @@ -2,7 +2,6 @@ from collections import defaultdict from datetime import datetime from distutils.util import strtobool - from flask import make_response, jsonify, request from flask_apispec import marshal_with, use_kwargs from sqlalchemy import desc, or_, and_, case, asc, text @@ -18,6 +17,7 @@ Status, Package, STATUS_STATUSES, + convert_cve_id_to_numerical_id, ) from webapp.schemas import ( CreateNoticeImportSchema, @@ -425,6 +425,7 @@ def get_notices(**kwargs): limit = kwargs.get("limit", 20) offset = kwargs.get("offset", 0) order_by = kwargs.get("order") + cves = kwargs.get("cves") notices_query: Query = db.session.query(Notice) @@ -445,29 +446,54 @@ def get_notices(**kwargs): Notice.id.ilike(f"%{details}%"), Notice.details.ilike(f"%{details}%"), Notice.title.ilike(f"%{details}%"), - Notice.cves.any(CVE.id.ilike(f"%{details}%")), ) ) sort = asc if order_by == "oldest" else desc - notices = ( - notices_query.options( - selectinload(Notice.cves).options( - selectinload(CVE.statuses), - selectinload(CVE.notices).options( - load_only( - Notice.id, Notice.is_hidden, Notice.release_packages - ) - ), + if cves: + # Get cves by numerical id + numerical_cve_ids = [ + convert_cve_id_to_numerical_id(cve) for cve in cves + ] + matched_cves = ( + db.session.query(CVE) + .filter(CVE.numerical_id.in_(numerical_cve_ids)) + .all() + ) + # Get notices_ids from cves + notice_ids = [] + for cve in matched_cves: + notice_ids += [notice.id for notice in cve.notices] + + notices = ( + notices_query.filter(Notice.id.in_(notice_ids)) + .order_by(sort(Notice.published), sort(Notice.id)) + .offset(offset) + .limit(limit) + .all() + ) + + else: + notices = ( + notices_query.options( + selectinload(Notice.cves).options( + selectinload(CVE.statuses), + selectinload(CVE.notices).options( + load_only( + Notice.id, + Notice.is_hidden, + Notice.release_packages, + ) + ), + ) ) + .options(selectinload(Notice.releases)) + .order_by(sort(Notice.published), sort(Notice.id)) + .offset(offset) + .limit(limit) + .all() ) - .options(selectinload(Notice.releases)) - .order_by(sort(Notice.published), sort(Notice.id)) - .offset(offset) - .limit(limit) - .all() - ) return { "notices": notices,