Skip to content

Commit

Permalink
Renamed template.execute() to template.evaluate() and added type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Jul 10, 2023
1 parent 50b6647 commit 40b9296
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
2 changes: 1 addition & 1 deletion llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def read_prompt():
template_obj = load_template(template)
prompt = read_prompt()
try:
prompt, system = template_obj.execute(prompt, params)
prompt, system = template_obj.evaluate(prompt, params)
except Template.MissingVariables as ex:
raise click.ClickException(str(ex))
if model_id is None and template_obj.model:
Expand Down
14 changes: 9 additions & 5 deletions llm/templates.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
from pydantic import ConfigDict, BaseModel
import string
from typing import Optional
from typing import Optional, Any, Dict, List, Tuple


class Template(BaseModel):
name: str
prompt: Optional[str] = None
system: Optional[str] = None
model: Optional[str] = None
defaults: Optional[dict] = None
defaults: Optional[Dict[str, Any]] = None
model_config = ConfigDict(extra="forbid")

class MissingVariables(Exception):
pass

def execute(self, input, params=None):
def evaluate(
self, input: str, params: Optional[Dict[str, Any]] = None
) -> Tuple[Optional[str], Optional[str]]:
params = params or {}
params["input"] = input
if self.defaults:
for k, v in self.defaults.items():
if k not in params:
params[k] = v
prompt: Optional[str] = None
system: Optional[str] = None
if not self.prompt:
system = self.interpolate(self.system, params)
prompt = input
Expand All @@ -30,7 +34,7 @@ def execute(self, input, params=None):
return prompt, system

@classmethod
def interpolate(cls, text, params):
def interpolate(cls, text: Optional[str], params: Dict[str, Any]) -> Optional[str]:
if not text:
return text
# Confirm all variables in text are provided
Expand All @@ -44,7 +48,7 @@ def interpolate(cls, text, params):
return string_template.substitute(**params)

@staticmethod
def extract_vars(string_template):
def extract_vars(string_template: string.Template) -> List[str]:
return [
match.group("named")
for match in string_template.pattern.finditer(string_template.template)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@
),
),
)
def test_template_execute(
def test_template_evaluate(
prompt, system, defaults, params, expected_prompt, expected_system, expected_error
):
t = Template(name="t", prompt=prompt, system=system, defaults=defaults)
if expected_error:
with pytest.raises(Template.MissingVariables) as ex:
prompt, system = t.execute("input", params)
prompt, system = t.evaluate("input", params)
assert ex.value.args[0] == expected_error
else:
prompt, system = t.execute("input", params)
prompt, system = t.evaluate("input", params)
assert prompt == expected_prompt
assert system == expected_system

Expand Down

0 comments on commit 40b9296

Please sign in to comment.