Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytorch Load / Save Plugin #1114

Merged
merged 16 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions bandit/blacklists/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,19 @@
| | | - os.tmpnam | |
+------+---------------------+------------------------------------+-----------+

B704: pytorch_load_save

Use of unsafe PyTorch load. `torch.load` can lead to arbitrary code execution,
and improper use of `torch.save` might expose sensitive data or lead to data
corruption.

+------+---------------------+--------------------------------------+-----------+
| ID | Name | Calls | Severity |
+======+=====================+======================================+===========+
| B704 | pytorch_load_save| | - torch.load | Medium |
| B704 | pytorch_load_save| | - torch.save | Medium |
+------+---------------------+--------------------------------------+-----------+
lukehinds marked this conversation as resolved.
Show resolved Hide resolved

"""
import sys

Expand Down Expand Up @@ -685,6 +698,17 @@ def gen_blacklist():
)
)

sets.append(
utils.build_conf_dict(
"pytorch_load_save",
"B704",
issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA,
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
["torch.load", "torch.save"],
"Use of unsafe PyTorch load or save",
"MEDIUM",
)
)

# skipped B324 (used in bandit/plugins/hashlib_new_insecure_functions.py)

# skipped B325 as the check for a call to os.tempnam and os.tmpnam have
Expand Down
68 changes: 68 additions & 0 deletions bandit/plugins/pytorch_load_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (c) 2024 Stacklok, Inc.
#
# SPDX-License-Identifier: Apache-2.0
r"""
=========================================
B704: Test for unsafe PyTorch load or save
=========================================
lukehinds marked this conversation as resolved.
Show resolved Hide resolved

This plugin checks for the use of `torch.load` and `torch.save`. Using `torch.load`
with untrusted data can lead to arbitrary code execution, and improper use of
`torch.save` might expose sensitive data or lead to data corruption.

:Example:

.. code-block:: none

>> Issue: Use of unsafe PyTorch load or save
Severity: Medium Confidence: High
CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html)
Location: examples/pytorch_load_save.py:8
7 loaded_model.load_state_dict(torch.load('model_weights.pth'))
8 another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
9
10 print("Model loaded successfully!")

.. seealso::

- https://cwe.mitre.org/data/definitions/94.html

lukehinds marked this conversation as resolved.
Show resolved Hide resolved
.. versionadded:: 1.7.8
lukehinds marked this conversation as resolved.
Show resolved Hide resolved

"""
import bandit
from bandit.core import issue
from bandit.core import test_properties as test


@test.checks("Call")
@test.test_id(
"B704"
) # Ensure the test ID is unique and does not conflict with existing Bandit tests
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
def pytorch_load_save(context):
"""
This plugin checks for the use of `torch.load` and `torch.save`. Using `torch.load`
with untrusted data can lead to arbitrary code execution, and improper use of
`torch.save` might expose sensitive data or lead to data corruption.
"""
imported = context.is_module_imported_exact("torch")
qualname = context.call_function_name_qual
if not imported and isinstance(qualname, str):
return

qualname_list = qualname.split(".")
func = qualname_list[-1]
if all(
[
"torch" in qualname_list,
func in ["load"],
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
not context.check_call_arg_value("map_location", "cpu"),
]
):
return bandit.Issue(
severity=bandit.MEDIUM,
confidence=bandit.HIGH,
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
text="Use of unsafe PyTorch load or save",
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
cwe=issue.Cwe.UNTRUSTED_INPUT,
lineno=context.get_lineno_for_call_arg("load"),
)
5 changes: 5 additions & 0 deletions doc/source/plugins/b704_pytorch_load_save.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-----------------------
B704: pytorch_load_save
-----------------------

.. automodule:: bandit.plugins.pytorch_load_save
16 changes: 16 additions & 0 deletions examples/pytorch_load_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch
import torchvision.models as models

# Example of saving a model
model = models.resnet18(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

# Example of loading the model weights in an insecure way
loaded_model = models.resnet18()
loaded_model.load_state_dict(torch.load('model_weights.pth'))

# Another example using torch.load with more parameters
another_model = models.resnet18()
another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
lukehinds marked this conversation as resolved.
Show resolved Hide resolved

print("Model loaded successfully!")
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ bandit.plugins =
#bandit/plugins/tarfile_unsafe_members.py
tarfile_unsafe_members = bandit.plugins.tarfile_unsafe_members:tarfile_unsafe_members

#bandit/plugins/pytorch_load_save.py
pytorch_load_save = bandit.plugins.pytorch_load_save:pytorch_load_save

[build_sphinx]
all_files = 1
build-dir = doc/build
Expand Down
8 changes: 8 additions & 0 deletions tests/functional/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,3 +930,11 @@ def test_tarfile_unsafe_members(self):
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 1, "MEDIUM": 2, "HIGH": 2},
}
self.check_example("tarfile_extractall.py", expect)

def test_pytorch_load_save(self):
"""Test insecure usage of torch.load and torch.save."""
expect = {
"SEVERITY": {"UNDEFINED": 0, "LOW": 1, "MEDIUM": 3, "HIGH": 0},
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 1, "HIGH": 3},
}
self.check_example("pytorch_load_save.py", expect)
Loading