Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add comments to pydantic code #123

Merged
merged 1 commit into from
Apr 18, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 72 additions & 20 deletions src/blueapi/utils/type_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def create_model_with_type_validators(
Args:
name: Name of the new model
definitions: Definitions of how to validate which types of field
base (Type[BaseModel]): Base class for the model
base: Base class for the model

Returns:
Type[BaseModel]: A new version of `base` with type validators
Expand All @@ -142,7 +142,6 @@ def create_model_with_type_validators(
base: Optional[Type[BaseModel]] = None,
func: Optional[Callable[..., Any]] = None,
config: Optional[Type[BaseConfig]] = None,
cache: Optional[Dict[Type, Type]] = None,
) -> Type[BaseModel]:
"""
Create a pydantic model with type validators according to
Expand All @@ -164,24 +163,17 @@ def create_model_with_type_validators(
Type[BaseModel]: A new pydantic model
"""

cache = cache or {}
# Fields are determined from various sources, directly passed, a base class
# and/or a function signature.
all_fields = {**(fields or {})}
if base is not None:
all_fields = {**all_fields, **_extract_fields_from_model(base)}
if func is not None:
all_fields = {**all_fields, **_extract_fields_from_function(func)}
for name, field in all_fields.items():
annotation, val = field
if annotation in cache:
all_fields[name] = cache[annotation], val
else:
all_fields[name] = apply_type_validators(annotation, definitions), val
# model_type = find_model_type(annotation)
# if model_type is not None:
# recursed = create_model_with_type_validators(
# annotation.__name__, definitions, base=model_type
# )
# all_fields[name] = recursed, val
all_fields[name] = apply_type_validators(annotation, definitions), val

validators = _type_validators(all_fields, definitions)
return create_model( # type: ignore
name, **all_fields, __base__=base, __validators__=validators, __config__=config
Expand All @@ -191,16 +183,24 @@ def create_model_with_type_validators(
def apply_type_validators(
model_type: Type,
definitions: List[TypeValidatorDefinition],
cache: Optional[Dict[Type, Type]] = None,
) -> Type:
cache = cache or {}
if model_type in cache:
return cache[model_type]
"""
Create a copy of a model (or modellable type) that has the defined
type validators.

Args:
model_type: The model to copy (e.g. a BaseModel or pydantic dataclass)
definitions: Definitions of type validators that the copy should have

Returns:
Type: A new pydantic model
"""

if isclass(model_type) and issubclass(model_type, BaseModel):
if "__root__" in model_type.__fields__:
return apply_type_validators(
model_type.__fields__["__root__"].type_, definitions, cache=cache
model_type.__fields__["__root__"].type_,
definitions,
)
else:
return create_model_with_type_validators(
Expand All @@ -210,14 +210,26 @@ def apply_type_validators(
)
elif isclass(model_type) and hasattr(model_type, "__pydantic_model__"):
model = getattr(model_type, "__pydantic_model__")
return apply_type_validators(model, definitions, cache=cache)
# Recursively apply to inner model
return apply_type_validators(
model,
definitions,
)
else:
# Apply to type parameters, e.g. apply to int in List[int]
params = [
apply_type_validators(param, definitions, cache=cache)
apply_type_validators(
param,
definitions,
)
for param in get_args(model_type)
]

# __origin__ converts Union[int, str] to Union
if params and hasattr(model_type, "__origin__"):
origin = getattr(model_type, "__origin__")
# Certain origins are different to their code notations,
# e.g. list vs List.
origin = _sanitise_origin(origin)
return origin[tuple(params)]
return model_type
Expand Down Expand Up @@ -304,17 +316,57 @@ def _determine_fields_of_type(fields: Fields, field_type: Type) -> Iterable[str]


def is_type_or_container_type(type_to_check: Type, field_type: Type) -> bool:
"""
Is the supplied type either the type to check against or a type that
contains it?
For example, if field_type=int, then this should return True if type_to_check
is int, List[int], Dict[str, int], Set[int] etc.

Args:
type_to_check: The type to check
field_type: The expected type

Returns:
bool: True if the types match
"""
return params_contains(type_to_check, field_type)


def params_contains(type_to_check: Type, field_type: Type) -> bool:
"""
Do the parameters of a type contain the type to check?
For example, do the parameters of List[int] conain int? Yes

Args:
type_to_check: The type to check
field_type: The expected type

Returns:
bool: True if the types match
"""

type_params = get_args(type_to_check)
return type_to_check is field_type or any(
map(lambda v: params_contains(v, field_type), type_params)
)


def apply_to_scalars(func: Callable[[T], U], obj: Any) -> Any:
"""
Apply the supplied function to all scalars within the JSON-serializable
object. In this case, scalars are values of type int, str, float, and bool.
For example, if the function multiplies by 2 and the object is:
{"a": 3, "b": [4, 5]} then the result should be {"a": 6, "b": [8, 10]}

Args:
func: The function to apply.
obj: An object that can be serialized to JSON

Returns:
Any: A new JSON-serializable object with the function
applied to all scalars.
"""

if is_list_type(obj):
return list(map(lambda v: apply_to_scalars(func, v), obj))
elif is_dict_type(obj):
Expand Down