Skip to content

Commit

Permalink
Add comments to pydantic code (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
callumforrester authored Apr 18, 2023
1 parent 1ca16e9 commit 798dcd0
Showing 1 changed file with 72 additions and 20 deletions.
92 changes: 72 additions & 20 deletions src/blueapi/utils/type_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def create_model_with_type_validators(
Args:
name: Name of the new model
definitions: Definitions of how to validate which types of field
base (Type[BaseModel]): Base class for the model
base: Base class for the model
Returns:
Type[BaseModel]: A new version of `base` with type validators
Expand All @@ -142,7 +142,6 @@ def create_model_with_type_validators(
base: Optional[Type[BaseModel]] = None,
func: Optional[Callable[..., Any]] = None,
config: Optional[Type[BaseConfig]] = None,
cache: Optional[Dict[Type, Type]] = None,
) -> Type[BaseModel]:
"""
Create a pydantic model with type validators according to
Expand All @@ -164,24 +163,17 @@ def create_model_with_type_validators(
Type[BaseModel]: A new pydantic model
"""

cache = cache or {}
# Fields are determined from various sources, directly passed, a base class
# and/or a function signature.
all_fields = {**(fields or {})}
if base is not None:
all_fields = {**all_fields, **_extract_fields_from_model(base)}
if func is not None:
all_fields = {**all_fields, **_extract_fields_from_function(func)}
for name, field in all_fields.items():
annotation, val = field
if annotation in cache:
all_fields[name] = cache[annotation], val
else:
all_fields[name] = apply_type_validators(annotation, definitions), val
# model_type = find_model_type(annotation)
# if model_type is not None:
# recursed = create_model_with_type_validators(
# annotation.__name__, definitions, base=model_type
# )
# all_fields[name] = recursed, val
all_fields[name] = apply_type_validators(annotation, definitions), val

validators = _type_validators(all_fields, definitions)
return create_model( # type: ignore
name, **all_fields, __base__=base, __validators__=validators, __config__=config
Expand All @@ -191,16 +183,24 @@ def create_model_with_type_validators(
def apply_type_validators(
model_type: Type,
definitions: List[TypeValidatorDefinition],
cache: Optional[Dict[Type, Type]] = None,
) -> Type:
cache = cache or {}
if model_type in cache:
return cache[model_type]
"""
Create a copy of a model (or modellable type) that has the defined
type validators.
Args:
model_type: The model to copy (e.g. a BaseModel or pydantic dataclass)
definitions: Definitions of type validators that the copy should have
Returns:
Type: A new pydantic model
"""

if isclass(model_type) and issubclass(model_type, BaseModel):
if "__root__" in model_type.__fields__:
return apply_type_validators(
model_type.__fields__["__root__"].type_, definitions, cache=cache
model_type.__fields__["__root__"].type_,
definitions,
)
else:
return create_model_with_type_validators(
Expand All @@ -210,14 +210,26 @@ def apply_type_validators(
)
elif isclass(model_type) and hasattr(model_type, "__pydantic_model__"):
model = getattr(model_type, "__pydantic_model__")
return apply_type_validators(model, definitions, cache=cache)
# Recursively apply to inner model
return apply_type_validators(
model,
definitions,
)
else:
# Apply to type parameters, e.g. apply to int in List[int]
params = [
apply_type_validators(param, definitions, cache=cache)
apply_type_validators(
param,
definitions,
)
for param in get_args(model_type)
]

# __origin__ converts Union[int, str] to Union
if params and hasattr(model_type, "__origin__"):
origin = getattr(model_type, "__origin__")
# Certain origins are different to their code notations,
# e.g. list vs List.
origin = _sanitise_origin(origin)
return origin[tuple(params)]
return model_type
Expand Down Expand Up @@ -304,17 +316,57 @@ def _determine_fields_of_type(fields: Fields, field_type: Type) -> Iterable[str]


def is_type_or_container_type(type_to_check: Type, field_type: Type) -> bool:
"""
Is the supplied type either the type to check against or a type that
contains it?
For example, if field_type=int, then this should return True if type_to_check
is int, List[int], Dict[str, int], Set[int] etc.
Args:
type_to_check: The type to check
field_type: The expected type
Returns:
bool: True if the types match
"""
return params_contains(type_to_check, field_type)


def params_contains(type_to_check: Type, field_type: Type) -> bool:
"""
Do the parameters of a type contain the type to check?
For example, do the parameters of List[int] conain int? Yes
Args:
type_to_check: The type to check
field_type: The expected type
Returns:
bool: True if the types match
"""

type_params = get_args(type_to_check)
return type_to_check is field_type or any(
map(lambda v: params_contains(v, field_type), type_params)
)


def apply_to_scalars(func: Callable[[T], U], obj: Any) -> Any:
"""
Apply the supplied function to all scalars within the JSON-serializable
object. In this case, scalars are values of type int, str, float, and bool.
For example, if the function multiplies by 2 and the object is:
{"a": 3, "b": [4, 5]} then the result should be {"a": 6, "b": [8, 10]}
Args:
func: The function to apply.
obj: An object that can be serialized to JSON
Returns:
Any: A new JSON-serializable object with the function
applied to all scalars.
"""

if is_list_type(obj):
return list(map(lambda v: apply_to_scalars(func, v), obj))
elif is_dict_type(obj):
Expand Down

0 comments on commit 798dcd0

Please sign in to comment.