diff --git a/docs/changelog.md b/docs/changelog.md index 8cda416c..bd2f2b8a 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,9 @@ ## Unreleased +- Add check that suggests parameter and return types for untyped + functions, using the new `suggested_parameter_type` and + `suggested_return_type` codes (#358) - Extract constraints from multi-comparisons (`a < b < c`) (#354) - Support positional-only arguments with the `__` prefix outside of stubs (#353) diff --git a/pyanalyze/__init__.py b/pyanalyze/__init__.py index 931cf8df..32b2dc7b 100644 --- a/pyanalyze/__init__.py +++ b/pyanalyze/__init__.py @@ -26,6 +26,7 @@ from . import safe from . import signature from . import stacked_scopes +from . import suggested_type from . import test_config from . import type_object from . import typeshed @@ -44,3 +45,4 @@ used(value.UNRESOLVED_VALUE) # keeping it around for now just in case used(reexport) used(checker) +used(suggested_type) diff --git a/pyanalyze/checker.py b/pyanalyze/checker.py index dbd7040a..633aa676 100644 --- a/pyanalyze/checker.py +++ b/pyanalyze/checker.py @@ -9,12 +9,14 @@ import sys from typing import Iterable, Iterator, List, Set, Tuple, Union, Dict +from .node_visitor import Failure from .value import TypedValue from .arg_spec import ArgSpecCache from .config import Config from .reexport import ImplicitReexportTracker from .safe import is_instance_of_typing_name, is_typing_name, safe_getattr from .type_object import TypeObject, get_mro +from .suggested_type import CallableTracker @dataclass @@ -22,6 +24,7 @@ class Checker: config: Config arg_spec_cache: ArgSpecCache = field(init=False) reexport_tracker: ImplicitReexportTracker = field(init=False) + callable_tracker: CallableTracker = field(init=False) type_object_cache: Dict[Union[type, super, str], TypeObject] = field( default_factory=dict, init=False, repr=False ) @@ -32,6 +35,10 @@ class Checker: def __post_init__(self) -> None: self.arg_spec_cache = ArgSpecCache(self.config) self.reexport_tracker = ImplicitReexportTracker(self.config) + self.callable_tracker = CallableTracker() + + def perform_final_checks(self) -> List[Failure]: + return self.callable_tracker.check() def get_additional_bases(self, typ: Union[type, super]) -> Set[type]: return self.config.get_additional_bases(typ) diff --git a/pyanalyze/error_code.py b/pyanalyze/error_code.py index 077e9d90..e7344026 100644 --- a/pyanalyze/error_code.py +++ b/pyanalyze/error_code.py @@ -85,12 +85,16 @@ class ErrorCode(enum.Enum): no_return_may_return = 68 implicit_reexport = 69 invalid_context_manager = 70 + suggested_return_type = 71 + suggested_parameter_type = 72 # Allow testing unannotated functions without too much fuss DISABLED_IN_TESTS = { ErrorCode.missing_return_annotation, ErrorCode.missing_parameter_annotation, + ErrorCode.suggested_return_type, + ErrorCode.suggested_parameter_type, } @@ -193,6 +197,8 @@ class ErrorCode(enum.Enum): ErrorCode.no_return_may_return: "Function is annotated as NoReturn but may return", ErrorCode.implicit_reexport: "Use of implicitly re-exported name", ErrorCode.invalid_context_manager: "Use of invalid object in with or async with", + ErrorCode.suggested_return_type: "Suggested return type", + ErrorCode.suggested_parameter_type: "Suggested parameter type", } diff --git a/pyanalyze/name_check_visitor.py b/pyanalyze/name_check_visitor.py index e7167344..1e6256b3 100644 --- a/pyanalyze/name_check_visitor.py +++ b/pyanalyze/name_check_visitor.py @@ -109,6 +109,7 @@ ARGS, KWARGS, ) +from .suggested_type import CallArgs, display_suggested_type from .asynq_checker import AsyncFunctionKind, AsynqChecker, FunctionInfo from .yield_checker import YieldChecker from .type_object import TypeObject, get_mro @@ -1306,6 +1307,17 @@ def visit_FunctionDef( else: potential_function = None + if ( + potential_function is not None + and self.settings + and self.settings[ErrorCode.suggested_parameter_type] + ): + sig = self.signature_from_value(KnownValue(potential_function)) + if isinstance(sig, Signature): + self.checker.callable_tracker.record_callable( + node, potential_function, sig, self + ) + self.yield_checker.reset_yield_checks() # This code handles nested functions @@ -1373,6 +1385,21 @@ def visit_FunctionDef( else: self._show_error_if_checking(node, error_code=ErrorCode.missing_return) + if ( + has_return + and expected_return_value is None + and not info.is_overload + and not any( + decorator == KnownValue(abstractmethod) + for _, decorator in info.decorators + ) + ): + self._show_error_if_checking( + node, + error_code=ErrorCode.suggested_return_type, + detail=display_suggested_type(return_value), + ) + if evaled_function: return evaled_function @@ -1427,6 +1454,10 @@ def visit_FunctionDef( self.log(logging.DEBUG, "No argspec", (potential_function, node)) return KnownValue(potential_function) + def record_call(self, callable: object, arguments: CallArgs) -> None: + if self.settings and self.settings[ErrorCode.suggested_parameter_type]: + self.checker.callable_tracker.record_call(callable, arguments) + def _visit_defaults( self, node: FunctionNode ) -> Tuple[List[Value], List[Optional[Value]]]: @@ -4439,6 +4470,12 @@ def prepare_constructor_kwargs(cls, kwargs: Mapping[str, Any]) -> Mapping[str, A kwargs.setdefault("checker", Checker(cls.config)) return kwargs + @classmethod + def perform_final_checks( + cls, kwargs: Mapping[str, Any] + ) -> List[node_visitor.Failure]: + return kwargs["checker"].perform_final_checks() + @classmethod def _run_on_files( cls, diff --git a/pyanalyze/node_visitor.py b/pyanalyze/node_visitor.py index a944a5f0..a79bf8f7 100644 --- a/pyanalyze/node_visitor.py +++ b/pyanalyze/node_visitor.py @@ -330,6 +330,10 @@ def get_files_to_check(cls, include_tests: bool) -> List[str]: def prepare_constructor_kwargs(cls, kwargs: Mapping[str, Any]) -> Mapping[str, Any]: return kwargs + @classmethod + def perform_final_checks(cls, kwargs: Mapping[str, Any]) -> List[Failure]: + return [] + @classmethod def main(cls) -> int: """Can be used as a main function. Calls the checker on files given on the command line.""" @@ -520,6 +524,7 @@ def show_error( obey_ignore: bool = True, ignore_comment: str = IGNORE_COMMENT, detail: Optional[str] = None, + save: bool = True, ) -> Optional[Failure]: """Shows an error associated with this node. @@ -647,7 +652,8 @@ def show_error( self._changes_for_fixer[self.filename].append(replacement) error["message"] = message - self.all_failures.append(error) + if save: + self.all_failures.append(error) sys.stderr.write(message) sys.stderr.flush() if self.fail_after_first: @@ -710,6 +716,7 @@ def _run_on_files(cls, files: Iterable[str], **kwargs: Any) -> List[Failure]: else: for failures, _ in map(cls._check_file_single_arg, args): all_failures += failures + all_failures += cls.perform_final_checks(kwargs) return all_failures @classmethod diff --git a/pyanalyze/reexport.py b/pyanalyze/reexport.py index a9fbe9a2..acc524b6 100644 --- a/pyanalyze/reexport.py +++ b/pyanalyze/reexport.py @@ -18,7 +18,13 @@ class ErrorContext: all_failures: List[Failure] def show_error( - self, node: AST, message: str, error_code: Enum + self, + node: AST, + message: str, + error_code: Enum, + *, + detail: Optional[str] = None, + save: bool = True, ) -> Optional[Failure]: raise NotImplementedError diff --git a/pyanalyze/signature.py b/pyanalyze/signature.py index c7fc8c77..f5c94bd8 100644 --- a/pyanalyze/signature.py +++ b/pyanalyze/signature.py @@ -727,6 +727,10 @@ def check_call_preprocessed( if bound_args is None: return self.get_default_return() variables = {key: composite.value for key, (_, composite) in bound_args.items()} + + if self.callable is not None: + visitor.record_call(self.callable, variables) + return_value = self.return_value typevar_values: Dict[TypeVar, Value] = {} if self.all_typevars: diff --git a/pyanalyze/suggested_type.py b/pyanalyze/suggested_type.py new file mode 100644 index 00000000..408bde75 --- /dev/null +++ b/pyanalyze/suggested_type.py @@ -0,0 +1,178 @@ +""" + +Suggest types for untyped code. + +""" +import ast +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, Iterator, List, Mapping, Sequence, Union + +from pyanalyze.safe import safe_isinstance + +from .error_code import ErrorCode +from .node_visitor import Failure +from .value import ( + AnnotatedValue, + AnySource, + AnyValue, + CallableValue, + CanAssignError, + GenericValue, + KnownValue, + SequenceIncompleteValue, + SubclassValue, + TypedDictValue, + TypedValue, + Value, + MultiValuedValue, + VariableNameValue, + replace_known_sequence_value, + unite_values, +) +from .reexport import ErrorContext +from .signature import Signature + +CallArgs = Mapping[str, Value] +FunctionNode = Union[ast.FunctionDef, ast.AsyncFunctionDef] + + +@dataclass +class CallableData: + node: FunctionNode + ctx: ErrorContext + sig: Signature + calls: List[CallArgs] = field(default_factory=list) + + def check(self) -> Iterator[Failure]: + if not self.calls: + return + for param in _extract_params(self.node): + if param.annotation is not None: + continue + sig_param = self.sig.parameters.get(param.arg) + if sig_param is None or not isinstance(sig_param.annotation, AnyValue): + continue # e.g. inferred type for self + all_values = [call[param.arg] for call in self.calls] + all_values = [prepare_type(v) for v in all_values] + all_values = [v for v in all_values if not isinstance(v, AnyValue)] + if not all_values: + continue + suggested = display_suggested_type(unite_values(*all_values)) + failure = self.ctx.show_error( + param, + f"Suggested type for parameter {param.arg}", + ErrorCode.suggested_parameter_type, + detail=suggested, + # Otherwise we record it twice in tests. We should ultimately + # refactor error tracking to make it less hacky for things that + # show errors outside of files. + save=False, + ) + if failure is not None: + yield failure + + +@dataclass +class CallableTracker: + callable_to_data: Dict[object, CallableData] = field(default_factory=dict) + callable_to_calls: Dict[object, List[CallArgs]] = field( + default_factory=lambda: defaultdict(list) + ) + + def record_callable( + self, node: FunctionNode, callable: object, sig: Signature, ctx: ErrorContext + ) -> None: + """Record when we encounter a callable.""" + self.callable_to_data[callable] = CallableData(node, ctx, sig) + + def record_call(self, callable: object, arguments: Mapping[str, Value]) -> None: + """Record the actual arguments passed in in a call.""" + self.callable_to_calls[callable].append(arguments) + + def check(self) -> List[Failure]: + failures = [] + for callable, calls in self.callable_to_calls.items(): + if callable in self.callable_to_data: + data = self.callable_to_data[callable] + data.calls += calls + failures += data.check() + return failures + + +def display_suggested_type(value: Value) -> str: + value = prepare_type(value) + if isinstance(value, MultiValuedValue) and value.vals: + cae = CanAssignError("Union", [CanAssignError(str(val)) for val in value.vals]) + else: + cae = CanAssignError(str(value)) + return str(cae) + + +def prepare_type(value: Value) -> Value: + """Simplify a type to turn it into a suggestion.""" + if isinstance(value, AnnotatedValue): + return prepare_type(value.value) + elif isinstance(value, SequenceIncompleteValue): + if value.typ is tuple: + return SequenceIncompleteValue( + tuple, [prepare_type(elt) for elt in value.members] + ) + else: + return GenericValue(value.typ, [prepare_type(arg) for arg in value.args]) + elif isinstance(value, (TypedDictValue, CallableValue)): + return value + elif isinstance(value, GenericValue): + # TODO maybe turn DictIncompleteValue into TypedDictValue? + return GenericValue(value.typ, [prepare_type(arg) for arg in value.args]) + elif isinstance(value, VariableNameValue): + return AnyValue(AnySource.unannotated) + elif isinstance(value, KnownValue): + if value.val is None or safe_isinstance(value.val, type): + return value + elif callable(value.val): + return value # TODO get the signature instead and return a CallableValue? + value = replace_known_sequence_value(value) + if isinstance(value, KnownValue): + return TypedValue(type(value.val)) + else: + return prepare_type(value) + elif isinstance(value, MultiValuedValue): + vals = [prepare_type(subval) for subval in value.vals] + type_literals = [ + v + for v in vals + if isinstance(v, KnownValue) and safe_isinstance(v.val, type) + ] + if len(type_literals) > 1: + types = [v.val for v in type_literals if isinstance(v.val, type)] + shared_type = get_shared_type(types) + type_val = SubclassValue(TypedValue(shared_type)) + others = [ + v + for v in vals + if not isinstance(v, KnownValue) or not safe_isinstance(v.val, type) + ] + return unite_values(type_val, *others) + return unite_values(*vals) + else: + return value + + +def get_shared_type(types: Sequence[type]) -> type: + mros = [t.mro() for t in types] + first, *rest = mros + rest_sets = [set(mro) for mro in rest] + for candidate in first: + if all(candidate in mro for mro in rest_sets): + return candidate + assert False, "should at least have found object" + + +def _extract_params(node: FunctionNode) -> Iterator[ast.arg]: + yield from node.args.args + if node.args.vararg is not None: + yield node.args.vararg + yield from node.args.kwonlyargs + if node.args.kwarg is not None: + yield node.args.kwarg diff --git a/pyanalyze/test_name_check_visitor.py b/pyanalyze/test_name_check_visitor.py index 208ac19a..0b694849 100644 --- a/pyanalyze/test_name_check_visitor.py +++ b/pyanalyze/test_name_check_visitor.py @@ -84,10 +84,11 @@ def _run_tree( verbosity = int(os.environ.get("ANS_TEST_SCOPE_VERBOSITY", 0)) mod = _make_module(code_str) kwargs = self.visitor_cls.prepare_constructor_kwargs(kwargs) + new_code = "" with ClassAttributeChecker( self.visitor_cls.config, enabled=check_attributes ) as attribute_checker: - return self.visitor_cls( + visitor = self.visitor_cls( mod.__name__, code_str, tree, @@ -96,7 +97,14 @@ def _run_tree( settings=default_settings, verbosity=verbosity, **kwargs, - ).check_for_test(apply_changes=apply_changes) + ) + result = visitor.check_for_test(apply_changes=apply_changes) + if apply_changes: + result, new_code = result + result += visitor.perform_final_checks(kwargs) + if apply_changes: + return result, new_code + return result class TestAnnotatingNodeVisitor(test_node_visitor.BaseNodeVisitorTester): diff --git a/pyanalyze/test_self.py b/pyanalyze/test_self.py index fbf48a5a..1f725bed 100644 --- a/pyanalyze/test_self.py +++ b/pyanalyze/test_self.py @@ -19,6 +19,8 @@ class PyanalyzeConfig(pyanalyze.config.Config): ErrorCode.missing_parameter_annotation, ErrorCode.unused_variable, ErrorCode.value_always_true, + ErrorCode.suggested_parameter_type, + ErrorCode.suggested_return_type, } diff --git a/pyanalyze/test_suggested_type.py b/pyanalyze/test_suggested_type.py new file mode 100644 index 00000000..1f94d51e --- /dev/null +++ b/pyanalyze/test_suggested_type.py @@ -0,0 +1,60 @@ +# static analysis: ignore +from .suggested_type import prepare_type +from .value import KnownValue, SubclassValue, TypedValue +from .error_code import ErrorCode +from .test_node_visitor import assert_passes +from .test_name_check_visitor import TestNameCheckVisitorBase + + +class TestSuggestedType(TestNameCheckVisitorBase): + @assert_passes(settings={ErrorCode.suggested_return_type: True}) + def test_return(self): + def capybara(): # E: suggested_return_type + return 1 + + def kerodon(cond): # E: suggested_return_type + if cond: + return 1 + else: + return 2 + + @assert_passes(settings={ErrorCode.suggested_parameter_type: True}) + def test_parameter(self): + def capybara(a): # E: suggested_parameter_type + pass + + def annotated(b: int): + pass + + class Mammalia: + # should not suggest a type for this + def method(self): + pass + + def kerodon(unannotated): + capybara(1) + annotated(2) + + m = Mammalia() + m.method() + Mammalia.method(unannotated) + + +class A: + pass + + +class B(A): + + pass + + +class C(A): + pass + + +def test_prepare_type() -> None: + assert prepare_type(KnownValue(int) | KnownValue(str)) == SubclassValue( + TypedValue(object) + ) + assert prepare_type(KnownValue(C) | KnownValue(B)) == SubclassValue(TypedValue(A)) diff --git a/pyanalyze/value.py b/pyanalyze/value.py index 30e4a156..4b1536d7 100644 --- a/pyanalyze/value.py +++ b/pyanalyze/value.py @@ -449,6 +449,9 @@ def substitute_typevars(self, typevars: TypeVarMap) -> "KnownValue": def simplify(self) -> Value: val = replace_known_sequence_value(self) if isinstance(val, KnownValue): + # don't simplify None + if val.val is None: + return self return TypedValue(type(val.val)) return val.simplify()