ENH remove trusted=True from skops.io.load(s) (#422)
* SEC remove trusted=True

* FIX make the rest of the fixes and make sure tests pass

* TST add test for trusted=True

* DOC add changelog

* DOC update PR number

* ENH make the error message more helpful

* Address Benjamin's comments

* CLN fix a typecheck to what we really want
adrinjalali authored Jun 13, 2024
1 parent f69b928 commit 25a9d99
Showing 21 changed files with 157 additions and 156 deletions.
5 changes: 5 additions & 0 deletions docs/changes.rst
@@ -14,6 +14,11 @@ v0.10
- Removes Python 3.8 support and adds Python 3.12 support :pr:`418` by :user:`Thomas Lazarus <lazarust>`.
- Removes a shortcut to add `sklearn-intelex` as a not dependency.
:pr:`420` by :user:`Thomas Lazarus <lazarust>`.
- ``trusted=True`` is now removed from ``skops.io.load`` and ``skops.io.loads``.
This is to further encourage users to inspect the input data before loading
it. :func:`skops.io.get_untrusted_types` can be used to get the untrusted types
present in the input.
:pr:`422` by `Adrin Jalali`_.

v0.9
----
39 changes: 17 additions & 22 deletions docs/persistence.rst
@@ -54,7 +54,7 @@ The code snippet below illustrates how to use :func:`skops.io.dump` and
from xgboost.sklearn import XGBClassifier
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.datasets import load_iris
from skops.io import dump, load
from skops.io import dump, load, get_untrusted_types
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4)
@@ -64,26 +64,24 @@ The code snippet below illustrates how to use :func:`skops.io.dump` and
0.9666666666666667
dump(clf, "my-model.skops")
# ...
loaded = load("my-model.skops", trusted=True)
unknown_types = get_untrusted_types(file="my-model.skops")
print(unknown_types)
['sklearn.metrics._scorer._passthrough_scorer',
'xgboost.core.Booster', 'xgboost.sklearn.XGBClassifier']
loaded = load("my-model.skops", trusted=unknown_types)
print(loaded.score(X_test, y_test))
0.9666666666666667
# in memory
from skops.io import dumps, loads
serialized = dumps(clf)
loaded = loads(serialized, trusted=True)
Note that you should only load files with ``trusted=True`` if you trust the
source. Otherwise you can get a list of untrusted types present in the dump
using :func:`skops.io.get_untrusted_types`:

.. code:: python
loaded = loads(serialized, trusted=unknown_types)
from skops.io import get_untrusted_types
unknown_types = get_untrusted_types(file="my-model.skops")
print(unknown_types)
['sklearn.metrics._scorer._passthrough_scorer',
'xgboost.core.Booster', 'xgboost.sklearn.XGBClassifier']
Note that the ``get_untrusted_types`` function is used to check which types are
not trusted by default. The user can then decide whether to trust them or not.
In versions before 0.10, users could pass ``trusted=True`` to skip the
audit phase, which is now removed to encourage users to validate the input
before loading.

Note that everything in the above list is safe to load. We already have many
types included as trusted by default, and some of the above values might be
@@ -92,10 +90,6 @@ added to that list in the future.
Once you check the list and you validate that everything in the list is safe,
you can load the file with ``trusted=unknown_types``:

.. code:: python
loaded = load("my-model.skops", trusted=unknown_types)
At the moment, we support the vast majority of sklearn estimators. This
includes complex use cases such as :class:`sklearn.pipeline.Pipeline`,
:class:`sklearn.model_selection.GridSearchCV`, classes using objects defined in
@@ -226,10 +220,11 @@ green to cyan. The ``rich`` docs list the `supported standard colors

Note that the visualization feature is intended to help understand the structure
of the object, e.g. what attributes are identified as untrusted. It is not a
replacement for a proper security check. In particular, just because an object's
visualization looks innocent does *not* mean you can just call `sio.load(<file>,
trusted=True)` on this object -- only pass the types you really trust to the
``trusted`` argument.
replacement for a proper security check of the included types in the file. In
particular, just because an object's visualization looks innocent does *not*
mean you can just call `sio.load(<file>,
trusted=get_untrusted_types(file=<file>))` on this object -- only pass the
types you really trust to the ``trusted`` argument.
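
One way to follow this advice is to compare the reported types against an explicit allowlist before loading. The following is a minimal sketch, not part of skops: ``unapproved_types``, ``load_after_review`` and the ``approved`` argument are hypothetical names introduced here for illustration, and only ``get_untrusted_types`` and ``load`` come from ``skops.io``.

```python
from typing import Iterable, List


def unapproved_types(found: Iterable[str], approved: Iterable[str]) -> List[str]:
    """Return the reported types that the caller has not explicitly approved."""
    approved_set = set(approved)
    return sorted(t for t in found if t not in approved_set)


def load_after_review(path: str, approved: Iterable[str]):
    """Load a .skops file only if every untrusted type in it is approved.

    Hypothetical helper; assumes skops >= 0.10 is installed.
    """
    # Imported lazily so the pure helper above works without skops installed.
    from skops.io import get_untrusted_types, load

    found = get_untrusted_types(file=path)
    rejected = unapproved_types(found, approved)
    if rejected:
        raise ValueError(f"Refusing to load; unapproved types: {rejected}")
    # Everything in `found` has been approved, so it can be passed as trusted.
    return load(path, trusted=found)
```

With this shape, a file that contains a type outside the allowlist is rejected before ``load`` is called, rather than being trusted wholesale.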

Supported libraries
-------------------
4 changes: 2 additions & 2 deletions skops/card/_model_card.py
@@ -12,7 +12,7 @@
from hashlib import sha256
from pathlib import Path
from reprlib import Repr
from typing import Any, Iterator, Literal, Sequence, Union
from typing import Any, Iterator, List, Literal, Optional, Sequence, Union

import joblib
from huggingface_hub import ModelCardData
@@ -488,7 +488,7 @@ def __init__(
model_diagram: bool | Literal["auto"] | str = "auto",
metadata: ModelCardData | None = None,
template: Literal["skops"] | dict[str, str] | None = "skops",
trusted: bool = False,
trusted: Optional[List[str]] = None,
) -> None:
self.model = model
self.metadata = metadata or ModelCardData()
21 changes: 14 additions & 7 deletions skops/card/tests/test_card.py
@@ -26,7 +26,7 @@
TableSection,
_load_model,
)
from skops.io import dump, load
from skops.io import dump, get_untrusted_types, load
from skops.utils.importutils import import_or_raise


@@ -51,10 +51,14 @@ def save_model_to_file(model_instance, suffix):
def test_load_model(suffix):
model0 = LinearRegression(n_jobs=123)
_, save_file = save_model_to_file(model0, suffix)
loaded_model_str = _load_model(save_file, trusted=True)
if suffix == ".skops":
untrusted_types = get_untrusted_types(file=save_file)
else:
untrusted_types = None
loaded_model_str = _load_model(save_file, trusted=untrusted_types)
save_file_path = Path(save_file)
loaded_model_path = _load_model(save_file_path, trusted=True)
loaded_model_instance = _load_model(model0, trusted=True)
loaded_model_path = _load_model(save_file_path, trusted=untrusted_types)
loaded_model_instance = _load_model(model0, trusted=untrusted_types)

assert loaded_model_str.n_jobs == 123
assert loaded_model_path.n_jobs == 123
@@ -1383,8 +1387,11 @@ def test_with_metadata(self, card: Card, meth, expected_lines):


class TestCardModelAttributeIsPath:
def path_to_card(self, path):
card = Card(model=path, trusted=True)
def path_to_card(self, path, suffix):
if suffix == ".skops":
card = Card(model=path, trusted=get_untrusted_types(file=path))
else:
card = Card(model=path)
return card

@pytest.mark.parametrize("meth", [repr, str])
@@ -1397,7 +1404,7 @@ def test_model_card_repr(self, meth, suffix):
model = LinearRegression(fit_intercept=False)
file_handle, file_name = save_model_to_file(model, suffix)
os.close(file_handle)
card_from_path = self.path_to_card(file_name)
card_from_path = self.path_to_card(file_name, suffix=suffix)

result0 = meth(card_from_path)
expected = "Card(\n model=LinearRegression(fit_intercept=False),"
4 changes: 2 additions & 2 deletions skops/cli/_update.py
@@ -9,7 +9,7 @@
from pathlib import Path

from skops.cli._utils import get_log_level
from skops.io import dump, load
from skops.io import dump, get_untrusted_types, load
from skops.io._protocol import PROTOCOL


@@ -48,7 +48,7 @@ def _update_file(
" file."
)

input_model = load(input_file, trusted=True)
input_model = load(input_file, trusted=get_untrusted_types(file=input_file))
with zipfile.ZipFile(input_file, "r") as zip_file:
input_file_schema = json.loads(zip_file.read("schema.json"))

4 changes: 2 additions & 2 deletions skops/cli/tests/test_convert.py
@@ -7,7 +7,7 @@
import pytest

from skops.cli import _convert
from skops.io import load
from skops.io import get_untrusted_types, load


class MockUnsafeType:
@@ -61,7 +61,7 @@ def test_unsafe_case_works_as_expected(
):
caplog.set_level(logging.WARNING)
_convert._convert_file(pkl_path, skops_path)
persisted_obj = load(skops_path, trusted=True)
persisted_obj = load(skops_path, trusted=get_untrusted_types(file=skops_path))

assert isinstance(persisted_obj, MockUnsafeType)

35 changes: 12 additions & 23 deletions skops/io/_audit.py
@@ -2,7 +2,7 @@

import io
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Type, Union
from typing import Any, Dict, Generator, List, Optional, Sequence, Type, Union

from ._protocol import PROTOCOL
from ._utils import LoadContext, get_module, get_type_paths
@@ -14,9 +14,7 @@
]


def check_type(
module_name: str, type_name: str, trusted: Literal[True] | Sequence[str]
) -> bool:
def check_type(module_name: str, type_name: str, trusted: Sequence[str]) -> bool:
"""Check if a type is safe to load.
A type is safe to load only if it's present in the trusted list.
@@ -38,16 +36,13 @@ def check_type(
is_safe : bool
True if the type is safe, False otherwise.
"""
if trusted is True:
return True
return module_name + "." + type_name in trusted


def audit_tree(tree: Node) -> None:
"""Audit a tree of nodes.
A tree is safe if it only contains trusted types. Audit is skipped if
trusted is ``True``.
A tree is safe if it only contains trusted types.
Parameters
----------
@@ -59,9 +54,6 @@ def audit_tree(tree: Node) -> None:
UntrustedTypesFoundException
If the tree contains an untrusted type.
"""
if tree.trusted is True:
return

unsafe = tree.get_unsafe_set()
if unsafe:
raise UntrustedTypesFoundException(unsafe)
@@ -142,7 +134,7 @@ def __init__(
self,
state: dict[str, Any],
load_context: LoadContext,
trusted: bool | Sequence[str] = False,
trusted: Optional[Sequence[str]] = None,
memoize: bool = True,
) -> None:
self.class_name, self.module_name = state["__class__"], state["__module__"]
@@ -180,22 +172,19 @@ def _construct(self):

@staticmethod
def _get_trusted(
trusted: bool | Sequence[Union[str, Type]], default: Sequence[Union[str, Type]]
) -> Literal[True] | list[str]:
trusted: Optional[Sequence[Union[str, Type]]],
default: Sequence[Union[str, Type]],
) -> list[str]:
"""Return a trusted list.
If ``trusted`` is ``False``, we return the ``default``. If a list of
If ``trusted`` is ``None``, we return the ``default``. If a list of
types are being passed, those types, as well as default trusted types,
are returned.
This is a convenience method called by child classes.
"""
if trusted is True:
# if trusted is True, we trust the node
return True

if trusted is False:
if trusted is None:
# if trusted is None, we only trust the defaults
return get_type_paths(default)

@@ -289,12 +278,12 @@ def __init__(
self,
state: dict[str, Any],
load_context: LoadContext,
trusted: bool = False,
trusted: Optional[List[str]] = None,
):
# we pass memoize as False because we don't want to memoize the cached
# node.
super().__init__(state, load_context, trusted, memoize=False)
self.trusted = True
self.trusted = self._get_trusted(trusted, default=[])
# TODO: deal with case that __id__ is unknown or prevent it from
# happening
self.cached = load_context.get_object(state.get("__id__")) # type: ignore
@@ -313,7 +302,7 @@ def get_tree(
def get_tree(
state: dict[str, Any],
load_context: LoadContext,
trusted: bool | Sequence[str],
trusted: Optional[Sequence[str]],
) -> Node:
"""Get the tree of nodes.
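
The audit behaviour after this change can be illustrated with a simplified, standalone sketch. This is not the skops implementation: ``DEFAULT_TRUSTED``, ``resolve_trusted``, ``is_trusted`` and ``find_unsafe`` are hypothetical names, and the rule that a passed list extends the defaults is taken from the ``_get_trusted`` docstring shown above rather than from its body, which the diff does not show in full.

```python
from typing import List, Optional, Sequence

# Hypothetical stand-in for a node's default trusted types.
DEFAULT_TRUSTED = ["builtins.list", "builtins.dict"]


def resolve_trusted(
    trusted: Optional[Sequence[str]], default: Sequence[str]
) -> List[str]:
    """None means "defaults only"; a list is combined with the defaults.

    Assumption: combining follows the docstring of ``_get_trusted``.
    """
    if trusted is None:
        return list(default)
    return list(trusted) + list(default)


def is_trusted(module_name: str, type_name: str, trusted: Sequence[str]) -> bool:
    """Same membership test as ``check_type`` in the diff."""
    return module_name + "." + type_name in trusted


def find_unsafe(
    type_paths: Sequence[str], trusted: Optional[Sequence[str]]
) -> List[str]:
    """Return the type paths that are not covered by the resolved trusted list."""
    allowed = resolve_trusted(trusted, DEFAULT_TRUSTED)
    return [path for path in type_paths if path not in allowed]
```

Because ``trusted`` is now either ``None`` or a sequence of type names, there is no longer a value that bypasses the membership check: every type must appear in the resolved list to be considered safe.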
