diff --git a/lib/modelscan_api/bailo_modelscan_api/dependencies.py b/lib/modelscan_api/bailo_modelscan_api/dependencies.py index b8ee898ff..2fbc9a1c4 100644 --- a/lib/modelscan_api/bailo_modelscan_api/dependencies.py +++ b/lib/modelscan_api/bailo_modelscan_api/dependencies.py @@ -15,7 +15,32 @@ def parse_path(path: str | Path | None) -> Path: :param path: System path to parse. Defaults to the file's current working directory if unspecified. :return: An absolute Path representation of the path parameter. """ - logger.info("Parsing path.") + logger.debug("Parsing path %s", path) if path is None: path = "." return Path().cwd() if path == "." else Path(path).absolute() + + +def safe_join(root_dir: str | Path | None, filename: str | Path) -> Path: + """Combine a trusted directory path with an untrusted filename to get a full path. + + :param root_dir: Trusted path/directory. + :param filename: Untrusted filename to join to the trusted path. Any path components are stripped off. + :return: Fully joined path with filename. + """ + logger.debug("Safely joining path '%s' with filename '%s'", root_dir, filename) + + if not filename or not str(filename).strip(): + raise ValueError("filename must not be empty") + + stripped_filename = Path(str(filename)).name.strip() + + if not stripped_filename: + raise ValueError("filename must not be empty") + + parent_dir = parse_path(root_dir).resolve() + full_path = parent_dir.joinpath(stripped_filename).resolve() + if not full_path.is_relative_to(parent_dir): + raise ValueError("Could not safely join paths.") + + return full_path diff --git a/lib/modelscan_api/bailo_modelscan_api/main.py b/lib/modelscan_api/bailo_modelscan_api/main.py index 929484a84..f9537163a 100644 --- a/lib/modelscan_api/bailo_modelscan_api/main.py +++ b/lib/modelscan_api/bailo_modelscan_api/main.py @@ -9,14 +9,14 @@ from http import HTTPStatus from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any +from typing import Annotated, Any import uvicorn -from bailo_modelscan_api.config import Settings -from bailo_modelscan_api.dependencies import parse_path from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, UploadFile from modelscan.modelscan import ModelScan +from bailo_modelscan_api.config import Settings +from bailo_modelscan_api.dependencies import safe_join logger = logging.getLogger(__name__) @@ -65,7 +65,9 @@ def health_check() -> dict[str, str]: status_code=HTTPStatus.OK, response_description="The result from ModelScan", ) -def scan_file(in_file: UploadFile, background_tasks: BackgroundTasks) -> dict[str, Any]: +def scan_file( + in_file: UploadFile, background_tasks: BackgroundTasks, settings: Annotated[Settings, Depends(get_settings)] +) -> dict[str, Any]: """API endpoint to upload and scan a file using modelscan. :param in_file: uploaded file to be scanned @@ -76,11 +78,17 @@ def scan_file(in_file: UploadFile, background_tasks: BackgroundTasks) -> dict[st logger.info("Called the API endpoint to scan an uploaded file") try: # Use Setting's download_dir if defined else use a temporary directory. - with ( - TemporaryDirectory() if not get_settings().download_dir else nullcontext(get_settings().download_dir) - ) as download_dir: + with TemporaryDirectory() if not settings.download_dir else nullcontext(settings.download_dir) as download_dir: if in_file.filename and str(in_file.filename).strip(): - pathlib_path = Path.joinpath(parse_path(download_dir), str(in_file.filename)) + # Prevent escaping to a parent dir + try: + pathlib_path = safe_join(download_dir, in_file.filename) + except ValueError: + logger.exception("Failed to safely join the filename to the path.") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="An error occurred while processing the uploaded file's name.", + ) else: raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, @@ -92,7 +100,7 @@ def scan_file(in_file: UploadFile, background_tasks: BackgroundTasks) -> dict[st # doesn't currently support streaming. try: with open(pathlib_path, "wb") as out_file: - while content := in_file.file.read(get_settings().block_size): + while content := in_file.file.read(settings.block_size): out_file.write(content) except OSError as exception: logger.exception("Failed writing the file to the disk.") diff --git a/lib/modelscan_api/bailo_modelscan_api/test_dependencies.py b/lib/modelscan_api/bailo_modelscan_api/test_dependencies.py new file mode 100644 index 000000000..17f452bc8 --- /dev/null +++ b/lib/modelscan_api/bailo_modelscan_api/test_dependencies.py @@ -0,0 +1,102 @@ +"""Test for the main.py file. +""" + +import itertools +from pathlib import Path + +import pytest + +from .dependencies import safe_join + + +# Helpers + + +def string_path_matrix(path1, path2): + # List of pairs of paths, as a str and Path representation of each. + return itertools.product(*[[str(x), Path(x)] for x in [path1, path2]]) + + +def helper_test_safe_join(path1, path2, output): + # check expected output given 2 inputs + for test_dir, test_file in string_path_matrix(path1, path2): + res = safe_join(test_dir, test_file) + assert res == output + + +def helper_test_safe_join_catch(path1, path2): + # check error thrown given two inputs + for test_dir, test_file in string_path_matrix(path1, path2): + with pytest.raises(ValueError): + safe_join(test_dir, test_file) + + +# Tests + + +def test_safe_join_blank(): + helper_test_safe_join("", "foo.bar", Path.cwd().joinpath("foo.bar")) + + +def test_safe_join_local(): + helper_test_safe_join(".", "foo.bar", Path.cwd().joinpath("foo.bar")) + + +def test_safe_join_abs(): + helper_test_safe_join("/tmp", "foo.bar", Path("/tmp").joinpath("foo.bar")) + + +def test_safe_join_abs_trailing(): + helper_test_safe_join("/tmp/", "foo.bar", Path("/tmp").joinpath("foo.bar")) + + +def test_safe_join_abs_dot(): + helper_test_safe_join("/tmp", ".foo.bar", Path("/tmp").joinpath(".foo.bar")) + + +def test_safe_join_abs_slash(): + helper_test_safe_join("/tmp", "/foo.bar", Path("/tmp").joinpath("foo.bar")) + + +def test_safe_join_abs_double_slash(): + helper_test_safe_join("/tmp", "//foo.bar", Path("/tmp").joinpath("foo.bar")) + + +def test_safe_join_abs_dot_slash(): + helper_test_safe_join("/tmp", "./foo.bar", Path("/tmp").joinpath("foo.bar")) + + +def test_safe_join_abs_dot_slash_dot(): + helper_test_safe_join("/tmp", "./.foo.bar", Path("/tmp").joinpath(".foo.bar")) + + +def test_safe_join_abs_double_dot(): + helper_test_safe_join("/tmp", "..foo.bar", Path("/tmp").joinpath("..foo.bar")) + + +def test_safe_join_abs_double_dot_slash(): + helper_test_safe_join("/tmp", "../foo.bar", Path("/tmp").joinpath("foo.bar")) + + +def test_safe_join_abs_double_dot_slash_dot(): + helper_test_safe_join("/tmp", "../.foo.bar", Path("/tmp").joinpath(".foo.bar")) + + +def test_safe_join_fail_blank(): + helper_test_safe_join_catch("/tmp", "") + + +def test_safe_join_fail_dot(): + helper_test_safe_join_catch("/tmp", ".") + + +def test_safe_join_fail_double_dot(): + helper_test_safe_join_catch("/tmp", "..") + + +def test_safe_join_fail_slash(): + helper_test_safe_join_catch("/tmp", "/") + + +def test_safe_join_fail_double_slash(): + helper_test_safe_join_catch("/tmp", "//") diff --git a/lib/modelscan_api/bailo_modelscan_api/test_main.py b/lib/modelscan_api/bailo_modelscan_api/test_main.py index 590e95c36..05492e8aa 100644 --- a/lib/modelscan_api/bailo_modelscan_api/test_main.py +++ b/lib/modelscan_api/bailo_modelscan_api/test_main.py @@ -37,6 +37,26 @@ def test_scan_file(mock_scan: Mock): mock_scan.assert_called_once() +@patch("modelscan.modelscan.ModelScan.scan") +def test_scan_file_escape_path(mock_scan: Mock): + mock_scan.return_value = {} + files = {"in_file": ("../foo.bar", rb"", "application/x-hdf5")} + + response = client.post("/scan/file", files=files) + + assert response.status_code == 200 + mock_scan.assert_called_once() + + +def test_scan_file_escape_path_error(): + files = {"in_file": ("..", rb"", "text/plain")} + + response = client.post("/scan/file", files=files) + + assert response.status_code == 500 + assert response.json() == {"detail": "An error occurred while processing the uploaded file's name."} + + @patch("modelscan.modelscan.ModelScan.scan") def test_scan_file_exception(mock_scan: Mock): mock_scan.side_effect = Exception("Mocked error!")