Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(core, generator): Raise exceptions instead of returning them. #206

Merged
merged 8 commits into from
Sep 5, 2023
30 changes: 24 additions & 6 deletions expediagroup/sdk/core/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,24 @@ def __init__(self, config: ClientConfig, auth_client_cls):
self.request_timeout = config.request_timeout

@staticmethod
def __build_response(response: requests.Response, response_models: list[pydantic.BaseModel]):
def __build_response(
response: requests.Response,
response_models: list[pydantic.BaseModel],
error_responses: dict[int, Any],
):
if response.status_code not in OK_STATUS_CODES_RANGE:
raise service_exception.ExpediaGroupServiceException.of(
error=Error.parse_obj(response.json()),
error_code=HTTPStatus(response.status_code),
)
exception: service_exception.ExpediaGroupApiException

if response.status_code in error_responses.keys():
error_object = pydantic.parse_obj_as(error_responses[response.status_code].model, response.json())
exception = error_responses[response.status_code].exception.of(error=error_object, error_code=HTTPStatus(response.status_code))
else:
exception = service_exception.ExpediaGroupApiException.of(
error=Error.parse_obj(response.json()),
error_code=HTTPStatus(response.status_code),
)

raise exception

response_object = None
for model in response_models:
Expand All @@ -76,6 +88,7 @@ def call(
body: pydantic.BaseModel,
headers: dict = dict(), # noqa
response_models: Optional[list[Any]] = [None], # noqa
error_responses: dict[int, Any] = {}, # noqa
) -> Any:
r"""Sends HTTP request to API.

Expand Down Expand Up @@ -121,7 +134,12 @@ def call(

LOG.info(log_constant.EXPEDIAGROUP_LOG_MESSAGE_TEMPLATE.format(request_log_message))

result = ApiClient.__build_response(response=response, response_models=response_models)
result = ApiClient.__build_response(
response=response,
response_models=response_models,
error_responses=error_responses,
)

return result

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions expediagroup/sdk/core/model/exception/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
from expediagroup.sdk.core.model.exception.expediagroup import ExpediaGroupException


class ExpediaGroupServiceException(ExpediaGroupException):
class ExpediaGroupApiException(ExpediaGroupException):
def __init__(self, message: str, cause: Optional[BaseException] = None):
super().__init__(message, cause)

@staticmethod
def of(error: Error, error_code: HTTPStatus):
return ExpediaGroupServiceException(message=f"[{error_code.value}] {error}")
return ExpediaGroupApiException(message=f"[{error_code.value}] {error}")


class ExpediaGroupAuthException(ExpediaGroupServiceException):
class ExpediaGroupAuthException(ExpediaGroupApiException):
def __init__(self, error_code: HTTPStatus, message: str):
super().__init__(message=f"[{error_code.value}] {message}")
1 change: 1 addition & 0 deletions expediagroup/sdk/generator/client/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@ def argument(self) -> str:
class Operation(fastapi_code_generator_parser.Operation):
arguments_list: list[Argument] = []
snake_case_arguments_list: list[Argument] = []
error_responses: dict[int, typing.Any] = dict()
15 changes: 15 additions & 0 deletions expediagroup/sdk/generator/client/templates/__model__.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
{{ model_imports }}
from typing import Union, Any, Literal
from pydantic import Extra
from pydantic.dataclasses import dataclass
from expediagroup.sdk.core.model.exception.service import ExpediaGroupApiException


{% for model in models %}
Expand Down Expand Up @@ -67,3 +69,16 @@ class {{ model.class_name }}{% if is_aliased[model.class_name] %}Generic{% endif
{% endif %}
{% endfor %}

{% for error_model in error_responses_models %}
class ExpediaGroup{{ error_model }}Exception(ExpediaGroupApiException):
r"""Exception wrapping a {{ error_model }} object."""
pass
{% endfor %}


{% for error_model in error_responses_models %}
@dataclass
class {{ error_model }}DeserializationContract:
exception: type = ExpediaGroup{{ error_model }}Exception
model: type = {{ error_model }}
{% endfor %}
12 changes: 11 additions & 1 deletion expediagroup/sdk/generator/client/templates/client.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ from expediagroup.sdk.core.constant import header
from expediagroup.sdk.core.configuration.client_config import ClientConfig
from furl import furl
from uuid import UUID, uuid4
{% if error_responses_models.__len__() %}
from .model import ({% for error_model in error_responses_models %}{{ error_model }}DeserializationContract,{% endfor %}
)
{% endif %}
{% if api.lower() == "rapid" %}from expediagroup.sdk.core.client.rapid_auth_client import _RapidAuthClient
{% else %}from expediagroup.sdk.core.client.expediagroup_auth_client import _ExpediaGroupAuthClient
{% endif %}
Expand Down Expand Up @@ -56,11 +60,17 @@ Args:
request_url.query.set(query)
request_url.path.normalize()

error_responses = {
{% for response_code in operation.error_responses.keys() %}{{ response_code }}: {{ operation.error_responses[response_code]["model"] }}DeserializationContract,
{% endfor %}
}

return self.__api_client.call(
headers=headers,
method='{{ operation.method }}',
body={% if 'body' in operation.snake_case_arguments %}body{% else %}None{% endif %},
response_models={{ operation.return_type.removeprefix('Union')}},
url=request_url
url=request_url,
error_responses=error_responses,
)
{% endfor %}
23 changes: 22 additions & 1 deletion expediagroup/sdk/generator/client/visitors/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
import collections
import dataclasses
from pathlib import Path
from typing import Any

from datamodel_code_generator.imports import Imports
from datamodel_code_generator.model import DataModel
from datamodel_code_generator.parser.base import sort_data_models
from datamodel_code_generator.types import DataType
from fastapi_code_generator.parser import OpenAPIParser
from fastapi_code_generator.parser import OpenAPIParser, Operation
from fastapi_code_generator.visitor import Visitor
from pydantic import BaseModel, Extra, Field, parse_obj_as

Expand Down Expand Up @@ -227,6 +228,24 @@ def parse_sorted_aliases(models: dict[str, DataModel], discriminators: list[Disc
return sorted(aliases, key=lambda alias_: alias_.order)


def set_other_responses_models(operations: list[Operation]):
ok_status_code_range = [code for code in range(200, 300)]
for index, operation in enumerate(operations):
error_responses: dict[int, Any] = {
int(code): response for code, response in operation.additional_responses.items() if int(code) not in ok_status_code_range
}

operations[index].error_responses = error_responses


def get_error_models(parser: OpenAPIParser) -> list:
error_models: set[str] = set()
for response in list(map(lambda operation: operation.error_responses, parser.operations.values())):
for model in response.values():
error_models.add(model["model"])
return list(error_models)


def get_models(parser: OpenAPIParser, model_path: Path) -> dict[str, object]:
r"""A visitor that exposes models and related data to `jinja2` templates.

Expand All @@ -243,6 +262,7 @@ def get_models(parser: OpenAPIParser, model_path: Path) -> dict[str, object]:
models: dict[str, DataModel] = parse_datamodels(parser)
discriminators: list[Discriminator] = parse_discriminators(parser=parser, models=models)

set_other_responses_models([operation for operation in parser.operations.values()])
anssari1 marked this conversation as resolved.
Show resolved Hide resolved
apply_discriminators_to_models(discriminators=discriminators, models=models)

aliases: list[Alias] = parse_sorted_aliases(models=models, discriminators=discriminators)
Expand All @@ -253,6 +273,7 @@ def get_models(parser: OpenAPIParser, model_path: Path) -> dict[str, object]:
"model_imports": collect_imports(sorted_models, parser),
"aliases": aliases,
"is_aliased": is_aliased,
"error_responses_models": get_error_models(parser),
}


Expand Down
2 changes: 1 addition & 1 deletion test/core/client/test_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_api_client_call_missing_obj(self):
def test_error_response(self):
api_client = ApiClient(Configs.client_config, _ExpediaGroupAuthClient)

with self.assertRaises(service_exception.ExpediaGroupServiceException) as call_error_response:
with self.assertRaises(service_exception.ExpediaGroupApiException) as call_error_response:
api_client.call(
body=api_constant.HELLO_WORLD_OBJECT,
method=api_constant.METHOD,
Expand Down