diff --git a/.github/workflows/validate-test-coverage.yaml b/.github/workflows/validate-test-coverage.yaml index 5e6ab606..e1e194cc 100644 --- a/.github/workflows/validate-test-coverage.yaml +++ b/.github/workflows/validate-test-coverage.yaml @@ -23,7 +23,7 @@ jobs: with: python-version: '3.11' - name: Install Dependencies - run: pip install -r requirements-dev.txt + run: pip install -r requirements-test.txt && pip install -r requirements-core.txt - name: Coverage Run run: coverage run --omit=$(cat test/exclude) --branch -m unittest - name: Coverage Report diff --git a/expediagroup/sdk/core/client/api.py b/expediagroup/sdk/core/client/api.py index df006f46..62429bc6 100644 --- a/expediagroup/sdk/core/client/api.py +++ b/expediagroup/sdk/core/client/api.py @@ -11,23 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import enum -import json import logging -import uuid +from copy import deepcopy from http import HTTPStatus from typing import Any, Optional -import pydantic -import pydantic.schema import requests -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter from expediagroup.sdk.core.client.auth_client import AuthClient from expediagroup.sdk.core.configuration.client_config import ClientConfig from expediagroup.sdk.core.constant import header as header_constant from expediagroup.sdk.core.constant import log as log_constant from expediagroup.sdk.core.constant.constant import OK_STATUS_CODES_RANGE +from expediagroup.sdk.core.model.api import RequestHeaders from expediagroup.sdk.core.model.error import Error from expediagroup.sdk.core.model.exception import service as service_exception from expediagroup.sdk.core.util import log as log_util @@ -52,18 +49,18 @@ def __init__(self, config: ClientConfig, auth_client_cls): @staticmethod def __build_response( response: requests.Response, - response_models: list[pydantic.BaseModel], + response_models: list[type], error_responses: dict[int, Any], ): if response.status_code not in OK_STATUS_CODES_RANGE: exception: service_exception.ExpediaGroupApiException if response.status_code in error_responses.keys(): - error_object = pydantic.parse_obj_as(error_responses[response.status_code].model, response.json()) + error_object = error_responses[response.status_code].model.model_validate(response.json()) exception = error_responses[response.status_code].exception.of(error=error_object, error_code=HTTPStatus(response.status_code)) else: exception = service_exception.ExpediaGroupApiException.of( - error=Error.parse_obj(response.json()), + error=Error.model_validate(response.json()), error_code=HTTPStatus(response.status_code), ) @@ -74,7 +71,7 @@ def __build_response( if not model: continue try: - response_object = pydantic.parse_obj_as(model, response.json()) + response_object = TypeAdapter(model).validate_python(response.json()) return response_object except Exception: continue @@ -85,10 +82,10 @@ def call( self, method: str, url: str, - body: pydantic.BaseModel, - headers: dict = dict(), # noqa - response_models: Optional[list[Any]] = [None], # noqa - error_responses: dict[int, Any] = {}, # noqa + body: BaseModel, + headers: RequestHeaders = RequestHeaders(), # noqa + response_models: Optional[list[Any]] = list(), # noqa + error_responses: dict[int, Any] = dict(), # noqa ) -> Any: r"""Sends HTTP request to API. @@ -113,7 +110,7 @@ def call( timeout=self.request_timeout, ) else: - request_body = body.json(exclude_none=True) + request_body = body.model_dump_json(exclude_none=True) response = requests.request( method=method.upper(), url=url, @@ -123,7 +120,7 @@ def call( timeout=self.request_timeout, ) - logged_body: dict[str, Any] = dict() if not body else body.dict() + logged_body: dict[str, Any] = dict() if not body else body.model_dump() request_log_message = log_util.request_log( headers=request_headers, @@ -146,24 +143,15 @@ def call( @staticmethod def __fill_request_headers(request_headers: dict): if not request_headers: - request_headers = dict() + return header_constant.API_REQUEST - request_header_keys = request_headers.keys() - for key, value in header_constant.API_REQUEST.items(): - if key in request_header_keys: - continue - - request_headers[key] = value + headers: dict = deepcopy(header_constant.API_REQUEST) + headers.update(request_headers) - return request_headers + return headers @staticmethod - def __prepare_request_headers(headers: dict) -> dict: - request_headers = dict() - for header_key, header_value in headers.items(): - if not header_value: - continue - needs_serialization: bool = isinstance(header_value, BaseModel) or isinstance(header_value, enum.Enum) or isinstance(header_value, uuid.UUID) - request_headers[header_key] = json.dumps(header_value, default=pydantic.schema.pydantic_encoder) if needs_serialization else header_value + def __prepare_request_headers(headers: RequestHeaders) -> dict: + request_headers: dict = headers.unwrap() return ApiClient.__fill_request_headers(request_headers) diff --git a/expediagroup/sdk/core/model/api.py b/expediagroup/sdk/core/model/api.py new file mode 100644 index 00000000..05fa12b4 --- /dev/null +++ b/expediagroup/sdk/core/model/api.py @@ -0,0 +1,46 @@ +# Copyright 2022 Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any + +from pydantic import BaseModel, Field + + +class RequestHeaders(BaseModel): + """ + RequestHeaders class represents the headers of an HTTP request. + + Attributes: + headers (Any): The HTTP request headers. It can be of any type. + """ + + headers: Any = Field(default=None) + + def unwrap(self) -> dict[str, Any]: + """ + Unwraps the headers from the model. + + Returns: + A dictionary containing the headers. + + Example: + >>> headers = RequestHeaders() + >>> headers.unwrap() + {'Content-Type': 'application/json', 'Authorization': 'Bearer token'} + """ + if not self.headers: + return dict() + + return json.loads(self.model_dump_json()).get("headers") diff --git a/expediagroup/sdk/core/model/authentication.py b/expediagroup/sdk/core/model/authentication.py index a865b5e7..8f99ee05 100644 --- a/expediagroup/sdk/core/model/authentication.py +++ b/expediagroup/sdk/core/model/authentication.py @@ -16,7 +16,7 @@ import logging from dataclasses import dataclass from multiprocessing import Lock -from typing import Optional +from typing import Optional, Union import pydantic.schema import requests @@ -40,7 +40,7 @@ class _TokenResponse(pydantic.BaseModel): """A model of an API response.""" access_token: str - expires_in: int + expires_in: Union[int, float] scope: str token_type: str id_token: Optional[str] = None @@ -53,7 +53,7 @@ def __init__(self, data: dict): :param data: token data """ - self.__token: _TokenResponse = _TokenResponse.parse_obj(data) + self.__token: _TokenResponse = _TokenResponse.model_validate(data) self.lock = Lock() self.__expiration_time = datetime.datetime.now() + datetime.timedelta(seconds=self.__token.expires_in) self.__auth_header = HttpBearerAuth(self.__token.access_token) diff --git a/expediagroup/sdk/generator/client/datatype_manager.py b/expediagroup/sdk/generator/client/datatype_manager.py new file mode 100644 index 00000000..bf8ede86 --- /dev/null +++ b/expediagroup/sdk/generator/client/datatype_manager.py @@ -0,0 +1,77 @@ +# Copyright 2022 Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Any + +from datamodel_code_generator.model import pydantic as datamodel_code_generator_pydantic +from datamodel_code_generator.model.pydantic.imports import IMPORT_CONSTR +from datamodel_code_generator.model.pydantic.types import ( + escape_characters, + string_kwargs, + transform_kwargs, +) +from datamodel_code_generator.types import DataType, StrictTypes, Types + +PYDANTIC_V2_MIGRATION_CONSTRAINTS_MAPPING: dict[str, str] = {"regex": "pattern"} + + +class PydanticV2DataTypeManager(datamodel_code_generator_pydantic.DataTypeManager): + r""" + Custom DataTypeManager to map PydanticV1 types to PydanticV2. + + Notes: + - This class is a temporary solution until `fastapi-code-generator` bumps up + its `datamodel-code-generator` from `0.16.1` to `>=0.25.1` which includes + PydanticV2 support. + GitHub Issue: https://github.com/koxudaxi/fastapi-code-generator/issues/378 + """ + + @staticmethod + def migrate_datatype_constraints(data_type_kwargs: dict[str, Any]) -> dict[str, Any]: + """ + Migrates datatype constraints in the given data_type_kwargs dictionary from pydantic v1 to v2 . + + Args: + data_type_kwargs (dict[str, Any]): A dictionary containing datatype constraints. + + Returns: + dict[str, Any]: The migrated datatype constraints dictionary. + """ + migrated_data_type_kwargs: dict[str, Any] = deepcopy(data_type_kwargs) + + for key, value in PYDANTIC_V2_MIGRATION_CONSTRAINTS_MAPPING.items(): + if migrated_data_type_kwargs.get(key): + migrated_data_type_kwargs.update([(value, migrated_data_type_kwargs.get(key))]) + migrated_data_type_kwargs.pop(key) + + return migrated_data_type_kwargs + + def get_data_str_type(self, types: Types, **kwargs: Any) -> DataType: + data_type_kwargs: dict[str, Any] = transform_kwargs(kwargs, string_kwargs) + strict = StrictTypes.str in self.strict_types + if data_type_kwargs: + if strict: + data_type_kwargs["strict"] = True + if "regex" in data_type_kwargs: + escaped_regex = data_type_kwargs["regex"].translate(escape_characters) + data_type_kwargs["regex"] = f"r'{escaped_regex}'" + + # Copied code, single line added. + data_type_kwargs = PydanticV2DataTypeManager.migrate_datatype_constraints(data_type_kwargs) + + return self.data_type.from_import(IMPORT_CONSTR, kwargs=data_type_kwargs) + if strict: + return self.strict_type_map[StrictTypes.str] + return self.type_map[types] diff --git a/expediagroup/sdk/generator/client/parser.py b/expediagroup/sdk/generator/client/parser.py index ad2c16bd..acc5fb1a 100644 --- a/expediagroup/sdk/generator/client/parser.py +++ b/expediagroup/sdk/generator/client/parser.py @@ -61,6 +61,7 @@ RequestBodyObject, ) from datamodel_code_generator.types import DataType, DataTypeManager, StrictTypes, Types +from datatype_manager import PydanticV2DataTypeManager from fastapi_code_generator import parser from model import Argument, Operation, ParamTypes from stringcase import snakecase @@ -73,7 +74,7 @@ def __init__( *, data_model_type: type[DataModel] = pydantic_model.BaseModel, data_model_root_type: type[DataModel] = pydantic_model.CustomRootType, - data_type_manager_type: type[DataTypeManager] = pydantic_model.DataTypeManager, + data_type_manager_type: type[DataTypeManager] = PydanticV2DataTypeManager, data_model_field_type: type[DataModelFieldBase] = pydantic_model.DataModelField, base_class: Optional[str] = None, custom_template_dir: Optional[pathlib.Path] = None, diff --git a/expediagroup/sdk/generator/client/templates/__model__.jinja2 b/expediagroup/sdk/generator/client/templates/__model__.jinja2 index 7f1b5168..7e6d3831 100644 --- a/expediagroup/sdk/generator/client/templates/__model__.jinja2 +++ b/expediagroup/sdk/generator/client/templates/__model__.jinja2 @@ -13,39 +13,36 @@ {# limitations under the License.#} {{ model_imports }} -from collections.abc import Callable from typing import Union, Any, Literal -from pydantic import Extra, validator, SecretStr, SecretBytes +from pydantic import field_validator, SecretStr, SecretBytes, ConfigDict from pydantic.dataclasses import dataclass from expediagroup.sdk.core.model.exception.service import ExpediaGroupApiException SecretStr.__str__ = lambda self: '<-- omitted -->' if self.get_secret_value() else '' -class PydanticModelConfig: - r"""List of configurations for all SDK pydantic models.""" - JSON_ENCODERS: dict[type, Callable] = { - SecretStr: lambda v: v.get_secret_value() if v else None, - SecretBytes: lambda v: v.get_secret_value() if v else None, - } - - EXTRA: bool = Extra.forbid - - SMART_UNION: bool = True +class PydanticModel(BaseModel): + r"""Generic model that is a parent to all pydantic models, holds models configuration.""" + model_config: dict[str, Any] = ConfigDict( + extra="forbid", + json_encoders={ + SecretStr: lambda v: v.get_secret_value() if v else None, + SecretBytes: lambda v: v.get_secret_value() if v else None, + } + ) {% for model in models %} {% for decorator in model.decorators -%} {{ decorator }} {% endfor -%} - -class {{ model.class_name }}{% if is_aliased[model.class_name] %}Generic{% endif %}({{ model.base_class }}{% if is_aliased[model.base_class] %}Generic{% endif %},{% if model.base_class != 'Enum' %} smart_union=PydanticModelConfig.SMART_UNION, extra=PydanticModelConfig.EXTRA, json_encoders=PydanticModelConfig.JSON_ENCODERS{% endif %}): {% if comment is defined %} # {{ model.comment }}{% endif %} +class {{ model.class_name }}{% if is_aliased[model.class_name] %}Generic{% endif %}({% if model.base_class == 'BaseModel' %}PydanticModel{% else %}{{ model.base_class }}{% endif %}{% if is_aliased[model.base_class] %}Generic{% endif %},): {% if comment is defined %} # {{ model.comment }}{% endif %} r"""pydantic model {{ model.class_name }}{%- if model.description %}: {{ model.description }}{%- endif %} {# comment for new line #} """ {% if model.class_name in omitted_log_fields.keys() %} {% for field in omitted_log_fields[model.class_name] %} - @validator("{{ field }}") + @field_validator("{{ field }}") def __{{ field }}_validator(cls, {{ field }}): return SecretStr(str({{ field }})) {% endfor %} @@ -82,7 +79,7 @@ class {{ model.class_name }}{% if is_aliased[model.class_name] %}Generic{% endif {% for model in models %} {% if not is_aliased[model.class_name] and model.base_class != 'Enum' %} -{{ model.class_name }}.update_forward_refs() +{{ model.class_name }}.model_rebuild() {% endif %} {% endfor %} diff --git a/expediagroup/sdk/generator/client/templates/client.jinja2 b/expediagroup/sdk/generator/client/templates/client.jinja2 index 2434b180..7e9df176 100644 --- a/expediagroup/sdk/generator/client/templates/client.jinja2 +++ b/expediagroup/sdk/generator/client/templates/client.jinja2 @@ -5,6 +5,7 @@ import pydantic.schema from expediagroup.sdk.core.client.api import ApiClient from expediagroup.sdk.core.constant import header from expediagroup.sdk.core.configuration.client_config import ClientConfig +from expediagroup.sdk.core.model.api import RequestHeaders from furl import furl from uuid import UUID, uuid4 {% if error_responses_models.__len__() %} @@ -36,7 +37,7 @@ class {{ classname }}: Args: {% for arguemnt in operation.snake_case_arguments_list %} {{ arguemnt.name }}({{ arguemnt.type_hint }}{% if not arguemnt.required %}, optional{% endif %}): {{ arguemnt.description.replace("\n", "") }} {% endfor %}""" - headers = { + headers = RequestHeaders(headers={ header.TRANSACTION_ID: transaction_id, header.USER_AGENT: self.__user_agent, {% for arguemnt in operation.snake_case_arguments_list %} @@ -44,7 +45,7 @@ Args: '{{ arguemnt.alias }}': {{ arguemnt.name.strip() }}, {% endif %} {% endfor %} - } + }) query = {key: value for key, value in { {% for arguemnt in operation.snake_case_arguments_list %} diff --git a/expediagroup/sdk/generator/client/visitors/models.py b/expediagroup/sdk/generator/client/visitors/models.py index 66906a32..419aead3 100644 --- a/expediagroup/sdk/generator/client/visitors/models.py +++ b/expediagroup/sdk/generator/client/visitors/models.py @@ -13,6 +13,7 @@ # limitations under the License. import collections import dataclasses +from collections.abc import Callable from pathlib import Path from typing import Any @@ -246,6 +247,21 @@ def get_error_models(parser: OpenAPIParser) -> list: return list(error_models) +def delete_root_models(models: dict[str, DataModel]): + """ + Deletes root models from a dictionary of DataModel objects. + + Args: + models (dict[str, DataModel]): A dictionary containing DataModel objects, where the keys are the class names. + + """ + is_root_model: Callable = lambda model: len(model.fields) == 1 and not model.fields[0].name + root_models_classnames: list[str] = list(map(lambda model: model.class_name, filter(is_root_model, models.values()))) + + for root_model_classname in root_models_classnames: + models.pop(root_model_classname) + + def get_models(parser: OpenAPIParser, model_path: Path) -> dict[str, object]: r"""A visitor that exposes models and related data to `jinja2` templates. @@ -256,10 +272,11 @@ def get_models(parser: OpenAPIParser, model_path: Path) -> dict[str, object]: Returns: dict[str, object]: Data to be exposed to `jinja2` templates. """ + models: dict[str, DataModel] = parse_datamodels(parser) + delete_root_models(models) - _, sorted_models, __ = sort_data_models(unsorted_data_models=[result for result in parser.results if isinstance(result, DataModel)]) + _, sorted_models, __ = sort_data_models(unsorted_data_models=list(models.values())) - models: dict[str, DataModel] = parse_datamodels(parser) discriminators: list[Discriminator] = parse_discriminators(parser=parser, models=models) set_other_responses_models([operation for operation in parser.operations.values()]) diff --git a/requirements-core.txt b/requirements-core.txt index c687b22f..ce8d1435 100644 --- a/requirements-core.txt +++ b/requirements-core.txt @@ -1,5 +1,5 @@ uri~=2.0.1 requests~=2.31.0 -pydantic==2.5.2 +pydantic~=2.5.2 urllib3==2.1.0 email-validator~=2.1.0.post1 diff --git a/requirements-dev.txt b/requirements-dev.txt index 9bc7b659..22626f0f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,14 +1,10 @@ autoflake~=2.2.1 build~=1.0.3 setuptools~=69.0.2 -coverage~=7.3.2 fastapi-code-generator==0.4.4 flake8~=6.0.0 flake8-black~=0.3.6 flake8-bugbear~=23.12.2 flake8-isort~=6.1.1 flake8-pep585~=0.1.7 -furl~=2.1.3 -prettytable~=3.9.0 -virtualenv~=20.25.0 docformatter~=1.7.5 diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 00000000..6fa5580c --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,4 @@ +furl~=2.1.3 +prettytable~=3.9.0 +virtualenv~=20.24.7 +coverage~=7.3.2 diff --git a/test/core/client/test_api_client.py b/test/core/client/test_api_client.py index 5a50da6e..af89ad48 100644 --- a/test/core/client/test_api_client.py +++ b/test/core/client/test_api_client.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import datetime -import time import unittest from test.core.constant import api as api_constant from test.core.constant import authentication as auth_constant @@ -25,6 +23,7 @@ ) from expediagroup.sdk.core.configuration.client_config import ClientConfig from expediagroup.sdk.core.constant import header as header_constant +from expediagroup.sdk.core.model.api import RequestHeaders from expediagroup.sdk.core.model.exception import service as service_exception @@ -78,7 +77,7 @@ def test_api_client_call(self): body=api_constant.HELLO_WORLD_OBJECT, response_models=[api_constant.HelloWorld], url=api_constant.ENDPOINT, - headers=dict(), + headers=RequestHeaders(), ) self.assertEqual(response_obj.message, api_constant.HELLO_WORLD_MESSAGE) @@ -103,7 +102,9 @@ def test_api_client_call_missing_url(self): api_client = ApiClient(Configs.client_config, _ExpediaGroupAuthClient) with self.assertRaises(Exception) as call_missing_url_test: - api_client.call(body=api_constant.HELLO_WORLD_OBJECT, method=api_constant.METHOD, response_models=[api_constant.HelloWorld], headers=dict()) + api_client.call( + body=api_constant.HELLO_WORLD_OBJECT, method=api_constant.METHOD, response_models=[api_constant.HelloWorld], headers=RequestHeaders() + ) @mock.patch.object(_ExpediaGroupAuthClient, "_ExpediaGroupAuthClient__retrieve_token", Mocks.authorized_retrieve_token_mock) @mock.patch("expediagroup.sdk.core.client.api.requests.request", Mocks.hello_world_request_response_mock) @@ -111,7 +112,7 @@ def test_api_client_call_default_response_model(self): api_client = ApiClient(Configs.client_config, _ExpediaGroupAuthClient) response_obj: api_constant.HelloWorld = api_client.call( - method=api_constant.METHOD, body=api_constant.HELLO_WORLD_OBJECT, url=api_constant.ENDPOINT, headers=dict() + method=api_constant.METHOD, body=api_constant.HELLO_WORLD_OBJECT, url=api_constant.ENDPOINT, headers=RequestHeaders() ) self.assertIsNone(response_obj) @@ -121,7 +122,7 @@ def test_api_client_call_missing_obj(self): api_client = ApiClient(Configs.client_config, _ExpediaGroupAuthClient) with self.assertRaises(Exception) as call_missing_obj_test: - api_client.call(method=api_constant.METHOD, url=api_constant.ENDPOINT, response_models=[api_constant.HelloWorld], headers=dict()) + api_client.call(method=api_constant.METHOD, url=api_constant.ENDPOINT, response_models=[api_constant.HelloWorld], headers=RequestHeaders()) @mock.patch.object(_ExpediaGroupAuthClient, "_ExpediaGroupAuthClient__retrieve_token", Mocks.authorized_retrieve_token_mock) @mock.patch("expediagroup.sdk.core.client.api.requests.request", Mocks.invalid_request_response_mock) @@ -134,7 +135,7 @@ def test_error_response(self): method=api_constant.METHOD, url=api_constant.ENDPOINT, response_models=[api_constant.HelloWorld], - headers=dict(), + headers=RequestHeaders(), ) @mock.patch.object(_ExpediaGroupAuthClient, "_ExpediaGroupAuthClient__retrieve_token", Mocks.authorized_retrieve_token_mock) @@ -146,7 +147,7 @@ def test_api_client_call_missing_method(self): body=api_constant.HELLO_WORLD_OBJECT, url=api_constant.ENDPOINT, response_models=[api_constant.HelloWorld], - headers=dict(), + headers=RequestHeaders(), ) @mock.patch.object(_ExpediaGroupAuthClient, "_ExpediaGroupAuthClient__retrieve_token", Mocks.authorized_retrieve_token_mock) @@ -155,7 +156,7 @@ def test_api_client_call_none_body(self): api_client = ApiClient(Configs.client_config, _ExpediaGroupAuthClient) response_obj: api_constant.HelloWorld = api_client.call( - method=api_constant.METHOD, response_models=[api_constant.HelloWorld], url=api_constant.ENDPOINT, headers=dict(), body=None + method=api_constant.METHOD, response_models=[api_constant.HelloWorld], url=api_constant.ENDPOINT, headers=RequestHeaders(), body=None ) self.assertEqual(response_obj.message, api_constant.HELLO_WORLD_MESSAGE) diff --git a/test/core/constant/api.py b/test/core/constant/api.py index 1df338d6..26313c03 100644 --- a/test/core/constant/api.py +++ b/test/core/constant/api.py @@ -56,7 +56,7 @@ def hello_world_response(): response.url = auth_constant.AUTH_ENDPOINT response.code = "ok" response.headers = dict() - response._content = json.dumps(HELLO_WORLD_OBJECT, default=pydantic.schema.pydantic_encoder).encode() + response._content = HELLO_WORLD_OBJECT.model_dump_json().encode() return response @staticmethod @@ -65,5 +65,5 @@ def invalid_response(): response.status_code = HTTPStatus.BAD_REQUEST response.url = auth_constant.AUTH_ENDPOINT response.code = "Bad Request" - response._content = json.dumps(ERROR_OBJECT, default=pydantic.schema.pydantic_encoder).encode() + response._content = ERROR_OBJECT.model_dump_json().encode() return response diff --git a/test/core/constant/authentication.py b/test/core/constant/authentication.py index 7c9ecb58..92184628 100644 --- a/test/core/constant/authentication.py +++ b/test/core/constant/authentication.py @@ -80,7 +80,7 @@ def default_token_response(): response.status_code = HTTPStatus.OK response.url = AUTH_ENDPOINT response.code = "ok" - response._content = json.dumps(TOKEN_RESPONSE_DATA.copy(), default=pydantic.schema.pydantic_encoder).encode() + response._content = json.dumps(TOKEN_RESPONSE_DATA.copy()).encode() return response @@ -93,7 +93,7 @@ def eleven_seconds_expiration_token_response(): content = TOKEN_RESPONSE_DATA.copy() content[EXPIRES_IN] = 11 - response._content = json.dumps(content, default=pydantic.schema.pydantic_encoder).encode() + response._content = json.dumps(content).encode() return response @staticmethod diff --git a/test/core/constant/pydantic_model.py b/test/core/constant/pydantic_model.py index 90ffebac..4e41618f 100644 --- a/test/core/constant/pydantic_model.py +++ b/test/core/constant/pydantic_model.py @@ -1,18 +1,30 @@ from typing import Literal, Union -from pydantic import BaseModel, Extra +from pydantic import BaseModel, ConfigDict, Extra class PolygonPydanticModels: - class Polygon(BaseModel, smart_union=True, extra=Extra.forbid): + class Polygon(BaseModel): + model_config = ConfigDict( + extra=Extra.forbid, + ) + type: Literal["Polygon"] coordinates: list[list[int]] - class MultiPolygon(BaseModel, smart_union=True, extra=Extra.forbid): + class MultiPolygon(BaseModel): + model_config = ConfigDict( + extra=Extra.forbid, + ) + type: Literal["MultiPolygon"] coordinates: list[list[list[int]]] - class FloatCoordinatesPolygon(BaseModel, smart_union=True, extra=Extra.forbid): + class FloatCoordinatesPolygon(BaseModel): + model_config = ConfigDict( + extra=Extra.forbid, + ) + type: Literal["FloatCoordinatesPolygon"] coordinates: list[list[float]] @@ -21,12 +33,16 @@ class TypeAliases: BoundingPolygon = Union[PolygonPydanticModels.Polygon, PolygonPydanticModels.MultiPolygon] -PolygonPydanticModels.Polygon.update_forward_refs() -PolygonPydanticModels.MultiPolygon.update_forward_refs() +PolygonPydanticModels.Polygon.model_rebuild() +PolygonPydanticModels.MultiPolygon.model_rebuild() class PolymorphicPydanticModels: - class PolygonWrapper(BaseModel, smart_union=True): + class PolygonWrapper(BaseModel): + model_config = ConfigDict( + extra=Extra.forbid, + ) + polygon: TypeAliases.BoundingPolygon diff --git a/test/core/model/test_authentication_model.py b/test/core/model/test_authentication_model.py index 1823622c..09102c66 100644 --- a/test/core/model/test_authentication_model.py +++ b/test/core/model/test_authentication_model.py @@ -28,7 +28,7 @@ class TokenTest(unittest.TestCase): def test_token_response_model(self): - token_response: _TokenResponse = _TokenResponse.parse_obj(auth_constant.TOKEN_RESPONSE_DATA) + token_response: _TokenResponse = _TokenResponse.model_validate(auth_constant.TOKEN_RESPONSE_DATA) self.assertIsNotNone(token_response) self.assertIsNotNone(token_response.expires_in) diff --git a/test/core/model/test_pydantic_model.py b/test/core/model/test_pydantic_model.py index 829e6ca3..7d82a16d 100644 --- a/test/core/model/test_pydantic_model.py +++ b/test/core/model/test_pydantic_model.py @@ -1,8 +1,7 @@ import unittest from test.core.constant.pydantic_model import * -from pydantic import parse_obj_as -from pydantic.error_wrappers import ValidationError +from pydantic import ValidationError class PydanticModelsTest(unittest.TestCase): @@ -60,7 +59,7 @@ def test_create_invalid_polymorphic_object(self): wrapped_polygon: PolymorphicPydanticModels.PolygonWrapper = PolymorphicPydanticModels.PolygonWrapper(polygon=float_coordinates_polygon) def test_serialize_objects(self): - serialized_polygon = PolygonObjects.POLYGON.dict() + serialized_polygon = PolygonObjects.POLYGON.model_dump() self.assertIsNotNone(serialized_polygon) self.assertIsNotNone(serialized_polygon["type"]) @@ -69,7 +68,7 @@ def test_serialize_objects(self): self.assertEqual(serialized_polygon, PolygonDictData.POLYGON) def test_deserialize_object(self): - deserialized_polygon = parse_obj_as(PolygonPydanticModels.Polygon, PolygonDictData.POLYGON) + deserialized_polygon = PolygonPydanticModels.Polygon.model_validate(PolygonDictData.POLYGON) self.assertIsNotNone(deserialized_polygon) self.assertIsNotNone(deserialized_polygon.type) @@ -81,11 +80,11 @@ def test_deserialize_object(self): def test_test_deserialize_object_invalid_data(self): with self.assertRaises(ValidationError): - deserialized_polygon = parse_obj_as(PolygonPydanticModels.MultiPolygon, PolygonDictData.POLYGON) + deserialized_polygon = PolygonPydanticModels.MultiPolygon.model_validate(PolygonDictData.POLYGON) def test_deserialize_polymorphic_object(self): # Case 1 - wrapped_polygon: PolymorphicPydanticModels.PolygonWrapper = parse_obj_as(PolymorphicPydanticModels.PolygonWrapper, PolygonDictData.WRAPPED_POLYGON) + wrapped_polygon: PolymorphicPydanticModels.PolygonWrapper = PolymorphicPydanticModels.PolygonWrapper.model_validate(PolygonDictData.WRAPPED_POLYGON) self.assertIsNotNone(wrapped_polygon) self.assertIsNotNone(wrapped_polygon.polygon) @@ -96,8 +95,8 @@ def test_deserialize_polymorphic_object(self): self.assertTrue(isinstance(wrapped_polygon.polygon, PolygonPydanticModels.Polygon)) # Case 2 - wrapped_multi_polygon: PolymorphicPydanticModels.PolygonWrapper = parse_obj_as( - PolymorphicPydanticModels.PolygonWrapper, PolygonDictData.WRAPPED_MULTI_POLYGON + wrapped_multi_polygon: PolymorphicPydanticModels.PolygonWrapper = PolymorphicPydanticModels.PolygonWrapper.model_validate( + PolygonDictData.WRAPPED_MULTI_POLYGON ) self.assertIsNotNone(wrapped_multi_polygon)