Skip to content

Commit

Permalink
fixed a bug where package modules would be created with an extra dire…
Browse files Browse the repository at this point in the history
…ctory
  • Loading branch information
Masara committed May 3, 2024
1 parent e4fdb39 commit ef2d167
Showing 1 changed file with 31 additions and 12 deletions.
43 changes: 31 additions & 12 deletions src/safeds_stubgen/stubs_generator/_generate_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class NamingConvention(IntEnum):
def generate_stub_data(
stubs_generator: StubsStringGenerator,
out_path: Path,
) -> list[tuple[Path, str, str]]:
) -> list[tuple[Path, str, str, bool]]:
"""Generate Safe-DS stubs.
Generates stub data from an API object.
Expand All @@ -54,10 +54,11 @@ def generate_stub_data(
Returns
-------
A list of tuples, which are 1. the path of the stub file, 2. the name of the stub file and 3. its content.
A list of tuples, which are 1. the path of the stub file, 2. the name of the stub file, 3. its content and 4. if
it's a package file (created through init reexports).
"""
api = stubs_generator.api
stubs_data: list[tuple[Path, str, str]] = []
stubs_data: list[tuple[Path, str, str, bool]] = []
for module in api.modules.values():
if module.name == "__init__":
continue
Expand Down Expand Up @@ -92,7 +93,7 @@ def generate_stub_data(
module_name = alias if alias else module.name

module_dir = Path(out_path / module_id)
stubs_data.append((module_dir, module_name, module_text))
stubs_data.append((module_dir, module_name, module_text, False))

reexport_module_data = stubs_generator.create_reexport_module_strings(out_path=out_path)

Expand All @@ -101,20 +102,28 @@ def generate_stub_data(

def create_stub_files(
stubs_generator: StubsStringGenerator,
stubs_data: list[tuple[Path, str, str]],
stubs_data: list[tuple[Path, str, str, bool]],
out_path: Path,
) -> None:
naming_convention = stubs_generator.naming_convention
for module_dir, module_name, module_text in stubs_data:
log_msg = f"Creating stub file for {module_dir}"
# A "package module" is a module which is created though the reexported classes and functions in the __init__.py
for module_dir, module_name, module_text, is_package_module in stubs_data:
if is_package_module:
# Cut out the last part of the path, since we don't want "path/to/package/package.sdsstubs" but
# "path/to/package.sdsstubs" so that "package" is not doubled
corrected_module_dir = Path("/".join(module_dir.parts[:-1]))
else:
corrected_module_dir = module_dir

log_msg = f"Creating stub file for {corrected_module_dir}"
logging.info(log_msg)

# Create module dir
module_dir.mkdir(parents=True, exist_ok=True)
corrected_module_dir.mkdir(parents=True, exist_ok=True)

# Create and open module file
public_module_name = module_name.lstrip("_")
file_path = Path(module_dir / f"{public_module_name}.sdsstub")
file_path = Path(corrected_module_dir / f"{public_module_name}.sdsstub")
Path(file_path).touch()

with file_path.open("w") as f:
Expand Down Expand Up @@ -204,7 +213,7 @@ def __call__(self, module: Module) -> tuple[str, str]:
self.class_generics: list = []
return self._create_module_string()

def create_reexport_module_strings(self, out_path: Path) -> list[tuple[Path, str, str]]:
def create_reexport_module_strings(self, out_path: Path) -> list[tuple[Path, str, str, bool]]:
module_data = []
for module_id in self.reexport_modules:
module_name = module_id.split("/")[-1]
Expand Down Expand Up @@ -234,7 +243,7 @@ def create_reexport_module_strings(self, out_path: Path) -> list[tuple[Path, str
# We add the class text after global functions
module_text += class_text

module_data.append((Path(out_path / module_id), module_name, module_text))
module_data.append((Path(out_path / module_id), module_name, module_text, True))
return module_data

def _create_module_string(self) -> tuple[str, str]:
Expand Down Expand Up @@ -1079,6 +1088,16 @@ def _create_sds_docstring(

# ############################### Utilities ############################### #

def is_qname_in_reexports(self, qname: str) -> bool:
name = qname.split("/")[-1]
for reexport in self.api.reexport_map:
if reexport.endswith(name):
for module in self.api.reexport_map[reexport]:
# Added pragma: no cover since I can't recreate this in the tests
if qname.startswith(module.id) and qname.lstrip(module.id).lstrip("/") == name: # pragma: no cover
return True
return False

def _add_to_imports(self, qname: str) -> None:
"""Check if the qname of a type is defined in the current module, if not, create an import for it.
Expand All @@ -1101,7 +1120,7 @@ def _add_to_imports(self, qname: str) -> None:
qname_path = qname.replace(".", "/")
in_package = False
for class_ in self.api.classes:
if class_.endswith(qname_path):
if class_.endswith(qname_path) or self.is_qname_in_reexports(qname_path):
qname = class_.replace("/", ".")

name = qname.split(".")[-1]
Expand Down

0 comments on commit ef2d167

Please sign in to comment.