From 3b22c25109580b3a1cf4a18ead5cb7fb4ef3f234 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sun, 16 Oct 2022 20:30:50 -0700 Subject: [PATCH] add docstring linter and docstrings --- langchain/__init__.py | 1 + langchain/formatting.py | 3 +++ langchain/prompt.py | 3 +++ requirements.txt | 1 + setup.py | 1 + tests/unit_tests/test_formatting.py | 4 ++++ tests/unit_tests/{test_schema.py => test_prompt.py} | 5 +++++ 7 files changed, 18 insertions(+) rename tests/unit_tests/{test_schema.py => test_prompt.py} (76%) diff --git a/langchain/__init__.py b/langchain/__init__.py index e69de29bb2d1d..7f43cb32ba075 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -0,0 +1 @@ +"""Main entrypoint into package.""" diff --git a/langchain/formatting.py b/langchain/formatting.py index 96b0bffa7c855..88f83d7893cc2 100644 --- a/langchain/formatting.py +++ b/langchain/formatting.py @@ -1,3 +1,4 @@ +"""Utilities for formatting strings.""" from string import Formatter @@ -5,11 +6,13 @@ class StrictFormatter(Formatter): """A subclass of formatter that checks for extra keys.""" def check_unused_args(self, used_args, args, kwargs): + """Check to see if extra parameters are passed.""" extra = set(kwargs).difference(used_args) if extra: raise KeyError(extra) def vformat(self, format_string, args, kwargs): + """Check that no arguments are provided.""" if len(args) > 0: raise ValueError( "No arguments should be provided, " diff --git a/langchain/prompt.py b/langchain/prompt.py index f4065bb1bcfee..15e96f41f0dd1 100644 --- a/langchain/prompt.py +++ b/langchain/prompt.py @@ -13,10 +13,13 @@ class Prompt(BaseModel): template: str class Config: + """Configuration for this pydantic object.""" + extra = Extra.forbid @root_validator() def template_is_valid(cls, values: Dict) -> Dict: + """Check that template and input variables are consistent.""" input_variables = values["input_variables"] template = values["template"] dummy_inputs = {input_variable: "foo" for input_variable in input_variables} diff --git a/requirements.txt b/requirements.txt index e9b0668dc49a6..af4814053d129 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ black isort mypy flake8 +flake8-docstrings diff --git a/setup.py b/setup.py index afe2d7f39ca04..83611874923f5 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +"""Set up the package.""" from setuptools import find_packages, setup with open("README.md", "r") as f: diff --git a/tests/unit_tests/test_formatting.py b/tests/unit_tests/test_formatting.py index 627a2ef43860e..5db90e37ab920 100644 --- a/tests/unit_tests/test_formatting.py +++ b/tests/unit_tests/test_formatting.py @@ -1,9 +1,11 @@ +"""Test formatting functionality.""" import pytest from langchain.formatting import formatter def test_valid_formatting(): + """Test formatting works as expected.""" template = "This is a {foo} test." output = formatter.format(template, foo="good") expected_output = "This is a good test." @@ -11,12 +13,14 @@ def test_valid_formatting(): def test_does_not_allow_args(): + """Test formatting raises error when args are provided.""" template = "This is a {} test." with pytest.raises(ValueError): formatter.format(template, "good") def test_does_not_allow_extra_kwargs(): + """Test formatting does not allow extra key word arguments.""" template = "This is a {foo} test." with pytest.raises(KeyError): formatter.format(template, foo="good", bar="oops") diff --git a/tests/unit_tests/test_schema.py b/tests/unit_tests/test_prompt.py similarity index 76% rename from tests/unit_tests/test_schema.py rename to tests/unit_tests/test_prompt.py index 4706475e5acd6..7e6ff31fb9d8d 100644 --- a/tests/unit_tests/test_schema.py +++ b/tests/unit_tests/test_prompt.py @@ -1,9 +1,11 @@ +"""Test functionality related to prompts.""" import pytest from langchain.prompt import Prompt def test_prompt_valid(): + """Test prompts can be constructed.""" template = "This is a {foo} test." input_variables = ["foo"] prompt = Prompt(input_variables=input_variables, template=template) @@ -12,6 +14,7 @@ def test_prompt_valid(): def test_prompt_missing_input_variables(): + """Test error is raised when input variables are not provided.""" template = "This is a {foo} test." input_variables = [] with pytest.raises(ValueError): @@ -19,6 +22,7 @@ def test_prompt_missing_input_variables(): def test_prompt_extra_input_variables(): + """Test error is raised when there are too many input variables.""" template = "This is a {foo} test." input_variables = ["foo", "bar"] with pytest.raises(ValueError): @@ -26,6 +30,7 @@ def test_prompt_extra_input_variables(): def test_prompt_wrong_input_variables(): + """Test error is raised when name of input variable is wrong.""" template = "This is a {foo} test." input_variables = ["bar"] with pytest.raises(ValueError):