Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNT Add light type annotations to skops.io #219

Merged
merged 8 commits into from
Dec 1, 2022
304 changes: 299 additions & 5 deletions skops/io/_audit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
from skops.io.exceptions import UntrustedTypesFoundException
from __future__ import annotations

import io
from contextlib import contextmanager
from typing import Any, Generator, Sequence

def check_type(module_name, type_name, trusted):
from ..utils.fixes import Literal
from ._trusted_types import PRIMITIVE_TYPE_NAMES
from ._utils import LoadContext, get_module
from .exceptions import UntrustedTypesFoundException

NODE_TYPE_MAPPING = {} # type: ignore


def check_type(
module_name: str, type_name: str, trusted: Literal[True] | Sequence[str]
) -> bool:
"""Check if a type is safe to load.

A type is safe to load only if it's present in the trusted list.
Expand All @@ -14,7 +27,7 @@ def check_type(module_name, type_name, trusted):
type_name : str
The class name of the type.

trusted : bool, or list of str
trusted : True, or list of str
If ``True``, the tree is considered safe. Otherwise trusted has to be
a list of trusted types.

Expand All @@ -28,7 +41,7 @@ def check_type(module_name, type_name, trusted):
return module_name + "." + type_name in trusted


def audit_tree(tree, trusted):
def audit_tree(tree: Node, trusted: bool | Sequence[str]) -> None:
"""Audit a tree of nodes.

A tree is safe if it only contains trusted types. Audit is skipped if
Expand All @@ -39,7 +52,7 @@ def audit_tree(tree, trusted):
tree : skops.io._dispatch.Node
The tree to audit.

trusted : bool, or list of str
trusted : True, or list of str
If ``True``, the tree is considered safe. Otherwise trusted has to be
a list of trusted types names.

Expand All @@ -59,3 +72,284 @@ def audit_tree(tree, trusted):
unsafe -= set(trusted)
if unsafe:
raise UntrustedTypesFoundException(unsafe)


class UNINITIALIZED:
"""Sentinel value to indicate that a value has not been initialized yet."""


@contextmanager
def temp_setattr(obj: Any, **kwargs: Any) -> Generator[None, None, None]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does Generator[None, None, None] mean? typehints should helps us understand the code better, not to make them more confusing :D

Also, if everything is Any here, then why do we have typehints?

This is an example where I wouldn't add any typehints here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The meaning for Generator is Generator[YieldType, SendType, ReturnType], which happens to be None, None, None here. I agree that types don't add much here, though **kwargs: Any is a shorthand for dict[str, Any]. The difference for me between a function annotated with Any vs no annotation is that with the latter, I don't know if it takes any or if it was just not annotated yet, so I guess not completely useless.

I would thus keep the annotations here, but if you insist, I can remove them.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think typehints should be useful, if they're not, they take extra space and make things less readable.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with Adrin

imo, type hints should really be used like docstrings, if having them just takes up space and doesn't make things clearer or help an IDE, they probably don't need to be there.

"""Context manager to temporarily set attributes on an object."""
existing_attrs = {k for k in kwargs.keys() if hasattr(obj, k)}
previous_values = {k: getattr(obj, k, None) for k in kwargs}
for k, v in kwargs.items():
setattr(obj, k, v)
try:
yield
finally:
for k, v in previous_values.items():
if k in existing_attrs:
setattr(obj, k, v)
else:
delattr(obj, k)


class Node:
"""A node in the tree of objects.

This class is a parent class for all nodes in the tree of objects. Each
type of object (e.g. dict, list, etc.) has its own subclass of Node.

Each child class has to implement two methods: ``__init__`` and
``_construct``.

``__init__`` takes care of traversing the state tree and to create the
corresponding ``Node`` objects. It has access to the ``load_context`` which
in turn has access to the source zip file. The child class's ``__init__``
must load attributes into the ``children`` attribute, which is a
dictionary of ``{child_name: unloaded_value/Node/list/etc}``. The
``get_unsafe_set`` should be able to parse and validate the values set
under the ``children`` attribute. Note that primitives are persisted as a
``JsonNode``.

``_construct`` takes care of constructing the object. It is only called
once and the result is cached in ``construct`` which is implemented in this
class. All required data to construct an instance should be loaded during
``__init__``.

The separation of ``__init__`` and ``_construct`` is necessary because
audit methods are called after ``__init__`` and before ``construct``.
Therefore ``__init__`` should avoid creating any instances or importing
any modules, to avoid running potentially untrusted code.

Parameters
----------
state : dict
A dict representing the state of the dumped object.

load_context : LoadContext
The context of the loading process.

trusted : bool or list of str, default=False
If ``True``, the object will be loaded without any security checks. If
``False``, the object will be loaded only if there are only trusted
objects in the dumped file. If a list of strings, the object will be
loaded only if all of its required types are listed in ``trusted``
or are trusted by default.

memoize : bool, default=True
If ``True``, the object will be memoized in the load context, if it has
the ``__id__`` set. This is used to avoid loading the same object
multiple times.
"""

def __init__(
self,
state: dict[str, Any],
load_context: LoadContext,
trusted: bool | Sequence[str] = False,
memoize: bool = True,
) -> None:
self.class_name, self.module_name = state["__class__"], state["__module__"]
self._is_safe = None
self._constructed = UNINITIALIZED
saved_id = state.get("__id__")
if saved_id and memoize:
# hold reference to obj in case same instance encountered again in
# save state
load_context.memoize(self, saved_id)

# subclasses should always:
# 1. call super().__init__()
# 2. set self.trusted = self._get_trusted(trusted, ...) where ... is a
# list of appropriate trusted types
# 3. set self.children, where children are states of child nodes; do not
# construct the children objects yet
self.trusted = self._get_trusted(trusted, [])
self.children: dict[str, Any] = {}

def construct(self) -> Any:
"""Construct the object.

We only construct the object once, and then cache the result.
"""
if self._constructed is not UNINITIALIZED:
return self._constructed
self._constructed = self._construct()
return self._constructed

def _construct(self) -> Any:
raise NotImplementedError(
f"{self.__class__.__name__} should implement a 'construct' method"
BenjaminBossan marked this conversation as resolved.
Show resolved Hide resolved
)

@staticmethod
def _get_trusted(
trusted: bool | Sequence[str], default: list[str]
) -> Literal[True] | list[str]:
"""Return a trusted list, or True.

If ``trusted`` is ``False``, we return the ``default``, otherwise the
``trusted`` value is used.

This is a convenience method called by child classes.
"""
if trusted is True:
# if trusted is True, we trust the node
return True

if trusted is False:
# if trusted is False, we only trust the defaults
return default

# otherwise we trust the given list, call list in case it's a tuple
return list(trusted)

def is_self_safe(self) -> bool:
"""True only if the node's type is considered safe.

This property only checks the type of the node, not its children.
"""
return check_type(self.module_name, self.class_name, self.trusted)

def is_safe(self) -> bool:
"""True only if the node and all its children are safe."""
# if trusted is set to True, we don't do any safety checks.
if self.trusted is True:
return True
E-Aho marked this conversation as resolved.
Show resolved Hide resolved

return len(self.get_unsafe_set()) == 0

def get_unsafe_set(self) -> set[str]:
"""Get the set of unsafe types.

This method returns all types which are not trusted, including this
node and all its children.

Returns
-------
unsafe_set : set
A set of unsafe types.
"""
if hasattr(self, "_computing_unsafe_set"):
# this means we're already computing this node's unsafe set, so we
# return an empty set and let the computation of the parent node
# continue. This is to avoid infinite recursion.
return set()

with temp_setattr(self, _computing_unsafe_set=True):
res = set()
if not self.is_self_safe():
res.add(self.module_name + "." + self.class_name)

for child in self.children.values():
if child is None:
continue

# Get the safety set based on the type of the child. In most cases
# other than ListNode and DictNode, children are all of type Node.
if isinstance(child, list):
# iterate through the list
for value in child:
res.update(value.get_unsafe_set())
elif isinstance(child, dict):
# iterate through the values of the dict only
# TODO: should we check the types of the keys?
for value in child.values():
res.update(value.get_unsafe_set())
elif isinstance(child, Node):
# delegate to the child Node
res.update(child.get_unsafe_set())
elif type(child) is type:
# the if condition bellow is not merged with the previous
# one because if the above condition is True, the following
# conditions about BytesIO, etc should be ignored.
if not check_type(get_module(child), child.__name__, self.trusted):
# if the child is a type, we check its safety
res.add(get_module(child) + "." + child.__name__)
elif isinstance(child, io.BytesIO):
# We trust BytesIO objects, which are read by other
# libraries such as numpy, scipy.
continue
elif check_type(
get_module(child), child.__class__.__name__, PRIMITIVE_TYPE_NAMES
):
# if the child is a primitive type, we don't need to check its
# safety.
continue
else:
raise ValueError(
f"Cannot determine the safety of type {type(child)}. Please"
" open an issue at https://github.com/skops-dev/skops/issues"
" for us to fix the issue."
)

return res


class CachedNode(Node):
def __init__(
self,
state: dict[str, Any],
load_context: LoadContext,
trusted: bool = False,
):
# we pass memoize as False because we don't want to memoize the cached
# node.
super().__init__(state, load_context, trusted, memoize=False)
self.trusted = True
# TODO: deal with case that __id__ is unknown or prevent it from
# happening
self.cached = load_context.get_object(state.get("__id__")) # type: ignore
self.children = {} # type: ignore

def _construct(self):
# TODO: FIXME This causes a recursion error when loading a cached
# object if we call the cached object's `construct``. Some refactoring
# is needed to fix this.
return self.cached.construct()


NODE_TYPE_MAPPING["CachedNode"] = CachedNode


def get_tree(state: dict[str, Any], load_context: LoadContext) -> Node:
"""Get the tree of nodes.

This function returns the root node of the tree of nodes. The tree is
constructed recursively by traversing the state tree. No instances are
created during this process. One would need to call ``construct`` on the
root node to create the instances.

This function also handles memoization of the nodes. If a node has already
been created, it is returned instead of creating a new one.

Parameters
----------
state : dict
The state of the dumped object.

load_context : LoadContext
The context of the loading process.
"""
saved_id = state.get("__id__")
if saved_id in load_context.memo:
# This means the node is already loaded, so we return it. Note that the
# node is not constructed at this point. It will be constructed when
# the parent node's ``construct`` method is called, and for this node
# it'll be called more than once. But that's not an issue since the
# node's ``construct`` method caches the instance.
return load_context.get_object(saved_id)

try:
node_cls = NODE_TYPE_MAPPING[state["__loader__"]]
except KeyError:
type_name = f"{state['__module__']}.{state['__class__']}"
raise TypeError(
f" Can't find loader {state['__loader__']} for type {type_name}."
)

loaded_tree = node_cls(state, load_context, trusted=False) # type: ignore

return loaded_tree
Loading