From 8cee81e6a6dc27295b6e7160823f8e78f1ffa2ed Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Sun, 11 Feb 2024 09:14:20 -0600 Subject: [PATCH] Remove Validator base class, fix hidden type error --- firecrown/connector/mapping.py | 2 +- firecrown/descriptors.py | 111 +++++++++++++++++++++------------ tests/test_descriptors.py | 4 +- 3 files changed, 74 insertions(+), 43 deletions(-) diff --git a/firecrown/connector/mapping.py b/firecrown/connector/mapping.py index ff488d4f..b90589d9 100644 --- a/firecrown/connector/mapping.py +++ b/firecrown/connector/mapping.py @@ -175,7 +175,7 @@ def redshift_to_scale_factor_p_k(p_k): p_k_out = np.flipud(p_k) return p_k_out - def asdict(self) -> Dict[str, Union[Optional[float], List[float]]]: + def asdict(self) -> Dict[str, Union[Optional[float], List[float], str]]: """Return a dictionary containing the cosmological constants.""" return { "Omega_c": self.Omega_c, diff --git a/firecrown/descriptors.py b/firecrown/descriptors.py index c1db8673..85759149 100644 --- a/firecrown/descriptors.py +++ b/firecrown/descriptors.py @@ -26,50 +26,26 @@ def __init__(self): # pylint: disable=R0903 import math -from abc import ABC, abstractmethod +from typing import Callable, Optional -class Validator(ABC): - """Abstract base class for all validators. - - The base class implements the methods required for setting values, which should - not be overridden in derived classes. - - Derived classes must implement the abstract method validate, which is called by - the base class when any variable protected by a Validator is set. - """ - - def __set_name__(self, owner, name): - """Create the name of the private instance variable that will hold the value.""" - self.private_name = "_" + name # pylint: disable-msg=W0201 - - def __get__(self, obj, objtype=None): - """Accessor method, which reads controlled value. - - This is invoked whenever the validated variable is read.""" - return getattr(obj, self.private_name) - - def __set__(self, obj, value): - """Setter for the validated variable. - - This function invokes the `validate` method of the derived class.""" - self.validate(value) - setattr(obj, self.private_name, value) - - @abstractmethod - def validate(self, value): - """Abstract method to perform whatever validation is required.""" - - -class TypeFloat(Validator): +class TypeFloat: """Floating point number attribute descriptor.""" - def __init__(self, minvalue=None, maxvalue=None, allow_none=False): + def __init__( + self, + minvalue: Optional[float] = None, + maxvalue: Optional[float] = None, + allow_none: bool = False, + ) -> None: self.minvalue = minvalue self.maxvalue = maxvalue self.allow_none = allow_none - def validate(self, value): + def validate(self, value: Optional[float]) -> None: + """Raise an exception if the provided `value` does not meet all of the + required conditions enforced by this validator. + """ if self.allow_none and value is None: return if not isinstance(value, float): @@ -81,19 +57,55 @@ def validate(self, value): if self._is_constrained() and math.isnan(value): raise ValueError("NaN is disallowed in a constrained float") - def _is_constrained(self): + def _is_constrained(self) -> bool: + """Return true if this validation enforces any constraint, and false + if it does not.""" return not ((self.minvalue is None) and (self.maxvalue is None)) + def __set_name__(self, _, name: str) -> None: + """Create the name of the private instance variable that will hold the value.""" + self.private_name = "_" + name # pylint: disable-msg=W0201 + + def __get__(self, obj, objtype=None) -> float: + """Accessor method, which reads controlled value. + + This is invoked whenever the validated variable is read.""" + return getattr(obj, self.private_name) + + def __set__(self, obj, value: Optional[float]) -> None: + """Setter for the validated variable. -class TypeString(Validator): - """String attribute descriptor.""" + This function invokes the `validate` method of the derived class.""" + self.validate(value) + setattr(obj, self.private_name, value) - def __init__(self, minsize=None, maxsize=None, predicate=None): + +class TypeString: + """String attribute descriptor. + + TypeString provides several different means of validation of the controlled + string attribute, all of which are optional. + `minsize` provides a required minimum length for the string + `maxsize` provides a required maximum length for the string + `predicate` allows specification of a function that must return true + when a string is provided to allow use of that string. + """ + + def __init__( + self, + minsize: Optional[int] = None, + maxsize: Optional[int] = None, + predicate: Optional[Callable[[str], bool]] = None, + ) -> None: + """Initialize the TypeString object'""" self.minsize = minsize self.maxsize = maxsize self.predicate = predicate - def validate(self, value): + def validate(self, value: Optional[str]) -> None: + """Raise an exception if the provided `value` does not meet all of the + required conditions enforced by this validator. + """ if not isinstance(value, str): raise TypeError(f"Expected {value!r} to be an str") if self.minsize is not None and len(value) < self.minsize: @@ -106,3 +118,20 @@ def validate(self, value): ) if self.predicate is not None and not self.predicate(value): raise ValueError(f"Expected {self.predicate} to be true for {value!r}") + + def __set_name__(self, _, name: str) -> None: + """Create the name of the private instance variable that will hold the value.""" + self.private_name = "_" + name # pylint: disable-msg=W0201 + + def __get__(self, obj, objtype=None) -> str: + """Accessor method, which reads controlled value. + + This is invoked whenever the validated variable is read.""" + return getattr(obj, self.private_name) + + def __set__(self, obj, value: Optional[str]) -> None: + """Setter for the validated variable. + + This function invokes the `validate` method of the derived class.""" + self.validate(value) + setattr(obj, self.private_name, value) diff --git a/tests/test_descriptors.py b/tests/test_descriptors.py index d69ab369..7e124fc9 100644 --- a/tests/test_descriptors.py +++ b/tests/test_descriptors.py @@ -219,7 +219,9 @@ def test_upper_bound_float(): def test_string_conversion_failure(): d = HasString() with pytest.raises(TypeError): - d.x = NotStringy() + # We ignore type checking on this line because we are testing the error + # handling for the very type error that mypy would detect. + d.x = NotStringy() # type: ignore def test_string_too_short():