diff --git a/scripts/tests/test_validate_docstrings.py b/scripts/tests/test_validate_docstrings.py index bb58449843096..6f78b91653a3f 100644 --- a/scripts/tests/test_validate_docstrings.py +++ b/scripts/tests/test_validate_docstrings.py @@ -229,6 +229,27 @@ def good_imports(self): """ pass + def no_returns(self): + """ + Say hello and have no returns. + """ + pass + + def empty_returns(self): + """ + Say hello and always return None. + + Since this function never returns a value, this + docstring doesn't need a return section. + """ + def say_hello(): + return "Hello World!" + say_hello() + if True: + return + else: + return None + class BadGenericDocStrings(object): """Everything here has a bad docstring @@ -783,7 +804,7 @@ def test_good_class(self, capsys): @pytest.mark.parametrize("func", [ 'plot', 'sample', 'random_letters', 'sample_values', 'head', 'head1', - 'contains', 'mode', 'good_imports']) + 'contains', 'mode', 'good_imports', 'no_returns', 'empty_returns']) def test_good_functions(self, capsys, func): errors = validate_one(self._import_path( klass='GoodDocStrings', func=func))['errors'] diff --git a/scripts/validate_docstrings.py b/scripts/validate_docstrings.py index bce33f7e78daa..446cd60968312 100755 --- a/scripts/validate_docstrings.py +++ b/scripts/validate_docstrings.py @@ -26,6 +26,8 @@ import importlib import doctest import tempfile +import ast +import textwrap import flake8.main.application @@ -490,9 +492,45 @@ def yields(self): @property def method_source(self): try: - return inspect.getsource(self.obj) + source = inspect.getsource(self.obj) except TypeError: return '' + return textwrap.dedent(source) + + @property + def method_returns_something(self): + ''' + Check if the docstrings method can return something. + + Bare returns, returns valued None and returns from nested functions are + disconsidered. + + Returns + ------- + bool + Whether the docstrings method can return something. + ''' + + def get_returns_not_on_nested_functions(node): + returns = [node] if isinstance(node, ast.Return) else [] + for child in ast.iter_child_nodes(node): + # Ignore nested functions and its subtrees. + if not isinstance(child, ast.FunctionDef): + child_returns = get_returns_not_on_nested_functions(child) + returns.extend(child_returns) + return returns + + tree = ast.parse(self.method_source).body + if tree: + returns = get_returns_not_on_nested_functions(tree[0]) + return_values = [r.value for r in returns] + # Replace NameConstant nodes valued None for None. + for i, v in enumerate(return_values): + if isinstance(v, ast.NameConstant) and v.value is None: + return_values[i] = None + return any(return_values) + else: + return False @property def first_line_ends_in_dot(self): @@ -691,7 +729,7 @@ def get_validation_data(doc): if doc.is_function_or_method: if not doc.returns: - if 'return' in doc.method_source: + if doc.method_returns_something: errs.append(error('RT01')) else: if len(doc.returns) == 1 and doc.returns[0][1]: