Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(deps): pydantic 1.x.x to 2.x.x migration #292

Merged
merged 9 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/validate-test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
with:
python-version: '3.11'
- name: Install Dependencies
run: pip install -r requirements-dev.txt
run: pip install -r requirements-test.txt && pip install -r requirements-core.txt
- name: Coverage Run
run: coverage run --omit=$(cat test/exclude) --branch -m unittest
- name: Coverage Report
Expand Down
50 changes: 19 additions & 31 deletions expediagroup/sdk/core/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import json
import logging
import uuid
from copy import deepcopy
from http import HTTPStatus
from typing import Any, Optional

import pydantic
import pydantic.schema
import requests
from pydantic import BaseModel
from pydantic import BaseModel, TypeAdapter

from expediagroup.sdk.core.client.auth_client import AuthClient
from expediagroup.sdk.core.configuration.client_config import ClientConfig
from expediagroup.sdk.core.constant import header as header_constant
from expediagroup.sdk.core.constant import log as log_constant
from expediagroup.sdk.core.constant.constant import OK_STATUS_CODES_RANGE
from expediagroup.sdk.core.model.api import RequestHeaders
from expediagroup.sdk.core.model.error import Error
from expediagroup.sdk.core.model.exception import service as service_exception
from expediagroup.sdk.core.util import log as log_util
Expand All @@ -52,18 +49,18 @@ def __init__(self, config: ClientConfig, auth_client_cls):
@staticmethod
def __build_response(
response: requests.Response,
response_models: list[pydantic.BaseModel],
response_models: list[type],
error_responses: dict[int, Any],
):
if response.status_code not in OK_STATUS_CODES_RANGE:
exception: service_exception.ExpediaGroupApiException

if response.status_code in error_responses.keys():
error_object = pydantic.parse_obj_as(error_responses[response.status_code].model, response.json())
error_object = error_responses[response.status_code].model.model_validate(response.json())
exception = error_responses[response.status_code].exception.of(error=error_object, error_code=HTTPStatus(response.status_code))
else:
exception = service_exception.ExpediaGroupApiException.of(
error=Error.parse_obj(response.json()),
error=Error.model_validate(response.json()),
error_code=HTTPStatus(response.status_code),
)

Expand All @@ -74,7 +71,7 @@ def __build_response(
if not model:
continue
try:
response_object = pydantic.parse_obj_as(model, response.json())
response_object = TypeAdapter(model).validate_python(response.json())
return response_object
except Exception:
continue
Expand All @@ -85,10 +82,10 @@ def call(
self,
method: str,
url: str,
body: pydantic.BaseModel,
headers: dict = dict(), # noqa
response_models: Optional[list[Any]] = [None], # noqa
error_responses: dict[int, Any] = {}, # noqa
body: BaseModel,
headers: RequestHeaders = RequestHeaders(), # noqa
response_models: Optional[list[Any]] = list(), # noqa
error_responses: dict[int, Any] = dict(), # noqa
) -> Any:
r"""Sends HTTP request to API.

Expand All @@ -113,7 +110,7 @@ def call(
timeout=self.request_timeout,
)
else:
request_body = body.json(exclude_none=True)
request_body = body.model_dump_json(exclude_none=True)
response = requests.request(
method=method.upper(),
url=url,
Expand All @@ -123,7 +120,7 @@ def call(
timeout=self.request_timeout,
)

logged_body: dict[str, Any] = dict() if not body else body.dict()
logged_body: dict[str, Any] = dict() if not body else body.model_dump()

request_log_message = log_util.request_log(
headers=request_headers,
Expand All @@ -146,24 +143,15 @@ def call(
@staticmethod
def __fill_request_headers(request_headers: dict):
if not request_headers:
request_headers = dict()
return header_constant.API_REQUEST

request_header_keys = request_headers.keys()
for key, value in header_constant.API_REQUEST.items():
if key in request_header_keys:
continue

request_headers[key] = value
headers: dict = deepcopy(header_constant.API_REQUEST)
headers.update(request_headers)

return request_headers
return headers

@staticmethod
def __prepare_request_headers(headers: dict) -> dict:
request_headers = dict()
for header_key, header_value in headers.items():
if not header_value:
continue
needs_serialization: bool = isinstance(header_value, BaseModel) or isinstance(header_value, enum.Enum) or isinstance(header_value, uuid.UUID)
request_headers[header_key] = json.dumps(header_value, default=pydantic.schema.pydantic_encoder) if needs_serialization else header_value
def __prepare_request_headers(headers: RequestHeaders) -> dict:
request_headers: dict = headers.unwrap()

return ApiClient.__fill_request_headers(request_headers)
46 changes: 46 additions & 0 deletions expediagroup/sdk/core/model/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2022 Expedia, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import Any

from pydantic import BaseModel, Field


class RequestHeaders(BaseModel):
"""
RequestHeaders class represents the headers of an HTTP request.

Attributes:
headers (Any): The HTTP request headers. It can be of any type.
"""

headers: Any = Field(default=None)

def unwrap(self) -> dict[str, Any]:
"""
Unwraps the headers from the model.

Returns:
A dictionary containing the headers.

Example:
>>> headers = RequestHeaders()
>>> headers.unwrap()
{'Content-Type': 'application/json', 'Authorization': 'Bearer token'}
"""
if not self.headers:
return dict()

return json.loads(self.model_dump_json()).get("headers")
6 changes: 3 additions & 3 deletions expediagroup/sdk/core/model/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
from dataclasses import dataclass
from multiprocessing import Lock
from typing import Optional
from typing import Optional, Union

import pydantic.schema
import requests
Expand All @@ -40,7 +40,7 @@ class _TokenResponse(pydantic.BaseModel):
"""A model of an API response."""

access_token: str
expires_in: int
expires_in: Union[int, float]
scope: str
token_type: str
id_token: Optional[str] = None
Expand All @@ -53,7 +53,7 @@ def __init__(self, data: dict):

:param data: token data
"""
self.__token: _TokenResponse = _TokenResponse.parse_obj(data)
self.__token: _TokenResponse = _TokenResponse.model_validate(data)
self.lock = Lock()
self.__expiration_time = datetime.datetime.now() + datetime.timedelta(seconds=self.__token.expires_in)
self.__auth_header = HttpBearerAuth(self.__token.access_token)
Expand Down
77 changes: 77 additions & 0 deletions expediagroup/sdk/generator/client/datatype_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2022 Expedia, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Any

from datamodel_code_generator.model import pydantic as datamodel_code_generator_pydantic
from datamodel_code_generator.model.pydantic.imports import IMPORT_CONSTR
from datamodel_code_generator.model.pydantic.types import (
escape_characters,
string_kwargs,
transform_kwargs,
)
from datamodel_code_generator.types import DataType, StrictTypes, Types

PYDANTIC_V2_MIGRATION_CONSTRAINTS_MAPPING: dict[str, str] = {"regex": "pattern"}


class PydanticV2DataTypeManager(datamodel_code_generator_pydantic.DataTypeManager):
r"""
Custom DataTypeManager to map PydanticV1 types to PydanticV2.

Notes:
- This class is a temporary solution until `fastapi-code-generator` bumps up
its `datamodel-code-generator` from `0.16.1` to `>=0.25.1` which includes
PydanticV2 support.
GitHub Issue: https://github.com/koxudaxi/fastapi-code-generator/issues/378
mohnoor94 marked this conversation as resolved.
Show resolved Hide resolved
"""

@staticmethod
def migrate_datatype_constraints(data_type_kwargs: dict[str, Any]) -> dict[str, Any]:
"""
Migrates datatype constraints in the given data_type_kwargs dictionary from pydantic v1 to v2 .

Args:
data_type_kwargs (dict[str, Any]): A dictionary containing datatype constraints.

Returns:
dict[str, Any]: The migrated datatype constraints dictionary.
"""
migrated_data_type_kwargs: dict[str, Any] = deepcopy(data_type_kwargs)

for key, value in PYDANTIC_V2_MIGRATION_CONSTRAINTS_MAPPING.items():
if migrated_data_type_kwargs.get(key):
migrated_data_type_kwargs.update([(value, migrated_data_type_kwargs.get(key))])
migrated_data_type_kwargs.pop(key)

return migrated_data_type_kwargs

def get_data_str_type(self, types: Types, **kwargs: Any) -> DataType:
data_type_kwargs: dict[str, Any] = transform_kwargs(kwargs, string_kwargs)
strict = StrictTypes.str in self.strict_types
if data_type_kwargs:
if strict:
data_type_kwargs["strict"] = True
if "regex" in data_type_kwargs:
escaped_regex = data_type_kwargs["regex"].translate(escape_characters)
data_type_kwargs["regex"] = f"r'{escaped_regex}'"

# Copied code, single line added.
data_type_kwargs = PydanticV2DataTypeManager.migrate_datatype_constraints(data_type_kwargs)

return self.data_type.from_import(IMPORT_CONSTR, kwargs=data_type_kwargs)
if strict:
return self.strict_type_map[StrictTypes.str]
return self.type_map[types]
3 changes: 2 additions & 1 deletion expediagroup/sdk/generator/client/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
RequestBodyObject,
)
from datamodel_code_generator.types import DataType, DataTypeManager, StrictTypes, Types
from datatype_manager import PydanticV2DataTypeManager
from fastapi_code_generator import parser
from model import Argument, Operation, ParamTypes
from stringcase import snakecase
Expand All @@ -73,7 +74,7 @@ def __init__(
*,
data_model_type: type[DataModel] = pydantic_model.BaseModel,
data_model_root_type: type[DataModel] = pydantic_model.CustomRootType,
data_type_manager_type: type[DataTypeManager] = pydantic_model.DataTypeManager,
data_type_manager_type: type[DataTypeManager] = PydanticV2DataTypeManager,
data_model_field_type: type[DataModelFieldBase] = pydantic_model.DataModelField,
base_class: Optional[str] = None,
custom_template_dir: Optional[pathlib.Path] = None,
Expand Down
29 changes: 13 additions & 16 deletions expediagroup/sdk/generator/client/templates/__model__.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -13,39 +13,36 @@
{# limitations under the License.#}
{{ model_imports }}

from collections.abc import Callable
from typing import Union, Any, Literal
from pydantic import Extra, validator, SecretStr, SecretBytes
from pydantic import field_validator, SecretStr, SecretBytes, ConfigDict
from pydantic.dataclasses import dataclass
from expediagroup.sdk.core.model.exception.service import ExpediaGroupApiException

SecretStr.__str__ = lambda self: '<-- omitted -->' if self.get_secret_value() else ''

class PydanticModelConfig:
r"""List of configurations for all SDK pydantic models."""

JSON_ENCODERS: dict[type, Callable] = {
SecretStr: lambda v: v.get_secret_value() if v else None,
SecretBytes: lambda v: v.get_secret_value() if v else None,
}

EXTRA: bool = Extra.forbid

SMART_UNION: bool = True
class PydanticModel(BaseModel):
r"""Generic model that is a parent to all pydantic models, holds models configuration."""

model_config: dict[str, Any] = ConfigDict(
extra="forbid",
json_encoders={
SecretStr: lambda v: v.get_secret_value() if v else None,
SecretBytes: lambda v: v.get_secret_value() if v else None,
}
)

{% for model in models %}
{% for decorator in model.decorators -%}
{{ decorator }}
{% endfor -%}

class {{ model.class_name }}{% if is_aliased[model.class_name] %}Generic{% endif %}({{ model.base_class }}{% if is_aliased[model.base_class] %}Generic{% endif %},{% if model.base_class != 'Enum' %} smart_union=PydanticModelConfig.SMART_UNION, extra=PydanticModelConfig.EXTRA, json_encoders=PydanticModelConfig.JSON_ENCODERS{% endif %}): {% if comment is defined %} # {{ model.comment }}{% endif %}
class {{ model.class_name }}{% if is_aliased[model.class_name] %}Generic{% endif %}({% if model.base_class == 'BaseModel' %}PydanticModel{% else %}{{ model.base_class }}{% endif %}{% if is_aliased[model.base_class] %}Generic{% endif %},): {% if comment is defined %} # {{ model.comment }}{% endif %}
r"""pydantic model {{ model.class_name }}{%- if model.description %}: {{ model.description }}{%- endif %}
{# comment for new line #}
"""
{% if model.class_name in omitted_log_fields.keys() %}
{% for field in omitted_log_fields[model.class_name] %}
@validator("{{ field }}")
@field_validator("{{ field }}")
def __{{ field }}_validator(cls, {{ field }}):
return SecretStr(str({{ field }}))
{% endfor %}
Expand Down Expand Up @@ -82,7 +79,7 @@ class {{ model.class_name }}{% if is_aliased[model.class_name] %}Generic{% endif

{% for model in models %}
{% if not is_aliased[model.class_name] and model.base_class != 'Enum' %}
{{ model.class_name }}.update_forward_refs()
{{ model.class_name }}.model_rebuild()
{% endif %}
{% endfor %}

Expand Down
5 changes: 3 additions & 2 deletions expediagroup/sdk/generator/client/templates/client.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import pydantic.schema
from expediagroup.sdk.core.client.api import ApiClient
from expediagroup.sdk.core.constant import header
from expediagroup.sdk.core.configuration.client_config import ClientConfig
from expediagroup.sdk.core.model.api import RequestHeaders
from furl import furl
from uuid import UUID, uuid4
{% if error_responses_models.__len__() %}
Expand Down Expand Up @@ -36,15 +37,15 @@ class {{ classname }}:
Args:
{% for arguemnt in operation.snake_case_arguments_list %} {{ arguemnt.name }}({{ arguemnt.type_hint }}{% if not arguemnt.required %}, optional{% endif %}): {{ arguemnt.description.replace("\n", "") }}
{% endfor %}"""
headers = {
headers = RequestHeaders(headers={
header.TRANSACTION_ID: transaction_id,
header.USER_AGENT: self.__user_agent,
{% for arguemnt in operation.snake_case_arguments_list %}
{% if arguemnt.in_.value == 'header' %}
'{{ arguemnt.alias }}': {{ arguemnt.name.strip() }},
{% endif %}
{% endfor %}
}
})

query = {key: value for key, value in {
{% for arguemnt in operation.snake_case_arguments_list %}
Expand Down
Loading