From a3abd3644c5f2c84d2fa65fb621327a6c6a0e695 Mon Sep 17 00:00:00 2001
From: Hugues
Date: Fri, 29 Apr 2022 07:21:22 -0700
Subject: [PATCH] errors: speedup for large error counts and improve error filtering (#12631)

* errors: speedup for large error counts

We have a legacy codebase with many errors across many files:

```
Found 7995 errors in 2218 files (checked 21364 source files)
```

For historical reasons, it hasn't been practical to fix all of these yet,
and we've been slowly chipping away at them over time.

Profiling shows that `is_blockers` is the biggest single hotspot, taking
roughly 1 minute, and `total_errors` accounts for another 11s.

Instead of computing those values on read by iterating over all errors,
update auxiliary variables appropriately every time a new error is recorded.

* errors: unify and optimize error filtering

Instead of maintaining two separate mechanisms to filter out errors (a
boolean flag in MessageBuilder and an explicit copy of MessageBuilder/Errors),
expand ErrorWatcher to support all relevant usage patterns and update all
usages accordingly.

This is both cleaner and more robust than the previous approach, and should
also offer a slight performance improvement by reducing allocations.
---
 mypy/checker.py        |  51 +++++------
 mypy/checkexpr.py      | 188 +++++++++++++++++++----------------------
 mypy/checkmember.py    |  17 ++--
 mypy/checkpattern.py   |  77 ++++++++---------
 mypy/checkstrformat.py |   4 +-
 mypy/errors.py         | 105 ++++++++++++++++++++---
 mypy/messages.py       |  75 +++++-----------
 7 files changed, 277 insertions(+), 240 deletions(-)

diff --git a/mypy/checker.py b/mypy/checker.py
index 24f101421ff4..002f28d4db6c 100644
--- a/mypy/checker.py
+++ b/mypy/checker.py
@@ -2196,7 +2196,7 @@ def is_raising_or_empty(self, s: Statement) -> bool:
             if isinstance(s.expr, EllipsisExpr):
                 return True
             elif isinstance(s.expr, CallExpr):
-                with self.expr_checker.msg.disable_errors():
+                with self.expr_checker.msg.filter_errors():
                     typ = get_proper_type(self.expr_checker.accept(
                         s.expr, allow_none_return=True, always_allow_any=True))
 
@@ -3377,7 +3377,7 @@ def check_member_assignment(self, instance_type: Type, attribute_type: Type,
         # For non-overloaded setters, the result should be type-checked like a regular assignment.
         # Hence, we first only try to infer the type by using the rvalue as type context.
type_context = rvalue - with self.msg.disable_errors(): + with self.msg.filter_errors(): _, inferred_dunder_set_type = self.expr_checker.check_call( dunder_set_type, [TempNode(instance_type, context=context), type_context], @@ -4312,9 +4312,6 @@ def _make_fake_typeinfo_and_full_name( ) return info_, full_name_ - old_msg = self.msg - new_msg = old_msg.clean_copy() - self.msg = new_msg base_classes = _get_base_classes(instances) # We use the pretty_names_list for error messages but can't # use it for the real name that goes into the symbol table @@ -4322,24 +4319,22 @@ def _make_fake_typeinfo_and_full_name( pretty_names_list = pretty_seq(format_type_distinctly(*base_classes, bare=True), "and") try: info, full_name = _make_fake_typeinfo_and_full_name(base_classes, curr_module) - self.check_multiple_inheritance(info) - if new_msg.is_errors(): + with self.msg.filter_errors() as local_errors: + self.check_multiple_inheritance(info) + if local_errors.has_new_errors(): # "class A(B, C)" unsafe, now check "class A(C, B)": - new_msg = new_msg.clean_copy() - self.msg = new_msg base_classes = _get_base_classes(instances[::-1]) info, full_name = _make_fake_typeinfo_and_full_name(base_classes, curr_module) - self.check_multiple_inheritance(info) + with self.msg.filter_errors() as local_errors: + self.check_multiple_inheritance(info) info.is_intersection = True except MroError: if self.should_report_unreachable_issues(): - old_msg.impossible_intersection( + self.msg.impossible_intersection( pretty_names_list, "inconsistent method resolution order", ctx) return None - finally: - self.msg = old_msg - if new_msg.is_errors(): + if local_errors.has_new_errors(): if self.should_report_unreachable_issues(): self.msg.impossible_intersection( pretty_names_list, "incompatible method signatures", ctx) @@ -4974,20 +4969,20 @@ def refine_parent_types(self, member_name = expr.name def replay_lookup(new_parent_type: ProperType) -> Optional[Type]: - msg_copy = self.msg.clean_copy() - member_type = analyze_member_access( - name=member_name, - typ=new_parent_type, - context=parent_expr, - is_lvalue=False, - is_super=False, - is_operator=False, - msg=msg_copy, - original_type=new_parent_type, - chk=self, - in_literal_context=False, - ) - if msg_copy.is_errors(): + with self.msg.filter_errors() as w: + member_type = analyze_member_access( + name=member_name, + typ=new_parent_type, + context=parent_expr, + is_lvalue=False, + is_super=False, + is_operator=False, + msg=self.msg, + original_type=new_parent_type, + chk=self, + in_literal_context=False, + ) + if w.has_new_errors(): return None else: return member_type diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index c6f4d24c1815..f6aef2e361cc 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1,6 +1,6 @@ """Expression type checker. 
This file is conceptually part of TypeChecker.""" -from mypy.backports import OrderedDict, nullcontext +from mypy.backports import OrderedDict from contextlib import contextmanager import itertools from typing import ( @@ -8,7 +8,7 @@ ) from typing_extensions import ClassVar, Final, overload, TypeAlias as _TypeAlias -from mypy.errors import report_internal_error +from mypy.errors import report_internal_error, ErrorWatcher from mypy.typeanal import ( has_any_from_unimported_type, check_for_explicit_any, set_any_tvars, expand_type_alias, make_optional_type, @@ -887,7 +887,7 @@ def check_union_call_expr(self, e: CallExpr, object_type: UnionType, member: str res: List[Type] = [] for typ in object_type.relevant_items(): # Member access errors are already reported when visiting the member expression. - with self.msg.disable_errors(): + with self.msg.filter_errors(): item = analyze_member_access(member, typ, e, False, False, False, self.msg, original_type=object_type, chk=self.chk, in_literal_context=self.is_literal_context(), @@ -1244,7 +1244,7 @@ def infer_function_type_arguments(self, callee_type: CallableType, # due to partial available context information at this time, but # these errors can be safely ignored as the arguments will be # inferred again later. - with self.msg.disable_errors(): + with self.msg.filter_errors(): arg_types = self.infer_arg_types_in_context( callee_type, args, arg_kinds, formal_to_actual) @@ -1594,13 +1594,13 @@ def check_overload_call(self, unioned_result: Optional[Tuple[Type, Type]] = None union_interrupted = False # did we try all union combinations? if any(self.real_union(arg) for arg in arg_types): - unioned_errors = arg_messages.clean_copy() try: - unioned_return = self.union_overload_result(plausible_targets, args, - arg_types, arg_kinds, arg_names, - callable_name, object_type, - context, - arg_messages=unioned_errors) + with arg_messages.filter_errors(): + unioned_return = self.union_overload_result(plausible_targets, args, + arg_types, arg_kinds, arg_names, + callable_name, object_type, + context, + arg_messages=arg_messages) except TooManyUnions: union_interrupted = True else: @@ -1751,29 +1751,18 @@ def infer_overload_return_type(self, args_contain_any = any(map(has_any_type, arg_types)) for typ in plausible_targets: - overload_messages = self.msg.clean_copy() - prev_messages = self.msg assert self.msg is self.chk.msg - self.msg = overload_messages - self.chk.msg = overload_messages - try: - # Passing `overload_messages` as the `arg_messages` parameter doesn't - # seem to reliably catch all possible errors. - # TODO: Figure out why + with self.msg.filter_errors() as w: ret_type, infer_type = self.check_call( callee=typ, args=args, arg_kinds=arg_kinds, arg_names=arg_names, context=context, - arg_messages=overload_messages, + arg_messages=self.msg, callable_name=callable_name, object_type=object_type) - finally: - self.chk.msg = prev_messages - self.msg = prev_messages - - is_match = not overload_messages.is_errors() + is_match = not w.has_new_errors() if is_match: # Return early if possible; otherwise record info so we can # check for ambiguity due to 'Any' below. @@ -2254,15 +2243,15 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: # Keep track of whether we get type check errors (these won't be reported, they # are just to verify whether something is valid typing wise). 
- local_errors = self.msg.clean_copy() - _, method_type = self.check_method_call_by_name( - method='__contains__', - base_type=right_type, - args=[left], - arg_kinds=[ARG_POS], - context=e, - local_errors=local_errors, - ) + with self.msg.filter_errors(save_filtered_errors=True) as local_errors: + _, method_type = self.check_method_call_by_name( + method='__contains__', + base_type=right_type, + args=[left], + arg_kinds=[ARG_POS], + context=e, + ) + sub_result = self.bool_type() # Container item type for strict type overlap checks. Note: we need to only # check for nominal type, because a usual "Unsupported operands for in" @@ -2272,7 +2261,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: if isinstance(right_type, PartialType): # We don't really know if this is an error or not, so just shut up. pass - elif (local_errors.is_errors() and + elif (local_errors.has_new_errors() and # is_valid_var_arg is True for any Iterable self.is_valid_var_arg(right_type)): _, itertype = self.chk.analyze_iterable_item_type(right) @@ -2285,20 +2274,22 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: if not is_subtype(left_type, itertype): self.msg.unsupported_operand_types('in', left_type, right_type, e) # Only show dangerous overlap if there are no other errors. - elif (not local_errors.is_errors() and cont_type and + elif (not local_errors.has_new_errors() and cont_type and self.dangerous_comparison(left_type, cont_type, original_container=right_type)): self.msg.dangerous_comparison(left_type, cont_type, 'container', e) else: - self.msg.add_errors(local_errors) + self.msg.add_errors(local_errors.filtered_errors()) elif operator in operators.op_methods: method = self.get_operator_method(operator) - err_count = self.msg.errors.total_errors() - sub_result, method_type = self.check_op(method, left_type, right, e, - allow_reverse=True) + + with ErrorWatcher(self.msg.errors) as w: + sub_result, method_type = self.check_op(method, left_type, right, e, + allow_reverse=True) + # Only show dangerous overlap if there are no other errors. See # testCustomEqCheckStrictEquality for an example. - if self.msg.errors.total_errors() == err_count and operator in ('==', '!='): + if not w.has_new_errors() and operator in ('==', '!='): right_type = self.accept(right) # We suppress the error if there is a custom __eq__() method on either # side. User defined (or even standard library) classes can define this @@ -2525,24 +2516,20 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: if not self.has_member(base_type, op_name): return None - local_errors = msg.clean_copy() - - member = analyze_member_access( - name=op_name, - typ=base_type, - is_lvalue=False, - is_super=False, - is_operator=True, - original_type=base_type, - context=context, - msg=local_errors, - chk=self.chk, - in_literal_context=self.is_literal_context() - ) - if local_errors.is_errors(): - return None - else: - return member + with self.msg.filter_errors() as w: + member = analyze_member_access( + name=op_name, + typ=base_type, + is_lvalue=False, + is_super=False, + is_operator=True, + original_type=base_type, + context=context, + msg=self.msg, + chk=self.chk, + in_literal_context=self.is_literal_context() + ) + return None if w.has_new_errors() else member def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]: """Returns the name of the class that contains the actual definition of attr_name. 
@@ -2656,11 +2643,11 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]: errors = [] results = [] for method, obj, arg in variants: - local_errors = msg.clean_copy() - result = self.check_method_call( - op_name, obj, method, [arg], [ARG_POS], context, local_errors) - if local_errors.is_errors(): - errors.append(local_errors) + with self.msg.filter_errors(save_filtered_errors=True) as local_errors: + result = self.check_method_call( + op_name, obj, method, [arg], [ARG_POS], context) + if local_errors.has_new_errors(): + errors.append(local_errors.filtered_errors()) results.append(result) else: return result @@ -2678,12 +2665,12 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]: # call the __op__ method (even though it's missing). if not variants: - local_errors = msg.clean_copy() - result = self.check_method_call_by_name( - op_name, left_type, [right_expr], [ARG_POS], context, local_errors) + with self.msg.filter_errors(save_filtered_errors=True) as local_errors: + result = self.check_method_call_by_name( + op_name, left_type, [right_expr], [ARG_POS], context) - if local_errors.is_errors(): - errors.append(local_errors) + if local_errors.has_new_errors(): + errors.append(local_errors.filtered_errors()) results.append(result) else: # In theory, we should never enter this case, but it seems @@ -2724,23 +2711,23 @@ def check_op(self, method: str, base_type: Type, # Step 1: We first try leaving the right arguments alone and destructure # just the left ones. (Mypy can sometimes perform some more precise inference # if we leave the right operands a union -- see testOperatorWithEmptyListAndSum.) - msg = self.msg.clean_copy() all_results = [] all_inferred = [] - for left_possible_type in left_variants: - result, inferred = self.check_op_reversible( - op_name=method, - left_type=left_possible_type, - left_expr=TempNode(left_possible_type, context=context), - right_type=right_type, - right_expr=arg, - context=context, - msg=msg) - all_results.append(result) - all_inferred.append(inferred) + with self.msg.filter_errors() as local_errors: + for left_possible_type in left_variants: + result, inferred = self.check_op_reversible( + op_name=method, + left_type=left_possible_type, + left_expr=TempNode(left_possible_type, context=context), + right_type=right_type, + right_expr=arg, + context=context, + msg=self.msg) + all_results.append(result) + all_inferred.append(inferred) - if not msg.is_errors(): + if not local_errors.has_new_errors(): results_final = make_simplified_union(all_results) inferred_final = make_simplified_union(all_inferred) return results_final, inferred_final @@ -2765,27 +2752,30 @@ def check_op(self, method: str, base_type: Type, handle_type_alias_type=True) ] - msg = self.msg.clean_copy() all_results = [] all_inferred = [] - for left_possible_type in left_variants: - for right_possible_type, right_expr in right_variants: - result, inferred = self.check_op_reversible( - op_name=method, - left_type=left_possible_type, - left_expr=TempNode(left_possible_type, context=context), - right_type=right_possible_type, - right_expr=right_expr, - context=context, - msg=msg) - all_results.append(result) - all_inferred.append(inferred) - - if msg.is_errors(): - self.msg.add_errors(msg) + with self.msg.filter_errors(save_filtered_errors=True) as local_errors: + for left_possible_type in left_variants: + for right_possible_type, right_expr in right_variants: + result, inferred = self.check_op_reversible( + op_name=method, + left_type=left_possible_type, + 
left_expr=TempNode(left_possible_type, context=context), + right_type=right_possible_type, + right_expr=right_expr, + context=context, + msg=self.msg) + all_results.append(result) + all_inferred.append(inferred) + + if local_errors.has_new_errors(): + self.msg.add_errors(local_errors.filtered_errors()) # Point any notes to the same location as an existing message. - recent_context = msg.most_recent_context() + err = local_errors.filtered_errors()[-1] + recent_context = TempNode(NoneType()) + recent_context.line = err.line + recent_context.column = err.column if len(left_variants) >= 2 and len(right_variants) >= 2: self.msg.warn_both_operands_are_from_unions(recent_context) elif len(left_variants) >= 2: @@ -2864,7 +2854,7 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type: # If right_map is None then we know mypy considers the right branch # to be unreachable and therefore any errors found in the right branch # should be suppressed. - with (self.msg.disable_errors() if right_map is None else nullcontext()): + with self.msg.filter_errors(filter_errors=right_map is None): right_type = self.analyze_cond_branch(right_map, e.right, expanded_left_type) if left_map is None and right_map is None: diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 16bd0e074c3f..2659ad18ed6e 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -273,13 +273,11 @@ def analyze_type_type_member_access(name: str, # Similar to analyze_type_callable_attribute_access. item = None fallback = mx.named_type('builtins.type') - ignore_messages = mx.msg.copy() - ignore_messages.disable_errors().__enter__() if isinstance(typ.item, Instance): item = typ.item elif isinstance(typ.item, AnyType): - mx = mx.copy_modified(messages=ignore_messages) - return _analyze_member_access(name, fallback, mx, override_info) + with mx.msg.filter_errors(): + return _analyze_member_access(name, fallback, mx, override_info) elif isinstance(typ.item, TypeVarType): upper_bound = get_proper_type(typ.item.upper_bound) if isinstance(upper_bound, Instance): @@ -287,8 +285,8 @@ def analyze_type_type_member_access(name: str, elif isinstance(upper_bound, TupleType): item = tuple_fallback(upper_bound) elif isinstance(upper_bound, AnyType): - mx = mx.copy_modified(messages=ignore_messages) - return _analyze_member_access(name, fallback, mx, override_info) + with mx.msg.filter_errors(): + return _analyze_member_access(name, fallback, mx, override_info) elif isinstance(typ.item, TupleType): item = tuple_fallback(typ.item) elif isinstance(typ.item, FunctionLike) and typ.item.is_type_obj(): @@ -297,6 +295,7 @@ def analyze_type_type_member_access(name: str, # Access member on metaclass object via Type[Type[C]] if isinstance(typ.item.item, Instance): item = typ.item.item.type.metaclass_type + ignore_messages = False if item and not mx.is_operator: # See comment above for why operators are skipped result = analyze_class_attribute_access(item, name, mx, override_info) @@ -305,10 +304,12 @@ def analyze_type_type_member_access(name: str, return result else: # We don't want errors on metaclass lookup for classes with Any fallback - mx = mx.copy_modified(messages=ignore_messages) + ignore_messages = True if item is not None: fallback = item.type.metaclass_type or fallback - return _analyze_member_access(name, fallback, mx, override_info) + + with mx.msg.filter_errors(filter_errors=ignore_messages): + return _analyze_member_access(name, fallback, mx, override_info) def analyze_union_member_access(name: str, typ: UnionType, mx: MemberContext) 
-> Type: diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index e1d4f9fe285e..84b9acae1aa2 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -416,29 +416,29 @@ def get_mapping_item_type(self, mapping_type: Type, key: Expression ) -> Optional[Type]: - local_errors = self.msg.clean_copy() - local_errors.disable_count = 0 mapping_type = get_proper_type(mapping_type) if isinstance(mapping_type, TypedDictType): - result: Optional[Type] = self.chk.expr_checker.visit_typeddict_index_expr( - mapping_type, key, local_errors=local_errors) + with self.msg.filter_errors() as local_errors: + result: Optional[Type] = self.chk.expr_checker.visit_typeddict_index_expr( + mapping_type, key) + has_local_errors = local_errors.has_new_errors() # If we can't determine the type statically fall back to treating it as a normal # mapping - if local_errors.is_errors(): - local_errors = self.msg.clean_copy() - local_errors.disable_count = 0 + if has_local_errors: + with self.msg.filter_errors() as local_errors: + result = self.get_simple_mapping_item_type(pattern, + mapping_type, + key, + self.msg) + + if local_errors.has_new_errors(): + result = None + else: + with self.msg.filter_errors(): result = self.get_simple_mapping_item_type(pattern, mapping_type, key, - local_errors) - - if local_errors.is_errors(): - result = None - else: - result = self.get_simple_mapping_item_type(pattern, - mapping_type, - key, - local_errors) + self.msg) return result def get_simple_mapping_item_type(self, @@ -507,14 +507,14 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: pattern_type.captures) captures = pattern_type.captures else: - local_errors = self.msg.clean_copy() - match_args_type = analyze_member_access("__match_args__", typ, o, - False, False, False, - local_errors, - original_type=typ, - chk=self.chk) - - if local_errors.is_errors(): + with self.msg.filter_errors() as local_errors: + match_args_type = analyze_member_access("__match_args__", typ, o, + False, False, False, + self.msg, + original_type=typ, + chk=self.chk) + has_local_errors = local_errors.has_new_errors() + if has_local_errors: self.msg.fail(message_registry.MISSING_MATCH_ARGS.format(typ), o) return self.early_non_match() @@ -561,20 +561,21 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: can_match = True for keyword, pattern in keyword_pairs: key_type: Optional[Type] = None - local_errors = self.msg.clean_copy() - if keyword is not None: - key_type = analyze_member_access(keyword, - narrowed_type, - pattern, - False, - False, - False, - local_errors, - original_type=new_type, - chk=self.chk) - else: - key_type = AnyType(TypeOfAny.from_error) - if local_errors.is_errors() or key_type is None: + with self.msg.filter_errors() as local_errors: + if keyword is not None: + key_type = analyze_member_access(keyword, + narrowed_type, + pattern, + False, + False, + False, + self.msg, + original_type=new_type, + chk=self.chk) + else: + key_type = AnyType(TypeOfAny.from_error) + has_local_errors = local_errors.has_new_errors() + if has_local_errors or key_type is None: key_type = AnyType(TypeOfAny.from_error) self.msg.fail(message_registry.CLASS_PATTERN_UNKNOWN_KEYWORD.format(typ, keyword), pattern) diff --git a/mypy/checkstrformat.py b/mypy/checkstrformat.py index dcb711150870..995f3073ba79 100644 --- a/mypy/checkstrformat.py +++ b/mypy/checkstrformat.py @@ -17,6 +17,7 @@ ) from typing_extensions import Final, TYPE_CHECKING, TypeAlias as _TypeAlias +from mypy.errors import Errors from mypy.types import ( Type, 
AnyType, TupleType, Instance, UnionType, TypeOfAny, get_proper_type, TypeVarType, LiteralType, get_proper_types @@ -512,8 +513,7 @@ def apply_field_accessors(self, spec: ConversionSpecifier, repl: Expression, return repl assert spec.field - # This is a bit of a dirty trick, but it looks like this is the simplest way. - temp_errors = self.msg.clean_copy().errors + temp_errors = Errors() dummy = DUMMY_FIELD_NAME + spec.field[len(spec.key):] temp_ast: Node = parse( dummy, fnam="", module=None, options=self.chk.options, errors=temp_errors diff --git a/mypy/errors.py b/mypy/errors.py index 20e43fa810f9..6dc0a21e7904 100644 --- a/mypy/errors.py +++ b/mypy/errors.py @@ -4,8 +4,8 @@ from mypy.backports import OrderedDict from collections import defaultdict -from typing import Tuple, List, TypeVar, Set, Dict, Optional, TextIO, Callable -from typing_extensions import Final +from typing import Tuple, List, TypeVar, Set, Dict, Optional, TextIO, Callable, Union +from typing_extensions import Final, Literal from mypy.scope import Scope from mypy.options import Options @@ -122,6 +122,58 @@ def __init__(self, Optional[ErrorCode]] +class ErrorWatcher: + """Context manager that can be used to keep track of new errors recorded + around a given operation. + + Errors maintain a stack of such watchers. The handler is called starting + at the top of the stack, and is propagated down the stack unless filtered + out by one of the ErrorWatcher instances. + """ + def __init__(self, errors: 'Errors', *, + filter_errors: Union[bool, Callable[[str, ErrorInfo], bool]] = False, + save_filtered_errors: bool = False): + self.errors = errors + self._has_new_errors = False + self._filter = filter_errors + self._filtered: Optional[List[ErrorInfo]] = [] if save_filtered_errors else None + + def __enter__(self) -> 'ErrorWatcher': + self.errors._watchers.append(self) + return self + + def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> Literal[False]: + assert self == self.errors._watchers.pop() + return False + + def on_error(self, file: str, info: ErrorInfo) -> bool: + """Handler called when a new error is recorded. + + The default implementation just sets the has_new_errors flag + + Return True to filter out the error, preventing it from being seen by other + ErrorWatcher further down the stack and from being recorded by Errors + """ + self._has_new_errors = True + if isinstance(self._filter, bool): + should_filter = self._filter + elif callable(self._filter): + should_filter = self._filter(file, info) + else: + raise AssertionError("invalid error filter: {}".format(type(self._filter))) + if should_filter and self._filtered is not None: + self._filtered.append(info) + + return should_filter + + def has_new_errors(self) -> bool: + return self._has_new_errors + + def filtered_errors(self) -> List[ErrorInfo]: + assert self._filtered is not None + return self._filtered + + class Errors: """Container for compile errors. @@ -134,6 +186,9 @@ class Errors: # files were processed. error_info_map: Dict[str, List[ErrorInfo]] + # optimization for legacy codebases with many files with errors + has_blockers: Set[str] + # Files that we have reported the errors for flushed_files: Set[str] @@ -178,6 +233,8 @@ class Errors: # in some cases to avoid reporting huge numbers of errors. 
seen_import_error = False + _watchers: List[ErrorWatcher] = [] + def __init__(self, show_error_context: bool = False, show_column_numbers: bool = False, @@ -209,6 +266,7 @@ def initialize(self) -> None: self.used_ignored_lines = defaultdict(lambda: defaultdict(list)) self.ignored_files = set() self.only_once_messages = set() + self.has_blockers = set() self.scope = None self.target_module = None self.seen_import_error = False @@ -234,9 +292,6 @@ def copy(self) -> 'Errors': new.seen_import_error = self.seen_import_error return new - def total_errors(self) -> int: - return sum(len(errs) for errs in self.error_info_map.values()) - def set_ignore_prefix(self, prefix: str) -> None: """Set path prefix that will be removed from all paths.""" prefix = os.path.normpath(prefix) @@ -354,14 +409,39 @@ def report(self, def _add_error_info(self, file: str, info: ErrorInfo) -> None: assert file not in self.flushed_files + # process the stack of ErrorWatchers before modifying any internal state + # in case we need to filter out the error entirely + if self._filter_error(file, info): + return if file not in self.error_info_map: self.error_info_map[file] = [] self.error_info_map[file].append(info) + if info.blocker: + self.has_blockers.add(file) if info.code is IMPORT: self.seen_import_error = True + def _filter_error(self, file: str, info: ErrorInfo) -> bool: + """ + process ErrorWatcher stack from top to bottom, + stopping early if error needs to be filtered out + """ + i = len(self._watchers) + while i > 0: + i -= 1 + w = self._watchers[i] + if w.on_error(file, info): + return True + return False + def add_error_info(self, info: ErrorInfo) -> None: file, line, end_line = info.origin + # process the stack of ErrorWatchers before modifying any internal state + # in case we need to filter out the error entirely + # NB: we need to do this both here and in _add_error_info, otherwise we + # might incorrectly update the sets of ignored or only_once messages + if self._filter_error(file, info): + return if not info.blocker: # Blockers cannot be ignored if file in self.ignored_lines: # It's okay if end_line is *before* line. 
@@ -476,12 +556,16 @@ def clear_errors_in_targets(self, path: str, targets: Set[str]) -> None: """Remove errors in specific fine-grained targets within a file.""" if path in self.error_info_map: new_errors = [] + has_blocker = False for info in self.error_info_map[path]: if info.target not in targets: new_errors.append(info) + has_blocker |= info.blocker elif info.only_once: self.only_once_messages.remove(info.message) self.error_info_map[path] = new_errors + if not has_blocker and path in self.has_blockers: + self.has_blockers.remove(path) def generate_unused_ignore_errors(self, file: str) -> None: ignored_lines = self.ignored_lines[file] @@ -551,19 +635,14 @@ def is_errors(self) -> bool: """Are there any generated messages?""" return bool(self.error_info_map) - def is_real_errors(self) -> bool: - """Are there any generated errors (not just notes, for example)?""" - return any(info.severity == 'error' - for infos in self.error_info_map.values() for info in infos) - def is_blockers(self) -> bool: """Are the any errors that are blockers?""" - return any(err for errs in self.error_info_map.values() for err in errs if err.blocker) + return bool(self.has_blockers) def blocker_module(self) -> Optional[str]: """Return the module with a blocking error, or None if not possible.""" - for errs in self.error_info_map.values(): - for err in errs: + for path in self.has_blockers: + for err in self.error_info_map[path]: if err.blocker: return err.module return None diff --git a/mypy/messages.py b/mypy/messages.py index 23ab172f5499..d499a574d527 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -21,7 +21,7 @@ from typing_extensions import Final from mypy.erasetype import erase_type -from mypy.errors import Errors +from mypy.errors import Errors, ErrorWatcher, ErrorInfo from mypy.types import ( Type, CallableType, Instance, TypeVarType, TupleType, TypedDictType, LiteralType, UnionType, NoneType, AnyType, Overloaded, FunctionLike, DeletedType, TypeType, @@ -33,7 +33,7 @@ TypeInfo, Context, MypyFile, FuncDef, reverse_builtin_aliases, ArgKind, ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2, ReturnStmt, NameExpr, Var, CONTRAVARIANT, COVARIANT, SymbolNode, - CallExpr, IndexExpr, StrExpr, SymbolTable, TempNode, SYMBOL_FUNCBASE_TYPES + CallExpr, IndexExpr, StrExpr, SymbolTable, SYMBOL_FUNCBASE_TYPES ) from mypy.operators import op_methods, op_methods_to_symbols from mypy.subtypes import ( @@ -106,66 +106,38 @@ class MessageBuilder: modules: Dict[str, MypyFile] - # Number of times errors have been disabled. 
- disable_count = 0 - # Hack to deduplicate error messages from union types - disable_type_names_count = 0 + _disable_type_names: List[bool] def __init__(self, errors: Errors, modules: Dict[str, MypyFile]) -> None: self.errors = errors self.modules = modules - self.disable_count = 0 - self.disable_type_names_count = 0 + self._disable_type_names = [] # # Helpers # - def copy(self) -> 'MessageBuilder': - new = MessageBuilder(self.errors.copy(), self.modules) - new.disable_count = self.disable_count - new.disable_type_names_count = self.disable_type_names_count - return new - - def clean_copy(self) -> 'MessageBuilder': - errors = self.errors.copy() - errors.error_info_map = OrderedDict() - return MessageBuilder(errors, self.modules) + def filter_errors(self, *, filter_errors: bool = True, + save_filtered_errors: bool = False) -> ErrorWatcher: + return ErrorWatcher(self.errors, filter_errors=filter_errors, + save_filtered_errors=save_filtered_errors) - def add_errors(self, messages: 'MessageBuilder') -> None: + def add_errors(self, errors: List[ErrorInfo]) -> None: """Add errors in messages to this builder.""" - if self.disable_count <= 0: - for errs in messages.errors.error_info_map.values(): - for info in errs: - self.errors.add_error_info(info) - - @contextmanager - def disable_errors(self) -> Iterator[None]: - self.disable_count += 1 - try: - yield - finally: - self.disable_count -= 1 + for info in errors: + self.errors.add_error_info(info) @contextmanager def disable_type_names(self) -> Iterator[None]: - self.disable_type_names_count += 1 + self._disable_type_names.append(True) try: yield finally: - self.disable_type_names_count -= 1 - - def is_errors(self) -> bool: - return self.errors.is_errors() + self._disable_type_names.pop() - def most_recent_context(self) -> Context: - """Return a dummy context matching the most recent generated error in current file.""" - line, column = self.errors.most_recent_error_location() - node = TempNode(NoneType()) - node.line = line - node.column = column - return node + def are_type_names_disabled(self) -> bool: + return len(self._disable_type_names) > 0 and self._disable_type_names[-1] def report(self, msg: str, @@ -184,12 +156,11 @@ def report(self, end_line = context.end_line else: end_line = None - if self.disable_count <= 0: - self.errors.report(context.get_line() if context else -1, - context.get_column() if context else -1, - msg, severity=severity, file=file, offset=offset, - origin_line=origin.get_line() if origin else None, - end_line=end_line, code=code, allow_dups=allow_dups) + self.errors.report(context.get_line() if context else -1, + context.get_column() if context else -1, + msg, severity=severity, file=file, offset=offset, + origin_line=origin.get_line() if origin else None, + end_line=end_line, code=code, allow_dups=allow_dups) def fail(self, msg: str, @@ -309,7 +280,7 @@ def has_no_attr(self, extra = ' (not iterable)' elif member == '__aiter__': extra = ' (not async iterable)' - if not self.disable_type_names_count: + if not self.are_type_names_disabled(): failed = False if isinstance(original_type, Instance) and original_type.type.names: alternatives = set(original_type.type.names.keys()) @@ -391,7 +362,7 @@ def unsupported_operand_types(self, else: right_str = format_type(right_type) - if self.disable_type_names_count: + if self.are_type_names_disabled(): msg = 'Unsupported operand types for {} (likely involving Union)'.format(op) else: msg = 'Unsupported operand types for {} ({} and {})'.format( @@ -400,7 +371,7 @@ def 
unsupported_operand_types(self, def unsupported_left_operand(self, op: str, typ: Type, context: Context) -> None: - if self.disable_type_names_count: + if self.are_type_names_disabled(): msg = 'Unsupported left operand type for {} (some union)'.format(op) else: msg = 'Unsupported left operand type for {} ({})'.format(