diff --git a/llm/cli.py b/llm/cli.py index 18087bcd..e0e4d1a5 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -171,7 +171,7 @@ def read_prompt(): template_obj = load_template(template) prompt = read_prompt() try: - prompt, system = template_obj.execute(prompt, params) + prompt, system = template_obj.evaluate(prompt, params) except Template.MissingVariables as ex: raise click.ClickException(str(ex)) if model_id is None and template_obj.model: diff --git a/llm/templates.py b/llm/templates.py index 6885fe38..4a18a44e 100644 --- a/llm/templates.py +++ b/llm/templates.py @@ -1,6 +1,6 @@ from pydantic import ConfigDict, BaseModel import string -from typing import Optional +from typing import Optional, Any, Dict, List, Tuple class Template(BaseModel): @@ -8,19 +8,23 @@ class Template(BaseModel): prompt: Optional[str] = None system: Optional[str] = None model: Optional[str] = None - defaults: Optional[dict] = None + defaults: Optional[Dict[str, Any]] = None model_config = ConfigDict(extra="forbid") class MissingVariables(Exception): pass - def execute(self, input, params=None): + def evaluate( + self, input: str, params: Optional[Dict[str, Any]] = None + ) -> Tuple[Optional[str], Optional[str]]: params = params or {} params["input"] = input if self.defaults: for k, v in self.defaults.items(): if k not in params: params[k] = v + prompt: Optional[str] = None + system: Optional[str] = None if not self.prompt: system = self.interpolate(self.system, params) prompt = input @@ -30,7 +34,7 @@ def execute(self, input, params=None): return prompt, system @classmethod - def interpolate(cls, text, params): + def interpolate(cls, text: Optional[str], params: Dict[str, Any]) -> Optional[str]: if not text: return text # Confirm all variables in text are provided @@ -44,7 +48,7 @@ def interpolate(cls, text, params): return string_template.substitute(**params) @staticmethod - def extract_vars(string_template): + def extract_vars(string_template: string.Template) -> List[str]: return [ match.group("named") for match in string_template.pattern.finditer(string_template.template) diff --git a/tests/test_templates.py b/tests/test_templates.py index 5ca22827..08e6c424 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -27,16 +27,16 @@ ), ), ) -def test_template_execute( +def test_template_evaluate( prompt, system, defaults, params, expected_prompt, expected_system, expected_error ): t = Template(name="t", prompt=prompt, system=system, defaults=defaults) if expected_error: with pytest.raises(Template.MissingVariables) as ex: - prompt, system = t.execute("input", params) + prompt, system = t.evaluate("input", params) assert ex.value.args[0] == expected_error else: - prompt, system = t.execute("input", params) + prompt, system = t.evaluate("input", params) assert prompt == expected_prompt assert system == expected_system