Skip to content

Commit

Permalink
BAI-1540 rework basic modelscan tests
Browse files Browse the repository at this point in the history
  • Loading branch information
PE39806 committed Dec 17, 2024
1 parent 31a25d3 commit f549fe3
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 23 deletions.
5 changes: 5 additions & 0 deletions lib/modelscan_api/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[pytest]
filterwarnings = ignore::DeprecationWarning
pythonpath = "."
testpaths = "tests"
junit_family = "xunit2"
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pytest

from .dependencies import safe_join
from bailo_modelscan_api.dependencies import safe_join

# Helpers

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@

from functools import lru_cache
from pathlib import Path
from typing import Any
from unittest.mock import Mock, patch

import modelscan
from fastapi.testclient import TestClient
import pytest

from .config import Settings
from .dependencies import parse_path
from .main import app, get_settings
from bailo_modelscan_api.config import Settings
from bailo_modelscan_api.dependencies import parse_path
from bailo_modelscan_api.main import app, get_settings

client = TestClient(app)

Expand All @@ -25,6 +27,10 @@ def get_settings_override():
app.dependency_overrides[get_settings] = get_settings_override


EMPTY_CONTENTS = rb""
H5_MIME_TYPE = "application/x-hdf5"


def test_info():
response = client.get("/info")

Expand All @@ -38,29 +44,26 @@ def test_info():


@patch("modelscan.modelscan.ModelScan.scan")
def test_scan_file(mock_scan: Mock):
mock_scan.return_value = {}
files = {"in_file": ("foo.h5", rb"", "application/x-hdf5")}

response = client.post("/scan/file", files=files)

assert response.status_code == 200
mock_scan.assert_called_once()


@patch("modelscan.modelscan.ModelScan.scan")
def test_scan_file_escape_path(mock_scan: Mock):
@pytest.mark.parametrize(
("file_name", "file_content", "file_mime_type"),
[("foo.h5", EMPTY_CONTENTS, H5_MIME_TYPE), ("../foo.h5", EMPTY_CONTENTS, H5_MIME_TYPE)],
)
def test_scan_file(mock_scan: Mock, file_name: str, file_content: Any, file_mime_type: str):
mock_scan.return_value = {}
files = {"in_file": ("../foo.bar", rb"", "application/x-hdf5")}
files = {"in_file": (file_name, file_content, file_mime_type)}

response = client.post("/scan/file", files=files)

assert response.status_code == 200
mock_scan.assert_called_once()


def test_scan_file_escape_path_error():
files = {"in_file": ("..", rb"", "text/plain")}
@pytest.mark.parametrize(
("file_name", "file_content", "file_mime_type"),
[("..", EMPTY_CONTENTS, H5_MIME_TYPE), ("../", EMPTY_CONTENTS, H5_MIME_TYPE)],
)
def test_scan_file_escape_path_error(file_name: str, file_content: Any, file_mime_type: str):
files = {"in_file": (file_name, file_content, file_mime_type)}

response = client.post("/scan/file", files=files)

Expand All @@ -69,9 +72,13 @@ def test_scan_file_escape_path_error():


@patch("modelscan.modelscan.ModelScan.scan")
def test_scan_file_exception(mock_scan: Mock):
@pytest.mark.parametrize(
("file_name", "file_content", "file_mime_type"),
[("foo.h5", EMPTY_CONTENTS, H5_MIME_TYPE)],
)
def test_scan_file_exception(mock_scan: Mock, file_name: str, file_content: Any, file_mime_type: str):
mock_scan.side_effect = Exception("Mocked error!")
files = {"in_file": ("foo.h5", rb"", "application/x-hdf5")}
files = {"in_file": (file_name, file_content, file_mime_type)}

response = client.post("/scan/file", files=files)

Expand All @@ -83,8 +90,12 @@ def test_scan_file_exception(mock_scan: Mock):
Path.unlink(Path.joinpath(parse_path(get_settings().download_dir), "foo.h5"), missing_ok=True)


def test_scan_file_filename_missing():
files = {"in_file": (" ", rb"", "application/x-hdf5")}
@pytest.mark.parametrize(
("file_name", "file_content", "file_mime_type"),
[(" ", EMPTY_CONTENTS, H5_MIME_TYPE)],
)
def test_scan_file_filename_missing(file_name: str, file_content: Any, file_mime_type: str):
files = {"in_file": (file_name, file_content, file_mime_type)}

response = client.post("/scan/file", files=files)

Expand Down

0 comments on commit f549fe3

Please sign in to comment.