Skip to content

Commit

Permalink
Improve handling of default values (#336)
Browse files Browse the repository at this point in the history
- Disallow numbers where the string representation is >7 characters long
- Disallow strings or bytes where the string representation is >50
characters long
- Allow a small number of special constants from the `sys` module, such
as `sys.maxsize` and `sys.executable`

I've also changed a bunch of methods so that they're now functions.
These are all pure, stateless functions, so there's no need to have them
inside the body of the `PyiVisitor` class.
  • Loading branch information
AlexWaygood authored Jan 23, 2023
1 parent cbc447e commit 29ba4e3
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 67 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
# Change Log

## Unreleased

* Disallow numeric default values where `len(str(default)) > 7`. If a function
has a default value where the string representation is greater than 7
characters, it is likely to be an implementation detail or a constant that
varies depending on the system you're running on, such as `sys.maxsize`.
* Disallow `str` or `bytes` defaults where the default is >50 characters long,
for similar reasons.
* Allow `ast.Attribute` nodes as default values for a small number of special
cases, such as `sys.maxsize` and `sys.executable`.

## 23.1.0

Bugfixes:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ currently emitted:
| Y008 | Unrecognized platform in a `sys.platform` check. To prevent you from typos, we warn if you use a platform name outside a small set of known platforms (e.g. `"linux"` and `"win32"`).
| Y009 | Empty class or function body should contain `...`, not `pass`. This is just a stylistic choice, but it's the one typeshed made.
| Y010 | Function body must contain only `...`. Stub files should not contain code, so function bodies should be empty.
| Y011 | Only simple default values (`int`, `float`, `complex`, `bytes`, `str`, `bool`, `None` or `...`) are allowed for typed function arguments. Type checkers ignore the default value, so the default value is not useful information for type-checking, but it may be useful information for other users of stubs such as IDEs. If you're writing a stub for a function that has a more complex default value, use `...` instead of trying to reproduce the runtime default exactly in the stub.
| Y011 | Only simple default values (`int`, `float`, `complex`, `bytes`, `str`, `bool`, `None` or `...`) are allowed for typed function arguments. Type checkers ignore the default value, so the default value is not useful information for type-checking, but it may be useful information for other users of stubs such as IDEs. If you're writing a stub for a function that has a more complex default value, use `...` instead of trying to reproduce the runtime default exactly in the stub. Also use `...` for very long numbers, very long strings, very long bytes, or defaults that vary according to the machine Python is being run on.
| Y012 | Class body must not contain `pass`.
| Y013 | Non-empty class body must not contain `...`.
| Y014 | Only simple default values are allowed for any function arguments. A stronger version of Y011 that includes arguments without type annotations.
Expand Down
170 changes: 104 additions & 66 deletions pyi.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,103 @@ def _analyse_union(members: Sequence[ast.expr]) -> UnionAnalysis:
)


_ALLOWED_ATTRIBUTES_IN_DEFAULTS = frozenset(
{
"sys.base_prefix",
"sys.byteorder",
"sys.exec_prefix",
"sys.executable",
"sys.hexversion",
"sys.maxsize",
"sys.platform",
"sys.prefix",
"sys.stdin",
"sys.stdout",
"sys.stderr",
"sys.version",
"sys.version_info",
"sys.winver",
}
)


def _is_valid_stub_default(node: ast.expr) -> bool:
"""Is `node` valid as a default value for a function or method parameter in a stub?"""
# `...`, bools, None
if isinstance(node, (ast.Ellipsis, ast.NameConstant)):
return True

# strings, bytes
if isinstance(node, (ast.Str, ast.Bytes)):
return len(str(node.s)) <= 50

def _is_valid_Num(node: ast.expr) -> TypeGuard[ast.Num]:
return isinstance(node, ast.Num) and len(str(node.n)) <= 7

# Positive ints, positive floats, positive complex numbers with no real part
if _is_valid_Num(node):
return True
# Negative ints, negative floats, negative complex numbers with no real part
if (
isinstance(node, ast.UnaryOp)
and isinstance(node.op, ast.USub)
and _is_valid_Num(node.operand)
):
return True
# Complex numbers with a real part and an imaginary part...
if (
isinstance(node, ast.BinOp)
and isinstance(node.op, (ast.Add, ast.Sub))
and _is_valid_Num(node.right)
and type(node.right.n) is complex
):
left = node.left
# ...Where the real part is positive:
if isinstance(left, ast.Num) and type(left.n) is not complex:
return True
# ...Where the real part is negative:
if (
isinstance(left, ast.UnaryOp)
and isinstance(left.op, ast.USub)
and _is_valid_Num(left.operand)
and type(left.operand.n) is not complex
):
return True
# Special cases
if (
isinstance(node, ast.Attribute)
and isinstance(node.value, ast.Name)
and f"{node.value.id}.{node.attr}" in _ALLOWED_ATTRIBUTES_IN_DEFAULTS
):
return True
return False


def _is_valid_pep_604_union_member(node: ast.expr) -> bool:
return _is_None(node) or isinstance(node, (ast.Name, ast.Attribute, ast.Subscript))


def _is_valid_pep_604_union(node: ast.expr) -> TypeGuard[ast.BinOp]:
return (
isinstance(node, ast.BinOp)
and isinstance(node.op, ast.BitOr)
and (
_is_valid_pep_604_union_member(node.left)
or _is_valid_pep_604_union(node.left)
)
and _is_valid_pep_604_union_member(node.right)
)


def _is_valid_assignment_value(node: ast.expr) -> bool:
"""Is `node` valid as the default value for an assignment in a stub?"""
return (
isinstance(node, (ast.Call, ast.Name, ast.Attribute, ast.Subscript))
or _is_valid_pep_604_union(node)
or _is_valid_stub_default(node)
)


@dataclass
class NestingCounter:
"""Class to help the PyiVisitor keep track of internal state"""
Expand Down Expand Up @@ -847,30 +944,6 @@ def _check_for_typevarlike_assignments(
else:
self.error(node, Y001.format(cls_name))

@staticmethod
def _is_valid_pep_604_union_member(node: ast.expr) -> bool:
return _is_None(node) or isinstance(
node, (ast.Name, ast.Attribute, ast.Subscript)
)

def _is_valid_pep_604_union(self, node: ast.expr) -> TypeGuard[ast.BinOp]:
return (
isinstance(node, ast.BinOp)
and isinstance(node.op, ast.BitOr)
and (
self._is_valid_pep_604_union_member(node.left)
or self._is_valid_pep_604_union(node.left)
)
and self._is_valid_pep_604_union_member(node.right)
)

def _is_valid_assignment_value(self, node: ast.expr) -> bool:
return (
isinstance(node, (ast.Call, ast.Name, ast.Attribute, ast.Subscript))
or self._is_valid_pep_604_union(node)
or self._is_valid_stub_default(node)
)

def visit_Assign(self, node: ast.Assign) -> None:
if self.in_function.active:
# We error for unexpected things within functions separately.
Expand Down Expand Up @@ -911,7 +984,7 @@ def visit_Assign(self, node: ast.Assign) -> None:

if not is_special_assignment:
self._check_for_type_aliases(node, target, assignment)
if not self._is_valid_assignment_value(assignment):
if not _is_valid_assignment_value(assignment):
self.error(node, Y015)

def visit_AugAssign(self, node: ast.AugAssign) -> None:
Expand Down Expand Up @@ -946,7 +1019,7 @@ def _check_for_type_aliases(
"""
if (
isinstance(assignment, ast.Subscript)
or self._is_valid_pep_604_union(assignment)
or _is_valid_pep_604_union(assignment)
or _is_Any(assignment)
or _is_None(assignment)
):
Expand Down Expand Up @@ -1048,7 +1121,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
if _is_TypeAlias(node_annotation) and isinstance(node_target, ast.Name):
self._check_typealias(node=node, alias_name=node_target.id)

if node_value and not self._is_valid_assignment_value(node_value):
if node_value and not _is_valid_assignment_value(node_value):
self.error(node, Y015)

def _check_union_members(self, members: Sequence[ast.expr]) -> None:
Expand Down Expand Up @@ -1407,7 +1480,7 @@ def error_for_bad_annotation(
arg1_annotation = non_kw_only_args[1].annotation
if arg1_annotation is None or _is_object_or_Unused(arg1_annotation):
pass
elif self._is_valid_pep_604_union(arg1_annotation):
elif _is_valid_pep_604_union(arg1_annotation):
is_union_with_None, non_None_part = _analyse_exit_method_arg(
arg1_annotation
)
Expand All @@ -1425,7 +1498,7 @@ def error_for_bad_annotation(
arg2_annotation = non_kw_only_args[2].annotation
if arg2_annotation is None or _is_object_or_Unused(arg2_annotation):
pass
elif self._is_valid_pep_604_union(arg2_annotation):
elif _is_valid_pep_604_union(arg2_annotation):
is_union_with_None, non_None_part = _analyse_exit_method_arg(
arg2_annotation
)
Expand All @@ -1438,7 +1511,7 @@ def error_for_bad_annotation(
arg3_annotation = non_kw_only_args[3].annotation
if arg3_annotation is None or _is_object_or_Unused(arg3_annotation):
pass
elif self._is_valid_pep_604_union(arg3_annotation):
elif _is_valid_pep_604_union(arg3_annotation):
is_union_with_None, non_None_part = _analyse_exit_method_arg(
arg3_annotation
)
Expand Down Expand Up @@ -1684,47 +1757,12 @@ def visit_arguments(self, node: ast.arguments) -> None:
if node.kwarg is not None:
self.visit(node.kwarg)

def _is_valid_stub_default(self, default: ast.expr) -> bool:
# `...`, strings, bytes, bools, None
if isinstance(default, (ast.Ellipsis, ast.Str, ast.Bytes, ast.NameConstant)):
return True
# Positive ints, positive floats, positive complex numbers with no real part
if isinstance(default, ast.Num):
return True
# Negative ints, negative floats, negative complex numbers with no real part
if (
isinstance(default, ast.UnaryOp)
and isinstance(default.op, ast.USub)
and isinstance(default.operand, ast.Num)
):
return True
# Complex numbers with a real part and an imaginary part...
if (
isinstance(default, ast.BinOp)
and isinstance(default.op, (ast.Add, ast.Sub))
and isinstance(default.right, ast.Num)
and type(default.right.n) is complex
):
left = default.left
# ...Where the real part is positive:
if isinstance(left, ast.Num) and type(left.n) is not complex:
return True
# ...Where the real part is negative:
if (
isinstance(left, ast.UnaryOp)
and isinstance(left.op, ast.USub)
and isinstance(left.operand, ast.Num)
and type(left.operand.n) is not complex
):
return True
return False

def check_arg_default(self, arg: ast.arg, default: ast.expr | None) -> None:
self.visit(arg)
if default is not None:
with self.string_literals_allowed.enabled():
self.visit(default)
if default is not None and not self._is_valid_stub_default(default):
if default is not None and not _is_valid_stub_default(default):
self.error(default, (Y014 if arg.annotation is None else Y011))

def error(self, node: ast.AST, message: str) -> None:
Expand Down
22 changes: 22 additions & 0 deletions tests/defaults.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os
import sys

from _typeshed import SupportsRead, SupportsWrite

def f1(x: int = ...) -> None: ...
def f2(x: int = 3) -> None: ...
Expand Down Expand Up @@ -31,3 +34,22 @@ def f19(x: object = "foo" + 4) -> None: ... # Y011 Only simple default values a
def f20(x: int = 5 + 5) -> None: ... # Y011 Only simple default values allowed for typed arguments
def f21(x: complex = 3j - 3j) -> None: ... # Y011 Only simple default values allowed for typed arguments
def f22(x: complex = -42.5j + 4.3j) -> None: ... # Y011 Only simple default values allowed for typed arguments

# Special-cased attributes
def f23(x: str = sys.base_prefix) -> None: ...
def f24(x: str = sys.byteorder) -> None: ...
def f25(x: str = sys.exec_prefix) -> None: ...
def f26(x: str = sys.executable) -> None: ...
def f27(x: int = sys.hexversion) -> None: ...
def f28(x: int = sys.maxsize) -> None: ...
def f29(x: str = sys.platform) -> None: ...
def f30(x: str = sys.prefix) -> None: ...
def f31(x: SupportsRead[str] = sys.stdin) -> None: ...
def f32(x: SupportsWrite[str] = sys.stdout) -> None: ...
def f33(x: SupportsWrite[str] = sys.stderr) -> None: ...
def f34(x: str = sys.version) -> None: ...
def f35(x: tuple[int, ...] = sys.version_info) -> None: ...
def f36(x: str = sys.winver) -> None: ...

def f37(x: str = "a_very_long_stringgggggggggggggggggggggggggggggggggggggggggggggg") -> None: ... # Y011 Only simple default values allowed for typed arguments
def f38(x: bytes = b"a_very_long_byte_stringggggggggggggggggggggggggggggggggggggg") -> None: ... # Y011 Only simple default values allowed for typed arguments

0 comments on commit 29ba4e3

Please sign in to comment.