Skip to content

Commit

Permalink
feat: Correct stubs for TypeVars (#67)
Browse files Browse the repository at this point in the history
Closes #63

### Summary of Changes
TypeVars create correct stubs.

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
  • Loading branch information
Masara and megalinter-bot authored Feb 29, 2024
1 parent 216e179 commit df8c5c9
Show file tree
Hide file tree
Showing 16 changed files with 365 additions and 132 deletions.
11 changes: 8 additions & 3 deletions src/safeds_stubgen/api_analyzer/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
ResultDocstring,
)

from ._types import AbstractType
from ._types import AbstractType, TypeVarType

API_SCHEMA_VERSION = 1

Expand Down Expand Up @@ -237,6 +237,7 @@ class Function:
is_static: bool
is_class_method: bool
is_property: bool
type_var_types: list[TypeVarType] = field(default_factory=list)
results: list[Result] = field(default_factory=list)
reexported_by: list[Module] = field(default_factory=list)
parameters: list[Parameter] = field(default_factory=list)
Expand Down Expand Up @@ -311,11 +312,15 @@ class ParameterAssignment(PythonEnum):
@dataclass(frozen=True)
class TypeParameter:
name: str
type: AbstractType
type: AbstractType | None
variance: VarianceKind

def to_dict(self) -> dict[str, Any]:
return {"name": self.name, "type": self.type.to_dict(), "variance_type": self.variance.name}
return {
"name": self.name,
"type": self.type.to_dict() if self.type is not None else None,
"variance_type": self.variance.name,
}


class VarianceKind(PythonEnum):
Expand Down
38 changes: 30 additions & 8 deletions src/safeds_stubgen/api_analyzer/_ast_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def __init__(self, docstring_parser: AbstractDocstringParser, api: API, aliases:
self.__declaration_stack: list[Module | Class | Function | Enum | list[Attribute | EnumInstance]] = []
self.aliases = aliases
self.mypy_file: mp_nodes.MypyFile | None = None
# We gather type var types used as a parameter type in a function
self.type_var_types: set[sds_types.TypeVarType] = set()

def enter_moduledef(self, node: mp_nodes.MypyFile) -> None:
self.mypy_file = node
Expand Down Expand Up @@ -153,13 +155,17 @@ def enter_classdef(self, node: mp_nodes.ClassDef) -> None:

for generic_type in generic_types:
variance_type = mypy_variance_parser(generic_type.variance)
variance_values: sds_types.AbstractType
variance_values: sds_types.AbstractType | None = None
if variance_type == VarianceKind.INVARIANT:
variance_values = sds_types.UnionType(
[self.mypy_type_to_abstract_type(value) for value in generic_type.values],
)
values = [self.mypy_type_to_abstract_type(value) for value in generic_type.values]
if values:
variance_values = sds_types.UnionType(
[self.mypy_type_to_abstract_type(value) for value in generic_type.values],
)
else:
variance_values = self.mypy_type_to_abstract_type(generic_type.upper_bound)
upper_bound = generic_type.upper_bound
if upper_bound.__str__() != "builtins.object":
variance_values = self.mypy_type_to_abstract_type(upper_bound)

type_parameters.append(
TypeParameter(
Expand Down Expand Up @@ -229,11 +235,19 @@ def enter_funcdef(self, node: mp_nodes.FuncDef) -> None:
# Get docstring
docstring = self.docstring_parser.get_function_documentation(node)

# Function args
# Function args & TypeVar
arguments: list[Parameter] = []
type_var_types: list[sds_types.TypeVarType] = []
# Reset the type_var_types list
self.type_var_types = set()
if getattr(node, "arguments", None) is not None:
arguments = self._parse_parameter_data(node, function_id)

if self.type_var_types:
type_var_types = list(self.type_var_types)
# Sort for the snapshot tests
type_var_types.sort(key=lambda x: x.name)

# Create results
results = self._parse_results(node, function_id)

Expand All @@ -252,6 +266,7 @@ def enter_funcdef(self, node: mp_nodes.FuncDef) -> None:
results=results,
reexported_by=reexported_by,
parameters=arguments,
type_var_types=type_var_types,
)
self.__declaration_stack.append(function)

Expand Down Expand Up @@ -666,7 +681,7 @@ def _parse_parameter_data(self, node: mp_nodes.FuncDef, function_id: str) -> lis
for argument in node.arguments:
mypy_type = argument.variable.type
type_annotation = argument.type_annotation
arg_type = None
arg_type: AbstractType | None = None
default_value = None
default_is_none = False

Expand Down Expand Up @@ -827,7 +842,14 @@ def mypy_type_to_abstract_type(

# Special Cases
elif isinstance(mypy_type, mp_types.TypeVarType):
return sds_types.TypeVarType(mypy_type.name)
upper_bound = mypy_type.upper_bound
type_ = None
if upper_bound.__str__() != "builtins.object":
type_ = self.mypy_type_to_abstract_type(upper_bound)

type_var = sds_types.TypeVarType(name=mypy_type.name, upper_bound=type_)
self.type_var_types.add(type_var)
return type_var
elif isinstance(mypy_type, mp_types.CallableType):
return sds_types.CallableType(
parameter_types=[self.mypy_type_to_abstract_type(arg_type) for arg_type in mypy_type.arg_types],
Expand Down
15 changes: 10 additions & 5 deletions src/safeds_stubgen/api_analyzer/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,16 +391,21 @@ def __hash__(self) -> int:
@dataclass(frozen=True)
class TypeVarType(AbstractType):
name: str
upper_bound: AbstractType | None = None

@classmethod
def from_dict(cls, d: dict[str, str]) -> TypeVarType:
return TypeVarType(d["name"])
def from_dict(cls, d: dict[str, Any]) -> TypeVarType:
return TypeVarType(d["name"], d["upper_bound"])

def to_dict(self) -> dict[str, str]:
return {"kind": self.__class__.__name__, "name": self.name}
def to_dict(self) -> dict[str, Any]:
return {
"kind": self.__class__.__name__,
"name": self.name,
"upper_bound": self.upper_bound.to_dict() if self.upper_bound is not None else None,
}

def __hash__(self) -> int:
return hash(frozenset([self.name]))
return hash(frozenset([self.name, self.upper_bound]))


# ############################## Utilities ############################## #
Expand Down
42 changes: 33 additions & 9 deletions src/safeds_stubgen/stubs_generator/_generate_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __call__(self, module: Module) -> str:
self.module_imports = set()
self._current_todo_msgs: set[str] = set()
self.module = module
self.class_generics: list = []
return self._create_module_string(module)

def _create_module_string(self, module: Module) -> str:
Expand Down Expand Up @@ -175,7 +176,8 @@ def _create_class_string(self, class_: Class, class_indentation: str = "") -> st
constraints_info = ""
variance_info = ""
if class_.type_parameters:
variances = []
# We collect the class generics for the methods later
self.class_generics = []
out = "out "
for variance in class_.type_parameters:
variance_direction = {
Expand All @@ -189,12 +191,12 @@ def _create_class_string(self, class_: Class, class_indentation: str = "") -> st
variance_name_camel_case = self._replace_if_safeds_keyword(variance_name_camel_case)

variance_item = f"{variance_direction}{variance_name_camel_case}"
if variance_direction == out:
if variance.type is not None:
variance_item = f"{variance_item} sub {self._create_type_string(variance.type.to_dict())}"
variances.append(variance_item)
self.class_generics.append(variance_item)

if variances:
variance_info = f"<{', '.join(variances)}>"
if self.class_generics:
variance_info = f"<{', '.join(self.class_generics)}>"

# Class name - Convert to camelCase and check for keywords
class_name = class_.name
Expand Down Expand Up @@ -265,6 +267,10 @@ def _create_class_attribute_string(self, attributes: list[Attribute], inner_inde
if attribute.type:
attribute_type = attribute.type.to_dict()

# Don't create TypeVar attributes
if attribute_type["kind"] == "TypeVarType":
continue

static_string = "static " if attribute.is_static else ""

# Convert name to camelCase and add PythonName annotation
Expand Down Expand Up @@ -317,6 +323,24 @@ def _create_function_string(self, function: Function, indentations: str = "", is
is_instance_method=not is_static and is_method,
)

# TypeVar
type_var_info = ""
if function.type_var_types:
type_var_names = []
for type_var in function.type_var_types:
type_var_name = self._convert_snake_to_camel_case(type_var.name)
type_var_name = self._replace_if_safeds_keyword(type_var_name)

# We don't have to display generic types in methods if they were already displayed in the class
if not is_method or (is_method and type_var_name not in self.class_generics):
if type_var.upper_bound is not None:
type_var_name += f" sub {self._create_type_string(type_var.upper_bound.to_dict())}"
type_var_names.append(type_var_name)

if type_var_names:
type_var_string = ", ".join(type_var_names)
type_var_info = f"<{type_var_string}>"

# Convert function name to camelCase
name = function.name
camel_case_name = self._convert_snake_to_camel_case(name)
Expand All @@ -334,8 +358,8 @@ def _create_function_string(self, function: Function, indentations: str = "", is
f"{self._create_todo_msg(indentations)}"
f"{indentations}@Pure\n"
f"{function_name_annotation}"
f"{indentations}{static}fun {camel_case_name}({func_params})"
f"{result_string}"
f"{indentations}{static}fun {camel_case_name}{type_var_info}"
f"({func_params}){result_string}"
)

def _create_property_function_string(self, function: Function, indentations: str = "") -> str:
Expand Down Expand Up @@ -621,9 +645,9 @@ def _create_type_string(self, type_data: dict | None) -> str:
else:
types.append(f"{literal_type}")
return f"literal<{', '.join(types)}>"
# Todo See issue #63
elif kind == "TypeVarType":
return ""
name = self._convert_snake_to_camel_case(type_data["name"])
return self._replace_if_safeds_keyword(name)

raise ValueError(f"Unexpected type: {kind}") # pragma: no cover

Expand Down
4 changes: 1 addition & 3 deletions tests/data/various_modules_package/attribute_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Final, Literal, TypeVar
from typing import Optional, Final, Literal
from tests.data.main_package.another_path.another_module import AnotherClass


Expand Down Expand Up @@ -66,7 +66,5 @@ def some_func() -> bool:
attr_type_from_outside_package: AnotherClass
attr_default_value_from_outside_package = AnotherClass

type_var = TypeVar("type_var")

def __init__(self):
self.init_attr: bool = False
6 changes: 1 addition & 5 deletions tests/data/various_modules_package/function_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Literal, Any, TypeVar
from typing import Callable, Optional, Literal, Any
from tests.data.main_package.another_path.another_module import AnotherClass


Expand Down Expand Up @@ -187,10 +187,6 @@ def param_from_outside_the_package(param_type: AnotherClass, param_value=Another
def result_from_outside_the_package() -> AnotherClass: ...


_type_var = TypeVar("_type_var")
def type_var_func(type_var_list: list[_type_var]) -> list[_type_var]: ...


class FunctionModulePropertiesClass:
@property
def property_function(self): ...
Expand Down
30 changes: 30 additions & 0 deletions tests/data/various_modules_package/type_var_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import TypeVar, Generic

T = TypeVar('T')


class TypeVarClass(Generic[T]):
type_var = TypeVar("type_var")

def __init__(self, items: list[T]): ...

def type_var_class_method(self, a: T) -> T: ...


class TypeVarClass2(Generic[T]):
type_var = TypeVar("type_var")

def type_var_class_method2(self, a: T) -> T: ...


_type_var = TypeVar("_type_var")
def type_var_func(a: list[_type_var]) -> list[_type_var]: ...


_type_var1 = TypeVar("_type_var1")
_type_var2 = TypeVar("_type_var2")
def multiple_type_var(a: _type_var1, b: _type_var2) -> list[_type_var1 | _type_var2]: ...


T_in = TypeVar("T_in", bound=int)
def type_var_fun_invariance_with_bound(a: list[T_in]) -> T_in: ...
23 changes: 18 additions & 5 deletions tests/data/various_modules_package/variance_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,27 @@ class A:
...


_T_co = TypeVar("_T_co", covariant=True, bound=str)
_T_con = TypeVar("_T_con", contravariant=True, bound=A)
_T_in = TypeVar("_T_in", int, Literal[1, 2])
_T_in = TypeVar("_T_in")
_T_co = TypeVar("_T_co", covariant=True)
_T_con = TypeVar("_T_con", contravariant=True)


class VarianceClassAll(Generic[_T_co, _T_con, _T_in]):
class VarianceClassOnlyCovarianceNoBound(Generic[_T_co]):
...


class VarianceClassOnlyInvariance(Generic[_T_in]):
class VarianceClassOnlyVarianceNoBound(Generic[_T_in]):
...


class VarianceClassOnlyContravarianceNoBound(Generic[_T_con]):
...


_T_co2 = TypeVar("_T_co2", covariant=True, bound=str)
_T_con2 = TypeVar("_T_con2", contravariant=True, bound=A)
_T_in2 = TypeVar("_T_in2", int, Literal[1, 2])


class VarianceClassAll(Generic[_T_co2, _T_con2, _T_in2]):
...
Loading

0 comments on commit df8c5c9

Please sign in to comment.