Skip to content

Commit

Permalink
Merge pull request #10 from koxudaxi/improve_argument_list
Browse files Browse the repository at this point in the history
add field to argument
  • Loading branch information
koxudaxi authored Jun 17, 2020
2 parents 628e8c3 + d49a7d4 commit 0c7ef52
Showing 1 changed file with 44 additions and 51 deletions.
95 changes: 44 additions & 51 deletions fastapi_code_generator/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,20 @@ def camelcase(self) -> str:
return stringcase.camelcase(self)


class Argument(BaseModel):
class Argument(CachedPropertyModel):
name: UsefulStr
type_hint: UsefulStr
default: Optional[UsefulStr]
required: bool

def __str__(self) -> str:
return self.argument

# def __str__(self) -> UsefulStr:
# return self.name
@cached_property
def argument(self) -> str:
if not self.default and self.required:
return f'{self.name}: {self.type_hint}'
return f'{self.name}: {self.type_hint} = {self.default}'


class Operation(CachedPropertyModel):
Expand All @@ -93,22 +102,28 @@ def snake_case_path(self) -> str:
)

@cached_property
def request(self) -> Optional[str]:
models: List[str] = []
def request(self) -> Optional[Argument]:
arguments: List[Argument] = []
for requests in self.request_objects:
for content_type, schema in requests.contents.items():
# TODO: support other content-types
if content_type == "application/json":
models.append(schema.ref_object_name)
arguments.append(
# TODO: support multiple body
Argument(
name='body', # type: ignore
type_hint=schema.ref_object_name,
required=requests.required,
)
)
self.imports.append(
Import(
from_=model_path_var.get(), import_=schema.ref_object_name
)
)
if not models:
if not arguments:
return None
if len(models) > 1:
return f'Union[{",".join(models)}]'
return models[0]
return arguments[0]

@cached_property
def request_objects(self) -> List[Request]:
Expand Down Expand Up @@ -171,69 +186,47 @@ def snake_case_arguments(self) -> str:
return self.get_arguments(snake_case=True)

def get_arguments(self, snake_case: bool) -> str:
arguments: List[str] = []

if self.parameters:
for parameter in self.parameters:
arguments.append(self.get_parameter_type(parameter, snake_case))

if self.request:
arguments.append(f"body: {self.request}")

return ", ".join(arguments)
return ", ".join(
argument.argument for argument in self.get_argument_list(snake_case)
)

@cached_property
def argument_list(self) -> List[Argument]:
return self.get_argument_list(False)

def get_argument_list(self, snake_case: bool) -> List[Argument]:
arguments: List[Argument] = []

if self.parameters:
for parameter in self.parameters:
arguments.append(Argument.parse_obj(parameter))
arguments.append(self.get_parameter_type(parameter, snake_case))

if self.request:
arguments.append(Argument(name=UsefulStr('body')))

arguments.append(self.request)
return arguments

def get_parameter_type(
self, parameter: Dict[str, Union[str, Dict[str, str]]], snake_case: bool
) -> str:
) -> Argument:
schema: JsonSchemaObject = JsonSchemaObject.parse_obj(parameter["schema"])
format_ = schema.format or "default"
type_ = json_schema_data_formats[schema.type][format_]
return self.get_data_type_hint(
name=stringcase.snakecase(parameter["name"])
if snake_case
else parameter["name"],
name: str = parameter["name"] # type: ignore

field = DataModelField(
name=stringcase.snakecase(name) if snake_case else name,
data_types=[type_map[type_]],
required=parameter.get("required") == "true"
or parameter.get("in") == "path",
snake_case=snake_case,
default=schema.typed_default,
)

def get_data_type_hint(
self,
name: str,
data_types: List[DataType],
required: bool,
snake_case: bool,
default: Optional[str] = None,
auto_import: bool = True,
) -> str:
field = DataModelField(
name=stringcase.snakecase(name) if snake_case else name,
data_types=data_types,
required=required,
default=default,
self.imports.extend(field.imports)
return Argument(
name=field.name,
type_hint=field.type_hint,
default=field.default,
required=field.required,
)
if auto_import:
self.imports.extend(field.imports)

if not default and field.required:
return f"{field.name}: {field.type_hint}"

return f'{field.name}: {field.type_hint} = {default}'

@cached_property
def response(self) -> str:
Expand Down

0 comments on commit 0c7ef52

Please sign in to comment.