fix: check paths in an archive file before extracting (#366)
The paths in an archive file are checked for path traversal patterns before extraction. Also, Bandit v1.7.5 produces false positives for request timeout arguments that are in fact present; these warnings have been suppressed with # nosec annotations.

Signed-off-by: behnazh-w <behnaz.hassanshahi@oracle.com>
behnazh-w authored Jul 17, 2023
1 parent d5bea3c commit 74c9637
Showing 2 changed files with 37 additions and 13 deletions.
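For context, the class of bug being closed here is often called "zip-slip": an archive member whose name starts with "/" or contains ".." can be written outside the extraction directory. A minimal standalone sketch of the attack and of the screening idea used in the diff below (the in-memory archive, the member name, and the is_safe helper are illustrative, not part of the Macaron code):

    import io
    import tarfile

    # Build an in-memory tar.gz containing a member whose name tries to
    # climb out of the extraction directory.
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w:gz") as tar:
        payload = b"malicious"
        info = tarfile.TarInfo(name="../outside.txt")  # hypothetical malicious entry
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))
    buf.seek(0)

    # The same screening pattern the check below applies: reject absolute
    # paths and any name containing "..".
    def is_safe(name: str) -> bool:
        return not (name.startswith("/") or ".." in name)

    with tarfile.open(fileobj=buf, mode="r:gz") as tar:
        for member in tar.getmembers():
            print(member.name, "->", "extract" if is_safe(member.name) else "reject")
    # Output: ../outside.txt -> reject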
30 changes: 22 additions & 8 deletions src/macaron/slsa_analyzer/checks/provenance_l3_check.py
@@ -239,21 +239,35 @@ def _extract_archive(self, file_path: str, temp_path: str) -> bool:
         bool
             Returns True if successful.
         """
+
+        def _validate_path_traversal(path: str) -> bool:
+            """Check for path traversal attacks."""
+            if path.startswith("/") or ".." in path:
+                logger.debug("Found suspicious path in the archive file: %s.", path)
+                return False
+            try:
+                # Check if there are any symbolic links.
+                if os.path.realpath(path):
+                    return True
+            except OSError as error:
+                logger.debug("Failed to extract artifact from archive file: %s", error)
+                return False
+            return False
+
         try:
             if zipfile.is_zipfile(file_path):
                 with zipfile.ZipFile(file_path, "r") as zip_file:
-                    zip_file.extractall(temp_path)
+                    members = (path for path in zip_file.namelist() if _validate_path_traversal(path))
+                    zip_file.extractall(temp_path, members=members)  # nosec B202:tarfile_unsafe_members
                     return True
             elif tarfile.is_tarfile(file_path):
                 with tarfile.open(file_path, mode="r:gz") as tar_file:
-                    tar_file.extractall(temp_path)
+                    members_tarinfo = (
+                        tarinfo for tarinfo in tar_file.getmembers() if _validate_path_traversal(tarinfo.name)
+                    )
+                    tar_file.extractall(temp_path, members=members_tarinfo)  # nosec B202:tarfile_unsafe_members
                     return True
-        except (
-            tarfile.TarError,
-            zipfile.BadZipFile,
-            zipfile.LargeZipFile,
-            OSError,
-        ) as error:
+        except (tarfile.TarError, zipfile.BadZipFile, zipfile.LargeZipFile, OSError, ValueError) as error:
             logger.info(error)
 
         return False
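As an aside, newer Python releases offer a stdlib alternative for the tar case: PEP 706 extraction filters, added in Python 3.12 and backported to later security releases of 3.8–3.11. A sketch under that assumption — this is not what the commit uses, and the archive name is hypothetical:

    import tarfile

    # The "data" filter rejects absolute names, ".." components, and links
    # escaping the destination, raising tarfile.FilterError subclasses
    # instead of silently extracting them.
    with tarfile.open("artifact.tar.gz", mode="r:gz") as tar:
        tar.extractall("extract_dir", filter="data")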
20 changes: 15 additions & 5 deletions src/macaron/util.py
@@ -36,7 +36,9 @@ def send_get_http(url: str, headers: dict) -> dict:
         The response's json data or an empty dict if there is an error.
     """
     logger.debug("GET - %s", url)
-    response = requests.get(url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10))
+    response = requests.get(
+        url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10)
+    )  # nosec B113:request_without_timeout
     while response.status_code != 200:
         logger.error(
             "Receiving error code %s from server. Message: %s.",
@@ -47,7 +49,9 @@ def send_get_http(url: str, headers: dict) -> dict:
             check_rate_limit(response)
         else:
             return {}
-        response = requests.get(url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10))
+        response = requests.get(
+            url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10)
+        )  # nosec B113:request_without_timeout
 
     return dict(response.json())

@@ -70,7 +74,9 @@ def send_get_http_raw(url: str, headers: dict) -> Response | None:
         The response object or None if there is an error.
     """
     logger.debug("GET - %s", url)
-    response = requests.get(url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10))
+    response = requests.get(
+        url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10)
+    )  # nosec B113:request_without_timeout
     while response.status_code != 200:
         logger.error(
             "Receiving error code %s from server. Message: %s.",
@@ -81,7 +87,9 @@ def send_get_http_raw(url: str, headers: dict) -> Response | None:
             check_rate_limit(response)
         else:
             return None
-        response = requests.get(url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10))
+        response = requests.get(
+            url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10)
+        )  # nosec B113:request_without_timeout
 
     return response

@@ -155,7 +163,9 @@ def download_github_build_log(url: str, headers: dict) -> str:
         The content of the downloaded build log or empty if error.
     """
     logger.debug("Downloading content at link %s", url)
-    response = requests.get(url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10))
+    response = requests.get(
+        url=url, headers=headers, timeout=defaults.getint("requests", "timeout", fallback=10)
+    )  # nosec B113:request_without_timeout
 
     return response.content.decode("utf-8")
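In all of these calls the timeout comes from Macaron's defaults config rather than a hard-coded constant; configparser's getint only falls back to 10 when no value is configured. A standalone sketch of the same lookup pattern (the inline config text is hypothetical; the section and option names are taken from the diff):

    import configparser

    # A hypothetical config; Macaron reads its real defaults from an .ini file.
    defaults = configparser.ConfigParser()
    defaults.read_string("[requests]\ntimeout = 30\n")

    # Same lookup pattern as util.py: prefer the configured value and fall
    # back to 10 seconds when the [requests] timeout option is absent.
    timeout = defaults.getint("requests", "timeout", fallback=10)
    print(timeout)  # prints 30 here; would print 10 without the config entry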

