Skip to content

Commit

Permalink
fixed a bug where package modules would be created with an extra dire…
Browse files Browse the repository at this point in the history
…ctory
  • Loading branch information
Masara committed May 3, 2024
1 parent ef2d167 commit 8d7ef3b
Showing 1 changed file with 81 additions and 67 deletions.
148 changes: 81 additions & 67 deletions src/safeds_stubgen/stubs_generator/_generate_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,21 +201,31 @@ class StubsStringGenerator:
"""

def __init__(self, api: API, convert_identifiers: bool) -> None:
self.module_id: str = ""
self.class_generics: list = []
self.module_imports: set[str] = set()

self.api = api
self.naming_convention = NamingConvention.SAFE_DS if convert_identifiers else NamingConvention.PYTHON
self.classes_outside_package: set[str] = set()
self.reexport_modules: dict[str, list[Class | Function]] = defaultdict(list)

def __call__(self, module: Module) -> tuple[str, str]:
self.module_id = module.id
self.class_generics: list = []
self.module_imports: set[str] = set()

self._current_todo_msgs: set[str] = set()
self.module = module
self.class_generics: list = []
return self._create_module_string()
return self._create_module_string(module)

def create_reexport_module_strings(self, out_path: Path) -> list[tuple[Path, str, str, bool]]:
module_data = []
for module_id in self.reexport_modules:
# Reset the objects that we normally would reset in the __call__
self.module_imports = set()
self.class_generics = []
self.module_id = module_id

module_name = module_id.split("/")[-1]

# Create module header
Expand All @@ -224,9 +234,10 @@ def create_reexport_module_strings(self, out_path: Path) -> list[tuple[Path, str
module_name_info = ""
if package_info != package_info_camel_case:
module_name_info = f'@PythonModule("{package_info}")\n'
module_text = f"{module_name_info}package {package_info_camel_case}\n"
module_header = f"{module_name_info}package {package_info_camel_case}\n"

# Create module text
module_text = ""
elements = self.reexport_modules[module_id]

# We sort for the snapshot tests
Expand All @@ -243,24 +254,29 @@ def create_reexport_module_strings(self, out_path: Path) -> list[tuple[Path, str
# We add the class text after global functions
module_text += class_text

# Create imports - We have to create them last, since we have to check all used types in this module first
module_header += self._create_imports_string()

module_text = module_header + module_text

module_data.append((Path(out_path / module_id), module_name, module_text, True))
return module_data

def _create_module_string(self) -> tuple[str, str]:
def _create_module_string(self, module: Module) -> tuple[str, str]:
module_text = ""

# Create package info
package_info, _ = _get_shortest_public_reexport(
reexport_map=self.api.reexport_map,
name=self.module.name,
name=module.name,
qname="",
is_module=True,
)

in_reexport_module = bool(package_info)

if not package_info:
package_info = ".".join(self.module.id.split("/"))
package_info = ".".join(module.id.split("/"))

package_info_camel_case = _convert_name_to_convention(package_info, self.naming_convention)
module_name_info = ""
Expand All @@ -269,12 +285,12 @@ def _create_module_string(self) -> tuple[str, str]:
module_header = f"{module_name_info}package {package_info_camel_case}\n"

# Create docstring
docstring = self._create_sds_docstring_description(self.module.docstring, "")
docstring = self._create_sds_docstring_description(module.docstring, "")
if docstring:
docstring += "\n"

# Create global functions and properties
for function in self.module.global_functions:
for function in module.global_functions:
if function.is_public:
function_string = self._create_function_string(
function=function,
Expand All @@ -285,14 +301,14 @@ def _create_module_string(self) -> tuple[str, str]:
module_text += f"\n{function_string}\n"

# Create classes, class attr. & class methods
for class_ in self.module.classes:
for class_ in module.classes:
if class_.is_public and not class_.inherits_from_exception:
class_string = self._create_class_string(class_=class_, in_reexport_module=in_reexport_module)
if class_string:
module_text += f"\n{class_string}\n"

# Create enums & enum instances
for enum in self.module.enums:
for enum in module.enums:
module_text += f"\n{self._create_enum_string(enum)}\n"

# Create imports - We have to create them last, since we have to check all used types in this module first
Expand Down Expand Up @@ -326,25 +342,8 @@ def _create_imports_string(self) -> str:
return f"\n{import_string}\n"

def _create_class_string(self, class_: Class, class_indentation: str = "", in_reexport_module: bool = False) -> str:
# Check if this class is beeing reexported from a shorter path. If it is, we create it there, not in this module
if not in_reexport_module:
shortest_reexport_module = self.module
for reexport_module in class_.reexported_by:
if len(reexport_module.id.split("/")) < len(shortest_reexport_module.id.split("/")):
shortest_reexport_module = reexport_module

if shortest_reexport_module != self.module:
# Get alias
alias = None
for qualified_import in shortest_reexport_module.qualified_imports:
if qualified_import.qualified_name.endswith(class_.name):
alias = qualified_import.alias

if alias:
class_.name = alias

self.reexport_modules[shortest_reexport_module.id].append(class_)
return ""
if not in_reexport_module and self._has_node_shorter_reexport(node=class_):
return ""

# Set indentation
inner_indentations = class_indentation + INDENTATION
Expand Down Expand Up @@ -559,25 +558,8 @@ def _create_function_string(
in_reexport_module: bool = False,
) -> str:
"""Create a function string for Safe-DS stubs."""
# Check if this class is beeing reexported from a shorter path. If it is, we create it there, not in this module
if not is_method and not in_reexport_module:
shortest_reexport_module = self.module
for reexport_module in function.reexported_by:
if len(reexport_module.id.split("/")) < len(shortest_reexport_module.id.split("/")):
shortest_reexport_module = reexport_module

if shortest_reexport_module != self.module:
# Get alias
alias = None
for qualified_import in shortest_reexport_module.qualified_imports:
if qualified_import.qualified_name.endswith(function.name):
alias = qualified_import.alias

if alias:
function.name = alias

self.reexport_modules[shortest_reexport_module.id].append(function)
return ""
if not is_method and not in_reexport_module and self._has_node_shorter_reexport(node=function):
return ""

# Check if static or class method
is_static = function.is_static
Expand Down Expand Up @@ -1088,40 +1070,70 @@ def _create_sds_docstring(

# ############################### Utilities ############################### #

def is_qname_in_reexports(self, qname: str) -> bool:
name = qname.split("/")[-1]
def _has_node_shorter_reexport(self, node: Class | Function) -> bool:
# Check if this node is beeing reexported from a shorter path. If it is, we create it there, not in this module
shortest_reexport_module_id = self.module_id
shortest_reexport_module: Module | None = None
for reexport_module in node.reexported_by:
if len(reexport_module.id.split("/")) < len(shortest_reexport_module_id.split("/")):
shortest_reexport_module_id = reexport_module.id
shortest_reexport_module = reexport_module

if shortest_reexport_module_id != self.module_id and shortest_reexport_module is not None:
# Get alias
alias = None
for qualified_import in shortest_reexport_module.qualified_imports:
if qualified_import.qualified_name.endswith(node.name):
alias = qualified_import.alias

if alias:
node.name = alias

self.reexport_modules[shortest_reexport_module_id].append(node)
return True
return False

def _is_path_connected_to_class(self, path: str, class_path: str) -> bool:
if class_path.endswith(path):
return True

name = path.split("/")[-1]
class_name = class_path.split("/")[-1]
for reexport in self.api.reexport_map:
if reexport.endswith(name):
for module in self.api.reexport_map[reexport]:
# Added pragma: no cover since I can't recreate this in the tests
if qname.startswith(module.id) and qname.lstrip(module.id).lstrip("/") == name: # pragma: no cover
# Added "no cover" since I can't recreate this in the tests
if (path.startswith(module.id) and class_path.startswith(module.id) and
path.lstrip(module.id).lstrip("/") == name == class_name): # pragma: no cover
return True

return False

def _add_to_imports(self, qname: str) -> None:
def _add_to_imports(self, import_qname: str) -> None:
"""Check if the qname of a type is defined in the current module, if not, create an import for it.
Paramters
---------
qname
The qualified name of a module/class/etc.
"""
if qname == "": # pragma: no cover
if import_qname == "": # pragma: no cover
raise ValueError("Type has no import source.")

qname_parts = qname.split(".")
if (qname_parts[0] == "builtins" and len(qname_parts) == 2) or qname == "typing.Any":
qname_parts = import_qname.split(".")
if (qname_parts[0] == "builtins" and len(qname_parts) == 2) or import_qname == "typing.Any":
return

module_id = self.module.id.replace("/", ".")
if module_id not in qname:
module_id = self.module_id.replace("/", ".")
if module_id not in import_qname:
# We need the full path for an import from the same package, but we sometimes don't get enough information,
# therefore we have to search for the class and get its id
qname_path = qname.replace(".", "/")
import_qname_path = import_qname.replace(".", "/")
in_package = False
for class_ in self.api.classes:
if class_.endswith(qname_path) or self.is_qname_in_reexports(qname_path):
qname = class_.replace("/", ".")
qname = ""
for class_id in self.api.classes:
if self._is_path_connected_to_class(import_qname_path, class_id):
qname = class_id.replace("/", ".")

name = qname.split(".")[-1]
shortest_qname, _ = _get_shortest_public_reexport(
Expand All @@ -1137,6 +1149,8 @@ def _add_to_imports(self, qname: str) -> None:
in_package = True
break

qname = qname or import_qname

if not in_package:
self.classes_outside_package.add(qname)

Expand Down Expand Up @@ -1174,15 +1188,15 @@ def _create_todo_msg(self, indentations: str) -> str:
return indentations + f"\n{indentations}".join(todo_msgs) + "\n"

def _get_class_in_module(self, class_name: str) -> Class:
if f"{self.module.id}/{class_name}" in self.api.classes:
return self.api.classes[f"{self.module.id}/{class_name}"]
if f"{self.module_id}/{class_name}" in self.api.classes:
return self.api.classes[f"{self.module_id}/{class_name}"]

# If the class is a nested class
for class_ in self.api.classes:
if class_.startswith(self.module.id) and class_.endswith(class_name):
if class_.startswith(self.module_id) and class_.endswith(class_name):
return self.api.classes[class_]

raise LookupError(f"Expected finding class '{class_name}' in module '{self.module.id}'.") # pragma: no cover
raise LookupError(f"Expected finding class '{class_name}' in module '{self.module_id}'.") # pragma: no cover

@staticmethod
def _create_docstring_description_part(description: str, indentations: str) -> str:
Expand Down

0 comments on commit 8d7ef3b

Please sign in to comment.