Skip to content

Commit

Permalink
get rid of metaprogramming
Browse files Browse the repository at this point in the history
  • Loading branch information
scotttestbot[bot] committed Aug 29, 2024
1 parent 948a693 commit d725924
Showing 1 changed file with 23 additions and 36 deletions.
59 changes: 23 additions & 36 deletions spice/spice_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import mimetypes
from collections.abc import Collection
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, get_args
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -59,26 +59,6 @@ def __init__(
super().__init__(role=role, content=content, cache=cache, prompt_metadata=prompt_metadata)


def _generate_role_specific_methods(cls):
"""Generates add_user_text, add_system_text, etc. methods for SpiceMessages."""
add_methods = [method for method in dir(cls) if method.startswith("add_")]

def create_method(original, role):
def new_method(self, *args, **kwargs):
return original(self, role, *args, **kwargs)

return new_method

for method_name in add_methods:
original_method = getattr(cls, method_name)
for role in get_args(Role):
new_method_name = f"add_{role}_{method_name[4:]}"
setattr(cls, new_method_name, create_method(original_method, role))

return cls


@_generate_role_specific_methods
class SpiceMessages(List[SpiceMessage]):
"""A collection of messages to be sent to an API endpoint."""

Expand All @@ -87,19 +67,25 @@ def __init__(self, client: Optional[Spice] = None, messages: Collection[SpiceMes
self.data: list[SpiceMessage] = [message for message in messages]

def add_text(self, role: Role, text: str, cache: bool = False) -> SpiceMessages:
"""Appends a message with the given role and text."""
self.data.append(SpiceMessage(role=role, text=text, cache=cache))
return self

def add_image_from_url(self, role: Role, url: str, cache: bool = False) -> SpiceMessages:
"""Appends a message with the given role and image from the given url."""
def add_user_text(self, text: str, cache: bool = False) -> SpiceMessages:
return self.add_text("user", text, cache)

def add_system_text(self, text: str, cache: bool = False) -> SpiceMessages:
return self.add_text("system", text, cache)

def add_assistant_text(self, text: str, cache: bool = False) -> SpiceMessages:
return self.add_text("assistant", text, cache)

def add_user_image_from_url(self, url: str, cache: bool = False) -> SpiceMessages:
if not (url.startswith("http://") or url.startswith("https://")):
raise ImageError(f"Invalid image URL {url}: Must be http or https protocol.")
self.data.append(SpiceMessage(role=role, image_url=url, cache=cache))
self.data.append(SpiceMessage(role="user", image_url=url, cache=cache))
return self

def add_image_from_file(self, role: Role, file_path: Path | str, cache: bool = False) -> SpiceMessages:
"""Appends a message with the given role and image from the given file path."""
def add_user_image_from_file(self, file_path: Path | str, cache: bool = False) -> SpiceMessages:
file_path = Path(file_path).expanduser().resolve()
if not file_path.exists():
raise ImageError(f"Invalid image at {file_path}: file does not exist.")
Expand All @@ -109,7 +95,7 @@ def add_image_from_file(self, role: Role, file_path: Path | str, cache: bool = F
with file_path.open("rb") as file:
image_bytes = file.read()
image = base64.b64encode(image_bytes).decode("utf-8")
self.data.append(SpiceMessage(role=role, image_url=f"data:{media_type};base64,{image}", cache=cache))
self.data.append(SpiceMessage(role="user", image_url=f"data:{media_type};base64,{image}", cache=cache))
return self

def add_prompt(self, role: Role, name: str, cache: bool = False, **context: Any) -> SpiceMessages:
Expand All @@ -127,6 +113,15 @@ def add_prompt(self, role: Role, name: str, cache: bool = False, **context: Any)
)
return self

def add_user_prompt(self, name: str, cache: bool = False, **context: Any) -> SpiceMessages:
return self.add_prompt("user", name, cache, **context)

def add_system_prompt(self, name: str, cache: bool = False, **context: Any) -> SpiceMessages:
return self.add_prompt("system", name, cache, **context)

def add_assistant_prompt(self, name: str, cache: bool = False, **context: Any) -> SpiceMessages:
return self.add_prompt("assistant", name, cache, **context)

def __iter__(self):
return iter(self.data)

Expand All @@ -140,11 +135,3 @@ def copy(self):
new_copy = SpiceMessages(self._client)
new_copy.data = self.data.copy()
return new_copy

if TYPE_CHECKING:

def __getattribute__(self, name: str) -> Any:
for role in get_args(Role):
if name.startswith(f"add_{role}_"):
_, suffix = name.split(f"{role}_")
return getattr(self, f"add_{suffix}")

0 comments on commit d725924

Please sign in to comment.