From d98cab18d566ad4e2aa21180bab4f4b67368422f Mon Sep 17 00:00:00 2001 From: behnazh-w Date: Fri, 14 Jul 2023 15:52:48 +1000 Subject: [PATCH] fix: check paths in an archive file before extracting The paths in an archive file are checked for path traversal patterns before extraction. Also, Bandit v1.7.5 is producing false positives for request timeout arguments, which have been suppressed. Signed-off-by: behnazh-w --- .../checks/provenance_l3_check.py | 30 ++++++++++++++----- src/macaron/util.py | 20 +++++++++---- 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/src/macaron/slsa_analyzer/checks/provenance_l3_check.py b/src/macaron/slsa_analyzer/checks/provenance_l3_check.py index 3064f822c..9b778acc3 100644 --- a/src/macaron/slsa_analyzer/checks/provenance_l3_check.py +++ b/src/macaron/slsa_analyzer/checks/provenance_l3_check.py @@ -239,21 +239,35 @@ def _extract_archive(self, file_path: str, temp_path: str) -> bool: bool Returns True if successful. """ + + def _validate_path_traversal(path: str) -> bool: + """Check for path traversal attacks.""" + if path.startswith("/") or ".." in path: + logger.debug("Found suspicious path in the archive file: %s.", path) + return False + try: + # Check if there are any symbolic links. + if os.path.realpath(path): + return True + except OSError as error: + logger.debug("Failed to extract artifact from archive file: %s", error) + return False + return False + try: if zipfile.is_zipfile(file_path): with zipfile.ZipFile(file_path, "r") as zip_file: - zip_file.extractall(temp_path) + members = (path for path in zip_file.namelist() if _validate_path_traversal(path)) + zip_file.extractall(temp_path, members=members) # nosec B202:tarfile_unsafe_members return True elif tarfile.is_tarfile(file_path): with tarfile.open(file_path, mode="r:gz") as tar_file: - tar_file.extractall(temp_path) + members_tarinfo = ( + tarinfo for tarinfo in tar_file.getmembers() if _validate_path_traversal(tarinfo.name) + ) + tar_file.extractall(temp_path, members=members_tarinfo) # nosec B202:tarfile_unsafe_members return True - except ( - tarfile.TarError, - zipfile.BadZipFile, - zipfile.LargeZipFile, - OSError, - ) as error: + except (tarfile.TarError, zipfile.BadZipFile, zipfile.LargeZipFile, OSError, ValueError) as error: logger.info(error) return False diff --git a/src/macaron/util.py b/src/macaron/util.py index 91758c720..2db5a3d56 100644 --- a/src/macaron/util.py +++ b/src/macaron/util.py @@ -36,7 +36,9 @@ def send_get_http(url: str, headers: dict) -> dict: The response's json data or an empty dict if there is an error. """ logger.debug("GET - %s", url) - response = requests.get(url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10)) + response = requests.get( + url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10) + ) # nosec B113:request_without_timeout while response.status_code != 200: logger.error( "Receiving error code %s from server. Message: %s.", @@ -47,7 +49,9 @@ def send_get_http(url: str, headers: dict) -> dict: check_rate_limit(response) else: return {} - response = requests.get(url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10)) + response = requests.get( + url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10) + ) # nosec B113:request_without_timeout return dict(response.json()) @@ -70,7 +74,9 @@ def send_get_http_raw(url: str, headers: dict) -> Response | None: The response object or None if there is an error. """ logger.debug("GET - %s", url) - response = requests.get(url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10)) + response = requests.get( + url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10) + ) # nosec B113:request_without_timeout while response.status_code != 200: logger.error( "Receiving error code %s from server. Message: %s.", @@ -81,7 +87,9 @@ def send_get_http_raw(url: str, headers: dict) -> Response | None: check_rate_limit(response) else: return None - response = requests.get(url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10)) + response = requests.get( + url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10) + ) # nosec B113:request_without_timeout return response @@ -155,7 +163,9 @@ def download_github_build_log(url: str, headers: dict) -> str: The content of the downloaded build log or empty if error. """ logger.debug("Downloading content at link %s", url) - response = requests.get(url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10)) + response = requests.get( + url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10) + ) # nosec B113:request_without_timeout return response.content.decode("utf-8")