Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Teach stubgen to work with complex and unary expressions #15661

Merged
merged 3 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
CallExpr,
ClassDef,
ComparisonExpr,
ComplexExpr,
Decorator,
DictExpr,
EllipsisExpr,
Expand Down Expand Up @@ -1396,6 +1397,8 @@ def is_private_member(self, fullname: str) -> bool:
def get_str_type_of_node(
self, rvalue: Expression, can_infer_optional: bool = False, can_be_any: bool = True
) -> str:
rvalue = self.maybe_unwrap_unary_expr(rvalue)

if isinstance(rvalue, IntExpr):
return "int"
if isinstance(rvalue, StrExpr):
Expand All @@ -1404,8 +1407,13 @@ def get_str_type_of_node(
return "bytes"
if isinstance(rvalue, FloatExpr):
return "float"
if isinstance(rvalue, UnaryExpr) and isinstance(rvalue.expr, IntExpr):
return "int"
if isinstance(rvalue, ComplexExpr): # 1j
return "complex"
if isinstance(rvalue, OpExpr) and rvalue.op in ("-", "+"): # -1j + 1
if isinstance(self.maybe_unwrap_unary_expr(rvalue.left), ComplexExpr) or isinstance(
self.maybe_unwrap_unary_expr(rvalue.right), ComplexExpr
):
return "complex"
if isinstance(rvalue, NameExpr) and rvalue.name in ("True", "False"):
return "bool"
if can_infer_optional and isinstance(rvalue, NameExpr) and rvalue.name == "None":
Expand All @@ -1417,6 +1425,40 @@ def get_str_type_of_node(
else:
return ""

def maybe_unwrap_unary_expr(self, expr: Expression) -> Expression:
"""Unwrap (possibly nested) unary expressions.

But, some unary expressions can change the type of expression.
While we want to preserve it. For example, `~True` is `int`.
So, we only allow a subset of unary expressions to be unwrapped.
"""
if not isinstance(expr, UnaryExpr):
return expr

# First, try to unwrap `[+-]+ (int|float|complex)` expr:
math_ops = ("+", "-")
if expr.op in math_ops:
while isinstance(expr, UnaryExpr):
if expr.op not in math_ops or not isinstance(
expr.expr, (IntExpr, FloatExpr, ComplexExpr, UnaryExpr)
):
break
expr = expr.expr
sobolevn marked this conversation as resolved.
Show resolved Hide resolved
return expr

# Next, try `not bool` expr:
if expr.op == "not":
while isinstance(expr, UnaryExpr):
if expr.op != "not" or not isinstance(expr.expr, (NameExpr, UnaryExpr)):
break
if isinstance(expr.expr, NameExpr) and expr.expr.name not in ("True", "False"):
break
expr = expr.expr
return expr

# This is some other unary expr, we cannot do anything with it (yet?).
return expr

def print_annotation(self, t: Type) -> str:
printer = AnnotationPrinter(self)
return t.accept(printer)
Expand Down
48 changes: 45 additions & 3 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,52 @@ class A:

def g() -> None: ...

[case testVariable]
x = 1
[case testVariables]
i = 1
s = 'a'
f = 1.5
c1 = 1j
c2 = 0j + 1
bl1 = True
bl2 = False
bts = b''
[out]
i: int
s: str
f: float
c1: complex
c2: complex
bl1: bool
bl2: bool
bts: bytes

[case testVariablesWithUnary]
i = +-1
f = -1.5
c1 = -1j
c2 = -1j + 1
bl1 = not True
bl2 = not not False
[out]
i: int
f: float
c1: complex
c2: complex
bl1: bool
bl2: bool

[case testVariablesWithUnaryWrong]
i = not +1
bl1 = -True
bl2 = not -False
bl3 = -(not False)
[out]
x: int
from _typeshed import Incomplete

i: Incomplete
bl1: Incomplete
bl2: Incomplete
bl3: Incomplete

[case testAnnotatedVariable]
x: int = 1
Expand Down