Skip to content

Commit

Permalink
Refactor Template into templates.py
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Jul 1, 2023
1 parent 5e056fa commit 13fb4c2
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 55 deletions.
59 changes: 4 additions & 55 deletions llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,5 @@
from pydantic import BaseModel
import string
from typing import Optional
from .hookspecs import hookimpl # noqa
from .hookspecs import hookspec # noqa
from .models import Model, Prompt, Response, OptionsError # noqa
from .hookspecs import hookimpl
from .models import Model, Prompt, Response, OptionsError
from .templates import Template


class Template(BaseModel):
name: str
prompt: Optional[str]
system: Optional[str]
model: Optional[str]
defaults: Optional[dict]

class Config:
extra = "forbid"

class MissingVariables(Exception):
pass

def execute(self, input, params=None):
params = params or {}
params["input"] = input
if self.defaults:
for k, v in self.defaults.items():
if k not in params:
params[k] = v
if not self.prompt:
system = self.interpolate(self.system, params)
prompt = input
else:
prompt = self.interpolate(self.prompt, params)
system = self.interpolate(self.system, params)
return prompt, system

@classmethod
def interpolate(cls, text, params):
if not text:
return text
# Confirm all variables in text are provided
string_template = string.Template(text)
vars = cls.extract_vars(string_template)
missing = [p for p in vars if p not in params]
if missing:
raise cls.MissingVariables(
"Missing variables: {}".format(", ".join(missing))
)
return string_template.substitute(**params)

@staticmethod
def extract_vars(string_template):
return [
match.group("named")
for match in string_template.pattern.finditer(string_template.template)
]
__all__ = ["Template", "Model", "Prompt", "Response", "OptionsError", "hookimpl"]
53 changes: 53 additions & 0 deletions llm/templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from pydantic import BaseModel
import string
from typing import Optional


class Template(BaseModel):
name: str
prompt: Optional[str]
system: Optional[str]
model: Optional[str]
defaults: Optional[dict]

class Config:
extra = "forbid"

class MissingVariables(Exception):
pass

def execute(self, input, params=None):
params = params or {}
params["input"] = input
if self.defaults:
for k, v in self.defaults.items():
if k not in params:
params[k] = v
if not self.prompt:
system = self.interpolate(self.system, params)
prompt = input
else:
prompt = self.interpolate(self.prompt, params)
system = self.interpolate(self.system, params)
return prompt, system

@classmethod
def interpolate(cls, text, params):
if not text:
return text
# Confirm all variables in text are provided
string_template = string.Template(text)
vars = cls.extract_vars(string_template)
missing = [p for p in vars if p not in params]
if missing:
raise cls.MissingVariables(
"Missing variables: {}".format(", ".join(missing))
)
return string_template.substitute(**params)

@staticmethod
def extract_vars(string_template):
return [
match.group("named")
for match in string_template.pattern.finditer(string_template.template)
]

0 comments on commit 13fb4c2

Please sign in to comment.