Skip to content

Commit

Permalink
BAI-1540 continue modelscan api pytest rework
Browse files Browse the repository at this point in the history
  • Loading branch information
PE39806 committed Dec 18, 2024
1 parent f549fe3 commit 0aaa2c4
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 94 deletions.
19 changes: 15 additions & 4 deletions lib/modelscan_api/bailo_modelscan_api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import logging
import re
from pathlib import Path

logger = logging.getLogger(__name__)
Expand All @@ -21,7 +22,17 @@ def parse_path(path: str | Path | None) -> Path:
return Path().cwd() if path == "." else Path(path).absolute()


def safe_join(root_dir: str | Path | None, filename: str | Path) -> Path:
def sanitise_unix_filename(filename: str) -> str:
    """Replace every character that is reserved in filenames with a hyphen.

    The forward slash, the backslash, each of ``? % * : | " < >``, the DEL character (0x7F) and all ASCII control
    characters (0x00-0x1F) are swapped for ``-``. Every other character is kept as it is, which means dots are not
    altered and ``.`` and ``..`` are returned unchanged.
    See https://en.wikipedia.org/wiki/Filename#Reserved_characters_and_words

    Note that the result is not safe for Windows users as reserved words (e.g. CON and AUX) are not checked.

    :param filename: the untrusted filename to be sanitised
    :return: the filename with each reserved character replaced by "-"
    """
    reserved_characters = re.compile(r"[/\\?%*:|\"<>\x7F\x00-\x1F]")
    return reserved_characters.sub("-", filename)


def safe_join(root_dir: str | Path | None, filename: str) -> Path:
"""Combine a trusted directory path with an untrusted filename to get a full path.
:param root_dir: Trusted path/directory.
Expand All @@ -33,13 +44,13 @@ def safe_join(root_dir: str | Path | None, filename: str | Path) -> Path:
if not filename or not str(filename).strip():
raise ValueError("filename must not be empty")

stripped_filename = Path(str(filename)).name.strip()
safe_filename = sanitise_unix_filename(filename).strip()

if not stripped_filename:
if not safe_filename:
raise ValueError("filename must not be empty")

parent_dir = parse_path(root_dir).resolve()
full_path = parent_dir.joinpath(stripped_filename).resolve()
full_path = parent_dir.joinpath(safe_filename).resolve()
if not full_path.is_relative_to(parent_dir):
raise ValueError("Could not safely join paths.")

Expand Down
169 changes: 81 additions & 88 deletions lib/modelscan_api/tests/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pytest

from bailo_modelscan_api.dependencies import safe_join
from bailo_modelscan_api.dependencies import parse_path, safe_join, sanitise_unix_filename

# Helpers

Expand All @@ -26,106 +26,99 @@ def type_matrix(data: Iterable[Any], types: Iterable[type]) -> itertools.product
return itertools.product(*[[t(d) for t in types] for d in data])


def string_path_matrix(path1: str | Path, path2: str | Path) -> itertools.product[tuple[str, Path]]:
"""Wrap type_matrix for convenience with str and Path types.
:param path1: A path to process.
:param path2: Another path to process.
:return: The matrix of both paths with types str and Path.
"""
return type_matrix([path1, path2], [str, Path])
# Tests


def helper_test_safe_join(path1: str | Path, path2: str | Path, output: Path) -> None:
"""Helper method for testing that all str and Path representations of the two paths will match the given output when joined.
@pytest.mark.parametrize(
    ("path", "output"),
    [
        ("foo.bar", "foo.bar"),
        (".foo.bar", ".foo.bar"),
        ("/foo.bar", "-foo.bar"),
        ("foo/./bar", "foo-.-bar"),
        ("foo.-/bar", "foo.--bar"),
        (".", "."),
        ("..", ".."),
        ("/", "-"),
        ("/.", "-."),
        ("./", ".-"),
        ("\n", "-"),
        ("\r", "-"),
        ("~", "~"),
        ("".join(['\\[/\\?%*:|"<>0x7F0x00-0x1F]', chr(0x1F) * 15]), "-[----------0x7F0x00-0x1F]---------------"),
        ("ad\nbla'{-+\\)(ç?", "ad-bla'{-+-)(ç-"),  # type: ignore
    ],
)
def test_sanitise_unix_filename(path: str, output: str) -> None:
    """Test that sanitising an untrusted filename gives the expected filename.

    :param path: Untrusted filename passed to sanitise_unix_filename.
    :param output: Expected sanitised filename.
    """
    sanitised = sanitise_unix_filename(path)
    assert sanitised == output


@pytest.mark.parametrize(
    ("path", "output"),
    [
        (None, Path().cwd()),
        ("", Path().cwd()),
        (".", Path().cwd()),
        ("/tmp", Path("/tmp")),
        ("/foo/bar", Path("/foo/bar")),
        ("/foo/../bar", Path("/foo/../bar")),
        ("/foo/bar space/baz", Path("/foo/bar space/baz")),
        ("/C:\\Program Files\\HAL 9000", Path("/C:\\Program Files\\HAL 9000")),
        ("/ISO&Emulator", Path("/ISO&Emulator")),
        ("/$HOME", Path("/$HOME")),
        ("~", Path().cwd().joinpath("~")),
    ],
)
def test_parse_path(path: str | Path | None, output: Path) -> None:
    """Test that parse_path returns the expected path for the str and Path forms of the input.

    :param path: Value to parse; None is passed to parse_path directly.
    :param output: Expected parsed path.
    """
    if path is None:
        # None is checked as-is rather than being converted to str and Path.
        candidates = [None]
    else:
        candidates = [candidate for (candidate,) in type_matrix((path,), (str, Path))]

    for candidate in candidates:
        assert parse_path(candidate) == output


@pytest.mark.parametrize(
("path1", "path2", "output"),
[
("", "foo.bar", Path.cwd().joinpath("foo.bar")),
(".", "foo.bar", Path.cwd().joinpath("foo.bar")),
("/tmp", "foo.bar", Path("/tmp/foo.bar")),
("/tmp/", "foo.bar", Path("/tmp/foo.bar")),
("/tmp/", "/foo.bar", Path("/tmp/-foo.bar")),
("/tmp", ".foo.bar", Path("/tmp/.foo.bar")),
("/tmp", "/foo.bar", Path("/tmp/-foo.bar")),
("/tmp", "//foo.bar", Path("/tmp/--foo.bar")),
("/tmp", "./foo.bar", Path("/tmp/.-foo.bar")),
("/tmp", "./.foo.bar", Path("/tmp/.-.foo.bar")),
("/tmp", "..foo.bar", Path("/tmp/..foo.bar")),
("/tmp", "../foo.bar", Path("/tmp/..-foo.bar")),
("/tmp", "../.foo.bar", Path("/tmp/..-.foo.bar")),
("/tmp", ".", Path("/tmp/.")),
("/tmp", "/", Path("/tmp/-")),
("/tmp", "//", Path("/tmp/--")),
("/tmp", "~", Path("/tmp/~")),
],
)
def test_safe_join(path1: str | Path, path2: str, output: Path) -> None:
"""Test that all str and Path representations of the two paths will match the given output when joined.
:param path1: Directory part of the final path.
:param path2: Filename part of the final path.
:param output: Expected final path value.
"""
for test_dir, test_file in string_path_matrix(path1, path2):
res = safe_join(test_dir, test_file)
for (test_dir,) in type_matrix((path1,), (str, Path)):
res = safe_join(test_dir, path2)
assert res == output


def helper_test_safe_join_catch(path1: str | Path, path2: str | Path) -> None:
"""Helper method for testing that all str and Path representation of the two paths will throw an error when joined.
@pytest.mark.parametrize(("path1", "path2"), [("/tmp", ""), ("/tmp", "..")])
def test_safe_join_catch(path1: str | Path, path2: str) -> None:
"""Test that all str and Path representation of the two paths will throw an error when joined.
:param path1: Directory part of the final path.
:param path2: Filename part of the final path.
"""
# check error thrown given two inputs
for test_dir, test_file in string_path_matrix(path1, path2):
for (test_dir,) in type_matrix((path1,), (str, Path)):
with pytest.raises(ValueError):
safe_join(test_dir, test_file)


# Tests


def test_safe_join_blank():
helper_test_safe_join("", "foo.bar", Path.cwd().joinpath("foo.bar"))


def test_safe_join_local():
helper_test_safe_join(".", "foo.bar", Path.cwd().joinpath("foo.bar"))


def test_safe_join_abs():
helper_test_safe_join("/tmp", "foo.bar", Path("/tmp").joinpath("foo.bar"))


def test_safe_join_abs_trailing():
helper_test_safe_join("/tmp/", "foo.bar", Path("/tmp").joinpath("foo.bar"))


def test_safe_join_abs_dot():
helper_test_safe_join("/tmp", ".foo.bar", Path("/tmp").joinpath(".foo.bar"))


def test_safe_join_abs_slash():
helper_test_safe_join("/tmp", "/foo.bar", Path("/tmp").joinpath("foo.bar"))


def test_safe_join_abs_double_slash():
helper_test_safe_join("/tmp", "//foo.bar", Path("/tmp").joinpath("foo.bar"))


def test_safe_join_abs_dot_slash():
helper_test_safe_join("/tmp", "./foo.bar", Path("/tmp").joinpath("foo.bar"))


def test_safe_join_abs_dot_slash_dot():
helper_test_safe_join("/tmp", "./.foo.bar", Path("/tmp").joinpath(".foo.bar"))


def test_safe_join_abs_double_dot():
helper_test_safe_join("/tmp", "..foo.bar", Path("/tmp").joinpath("..foo.bar"))


def test_safe_join_abs_double_dot_slash():
helper_test_safe_join("/tmp", "../foo.bar", Path("/tmp").joinpath("foo.bar"))


def test_safe_join_abs_double_dot_slash_dot():
helper_test_safe_join("/tmp", "../.foo.bar", Path("/tmp").joinpath(".foo.bar"))


def test_safe_join_fail_blank():
helper_test_safe_join_catch("/tmp", "")


def test_safe_join_fail_dot():
helper_test_safe_join_catch("/tmp", ".")


def test_safe_join_fail_double_dot():
helper_test_safe_join_catch("/tmp", "..")


def test_safe_join_fail_slash():
helper_test_safe_join_catch("/tmp", "/")


def test_safe_join_fail_double_slash():
helper_test_safe_join_catch("/tmp", "//")
safe_join(test_dir, path2)
4 changes: 2 additions & 2 deletions lib/modelscan_api/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from unittest.mock import Mock, patch

import modelscan
from fastapi.testclient import TestClient
import pytest
from fastapi.testclient import TestClient

from bailo_modelscan_api.config import Settings
from bailo_modelscan_api.dependencies import parse_path
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_scan_file(mock_scan: Mock, file_name: str, file_content: Any, file_mime

@pytest.mark.parametrize(
("file_name", "file_content", "file_mime_type"),
[("..", EMPTY_CONTENTS, H5_MIME_TYPE), ("../", EMPTY_CONTENTS, H5_MIME_TYPE)],
[("..", EMPTY_CONTENTS, H5_MIME_TYPE)],
)
def test_scan_file_escape_path_error(file_name: str, file_content: Any, file_mime_type: str):
files = {"in_file": (file_name, file_content, file_mime_type)}
Expand Down

0 comments on commit 0aaa2c4

Please sign in to comment.