Skip to content

Commit

Permalink
Add pytypes property to parameters to return Python type declaration
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Jan 9, 2023
1 parent 111cfc1 commit 4376014
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 2 deletions.
129 changes: 127 additions & 2 deletions param/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import re
import datetime as dt
import collections
import numbers
import typing

from .parameterized import (
Parameterized, Parameter, String, ParameterizedFunction, ParamOverrides,
Expand Down Expand Up @@ -692,7 +694,6 @@ def _force(self,obj,objtype=None):
return gen


import numbers
def _is_number(obj):
if isinstance(obj, numbers.Number): return True
# The extra check is for classes that behave like numbers, such as those
Expand Down Expand Up @@ -752,6 +753,10 @@ def __init__(self, default=b"", regex=None, allow_None=False, **kwargs):
self.allow_None = (default is None or allow_None)
self._validate(default)

@property
def pytype(self):
return typing.Union[bytes, None] if self.allow_None else bytes

def _validate_regex(self, val, regex):
if (val is None and self.allow_None):
return
Expand Down Expand Up @@ -834,6 +839,10 @@ def __init__(self, default=0.0, bounds=None, softbounds=None,
self.step = step
self._validate(default)

@property
def pytype(self):
return typing.Union[numbers.Number, None] if self.allow_None else numbers.Number

def __get__(self, obj, objtype):
"""
Same as the superclass's __get__, but if the value was
Expand Down Expand Up @@ -963,6 +972,10 @@ class Integer(Number):
def __init__(self, default=0, **params):
Number.__init__(self, default=default, **params)

@property
def pytype(self):
return typing.Union[int, None] if self.allow_None else int

def _validate_value(self, val, allow_None):
if callable(val):
return
Expand Down Expand Up @@ -1000,6 +1013,10 @@ def __init__(self, default=False, bounds=(0,1), **params):
self.bounds = bounds
super(Boolean, self).__init__(default=default, **params)

@property
def pytype(self):
return typing.Union[bool, None] if self.allow_None else bool

def _validate_value(self, val, allow_None):
if allow_None:
if not isinstance(val, bool) and val is not None:
Expand Down Expand Up @@ -1034,6 +1051,14 @@ def __init__(self, default=(0,0), length=None, **params):
self.length = length
self._validate(default)

@property
def pytype(self):
if self.length:
pytype = typing.Tuple[(typing.Any,)*self.length]
else:
ptype = typing.Tuple[typing.Any, ...]
return typing.Union[pytype, None] if self.allow_None else pytype

def _validate_value(self, val, allow_None):
if val is None and allow_None:
return
Expand Down Expand Up @@ -1071,6 +1096,14 @@ def deserialize(cls, value):
class NumericTuple(Tuple):
"""A numeric tuple Parameter (e.g. (4.5,7.6,3)) with a fixed tuple length."""

@property
def pytype(self):
if self.length:
pytype = typing.Tuple[(numbers.Number,)*self.length]
else:
ptype = typing.Tuple[numbers.Number, ...]
return typing.Union[pytype, None] if self.allow_None else pytype

def _validate_value(self, val, allow_None):
super(NumericTuple, self)._validate_value(val, allow_None)
if allow_None and val is None:
Expand All @@ -1088,6 +1121,11 @@ class XYCoordinates(NumericTuple):
def __init__(self, default=(0.0, 0.0), **params):
super(XYCoordinates,self).__init__(default=default, length=2, **params)

@property
def pytype(self):
pytype = typing.Tuple[numbers.Number, numbers.Number]
return typing.Union[pytype, None] if self.allow_None else pytype


class Callable(Parameter):
"""
Expand All @@ -1099,6 +1137,11 @@ class Callable(Parameter):
2.4, so instantiate must be False for those values.
"""

@property
def pytype(self):
ctype = typing.Callable[..., typing.Any]
return typing.Union[ctype, None] if self.allow_None else ctype

def _validate_value(self, val, allow_None):
if (allow_None and val is None) or callable(val):
return
Expand Down Expand Up @@ -1197,6 +1240,11 @@ class SelectorBase(Parameter):

__abstract = True

@property
def pytype(self):
literal = typing.Literal[tuple(self.get_range().values())]
return typing.Union[literal, None] if self.allow_None else literal

def get_range(self):
raise NotImplementedError("get_range() must be implemented in subclasses.")

Expand Down Expand Up @@ -1433,6 +1481,17 @@ def __init__(self, default=[], class_=None, item_type=None,
**params)
self._validate(default)

@property
def pytype(self):
if isinstance(self.item_type, tuple):
item_type = typing.Union[self.item_type]
elif self.item_type is not None:
item_type = self.item_type
else:
item_type = typing.Any
list_type = typing.List[item_type]
return typing.Union[list_type, None] if self.allow_None else list_type

def _validate(self, val):
"""
Checks that the value is numeric and that it is within the hard
Expand Down Expand Up @@ -1487,6 +1546,11 @@ class HookList(List):
"""
__slots__ = ['class_', 'bounds']

@property
def pytype(self):
list_type = typing.List[typing.Callable[[], None]]
return typing.Union[list_type, None] if self.allow_None else list_type

def _validate_value(self, val, allow_None):
super(HookList, self)._validate_value(val, allow_None)
if allow_None and val is None:
Expand All @@ -1506,16 +1570,27 @@ class Dict(ClassSelector):
def __init__(self, default=None, **params):
super(Dict, self).__init__(dict, default=default, **params)

@property
def pytype(self):
dict_type = typing.Dict[typing.Hashable, typing.Any]
return typing.Union[dict_type, None] if self.allow_None else dict_type



class Array(ClassSelector):
"""
Parameter whose value is a numpy array.
"""

def __init__(self, default=None, **params):
from numpy import ndarray
from numpy import ndarray
super(Array, self).__init__(ndarray, allow_None=True, default=default, **params)

@property
def pytype(self):
from numpy import ndarray
return ndarray

@classmethod
def serialize(cls, value):
if value is None:
Expand Down Expand Up @@ -1559,6 +1634,11 @@ def __init__(self, default=None, rows=None, columns=None, ordered=None, **params
super(DataFrame,self).__init__(pdDFrame, default=default, **params)
self._validate(self.default)

@property
def pytype(self):
from pandas import DataFrame
return DataFrame

def _length_bounds_check(self, bounds, length, name):
message = '{name} length {length} does not match declared bounds of {bounds}'
if not isinstance(bounds, tuple):
Expand Down Expand Up @@ -1635,6 +1715,11 @@ def __init__(self, default=None, rows=None, allow_None=False, **params):
**params)
self._validate(self.default)

@property
def pytype(self):
from pandas import Series
return Series

def _length_bounds_check(self, bounds, length, name):
message = '{name} length {length} does not match declared bounds of {bounds}'
if not isinstance(bounds, tuple):
Expand Down Expand Up @@ -1778,6 +1863,13 @@ def __init__(self, default=None, search_paths=None, **params):
self.search_paths = search_paths
super(Path,self).__init__(default,**params)

@property
def pytype(self):
path_types =(str, pathlib.Path)
if self.allow_None:
path_types += (None,)
return typing.Union[path_types]

def _resolve(self, path):
return resolve_path(path, path_to_file=None, search_paths=self.search_paths)

Expand Down Expand Up @@ -1904,6 +1996,12 @@ def __init__(self, default=None, objects=None, **kwargs):
super(ListSelector,self).__init__(
objects=objects, default=default, empty_default=True, **kwargs)

@property
def pytype(self):
literal = typing.Literal[tuple(self.get_range().values())]
ltype = typing.List[literal]
return typing.Union[ltype, None] if self.allow_None else ltype

def compute_default(self):
if self.default is None and callable(self.compute_default_fn):
self.default = self.compute_default_fn()
Expand Down Expand Up @@ -1954,6 +2052,14 @@ class Date(Number):
def __init__(self, default=None, **kwargs):
super(Date, self).__init__(default=default, **kwargs)

@property
def pytype(self):
if self.allow_None:
date_types = dt_types + (None,)
else:
date_types = dt_types
return typing.Union[date_types]

def _validate_value(self, val, allow_None):
"""
Checks that the value is numeric and that it is within the hard
Expand Down Expand Up @@ -1998,6 +2104,10 @@ class CalendarDate(Number):
def __init__(self, default=None, **kwargs):
super(CalendarDate, self).__init__(default=default, **kwargs)

@property
def pytype(self):
return typing.Union[dt.datetime, None] if self.allow_None else dt.datetime

def _validate_value(self, val, allow_None):
"""
Checks that the value is numeric and that it is within the hard
Expand Down Expand Up @@ -2076,6 +2186,10 @@ def __init__(self, default=None, allow_named=True, **kwargs):
self.allow_named = allow_named
self._validate(default)

@property
def pytype(self):
return typing.Union[str, None] if self.allow_None else str

def _validate(self, val):
self._validate_value(val, self.allow_None)
self._validate_allow_named(val, self.allow_named)
Expand Down Expand Up @@ -2151,6 +2265,12 @@ class DateRange(Range):
Bounds must be specified as datetime or date types (see param.dt_types).
"""

@property
def pytype(self):
date_type = typing.Union[dt_types]
range_type = typing.Tuple[date_type, date_type]
return typing.Union[range_type, None] if self.allow_None else range_type

def _validate_value(self, val, allow_None):
# Cannot use super()._validate_value as DateRange inherits from
# NumericTuple which check that the tuple values are numbers and
Expand Down Expand Up @@ -2205,6 +2325,7 @@ def deserialize(cls, value):
# As JSON has no tuple representation
return tuple(deserialized)


class CalendarDateRange(Range):
"""
A date range specified as (start_date, end_date).
Expand Down Expand Up @@ -2281,6 +2402,10 @@ def __init__(self,default=False,bounds=(0,1),**params):
# back to False while triggered callbacks are executing
super(Event, self).__init__(default=default,**params)

@property
def pytype(self):
return bool

def _reset_event(self, obj, val):
val = False
if obj is None:
Expand Down
9 changes: 9 additions & 0 deletions param/parameterized.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import random
import numbers
import operator
import typing

# Allow this file to be used standalone if desired, albeit without JSON serialization
try:
Expand Down Expand Up @@ -1071,6 +1072,10 @@ class hierarchy (see ParameterizedMetaclass).
self.watchers = {}
self.per_instance = per_instance

@property
def pytype(self):
return typing.Any

@classmethod
def serialize(cls, value):
"Given the parameter value, return a Python value suitable for serialization"
Expand Down Expand Up @@ -1331,6 +1336,10 @@ def __init__(self, default="", regex=None, allow_None=False, **kwargs):
self.allow_None = (default is None or allow_None)
self._validate(default)

@property
def pytype(self):
return typing.Union[str, None] if self.allow_None else str

def _validate_regex(self, val, regex):
if (val is None and self.allow_None):
return
Expand Down

0 comments on commit 4376014

Please sign in to comment.