Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Validator base class, fix hidden type error #378

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion firecrown/connector/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def redshift_to_scale_factor_p_k(p_k):
p_k_out = np.flipud(p_k)
return p_k_out

def asdict(self) -> Dict[str, Union[Optional[float], List[float]]]:
def asdict(self) -> Dict[str, Union[Optional[float], List[float], str]]:
"""Return a dictionary containing the cosmological constants."""
return {
"Omega_c": self.Omega_c,
Expand Down
111 changes: 70 additions & 41 deletions firecrown/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,50 +26,26 @@ def __init__(self):
# pylint: disable=R0903

import math
from abc import ABC, abstractmethod
from typing import Callable, Optional


class Validator(ABC):
"""Abstract base class for all validators.

The base class implements the methods required for setting values, which should
not be overridden in derived classes.

Derived classes must implement the abstract method validate, which is called by
the base class when any variable protected by a Validator is set.
"""

def __set_name__(self, owner, name):
"""Create the name of the private instance variable that will hold the value."""
self.private_name = "_" + name # pylint: disable-msg=W0201

def __get__(self, obj, objtype=None):
"""Accessor method, which reads controlled value.

This is invoked whenever the validated variable is read."""
return getattr(obj, self.private_name)

def __set__(self, obj, value):
"""Setter for the validated variable.

This function invokes the `validate` method of the derived class."""
self.validate(value)
setattr(obj, self.private_name, value)

@abstractmethod
def validate(self, value):
"""Abstract method to perform whatever validation is required."""


class TypeFloat(Validator):
class TypeFloat:
"""Floating point number attribute descriptor."""

def __init__(self, minvalue=None, maxvalue=None, allow_none=False):
def __init__(
self,
minvalue: Optional[float] = None,
maxvalue: Optional[float] = None,
allow_none: bool = False,
) -> None:
self.minvalue = minvalue
self.maxvalue = maxvalue
self.allow_none = allow_none

def validate(self, value):
def validate(self, value: Optional[float]) -> None:
"""Raise an exception if the provided `value` does not meet all of the
required conditions enforced by this validator.
"""
if self.allow_none and value is None:
return
if not isinstance(value, float):
Expand All @@ -81,19 +57,55 @@ def validate(self, value):
if self._is_constrained() and math.isnan(value):
raise ValueError("NaN is disallowed in a constrained float")

def _is_constrained(self):
def _is_constrained(self) -> bool:
"""Return true if this validation enforces any constraint, and false
if it does not."""
return not ((self.minvalue is None) and (self.maxvalue is None))

def __set_name__(self, _, name: str) -> None:
"""Create the name of the private instance variable that will hold the value."""
self.private_name = "_" + name # pylint: disable-msg=W0201

def __get__(self, obj, objtype=None) -> float:
"""Accessor method, which reads controlled value.

This is invoked whenever the validated variable is read."""
return getattr(obj, self.private_name)

def __set__(self, obj, value: Optional[float]) -> None:
"""Setter for the validated variable.

class TypeString(Validator):
"""String attribute descriptor."""
This function invokes the `validate` method of the derived class."""
self.validate(value)
setattr(obj, self.private_name, value)

def __init__(self, minsize=None, maxsize=None, predicate=None):

class TypeString:
"""String attribute descriptor.

TypeString provides several different means of validation of the controlled
string attribute, all of which are optional.
`minsize` provides a required minimum length for the string
`maxsize` provides a required maximum length for the string
`predicate` allows specification of a function that must return true
when a string is provided to allow use of that string.
"""

def __init__(
self,
minsize: Optional[int] = None,
maxsize: Optional[int] = None,
predicate: Optional[Callable[[str], bool]] = None,
) -> None:
"""Initialize the TypeString object'"""
self.minsize = minsize
self.maxsize = maxsize
self.predicate = predicate

def validate(self, value):
def validate(self, value: Optional[str]) -> None:
"""Raise an exception if the provided `value` does not meet all of the
required conditions enforced by this validator.
"""
if not isinstance(value, str):
raise TypeError(f"Expected {value!r} to be an str")
if self.minsize is not None and len(value) < self.minsize:
Expand All @@ -106,3 +118,20 @@ def validate(self, value):
)
if self.predicate is not None and not self.predicate(value):
raise ValueError(f"Expected {self.predicate} to be true for {value!r}")

def __set_name__(self, _, name: str) -> None:
"""Create the name of the private instance variable that will hold the value."""
self.private_name = "_" + name # pylint: disable-msg=W0201

def __get__(self, obj, objtype=None) -> str:
"""Accessor method, which reads controlled value.

This is invoked whenever the validated variable is read."""
return getattr(obj, self.private_name)

def __set__(self, obj, value: Optional[str]) -> None:
"""Setter for the validated variable.

This function invokes the `validate` method of the derived class."""
self.validate(value)
setattr(obj, self.private_name, value)
4 changes: 3 additions & 1 deletion tests/test_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,9 @@ def test_upper_bound_float():
def test_string_conversion_failure():
d = HasString()
with pytest.raises(TypeError):
d.x = NotStringy()
# We ignore type checking on this line because we are testing the error
# handling for the very type error that mypy would detect.
d.x = NotStringy() # type: ignore


def test_string_too_short():
Expand Down
Loading