Skip to content

Commit

Permalink
refactor: Improve visitor getters
Browse files Browse the repository at this point in the history
  • Loading branch information
pawamoy committed Nov 13, 2021
1 parent 6298ba3 commit 2ea88c0
Showing 1 changed file with 188 additions and 64 deletions.
252 changes: 188 additions & 64 deletions src/griffe/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,45 @@
from __future__ import annotations

import inspect
from ast import (
AST,
AnnAssign,
Assign,
Attribute,
BinOp,
BitOr,
Call,
Constant,
Dict,
Expr,
FormattedValue,
Index,
JoinedStr,
List,
Name,
PyCF_ONLY_AST,
Str,
Subscript,
Tuple,
keyword,
)
from ast import AST as Node
from ast import And as NodeAnd
from ast import AnnAssign as NodeAnnAssign
from ast import Assign as NodeAssign
from ast import Attribute as NodeAttribute
from ast import BinOp as NodeBinOp
from ast import BitOr as NodeBitOr
from ast import BoolOp as NodeBoolOp
from ast import Call as NodeCall
from ast import Compare as NodeCompare
from ast import Constant as NodeConstant
from ast import Dict as NodeDict
from ast import DictComp as NodeDictComp
from ast import Expr as NodeExpr
from ast import FormattedValue as NodeFormattedValue
from ast import GeneratorExp as NodeGeneratorExp
from ast import IfExp as NodeIfExp
from ast import Index as NodeIndex
from ast import JoinedStr as NodeJoinedStr
from ast import Lambda as NodeLambda
from ast import List as NodeList
from ast import ListComp as NodeListComp
from ast import Mult as NodeMult
from ast import Name as NodeName
from ast import Not as NodeNot
from ast import NotEq as NodeNotEq
from ast import Or as NodeOr
from ast import PyCF_ONLY_AST
from ast import Set as NodeSet
from ast import Slice as NodeSlice
from ast import Starred as NodeStarred
from ast import Str as NodeStr
from ast import Subscript as NodeSubscript
from ast import Tuple as NodeTuple
from ast import UAdd as NodeUAdd
from ast import UnaryOp as NodeUnaryOp
from ast import USub as NodeUSub
from ast import comprehension as NodeComprehension
from ast import keyword as NodeKeyword
from itertools import zip_longest
from pathlib import Path

Expand Down Expand Up @@ -66,26 +83,29 @@ def visit(
# ==========================================================
# docstrings
def _get_docstring(node):
if isinstance(node, Expr):
if isinstance(node, NodeExpr):
doc = node.value
elif node.body and isinstance(node.body[0], Expr):
elif node.body and isinstance(node.body[0], NodeExpr):
doc = node.body[0].value
else:
return None
if isinstance(doc, Constant) and isinstance(doc.value, str):
if isinstance(doc, NodeConstant) and isinstance(doc.value, str):
return Docstring(doc.value, doc.lineno, doc.end_lineno)
if isinstance(doc, Str):
if isinstance(doc, NodeStr):
return Docstring(doc.s, doc.lineno, doc.end_lineno)
return None


# ==========================================================
# base classes
def _get_base_class_name(node):
if isinstance(node, Name):
if isinstance(node, NodeName):
return node.id
if isinstance(node, Attribute):
if isinstance(node, NodeAttribute):
return f"{_get_base_class_name(node.value)}.{node.attr}"
# TODO: resolve subscript
if isinstance(node, NodeSubscript):
return f"{_get_base_class_name(node.value)}[{_get_base_class_name(node.slice)}]"


# ==========================================================
Expand All @@ -103,7 +123,7 @@ def _get_attribute_annotation(node):


def _get_binop_annotation(node):
if isinstance(node.op, BitOr):
if isinstance(node.op, NodeBitOr):
return f"{_get_annotation(node.left)} | {_get_annotation(node.right)}"


Expand All @@ -124,14 +144,14 @@ def _get_list_annotation(node):


_node_annotation_map = {
Name: _get_name_annotation,
Constant: _get_constant_annotation,
Attribute: _get_attribute_annotation,
BinOp: _get_binop_annotation,
Subscript: _get_subscript_annotation,
Index: _get_index_annotation,
Tuple: _get_tuple_annotation,
List: _get_list_annotation,
NodeName: _get_name_annotation,
NodeConstant: _get_constant_annotation,
NodeAttribute: _get_attribute_annotation,
NodeBinOp: _get_binop_annotation,
NodeSubscript: _get_subscript_annotation,
NodeIndex: _get_index_annotation,
NodeTuple: _get_tuple_annotation,
NodeList: _get_list_annotation,
}


Expand All @@ -154,8 +174,31 @@ def _get_attribute_value(node):


def _get_binop_value(node):
if isinstance(node.op, BitOr):
return f"{_get_value(node.left)} | {_get_value(node.right)}"
return f"{_get_value(node.left)} {_get_value(node.op)} {_get_value(node.right)}"


def _get_bitor_value(node):
return "|"


def _get_mult_value(node):
return "*"


def _get_unaryop_value(node):
if isinstance(node.op, NodeUSub):
return f"-{_get_value(node.operand)}"
if isinstance(node.op, NodeUAdd):
return f"+{_get_value(node.operand)}"
if isinstance(node.op, NodeNot):
return f"not {_get_value(node.operand)}"


def _get_slice_value(node):
value = f"{_get_value(node.lower) if node.lower else ''}:{_get_value(node.upper) if node.upper else ''}"
if node.step:
value = f"{value}:{_get_value(node.step)}"
return value


def _get_subscript_value(node):
Expand All @@ -166,6 +209,10 @@ def _get_index_value(node):
return _get_value(node.value)


def _get_lambda_value(node):
return f"lambda {_get_value(node.args)}: {_get_value(node.body)}"


def _get_list_value(node):
return "[" + ", ".join(_get_value(el) for el in node.elts) + "]"

Expand All @@ -183,10 +230,18 @@ def _get_dict_value(node):
return "{" + ", ".join(f"{_get_value(key)}: {_get_value(value)}" for key, value in pairs) + "}"


def _get_set_value(node):
return "{" + ", ".join(_get_value(el) for el in node.elts) + "}"


def _get_ellipsis_value(node):
return "..."


def _get_starred_value(node):
return _get_value(node.value)


def _get_formatted_value(node):
return f"{{{_get_value(node.value)}}}"

Expand All @@ -195,6 +250,59 @@ def _get_joinedstr_value(node):
return "".join(_get_value(value) for value in node.values)


def _get_boolop_value(node):
if isinstance(node.op, NodeOr):
return " or ".join(_get_value(value) for value in node.values)
if isinstance(node.op, NodeAnd):
return " and ".join(_get_value(value) for value in node.values)


def _get_compare_value(node):
left = _get_value(node.left)
ops = [_get_value(op) for op in node.ops]
comparators = [_get_value(comparator) for comparator in node.comparators]
return f"{left} " + " ".join(f"{op} {comp}" for op, comp in zip(ops, comparators))


def _get_noteq_value(node):
return "!="


def _get_generatorexp_value(node):
element = _get_value(node.elt)
generators = [_get_value(gen) for gen in node.generators]
return f"{element} " + " ".join(generators)


def _get_listcomp_value(node):
element = _get_value(node.elt)
generators = [_get_value(gen) for gen in node.generators]
return f"[{element} " + " ".join(generators) + "]"


def _get_dictcomp_value(node):
key = _get_value(node.key)
value = _get_value(node.value)
generators = [_get_value(gen) for gen in node.generators]
return f"{{{key}: {value} " + " ".join(generators) + "}"


def _get_comprehension_value(node):
target = _get_value(node.target)
iterable = _get_value(node.iter)
conditions = [_get_value(condition) for condition in node.ifs]
value = f"for {target} in {iterable}"
if conditions:
value = f"{value} if " + " if ".join(conditions)
if node.is_async:
value = f"async {value}"
return value


def _get_ifexp_value(node):
return f"{_get_value(node.body)} if {_get_value(node.test)} else {_get_value(node.orelse)}"


def _get_call_value(node):
posargs = ", ".join(_get_value(arg) for arg in node.args)
kwargs = ", ".join(_get_value(kwarg) for kwarg in node.keywords)
Expand All @@ -211,19 +319,34 @@ def _get_call_value(node):

_node_value_map = {
type(None): lambda _: repr(None),
Name: _get_name_value,
Constant: _get_constant_value,
Attribute: _get_attribute_value,
BinOp: _get_binop_value,
Subscript: _get_subscript_value,
Index: _get_index_value,
List: _get_list_value,
Tuple: _get_tuple_value,
keyword: _get_keyword_value,
Dict: _get_dict_value,
FormattedValue: _get_formatted_value,
JoinedStr: _get_joinedstr_value,
Call: _get_call_value,
NodeName: _get_name_value,
NodeConstant: _get_constant_value,
NodeAttribute: _get_attribute_value,
NodeBinOp: _get_binop_value,
NodeUnaryOp: _get_unaryop_value,
NodeSubscript: _get_subscript_value,
NodeIndex: _get_index_value,
NodeList: _get_list_value,
NodeTuple: _get_tuple_value,
NodeKeyword: _get_keyword_value,
NodeDict: _get_dict_value,
NodeSet: _get_set_value,
NodeFormattedValue: _get_formatted_value,
NodeJoinedStr: _get_joinedstr_value,
NodeCall: _get_call_value,
NodeSlice: _get_slice_value,
NodeBoolOp: _get_boolop_value,
NodeGeneratorExp: _get_generatorexp_value,
NodeComprehension: _get_comprehension_value,
NodeCompare: _get_compare_value,
NodeNotEq: _get_noteq_value,
NodeBitOr: _get_bitor_value,
NodeMult: _get_mult_value,
NodeListComp: _get_listcomp_value,
NodeLambda: _get_lambda_value,
NodeDictComp: _get_dictcomp_value,
NodeStarred: _get_starred_value,
NodeIfExp: _get_ifexp_value,
}


Expand All @@ -234,26 +357,26 @@ def _get_value(node):
# ==========================================================
# names
def _get_attribute_name(node):
return f"{node.attr}.{_get_names(node.value)}"
return f"{_get_names(node.value)}.{node.attr}"


def _get_name_name(node):
return node.id


def _get_assign_names(node):
return [_get_names(target) for target in node.targets]
return [name for name in [_get_names(target) for target in node.targets] if name]


def _get_annassign_names(node):
return [_get_names(node.target)]
return [name for name in _get_names(node.target) if name]


_node_names_map = {
Assign: _get_assign_names,
AnnAssign: _get_annassign_names,
Name: _get_name_name,
Attribute: _get_attribute_name,
NodeAssign: _get_assign_names,
NodeAnnAssign: _get_annassign_names,
NodeName: _get_name_name,
NodeAttribute: _get_attribute_name,
}


Expand All @@ -270,9 +393,9 @@ def _get_instance_names(node):
def _get_parameter_default(node, filepath):
if node is None:
return None
if isinstance(node, Constant):
if isinstance(node, NodeConstant):
return repr(node.value)
if isinstance(node, Name):
if isinstance(node, NodeName):
return node.id
if node.lineno == node.end_lineno:
return lines_collection[filepath][node.lineno - 1][node.col_offset : node.end_col_offset]
Expand All @@ -296,11 +419,12 @@ def __init__(
self.code: str = code
self.extensions: Extensions = extensions.instantiate(self)
# self.scope = defaultdict(dict)
self.root: Node | None = None
self.parent: Module | None = parent
self.current: Module | Class | Function = None # type: ignore
self.in_decorator: bool = False

def _visit(self, node: AST, parent: AST | None = None) -> None:
def _visit(self, node: Node, parent: Node | None = None) -> None:
node.parent = parent # type: ignore
self._run_specific_or_generic(node)

Expand All @@ -311,14 +435,14 @@ def get_module(self) -> Module:
self.visit(top_node)
return self.current.module # type: ignore # there's always a module after the visit

def visit(self, node: AST, parent: AST | None = None) -> None:
def visit(self, node: Node, parent: Node | None = None) -> None:
for start_visitor in self.extensions.when_visit_starts:
start_visitor.visit(node, parent)
super().visit(node, parent)
for stop_visitor in self.extensions.when_visit_stops:
stop_visitor.visit(node, parent)

def generic_visit(self, node: AST) -> None: # noqa: WPS231
def generic_visit(self, node: Node) -> None: # noqa: WPS231
for start_visitor in self.extensions.when_children_visit_starts:
start_visitor.visit(node)
super().generic_visit(node)
Expand Down

0 comments on commit 2ea88c0

Please sign in to comment.