From 940c9b9664250660b05d4a57f6bb80011c5e6da0 Mon Sep 17 00:00:00 2001 From: snarkmaster Date: Sun, 2 Sep 2018 18:00:57 -0700 Subject: [PATCH] Add a `customize_class_mro` plugin hook (#4567) The rationale for this MRO hook is documented on https://github.com/python/mypy/issues/4527 This patch completely addresses my need for customizing the MRO of types that use my metaclass, and I believe it is simple & general enough for other plugin authors. --- mypy/plugin.py | 8 ++++++++ mypy/semanal.py | 39 ++++++++++++++++++++++----------------- mypy/semanal_pass3.py | 8 ++++---- 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/mypy/plugin.py b/mypy/plugin.py index 9cbe827ab0b2..2586789830ab 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -214,6 +214,10 @@ def get_base_class_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: return None + def get_customize_class_mro_hook(self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + return None + T = TypeVar('T') @@ -270,6 +274,10 @@ def get_base_class_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: return self._find_hook(lambda plugin: plugin.get_base_class_hook(fullname)) + def get_customize_class_mro_hook(self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + return self._find_hook(lambda plugin: plugin.get_customize_class_mro_hook(fullname)) + def _find_hook(self, lookup: Callable[[Plugin], T]) -> Optional[T]: for plugin in self._plugins: hook = lookup(plugin) diff --git a/mypy/semanal.py b/mypy/semanal.py index 8036dd8f85dc..3120a7b6a595 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1145,7 +1145,28 @@ def analyze_base_classes(self, defn: ClassDef) -> None: return # TODO: Ideally we should move MRO calculation to a later stage, but this is # not easy, see issue #5536. - calculate_class_mro(defn, self.fail_blocker, self.object_type) + self.calculate_class_mro(defn, self.object_type) + + def calculate_class_mro(self, defn: ClassDef, + obj_type: Optional[Callable[[], Instance]] = None) -> None: + """Calculate method resolution order for a class. + + `obj_type` may be omitted in the third pass when all classes are already analyzed. + It exists just to fill in empty base class list during second pass in case of + an import cycle. + """ + try: + calculate_mro(defn.info, obj_type) + except MroError: + self.fail_blocker('Cannot determine consistent method resolution ' + 'order (MRO) for "%s"' % defn.name, defn) + defn.info.mro = [] + # Allow plugins to alter the MRO to handle the fact that `def mro()` + # on metaclasses permits MRO rewriting. + if defn.fullname: + hook = self.plugin.get_customize_class_mro_hook(defn.fullname) + if hook: + hook(ClassDefContext(defn, Expression(), self)) def update_metaclass(self, defn: ClassDef) -> None: """Lookup for special metaclass declarations, and update defn fields accordingly. @@ -3428,22 +3449,6 @@ def refers_to_class_or_function(node: Expression) -> bool: isinstance(node.node, (TypeInfo, FuncDef, OverloadedFuncDef))) -def calculate_class_mro(defn: ClassDef, fail: Callable[[str, Context], None], - obj_type: Optional[Callable[[], Instance]] = None) -> None: - """Calculate method resolution order for a class. - - `obj_type` may be omitted in the third pass when all classes are already analyzed. - It exists just to fill in empty base class list during second pass in case of - an import cycle. - """ - try: - calculate_mro(defn.info, obj_type) - except MroError: - fail("Cannot determine consistent method resolution order " - '(MRO) for "%s"' % defn.name, defn) - defn.info.mro = [] - - def calculate_mro(info: TypeInfo, obj_type: Optional[Callable[[], Instance]] = None) -> None: """Calculate and set mro (method resolution order). diff --git a/mypy/semanal_pass3.py b/mypy/semanal_pass3.py index cb7c02a9c671..90e2149ef905 100644 --- a/mypy/semanal_pass3.py +++ b/mypy/semanal_pass3.py @@ -30,11 +30,11 @@ from mypy.typeanal import TypeAnalyserPass3, collect_any_types from mypy.typevars import has_no_typevars from mypy.semanal_shared import PRIORITY_FORWARD_REF, PRIORITY_TYPEVAR_VALUES +from mypy.semanal import SemanticAnalyzerPass2 from mypy.subtypes import is_subtype from mypy.sametypes import is_same_type from mypy.scope import Scope from mypy.semanal_shared import SemanticAnalyzerCoreInterface -import mypy.semanal class SemanticAnalyzerPass3(TraverserVisitor, SemanticAnalyzerCoreInterface): @@ -45,7 +45,7 @@ class SemanticAnalyzerPass3(TraverserVisitor, SemanticAnalyzerCoreInterface): """ def __init__(self, modules: Dict[str, MypyFile], errors: Errors, - sem: 'mypy.semanal.SemanticAnalyzerPass2') -> None: + sem: SemanticAnalyzerPass2) -> None: self.modules = modules self.errors = errors self.sem = sem @@ -138,7 +138,7 @@ def visit_class_def(self, tdef: ClassDef) -> None: # import loop. (Only do so if we succeeded the first time.) if tdef.info.mro: tdef.info.mro = [] # Force recomputation - mypy.semanal.calculate_class_mro(tdef, self.fail_blocker) + self.sem.calculate_class_mro(tdef) if tdef.analyzed is not None: # Also check synthetic types associated with this ClassDef. # Currently these are TypedDict, and NamedTuple. @@ -230,7 +230,7 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: self.analyze_info(analyzed.info) if analyzed.info and analyzed.info.mro: analyzed.info.mro = [] # Force recomputation - mypy.semanal.calculate_class_mro(analyzed.info.defn, self.fail_blocker) + self.sem.calculate_class_mro(analyzed.info.defn) if isinstance(analyzed, TypeVarExpr): types = [] if analyzed.upper_bound: