Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add suggested param/return type #358

Merged
merged 4 commits into from
Dec 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## Unreleased

- Add check that suggests parameter and return types for untyped
functions, using the new `suggested_parameter_type` and
`suggested_return_type` codes (#358)
- Extract constraints from multi-comparisons (`a < b < c`) (#354)
- Support positional-only arguments with the `__` prefix
outside of stubs (#353)
Expand Down
2 changes: 2 additions & 0 deletions pyanalyze/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from . import safe
from . import signature
from . import stacked_scopes
from . import suggested_type
from . import test_config
from . import type_object
from . import typeshed
Expand All @@ -44,3 +45,4 @@
used(value.UNRESOLVED_VALUE) # keeping it around for now just in case
used(reexport)
used(checker)
used(suggested_type)
7 changes: 7 additions & 0 deletions pyanalyze/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,22 @@
import sys
from typing import Iterable, Iterator, List, Set, Tuple, Union, Dict

from .node_visitor import Failure
from .value import TypedValue
from .arg_spec import ArgSpecCache
from .config import Config
from .reexport import ImplicitReexportTracker
from .safe import is_instance_of_typing_name, is_typing_name, safe_getattr
from .type_object import TypeObject, get_mro
from .suggested_type import CallableTracker


@dataclass
class Checker:
config: Config
arg_spec_cache: ArgSpecCache = field(init=False)
reexport_tracker: ImplicitReexportTracker = field(init=False)
callable_tracker: CallableTracker = field(init=False)
type_object_cache: Dict[Union[type, super, str], TypeObject] = field(
default_factory=dict, init=False, repr=False
)
Expand All @@ -32,6 +35,10 @@ class Checker:
def __post_init__(self) -> None:
self.arg_spec_cache = ArgSpecCache(self.config)
self.reexport_tracker = ImplicitReexportTracker(self.config)
self.callable_tracker = CallableTracker()

def perform_final_checks(self) -> List[Failure]:
return self.callable_tracker.check()

def get_additional_bases(self, typ: Union[type, super]) -> Set[type]:
return self.config.get_additional_bases(typ)
Expand Down
6 changes: 6 additions & 0 deletions pyanalyze/error_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,16 @@ class ErrorCode(enum.Enum):
no_return_may_return = 68
implicit_reexport = 69
invalid_context_manager = 70
suggested_return_type = 71
suggested_parameter_type = 72


# Allow testing unannotated functions without too much fuss
DISABLED_IN_TESTS = {
ErrorCode.missing_return_annotation,
ErrorCode.missing_parameter_annotation,
ErrorCode.suggested_return_type,
ErrorCode.suggested_parameter_type,
}


Expand Down Expand Up @@ -193,6 +197,8 @@ class ErrorCode(enum.Enum):
ErrorCode.no_return_may_return: "Function is annotated as NoReturn but may return",
ErrorCode.implicit_reexport: "Use of implicitly re-exported name",
ErrorCode.invalid_context_manager: "Use of invalid object in with or async with",
ErrorCode.suggested_return_type: "Suggested return type",
ErrorCode.suggested_parameter_type: "Suggested parameter type",
}


Expand Down
37 changes: 37 additions & 0 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
ARGS,
KWARGS,
)
from .suggested_type import CallArgs, display_suggested_type
from .asynq_checker import AsyncFunctionKind, AsynqChecker, FunctionInfo
from .yield_checker import YieldChecker
from .type_object import TypeObject, get_mro
Expand Down Expand Up @@ -1306,6 +1307,17 @@ def visit_FunctionDef(
else:
potential_function = None

if (
potential_function is not None
and self.settings
and self.settings[ErrorCode.suggested_parameter_type]
):
sig = self.signature_from_value(KnownValue(potential_function))
if isinstance(sig, Signature):
self.checker.callable_tracker.record_callable(
node, potential_function, sig, self
)

self.yield_checker.reset_yield_checks()

# This code handles nested functions
Expand Down Expand Up @@ -1373,6 +1385,21 @@ def visit_FunctionDef(
else:
self._show_error_if_checking(node, error_code=ErrorCode.missing_return)

if (
has_return
and expected_return_value is None
and not info.is_overload
and not any(
decorator == KnownValue(abstractmethod)
for _, decorator in info.decorators
)
):
self._show_error_if_checking(
node,
error_code=ErrorCode.suggested_return_type,
detail=display_suggested_type(return_value),
)

if evaled_function:
return evaled_function

Expand Down Expand Up @@ -1427,6 +1454,10 @@ def visit_FunctionDef(
self.log(logging.DEBUG, "No argspec", (potential_function, node))
return KnownValue(potential_function)

def record_call(self, callable: object, arguments: CallArgs) -> None:
if self.settings and self.settings[ErrorCode.suggested_parameter_type]:
self.checker.callable_tracker.record_call(callable, arguments)

def _visit_defaults(
self, node: FunctionNode
) -> Tuple[List[Value], List[Optional[Value]]]:
Expand Down Expand Up @@ -4439,6 +4470,12 @@ def prepare_constructor_kwargs(cls, kwargs: Mapping[str, Any]) -> Mapping[str, A
kwargs.setdefault("checker", Checker(cls.config))
return kwargs

@classmethod
def perform_final_checks(
cls, kwargs: Mapping[str, Any]
) -> List[node_visitor.Failure]:
return kwargs["checker"].perform_final_checks()

@classmethod
def _run_on_files(
cls,
Expand Down
9 changes: 8 additions & 1 deletion pyanalyze/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,10 @@ def get_files_to_check(cls, include_tests: bool) -> List[str]:
def prepare_constructor_kwargs(cls, kwargs: Mapping[str, Any]) -> Mapping[str, Any]:
return kwargs

@classmethod
def perform_final_checks(cls, kwargs: Mapping[str, Any]) -> List[Failure]:
return []

@classmethod
def main(cls) -> int:
"""Can be used as a main function. Calls the checker on files given on the command line."""
Expand Down Expand Up @@ -520,6 +524,7 @@ def show_error(
obey_ignore: bool = True,
ignore_comment: str = IGNORE_COMMENT,
detail: Optional[str] = None,
save: bool = True,
) -> Optional[Failure]:
"""Shows an error associated with this node.

Expand Down Expand Up @@ -647,7 +652,8 @@ def show_error(
self._changes_for_fixer[self.filename].append(replacement)

error["message"] = message
self.all_failures.append(error)
if save:
self.all_failures.append(error)
sys.stderr.write(message)
sys.stderr.flush()
if self.fail_after_first:
Expand Down Expand Up @@ -710,6 +716,7 @@ def _run_on_files(cls, files: Iterable[str], **kwargs: Any) -> List[Failure]:
else:
for failures, _ in map(cls._check_file_single_arg, args):
all_failures += failures
all_failures += cls.perform_final_checks(kwargs)
return all_failures

@classmethod
Expand Down
8 changes: 7 additions & 1 deletion pyanalyze/reexport.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ class ErrorContext:
all_failures: List[Failure]

def show_error(
self, node: AST, message: str, error_code: Enum
self,
node: AST,
message: str,
error_code: Enum,
*,
detail: Optional[str] = None,
save: bool = True,
) -> Optional[Failure]:
raise NotImplementedError

Expand Down
4 changes: 4 additions & 0 deletions pyanalyze/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,10 @@ def check_call_preprocessed(
if bound_args is None:
return self.get_default_return()
variables = {key: composite.value for key, (_, composite) in bound_args.items()}

if self.callable is not None:
visitor.record_call(self.callable, variables)

return_value = self.return_value
typevar_values: Dict[TypeVar, Value] = {}
if self.all_typevars:
Expand Down
178 changes: 178 additions & 0 deletions pyanalyze/suggested_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
"""

Suggest types for untyped code.

"""
import ast
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Iterator, List, Mapping, Sequence, Union

from pyanalyze.safe import safe_isinstance

from .error_code import ErrorCode
from .node_visitor import Failure
from .value import (
AnnotatedValue,
AnySource,
AnyValue,
CallableValue,
CanAssignError,
GenericValue,
KnownValue,
SequenceIncompleteValue,
SubclassValue,
TypedDictValue,
TypedValue,
Value,
MultiValuedValue,
VariableNameValue,
replace_known_sequence_value,
unite_values,
)
from .reexport import ErrorContext
from .signature import Signature

CallArgs = Mapping[str, Value]
FunctionNode = Union[ast.FunctionDef, ast.AsyncFunctionDef]


@dataclass
class CallableData:
node: FunctionNode
ctx: ErrorContext
sig: Signature
calls: List[CallArgs] = field(default_factory=list)

def check(self) -> Iterator[Failure]:
if not self.calls:
return
for param in _extract_params(self.node):
if param.annotation is not None:
continue
sig_param = self.sig.parameters.get(param.arg)
if sig_param is None or not isinstance(sig_param.annotation, AnyValue):
continue # e.g. inferred type for self
all_values = [call[param.arg] for call in self.calls]
all_values = [prepare_type(v) for v in all_values]
all_values = [v for v in all_values if not isinstance(v, AnyValue)]
if not all_values:
continue
suggested = display_suggested_type(unite_values(*all_values))
failure = self.ctx.show_error(
param,
f"Suggested type for parameter {param.arg}",
ErrorCode.suggested_parameter_type,
detail=suggested,
# Otherwise we record it twice in tests. We should ultimately
# refactor error tracking to make it less hacky for things that
# show errors outside of files.
save=False,
)
if failure is not None:
yield failure


@dataclass
class CallableTracker:
callable_to_data: Dict[object, CallableData] = field(default_factory=dict)
callable_to_calls: Dict[object, List[CallArgs]] = field(
default_factory=lambda: defaultdict(list)
)

def record_callable(
self, node: FunctionNode, callable: object, sig: Signature, ctx: ErrorContext
) -> None:
"""Record when we encounter a callable."""
self.callable_to_data[callable] = CallableData(node, ctx, sig)

def record_call(self, callable: object, arguments: Mapping[str, Value]) -> None:
"""Record the actual arguments passed in in a call."""
self.callable_to_calls[callable].append(arguments)

def check(self) -> List[Failure]:
failures = []
for callable, calls in self.callable_to_calls.items():
if callable in self.callable_to_data:
data = self.callable_to_data[callable]
data.calls += calls
failures += data.check()
return failures


def display_suggested_type(value: Value) -> str:
value = prepare_type(value)
if isinstance(value, MultiValuedValue) and value.vals:
cae = CanAssignError("Union", [CanAssignError(str(val)) for val in value.vals])
else:
cae = CanAssignError(str(value))
return str(cae)


def prepare_type(value: Value) -> Value:
"""Simplify a type to turn it into a suggestion."""
if isinstance(value, AnnotatedValue):
return prepare_type(value.value)
elif isinstance(value, SequenceIncompleteValue):
if value.typ is tuple:
return SequenceIncompleteValue(
tuple, [prepare_type(elt) for elt in value.members]
)
else:
return GenericValue(value.typ, [prepare_type(arg) for arg in value.args])
elif isinstance(value, (TypedDictValue, CallableValue)):
return value
elif isinstance(value, GenericValue):
# TODO maybe turn DictIncompleteValue into TypedDictValue?
return GenericValue(value.typ, [prepare_type(arg) for arg in value.args])
elif isinstance(value, VariableNameValue):
return AnyValue(AnySource.unannotated)
elif isinstance(value, KnownValue):
if value.val is None or safe_isinstance(value.val, type):
return value
elif callable(value.val):
return value # TODO get the signature instead and return a CallableValue?
value = replace_known_sequence_value(value)
if isinstance(value, KnownValue):
return TypedValue(type(value.val))
else:
return prepare_type(value)
elif isinstance(value, MultiValuedValue):
vals = [prepare_type(subval) for subval in value.vals]
type_literals = [
v
for v in vals
if isinstance(v, KnownValue) and safe_isinstance(v.val, type)
]
if len(type_literals) > 1:
types = [v.val for v in type_literals if isinstance(v.val, type)]
shared_type = get_shared_type(types)
type_val = SubclassValue(TypedValue(shared_type))
others = [
v
for v in vals
if not isinstance(v, KnownValue) or not safe_isinstance(v.val, type)
]
return unite_values(type_val, *others)
return unite_values(*vals)
else:
return value


def get_shared_type(types: Sequence[type]) -> type:
mros = [t.mro() for t in types]
first, *rest = mros
rest_sets = [set(mro) for mro in rest]
for candidate in first:
if all(candidate in mro for mro in rest_sets):
return candidate
assert False, "should at least have found object"


def _extract_params(node: FunctionNode) -> Iterator[ast.arg]:
yield from node.args.args
if node.args.vararg is not None:
yield node.args.vararg
yield from node.args.kwonlyargs
if node.args.kwarg is not None:
yield node.args.kwarg
Loading