Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytorch Load / Save Plugin #1114

Merged
merged 16 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions bandit/plugins/pytorch_load_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright (c) 2024 Stacklok, Inc.
#
# SPDX-License-Identifier: Apache-2.0
r"""
==========================================
B613: Test for unsafe PyTorch load or save
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
==========================================

This plugin checks for the use of `torch.load` and `torch.save`. Using
`torch.load` with untrusted data can lead to arbitrary code execution, and
improper use of `torch.save` might expose sensitive data or lead to data
corruption. A safe alternative is to use `torch.load` with the `safetensors`
library from hugingface, which provides a safe deserialization mechanism.

:Example:

.. code-block:: none

>> Issue: Use of unsafe PyTorch load or save
Severity: Medium Confidence: High
CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html)
Location: examples/pytorch_load_save.py:8
7 loaded_model.load_state_dict(torch.load('model_weights.pth'))
8 another_model.load_state_dict(torch.load('model_weights.pth',
map_location='cpu'))
9
10 print("Model loaded successfully!")

.. seealso::

- https://cwe.mitre.org/data/definitions/94.html
- https://pytorch.org/docs/stable/generated/torch.load.html#torch.load
- https://github.com/huggingface/safetensors

lukehinds marked this conversation as resolved.
Show resolved Hide resolved
.. versionadded:: 1.7.10

"""
import bandit
from bandit.core import issue
from bandit.core import test_properties as test


@test.checks("Call")
@test.test_id("B613")
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
def pytorch_load_save(context):
"""
This plugin checks for the use of `torch.load` and `torch.save`. Using
`torch.load` with untrusted data can lead to arbitrary code execution,
and improper use of `torch.save` might expose sensitive data or lead
to data corruption.
"""
imported = context.is_module_imported_exact("torch")
qualname = context.call_function_name_qual
if not imported and isinstance(qualname, str):
return

qualname_list = qualname.split(".")
func = qualname_list[-1]
if all(
[
"torch" in qualname_list,
func in ["load", "save"],
not context.check_call_arg_value("map_location", "cpu"),
]
):
return bandit.Issue(
severity=bandit.MEDIUM,
confidence=bandit.HIGH,
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
text="Use of unsafe PyTorch load or save",
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
cwe=issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA,
lineno=context.get_lineno_for_call_arg("load"),
)
5 changes: 5 additions & 0 deletions doc/source/plugins/b704_pytorch_load_save.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-----------------------
B613: pytorch_load_save
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
-----------------------

.. automodule:: bandit.plugins.pytorch_load_save
21 changes: 21 additions & 0 deletions examples/pytorch_load_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
import torchvision.models as models

# Example of saving a model
model = models.resnet18(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

# Example of loading the model weights in an insecure way
loaded_model = models.resnet18()
loaded_model.load_state_dict(torch.load('model_weights.pth'))

# Save the model
torch.save(loaded_model.state_dict(), 'model_weights.pth')

# Another example using torch.load with more parameters
another_model = models.resnet18()
another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
lukehinds marked this conversation as resolved.
Show resolved Hide resolved

# Save the model
torch.save(another_model.state_dict(), 'model_weights.pth')

3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ bandit.plugins =
#bandit/plugins/tarfile_unsafe_members.py
tarfile_unsafe_members = bandit.plugins.tarfile_unsafe_members:tarfile_unsafe_members

#bandit/plugins/pytorch_load_save.py
pytorch_load_save = bandit.plugins.pytorch_load_save:pytorch_load_save

# bandit/plugins/trojansource.py
trojansource = bandit.plugins.trojansource:trojansource

Expand Down
8 changes: 8 additions & 0 deletions tests/functional/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,14 @@ def test_tarfile_unsafe_members(self):
}
self.check_example("tarfile_extractall.py", expect)

def test_pytorch_load_save(self):
"""Test insecure usage of torch.load and torch.save."""
expect = {
"SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 4, "HIGH": 0},
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 4},
}
self.check_example("pytorch_load_save.py", expect)

def test_trojansource(self):
expect = {
"SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 1},
Expand Down
Loading