diff --git a/fastapi_code_generator/__main__.py b/fastapi_code_generator/__main__.py index a551e87..630f7dd 100644 --- a/fastapi_code_generator/__main__.py +++ b/fastapi_code_generator/__main__.py @@ -43,7 +43,8 @@ def dynamic_load_module(module_path: Path) -> Any: @app.command() def main( - input_file: typer.FileText = typer.Option(..., "--input", "-i"), + encoding: str = typer.Option("utf-8", "--encoding", "-e"), + input_file: str = typer.Option(..., "--input", "-i"), output_dir: Path = typer.Option(..., "--output", "-o"), model_file: str = typer.Option(None, "--model-file", "-m"), template_dir: Optional[Path] = typer.Option(None, "--template-dir", "-t"), @@ -57,8 +58,12 @@ def main( ), disable_timestamp: bool = typer.Option(False, "--disable-timestamp"), ) -> None: - input_name: str = input_file.name - input_text: str = input_file.read() + input_name: str = input_file + input_text: str + + with open(input_file, encoding=encoding) as f: + input_text = f.read() + if model_file: model_path = Path(model_file).with_suffix('.py') else: @@ -68,6 +73,7 @@ def main( return generate_code( input_name, input_text, + encoding, output_dir, template_dir, model_path, @@ -80,6 +86,7 @@ def main( return generate_code( input_name, input_text, + encoding, output_dir, template_dir, model_path, @@ -103,6 +110,7 @@ def _get_most_of_reference(data_type: DataType) -> Optional[Reference]: def generate_code( input_name: str, input_text: str, + encoding: str, output_dir: Path, template_dir: Optional[Path], model_path: Optional[Path] = None, @@ -218,7 +226,9 @@ def generate_code( header += f"\n# timestamp: {timestamp}" for path, code in results.items(): - with output_dir.joinpath(path.with_suffix(".py")).open("wt") as file: + with output_dir.joinpath(path.with_suffix(".py")).open( + "wt", encoding=encoding + ) as file: print(header, file=file) print("", file=file) print(code.rstrip(), file=file) diff --git a/tests/test_generate.py b/tests/test_generate.py index b6278e3..32ce824 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -22,6 +22,8 @@ SPECIFIC_TAGS = 'Wild Boars, Fat Cats' +ENCODING = 'utf-8' + @pytest.mark.parametrize( "oas_file", (DATA_DIR / OPEN_API_DEFAULT_TEMPLATE_DIR_NAME).glob("*.yaml") @@ -33,6 +35,7 @@ def test_generate_default_template(oas_file): generate_code( input_name=oas_file.name, input_text=oas_file.read_text(), + encoding=ENCODING, output_dir=output_dir, template_dir=None, ) @@ -54,6 +57,7 @@ def test_generate_custom_security_template(oas_file): generate_code( input_name=oas_file.name, input_text=oas_file.read_text(), + encoding=ENCODING, output_dir=output_dir, template_dir=DATA_DIR / 'custom_template' / 'security', ) @@ -79,6 +83,7 @@ def test_generate_remote_ref(mocker): generate_code( input_name=oas_file.name, input_text=oas_file.read_text(), + encoding=ENCODING, output_dir=output_dir, template_dir=None, ) @@ -105,6 +110,7 @@ def test_disable_timestamp(oas_file): generate_code( input_name=oas_file.name, input_text=oas_file.read_text(), + encoding=ENCODING, output_dir=output_dir, template_dir=None, disable_timestamp=True, @@ -130,6 +136,7 @@ def test_generate_using_routers(oas_file): generate_code( input_name=oas_file.name, input_text=oas_file.read_text(), + encoding=ENCODING, output_dir=output_dir, template_dir=BUILTIN_MODULAR_TEMPLATE_DIR, generate_routers=True, @@ -166,6 +173,7 @@ def test_generate_modify_specific_routers(oas_file): generate_code( input_name=oas_file.name, input_text=oas_file.read_text(), + encoding=ENCODING, output_dir=output_dir, template_dir=BUILTIN_MODULAR_TEMPLATE_DIR, generate_routers=True,