diff --git a/src/blueapi/utils/type_validator.py b/src/blueapi/utils/type_validator.py
index 461324c43..fee74c2d4 100644
--- a/src/blueapi/utils/type_validator.py
+++ b/src/blueapi/utils/type_validator.py
@@ -125,7 +125,7 @@ def create_model_with_type_validators(
     Args:
         name: Name of the new model
         definitions: Definitions of how to validate which types of field
-        base (Type[BaseModel]): Base class for the model
+        base: Base class for the model
 
     Returns:
         Type[BaseModel]: A new version of `base` with type validators
@@ -142,7 +142,6 @@ def create_model_with_type_validators(
     base: Optional[Type[BaseModel]] = None,
     func: Optional[Callable[..., Any]] = None,
     config: Optional[Type[BaseConfig]] = None,
-    cache: Optional[Dict[Type, Type]] = None,
 ) -> Type[BaseModel]:
     """
     Create a pydantic model with type validators according to
@@ -164,7 +163,8 @@ def create_model_with_type_validators(
         Type[BaseModel]: A new pydantic model
     """
 
-    cache = cache or {}
+    # Fields are determined from various sources: directly passed, from a
+    # base class and/or from a function signature.
     all_fields = {**(fields or {})}
     if base is not None:
         all_fields = {**all_fields, **_extract_fields_from_model(base)}
@@ -172,16 +172,8 @@ def create_model_with_type_validators(
         all_fields = {**all_fields, **_extract_fields_from_function(func)}
     for name, field in all_fields.items():
         annotation, val = field
-        if annotation in cache:
-            all_fields[name] = cache[annotation], val
-        else:
-            all_fields[name] = apply_type_validators(annotation, definitions), val
-        # model_type = find_model_type(annotation)
-        # if model_type is not None:
-        #     recursed = create_model_with_type_validators(
-        #         annotation.__name__, definitions, base=model_type
-        #     )
-        #     all_fields[name] = recursed, val
+        all_fields[name] = apply_type_validators(annotation, definitions), val
+
     validators = _type_validators(all_fields, definitions)
     return create_model(  # type: ignore
         name, **all_fields, __base__=base, __validators__=validators, __config__=config
@@ -191,16 +183,24 @@ def create_model_with_type_validators(
 def apply_type_validators(
     model_type: Type,
     definitions: List[TypeValidatorDefinition],
-    cache: Optional[Dict[Type, Type]] = None,
 ) -> Type:
-    cache = cache or {}
-    if model_type in cache:
-        return cache[model_type]
+    """
+    Create a copy of a model (or modellable type) that has the defined
+    type validators.
+
+    Args:
+        model_type: The model to copy (e.g. a BaseModel or pydantic dataclass)
+        definitions: Definitions of type validators that the copy should have
+
+    Returns:
+        Type: A new pydantic model
+    """
 
     if isclass(model_type) and issubclass(model_type, BaseModel):
         if "__root__" in model_type.__fields__:
             return apply_type_validators(
-                model_type.__fields__["__root__"].type_, definitions, cache=cache
+                model_type.__fields__["__root__"].type_,
+                definitions,
             )
         else:
             return create_model_with_type_validators(
@@ -210,14 +210,26 @@ def apply_type_validators(
             )
     elif isclass(model_type) and hasattr(model_type, "__pydantic_model__"):
         model = getattr(model_type, "__pydantic_model__")
-        return apply_type_validators(model, definitions, cache=cache)
+        # Recursively apply to the inner model
+        return apply_type_validators(
+            model,
+            definitions,
+        )
     else:
+        # Apply to type parameters, e.g. apply to int in List[int]
         params = [
-            apply_type_validators(param, definitions, cache=cache)
+            apply_type_validators(
+                param,
+                definitions,
+            )
             for param in get_args(model_type)
         ]
+
+        # The __origin__ of e.g. Union[int, str] is Union
         if params and hasattr(model_type, "__origin__"):
             origin = getattr(model_type, "__origin__")
+            # Certain origins are different to their code notations,
+            # e.g. list vs List.
             origin = _sanitise_origin(origin)
             return origin[tuple(params)]
     return model_type
@@ -304,10 +316,35 @@ def _determine_fields_of_type(fields: Fields, field_type: Type) -> Iterable[str]
 
 
 def is_type_or_container_type(type_to_check: Type, field_type: Type) -> bool:
+    """
+    Is the supplied type either the expected type or a container type
+    that holds it?
+    For example, if field_type=int, this should return True if type_to_check
+    is int, List[int], Dict[str, int], Set[int] etc.
+
+    Args:
+        type_to_check: The type to check
+        field_type: The expected type
+
+    Returns:
+        bool: True if the types match
+    """
     return params_contains(type_to_check, field_type)
 
 
 def params_contains(type_to_check: Type, field_type: Type) -> bool:
+    """
+    Do the type parameters of type_to_check (recursively) contain field_type?
+    For example, the type parameters of List[int] contain int.
+
+    Args:
+        type_to_check: The type to check
+        field_type: The expected type
+
+    Returns:
+        bool: True if the types match
+    """
+
     type_params = get_args(type_to_check)
     return type_to_check is field_type or any(
         map(lambda v: params_contains(v, field_type), type_params)
@@ -315,6 +352,21 @@ def params_contains(type_to_check: Type, field_type: Type) -> bool:
 
 
 def apply_to_scalars(func: Callable[[T], U], obj: Any) -> Any:
+    """
+    Apply the supplied function to all scalars within a JSON-serializable
+    object. In this case, scalars are values of type int, str, float and bool.
+    For example, if the function multiplies by 2 and the object is
+    {"a": 3, "b": [4, 5]}, then the result is {"a": 6, "b": [8, 10]}.
+
+    Args:
+        func: The function to apply
+        obj: An object that can be serialized to JSON
+
+    Returns:
+        Any: A new JSON-serializable object with the function
+        applied to all scalars.
+    """
+
     if is_list_type(obj):
         return list(map(lambda v: apply_to_scalars(func, v), obj))
     elif is_dict_type(obj):
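
As a quick sanity check of the behaviour the new docstrings promise, here is a minimal sketch exercising params_contains and apply_to_scalars. It is illustrative only and not part of the patch; it assumes both helpers are importable from blueapi.utils.type_validator (the module this diff touches), and the expected results follow directly from the docstring examples above.

from typing import Dict, List

from blueapi.utils.type_validator import apply_to_scalars, params_contains

# params_contains recurses through type parameters (via get_args), so a
# nested container of the target type is found at any depth.
assert params_contains(int, int)                   # the type itself
assert params_contains(List[int], int)             # direct type parameter
assert params_contains(Dict[str, List[int]], int)  # nested type parameter
assert not params_contains(List[str], int)         # int appears nowhere

# apply_to_scalars maps func over every scalar (int, str, float, bool) in a
# JSON-serializable object, reproducing the docstring's worked example.
assert apply_to_scalars(lambda v: v * 2, {"a": 3, "b": [4, 5]}) == {
    "a": 6,
    "b": [8, 10],
}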