Skip to content

Commit

Permalink
BAI-1502 prevent upload filepath escaping for security
Browse files Browse the repository at this point in the history
  • Loading branch information
PE39806 committed Nov 15, 2024
1 parent fe40643 commit 4c52676
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 10 deletions.
27 changes: 26 additions & 1 deletion lib/modelscan_api/bailo_modelscan_api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,32 @@ def parse_path(path: str | Path | None) -> Path:
:param path: System path to parse. Defaults to the file's current working directory if unspecified.
:return: An absolute Path representation of the path parameter.
"""
logger.info("Parsing path.")
logger.debug("Parsing path %s", path)
if path is None:
path = "."
return Path().cwd() if path == "." else Path(path).absolute()


def safe_join(root_dir: str | Path | None, filename: str | Path) -> Path:
"""Combine a trusted directory path with an untrusted filename to get a full path.
:param root_dir: Trusted path/directory.
:param filename: Untrusted filename to join to the trusted path. Any path components are stripped off.
:return: Fully joined path with filename.
"""
logger.debug("Safely joining path '%s' with filename '%s'", root_dir, filename)

if not filename or not str(filename).strip():
raise ValueError("filename must not be empty")

stripped_filename = Path(str(filename)).name.strip()

if not stripped_filename:
raise ValueError("filename must not be empty")

parent_dir = parse_path(root_dir).resolve()
full_path = parent_dir.joinpath(stripped_filename).resolve()
if not full_path.is_relative_to(parent_dir):
raise ValueError("Could not safely join paths.")

return full_path
26 changes: 17 additions & 9 deletions lib/modelscan_api/bailo_modelscan_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
from http import HTTPStatus
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
from typing import Annotated, Any

import uvicorn
from bailo_modelscan_api.config import Settings
from bailo_modelscan_api.dependencies import parse_path
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, UploadFile
from modelscan.modelscan import ModelScan

from bailo_modelscan_api.config import Settings
from bailo_modelscan_api.dependencies import safe_join

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -65,7 +65,9 @@ def health_check() -> dict[str, str]:
status_code=HTTPStatus.OK,
response_description="The result from ModelScan",
)
def scan_file(in_file: UploadFile, background_tasks: BackgroundTasks) -> dict[str, Any]:
def scan_file(
in_file: UploadFile, background_tasks: BackgroundTasks, settings: Annotated[Settings, Depends(get_settings)]
) -> dict[str, Any]:
"""API endpoint to upload and scan a file using modelscan.
:param in_file: uploaded file to be scanned
Expand All @@ -76,11 +78,17 @@ def scan_file(in_file: UploadFile, background_tasks: BackgroundTasks) -> dict[st
logger.info("Called the API endpoint to scan an uploaded file")
try:
# Use Setting's download_dir if defined else use a temporary directory.
with (
TemporaryDirectory() if not get_settings().download_dir else nullcontext(get_settings().download_dir)
) as download_dir:
with TemporaryDirectory() if not settings.download_dir else nullcontext(settings.download_dir) as download_dir:
if in_file.filename and str(in_file.filename).strip():
pathlib_path = Path.joinpath(parse_path(download_dir), str(in_file.filename))
# Prevent escaping to a parent dir
try:
pathlib_path = safe_join(download_dir, in_file.filename)
except ValueError:
logger.exception("Failed to safely join the filename to the path.")
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="An error occurred while processing the uploaded file's name.",
)
else:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
Expand All @@ -92,7 +100,7 @@ def scan_file(in_file: UploadFile, background_tasks: BackgroundTasks) -> dict[st
# doesn't currently support streaming.
try:
with open(pathlib_path, "wb") as out_file:
while content := in_file.file.read(get_settings().block_size):
while content := in_file.file.read(settings.block_size):
out_file.write(content)
except OSError as exception:
logger.exception("Failed writing the file to the disk.")
Expand Down
102 changes: 102 additions & 0 deletions lib/modelscan_api/bailo_modelscan_api/test_dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Test for the main.py file.
"""

import itertools
from pathlib import Path

import pytest

from .dependencies import safe_join


# Helpers


def string_path_matrix(path1, path2):
# List of pairs of paths, as a str and Path representation of each.
return itertools.product(*[[str(x), Path(x)] for x in [path1, path2]])


def helper_test_safe_join(path1, path2, output):
# check expected output given 2 inputs
for test_dir, test_file in string_path_matrix(path1, path2):
res = safe_join(test_dir, test_file)
assert res == output


def helper_test_safe_join_catch(path1, path2):
# check error thrown given two inputs
for test_dir, test_file in string_path_matrix(path1, path2):
with pytest.raises(ValueError):
safe_join(test_dir, test_file)


# Tests


def test_safe_join_blank():
helper_test_safe_join("", "foo.bar", Path.cwd().joinpath("foo.bar"))


def test_safe_join_local():
helper_test_safe_join(".", "foo.bar", Path.cwd().joinpath("foo.bar"))


def test_safe_join_abs():
helper_test_safe_join("/tmp", "foo.bar", Path("/tmp").joinpath("foo.bar"))


def test_safe_join_abs_trailing():
helper_test_safe_join("/tmp/", "foo.bar", Path("/tmp").joinpath("foo.bar"))


def test_safe_join_abs_dot():
helper_test_safe_join("/tmp", ".foo.bar", Path("/tmp").joinpath(".foo.bar"))


def test_safe_join_abs_slash():
helper_test_safe_join("/tmp", "/foo.bar", Path("/tmp").joinpath("foo.bar"))


def test_safe_join_abs_double_slash():
helper_test_safe_join("/tmp", "//foo.bar", Path("/tmp").joinpath("foo.bar"))


def test_safe_join_abs_dot_slash():
helper_test_safe_join("/tmp", "./foo.bar", Path("/tmp").joinpath("foo.bar"))


def test_safe_join_abs_dot_slash_dot():
helper_test_safe_join("/tmp", "./.foo.bar", Path("/tmp").joinpath(".foo.bar"))


def test_safe_join_abs_double_dot():
helper_test_safe_join("/tmp", "..foo.bar", Path("/tmp").joinpath("..foo.bar"))


def test_safe_join_abs_double_dot_slash():
helper_test_safe_join("/tmp", "../foo.bar", Path("/tmp").joinpath("foo.bar"))


def test_safe_join_abs_double_dot_slash_dot():
helper_test_safe_join("/tmp", "../.foo.bar", Path("/tmp").joinpath(".foo.bar"))


def test_safe_join_fail_blank():
helper_test_safe_join_catch("/tmp", "")


def test_safe_join_fail_dot():
helper_test_safe_join_catch("/tmp", ".")


def test_safe_join_fail_double_dot():
helper_test_safe_join_catch("/tmp", "..")


def test_safe_join_fail_slash():
helper_test_safe_join_catch("/tmp", "/")


def test_safe_join_fail_double_slash():
helper_test_safe_join_catch("/tmp", "//")
20 changes: 20 additions & 0 deletions lib/modelscan_api/bailo_modelscan_api/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,26 @@ def test_scan_file(mock_scan: Mock):
mock_scan.assert_called_once()


@patch("modelscan.modelscan.ModelScan.scan")
def test_scan_file_escape_path(mock_scan: Mock):
mock_scan.return_value = {}
files = {"in_file": ("../foo.bar", rb"", "application/x-hdf5")}

response = client.post("/scan/file", files=files)

assert response.status_code == 200
mock_scan.assert_called_once()


def test_scan_file_escape_path_error():
files = {"in_file": ("..", rb"", "text/plain")}

response = client.post("/scan/file", files=files)

assert response.status_code == 500
assert response.json() == {"detail": "An error occurred while processing the uploaded file's name."}


@patch("modelscan.modelscan.ModelScan.scan")
def test_scan_file_exception(mock_scan: Mock):
mock_scan.side_effect = Exception("Mocked error!")
Expand Down

0 comments on commit 4c52676

Please sign in to comment.