Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

type: Infer parameters types for __init__(default...), __get__ and __set__ #985

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions param/parameterized.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from itertools import chain
from operator import itemgetter, attrgetter
from types import FunctionType, MethodType
from typing import TypeVar, Generic

from contextlib import contextmanager
from logging import DEBUG, INFO, WARNING, ERROR, CRITICAL
Expand All @@ -52,6 +53,8 @@
gen_types,
)

T = TypeVar("T")

# Ideally setting param_pager would be in __init__.py but param_pager is
# needed on import to create the Parameterized class, so it'd need to precede
# importing parameterized.py in __init__.py which would be a little weird.
Expand Down Expand Up @@ -1005,7 +1008,9 @@ def _sorter(p):
cls.__signature__ = new_sig


class Parameter(_ParameterBase):


class Parameter(Generic[T], _ParameterBase):
"""
An attribute descriptor for declaring parameters.

Expand Down Expand Up @@ -1431,7 +1436,7 @@ def _update_state(self):
values, after the slot values have been set in the inheritance procedure.
"""

def __get__(self, obj, objtype): # pylint: disable-msg=W0613
def __get__(self, obj, objtype) -> T:
"""
Return the value for this Parameter.

Expand All @@ -1455,7 +1460,7 @@ def __get__(self, obj, objtype): # pylint: disable-msg=W0613
return result

@instance_descriptor
def __set__(self, obj, val):
def __set__(self, obj, val: T) -> None:
"""
Set the value for this Parameter.

Expand Down Expand Up @@ -1603,7 +1608,7 @@ def __setstate__(self,state):


# Define one particular type of Parameter that is used in this file
class String(Parameter):
class String(Parameter[T]):
r"""
A String Parameter, with a default value and optional regular expression (regex) matching.

Expand All @@ -1624,15 +1629,15 @@ def __init__(self, default="0.0.0.0", allow_None=False, **kwargs):
@typing.overload
def __init__(
self,
default="", *, regex=None,
default: T = "", *, regex=None,
doc=None, label=None, precedence=None, instantiate=False, constant=False,
readonly=False, pickle_default_value=True, allow_None=False, per_instance=True,
allow_refs=False, nested_refs=False
):
...

@_deprecate_positional_args
def __init__(self, default=Undefined, *, regex=Undefined, **kwargs):
def __init__(self, default: T = Undefined, *, regex=Undefined, **kwargs):
super().__init__(default=default, **kwargs)
self.regex = regex
self._validate(self.default)
Expand Down
Loading
Loading