Skip to content

Commit

Permalink
fix: Some packages couldn't be analyzed (#51)
Browse files Browse the repository at this point in the history
Closes #48

### Summary of Changes
* Fixed a bug where some packages couldn't be analyzed because we could
not parse the Mypy build result.
* Added a TypeVarType class to _types.py to represent type variables.
* Various other bug fixes and refactoring.

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
  • Loading branch information
Masara and megalinter-bot authored Feb 21, 2024
1 parent a84ec64 commit fa3d020
Show file tree
Hide file tree
Showing 22 changed files with 350 additions and 104 deletions.
20 changes: 10 additions & 10 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"private": true,
"prettier": "@lars-reimann/prettier-config",
"devDependencies": {
"@lars-reimann/prettier-config": "^5.0.0",
"@lars-reimann/prettier-config": "^5.2.1",
"@semantic-release/changelog": "^6.0.3",
"@semantic-release/exec": "^6.0.3",
"@semantic-release/git": "^10.0.1",
Expand Down
2 changes: 2 additions & 0 deletions src/safeds_stubgen/api_analyzer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
NamedType,
SetType,
TupleType,
TypeVarType,
UnionType,
)

Expand Down Expand Up @@ -63,6 +64,7 @@
"Result",
"SetType",
"TupleType",
"TypeVarType",
"UnionType",
"VarianceKind",
"WildcardImport",
Expand Down
81 changes: 51 additions & 30 deletions src/safeds_stubgen/api_analyzer/_ast_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,12 @@ def enter_moduledef(self, node: mp_nodes.MypyFile) -> None:
docstring = definition.expr.value

# Create module id to get the full path
id_ = self._create_module_id(node.fullname)
id_ = self._create_module_id(node.path)

# If we are checking a package node.name will be the package name, but since we get import information from
# the __init__.py file we set the name to __init__
if is_package:
name = "__init__"
id_ += f"/{name}"
else:
name = node.name

Expand Down Expand Up @@ -151,9 +150,9 @@ def enter_classdef(self, node: mp_nodes.ClassDef) -> None:
variance_type = mypy_variance_parser(generic_type.variance)
variance_values: sds_types.AbstractType
if variance_type == VarianceKind.INVARIANT:
variance_values = sds_types.UnionType([
mypy_type_to_abstract_type(value) for value in generic_type.values
])
variance_values = sds_types.UnionType(
[mypy_type_to_abstract_type(value) for value in generic_type.values],
)
else:
variance_values = mypy_type_to_abstract_type(generic_type.upper_bound)

Expand Down Expand Up @@ -419,24 +418,38 @@ def _infer_type_from_return_stmts(func_node: mp_nodes.FuncDef) -> sds_types.Name
func_defn = get_funcdef_definitions(func_node)
return_stmts = find_return_stmts_recursive(func_defn)
if return_stmts:
# In this case the items of the types set can only be of the class "NamedType" or "TupleType" but we have to
# make a typecheck anyway for the mypy linter.
types = set()
for return_stmt in return_stmts:
if return_stmt.expr is not None:
type_ = mypy_expression_to_sds_type(return_stmt.expr)
if isinstance(type_, sds_types.NamedType | sds_types.TupleType):
types.add(type_)
if return_stmt.expr is None: # pragma: no cover
continue

if not isinstance(return_stmt.expr, mp_nodes.CallExpr | mp_nodes.MemberExpr):
# TODO (open question): Should conditional branches be parsed recursively?
# If the return statement is a conditional expression we parse the "if" and "else" branches
if isinstance(return_stmt.expr, mp_nodes.ConditionalExpr):
for conditional_branch in [return_stmt.expr.if_expr, return_stmt.expr.else_expr]:
if conditional_branch is None: # pragma: no cover
continue

if not isinstance(conditional_branch, mp_nodes.CallExpr | mp_nodes.MemberExpr):
type_ = mypy_expression_to_sds_type(conditional_branch)
if isinstance(type_, sds_types.NamedType | sds_types.TupleType):
types.add(type_)
else:
type_ = mypy_expression_to_sds_type(return_stmt.expr)
if isinstance(type_, sds_types.NamedType | sds_types.TupleType):
types.add(type_)

# We have to sort the list for the snapshot tests
return_stmt_types = list(types)
return_stmt_types.sort(
key=lambda x: (x.name if isinstance(x, sds_types.NamedType) else str(len(x.types))),
)

if len(return_stmt_types) >= 2:
if len(return_stmt_types) == 1:
return return_stmt_types[0]
elif len(return_stmt_types) >= 2:
return sds_types.TupleType(types=return_stmt_types)
return return_stmt_types[0]
return None

@staticmethod
Expand Down Expand Up @@ -561,11 +574,19 @@ def _create_attribute(
qname = getattr(attribute, "fullname", "")

# Get node information
type_: sds_types.AbstractType | None = None
node = None
if hasattr(attribute, "node"):
if not isinstance(attribute.node, mp_nodes.Var): # pragma: no cover
raise TypeError("node has wrong type")
if not isinstance(attribute.node, mp_nodes.Var):
# In this case we have a TypeVar attribute
attr_name = getattr(attribute, "name", "")

if not attr_name: # pragma: no cover
raise AttributeError("Expected TypeVar to have attribute 'name'.")

node: mp_nodes.Var = attribute.node
type_ = sds_types.TypeVarType(attr_name)
else:
node = attribute.node
else: # pragma: no cover
raise AttributeError("Expected attribute to have attribute 'node'.")

Expand All @@ -576,13 +597,13 @@ def _create_attribute(
attribute_type = None

# MemberExpr are constructor (__init__) attributes
if isinstance(attribute, mp_nodes.MemberExpr):
if node is not None and isinstance(attribute, mp_nodes.MemberExpr):
attribute_type = node.type
if isinstance(attribute_type, mp_types.AnyType) and not has_correct_type_of_any(attribute_type.type_of_any):
attribute_type = None

# NameExpr are class attributes
elif isinstance(attribute, mp_nodes.NameExpr):
elif node is not None and isinstance(attribute, mp_nodes.NameExpr):
if not node.explicit_self_type:
attribute_type = node.type

Expand All @@ -600,10 +621,6 @@ def _create_attribute(
else: # pragma: no cover
raise AttributeError("Could not get argument information for attribute.")

else: # pragma: no cover
raise TypeError("Attribute has an unexpected type.")

type_ = None
# Ignore types that are special mypy any types
if attribute_type is not None and not (
isinstance(attribute_type, mp_types.AnyType) and not has_correct_type_of_any(attribute_type.type_of_any)
Expand Down Expand Up @@ -770,7 +787,7 @@ def _add_reexports(self, module: Module) -> None:
def _check_if_qname_in_package(self, qname: str) -> bool:
return self.api.package in qname

def _create_module_id(self, qname: str) -> str:
def _create_module_id(self, module_path: str) -> str:
"""Create an ID for the module object.
Creates the module ID while discarding possible unnecessary information from the module qname.
Expand All @@ -787,19 +804,23 @@ def _create_module_id(self, qname: str) -> str:
"""
package_name = self.api.package

if package_name not in qname:
raise ValueError("Package name could not be found in the qualified name of the module.")
if package_name not in module_path:
raise ValueError("Package name could not be found in the module path.")

# We have to split the qname of the module at the first occurrence of the package name and reconnect it while
# discarding everything in front of it. This is necessary since the qname could contain unwanted information.
module_id = qname.split(f"{package_name}", 1)[-1]
module_id = module_path.split(package_name, 1)[-1]
module_id = module_id.replace("\\", "/")

if module_id.startswith("."):
if module_id.startswith("/"):
module_id = module_id[1:]

# Replaces dots with slashes and add the package name at the start of the id, since we removed it
module_id = f"/{module_id.replace('.', '/')}" if module_id else ""
return f"{package_name}{module_id}"
if module_id.endswith(".py"):
module_id = module_id[:-3]

if module_id:
return f"{package_name}/{module_id}"
return package_name

def _is_public(self, name: str, qualified_name: str) -> bool:
if name.startswith("_") and not name.endswith("__"):
Expand Down
67 changes: 25 additions & 42 deletions src/safeds_stubgen/api_analyzer/_get_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING

import mypy.build as mypy_build
Expand All @@ -15,6 +14,8 @@
from ._package_metadata import distribution, distribution_version, package_root

if TYPE_CHECKING:
from pathlib import Path

from mypy.nodes import MypyFile


Expand All @@ -39,7 +40,6 @@ def get_api(
walker = ASTWalker(callable_visitor)

walkable_files = []
package_paths = []
for file_path in root.glob(pattern="./**/*.py"):
logging.info(
"Working on file {posix_path}",
Expand All @@ -51,24 +51,16 @@ def get_api(
logging.info("Skipping test file")
continue

# Check if the current file is an init file
if file_path.parts[-1] == "__init__.py":
# if a directory contains an __init__.py file it's a package
package_paths.append(
file_path.parent,
)
continue

walkable_files.append(str(file_path))

mypy_trees = _get_mypy_ast(walkable_files, package_paths, root)
mypy_trees = _get_mypy_ast(walkable_files, root)
for tree in mypy_trees:
walker.walk(tree)

return callable_visitor.api


def _get_mypy_ast(files: list[str], package_paths: list[Path], root: Path) -> list[MypyFile]:
def _get_mypy_ast(files: list[str], root: Path) -> list[MypyFile]:
if not files:
raise ValueError("No files found to analyse.")

Expand All @@ -79,37 +71,28 @@ def _get_mypy_ast(files: list[str], package_paths: list[Path], root: Path) -> li
result = mypy_build.build(mypyfiles, options=opt)

# Check mypy data key root start
parts = root.parts
graph_keys = list(result.graph.keys())
root_start_after = -1
for i in range(len(parts)):
if ".".join(parts[i:]) in graph_keys:
root_start_after = i
break

# Create the keys for getting the corresponding data
packages = [
".".join(
package_path.parts[root_start_after:],
).replace(".py", "")
for package_path in package_paths
]

modules = [
".".join(
Path(file).parts[root_start_after:],
).replace(".py", "")
for file in files
]

# Get the needed data from mypy. The packages need to be checked first, since we have
# to get the reexported data first
all_paths = packages + modules
graphs = result.graph
graph_keys = list(graphs.keys())
root_path = str(root)

# Get the needed data from mypy. The __init__ files need to be checked first, since we have to get the
# reexported data for the packages first
results = []
for path_key in all_paths:
tree = result.graph[path_key].tree
if tree is not None:
init_results = []
for graph_key in graph_keys:
graph = graphs[graph_key]
graph_path = graph.abspath

if graph_path is None: # pragma: no cover
raise ValueError("Could not parse path of a module.")

tree = graph.tree
if tree is None or root_path not in graph_path or not graph_path.endswith(".py"):
continue

if graph_path.endswith("__init__.py"):
init_results.append(tree)
else:
results.append(tree)

return results
return init_results + results
2 changes: 2 additions & 0 deletions src/safeds_stubgen/api_analyzer/_mypy_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def mypy_type_to_abstract_type(
return sds_types.UnionType(types=[mypy_type_to_abstract_type(item) for item in mypy_type.items])

# Special Cases
elif isinstance(mypy_type, mp_types.TypeVarType):
return sds_types.TypeVarType(mypy_type.name)
elif isinstance(mypy_type, mp_types.CallableType):
return sds_types.CallableType(
parameter_types=[mypy_type_to_abstract_type(arg_type) for arg_type in mypy_type.arg_types],
Expand Down
20 changes: 18 additions & 2 deletions src/safeds_stubgen/api_analyzer/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@ def from_dict(cls, d: dict[str, Any]) -> AbstractType:
return UnionType.from_dict(d)
case CallableType.__name__:
return CallableType.from_dict(d)
case TypeVarType.__name__:
return TypeVarType.from_dict(d)
case _:
raise ValueError(f"Cannot parse {d['kind']} value.")

@abstractmethod
def to_dict(self) -> dict[str, Any]:
pass
def to_dict(self) -> dict[str, Any]: ...


@dataclass(frozen=True)
Expand Down Expand Up @@ -387,6 +388,21 @@ def __hash__(self) -> int:
return hash(frozenset(self.types))


@dataclass(frozen=True)
class TypeVarType(AbstractType):
    """Abstract type representing a type variable, identified only by its name."""

    name: str

    @classmethod
    def from_dict(cls, d: dict[str, str]) -> TypeVarType:
        """Create a TypeVarType from a dictionary that holds the variable name under "name"."""
        return TypeVarType(name=d["name"])

    def to_dict(self) -> dict[str, str]:
        """Return a dictionary with the class name under "kind" and the variable name under "name"."""
        kind = type(self).__name__
        return {"kind": kind, "name": self.name}

    def __hash__(self) -> int:
        # The hash is taken over a one-element frozenset of the name, not the bare string.
        return hash(frozenset((self.name,)))


# ############################## Utilities ############################## #
# def _dismantel_type_string_structure(type_structure: str) -> list:
# current_type = ""
Expand Down
Loading

0 comments on commit fa3d020

Please sign in to comment.