Anthropic Prompt Caching and more (#108)
---------

Co-authored-by: scotttestbot <176420715+scotttestbot@users.noreply.github.com>
biobootloader and scotttestbot[bot] authored Aug 29, 2024
1 parent d950e63 commit 1b1ca8e
Showing 16 changed files with 566 additions and 488 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -12,7 +12,7 @@ packages=["spice"]

 [project]
 name = "spiceai"
-version = "0.3.20"
+version = "0.4.0"
 license = {text = "Apache-2.0"}
 description = "A Python library for building AI-powered applications."
 readme = "README.md"
@@ -31,5 +31,6 @@ dev = [
"ruff",
"pyright",
"pytest",
"pytest-asyncio"
"pytest-asyncio",
"termcolor",
]
57 changes: 57 additions & 0 deletions scripts/caching.py
@@ -0,0 +1,57 @@
+import asyncio
+import os
+import random
+import sys
+
+# Modify sys.path to ensure the script can run even when it's not part of the installed library.
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+import requests
+from termcolor import cprint
+
+from spice import Spice
+from spice.models import SONNET_3_5
+from spice.utils import print_stream
+
+
+def display_response_stats(response):
+    cprint(f"Input tokens: {response.input_tokens}", "cyan")
+    cprint(f"Cache Creation Input Tokens: {response.cache_creation_input_tokens}", "cyan")
+    cprint(f"Cache Read Input Tokens: {response.cache_read_input_tokens}", "cyan")
+    cprint(f"Output tokens: {response.output_tokens}", "cyan")
+    cprint(f"Total time: {response.total_time:.2f}s", "green")
+    cprint(f"Cost: ${response.cost / 100:.2f}", "green")
+
+
+async def run(cache: bool):
+    cprint(f"Caching on: {cache}", "cyan")
+    client = Spice()
+    model = SONNET_3_5
+    book_text = requests.get("https://www.gutenberg.org/cache/epub/42671/pg42671.txt").text
+
+    messages = (
+        client.new_messages()
+        .add_system_text(f"Answer questions about a book. Seed:{random.random()}")
+        .add_user_text(f"<book>{book_text}</book>", cache=cache)
+    )
+
+    cprint("First model response:", "green")
+    response = await client.get_response(
+        messages=messages.copy().add_user_text("how many chapters are there? no elaboration."),
+        model=model,
+        streaming_callback=print_stream,
+    )
+    display_response_stats(response)
+
+    cprint("Second model response:", "green")
+    response = await client.get_response(
+        messages=messages.copy().add_user_text("how many volumes are there? no elaboration."),
+        model=model,
+        streaming_callback=print_stream,
+    )
+    display_response_stats(response)
+
+
+if __name__ == "__main__":
+    cache = any(arg in sys.argv for arg in ["--cache", "-c"])
+    asyncio.run(run(cache))
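
What the script demonstrates: the large book text is marked cacheable once (the random seed in the system message keeps each run distinct), so on a run with `--cache` the first response should report `cache_creation_input_tokens` and the second `cache_read_input_tokens` instead of re-billing the whole prompt. Below is a minimal sketch of the same pattern as a reusable helper; it uses only the builder methods and `cache=` keyword visible above, and the helper itself is illustrative, not part of Spice.

```python
# Illustrative helper built on the API shown in scripts/caching.py (a sketch,
# not part of the library). add_user_text(..., cache=True) marks the shared
# prefix for provider-side prompt caching.
import asyncio

from spice import Spice
from spice.models import SONNET_3_5


async def ask_about_document(document: str, questions: list[str]) -> list[str]:
    client = Spice()
    # Build the large, shared prefix once and mark it cacheable; each call
    # below appends only a short question to a copy of the message list.
    base = (
        client.new_messages()
        .add_system_text("Answer questions about a document.")
        .add_user_text(f"<document>{document}</document>", cache=True)
    )
    answers = []
    for question in questions:
        response = await client.get_response(
            messages=base.copy().add_user_text(question),
            model=SONNET_3_5,
        )
        answers.append(response.text)
    return answers


if __name__ == "__main__":
    print(asyncio.run(ask_about_document("Call me Ishmael.", ["Who narrates?"])))
```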
64 changes: 35 additions & 29 deletions scripts/run.py
@@ -1,26 +1,30 @@
 import asyncio
 import os
 import sys
-from typing import List
 
 # Modify sys.path to ensure the script can run even when it's not part of the installed library.
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 
 from spice import Spice
-from spice.spice_message import SpiceMessage
 
 
 async def basic_example():
     client = Spice()
+    model = "gpt-4o"
 
-    messages: List[SpiceMessage] = [
-        {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": "list 5 random words"},
-    ]
-    response = await client.get_response(messages=messages, model="gpt-4o")
+    messages = client.new_messages()
+    messages.add_system_text("You are a helpful assistant.")
+    messages.add_user_text("list 5 random words")
+
+    tokens = client.count_prompt_tokens(messages, model=model)
+    print(f"Prompt tokens: {tokens}")
+
+    response = await client.get_response(messages=messages, model=model)
     print(response.text)
+
+    json_response = response.model_dump_json(indent=2)
+    print(json_response)
 
 
 async def streaming_example():
     # You can set a default model for the client instead of passing it with each call
@@ -30,10 +34,10 @@ async def streaming_example():
     client.load_prompt("scripts/prompt.txt", name="my prompt")
 
     # Spice can also automatically render Jinja templates.
-    messages: List[SpiceMessage] = [
-        {"role": "system", "content": client.get_rendered_prompt("my prompt", assistant_name="Ryan Reynolds")},
-        {"role": "user", "content": "list 5 random words"},
-    ]
+    messages = client.new_messages()
+    messages.add_system_prompt("my prompt", assistant_name="Friendly Robot")
+    messages.add_user_text("list 5 random words")
 
     stream = await client.stream_response(messages=messages)
 
     async for text in stream:
@@ -62,10 +66,10 @@ async def multiple_providers_example():

     client = Spice(model_aliases=model_aliases)
 
-    messages: List[SpiceMessage] = [
-        {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": "list 5 random words"},
-    ]
+    messages = client.new_messages()
+    messages.add_system_text("You are a helpful assistant.")
+    messages.add_user_text("list 5 random words")
 
     responses = await asyncio.gather(
         client.get_response(messages=messages, model="task1_model"),
         client.get_response(messages=messages, model="task2_model"),
@@ -86,10 +90,9 @@ async def multiple_providers_example():
 async def azure_example():
     client = Spice()
 
-    messages: List[SpiceMessage] = [
-        {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": "list 5 random words"},
-    ]
+    messages = client.new_messages()
+    messages.add_system_text("You are a helpful assistant.")
+    messages.add_user_text("list 5 random words")
 
     # To use Azure, specify the provider and the deployment model name
     response = await client.get_response(messages=messages, model="first-gpt35", provider="azure")
@@ -112,19 +115,20 @@ async def vision_example():
     client = Spice()
 
     # Spice makes it easy to add images from files or the internet
-    from spice import SpiceMessage, SpiceMessages
-    from spice.models import CLAUDE_3_OPUS_20240229, GPT_4_1106_VISION_PREVIEW
-    from spice.spice_message import file_image_message, user_message
+    from spice.models import CLAUDE_3_OPUS_20240229, GPT_4o
 
-    messages: List[SpiceMessage] = [user_message("What do you see?"), file_image_message("~/.mentat/picture.png")]
-    response = await client.get_response(messages, GPT_4_1106_VISION_PREVIEW)
+    messages = client.new_messages()
+    messages.add_user_image_from_file("~/.mentat/picture.png")
+    messages.add_user_text("What do you see?")
+    response = await client.get_response(messages, GPT_4o)
     print(response.text)
 
-    # Alternatively, you can use the SpiceMessages wrapper to easily create your prompts
-    spice_messages: SpiceMessages = SpiceMessages(client)
-    spice_messages.add_user_message("What do you see?")
-    spice_messages.add_file_image_message("~/.mentat/picture.png")
-    response = await client.get_response(spice_messages, CLAUDE_3_OPUS_20240229)
+    messages = (
+        client.new_messages()
+        .add_user_image_from_file("~/.mentat/picture.png")
+        .add_user_text("What do you see? Describe the objects, colors, and style.")
+    )
+    response = await client.get_response(messages, CLAUDE_3_OPUS_20240229)
     print(response.text)


@@ -137,7 +141,9 @@ async def embeddings_and_transcription_example():

     embeddings = await client.get_embeddings(input_texts, TEXT_EMBEDDING_ADA_002)
     transcription = await client.get_transcription("~/.mentat/logs/audio/talk_transcription.wav", WHISPER_1)
+
     print(transcription.text)
+    print(f"{len(embeddings.embeddings)} embeddings fetched for ${(embeddings.cost or 0) / 100:.2f}")


 async def run_all_examples():
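
A small convention worth flagging across these example scripts: both `scripts/run.py` and `scripts/caching.py` print `cost / 100` behind a dollar sign, which suggests response costs are tracked in cents. A tiny hedged helper under that assumption:

```python
from typing import Optional


def format_cost(cost_in_cents: Optional[float]) -> str:
    """Format a Spice response cost for display.

    Assumes cost is reported in cents, as the scripts above imply
    (they print cost / 100 behind a "$"); None means no usage data.
    """
    if cost_in_cents is None:
        return "unknown"
    return f"${cost_in_cents / 100:.2f}"
```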
10 changes: 4 additions & 6 deletions scripts/timing.py
@@ -2,23 +2,21 @@
 import os
 import sys
 from timeit import default_timer as timer
-from typing import List
 
 from spice.models import SONNET_3_5, GPT_4o_2024_08_06
 
 # Modify sys.path to ensure the script can run even when it's not part of the installed library.
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 
-from spice import Spice, SpiceMessage
+from spice import Spice
 
 
 async def speed_compare():
     client = Spice()
 
-    messages: List[SpiceMessage] = [
-        {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": "list 100 random words"},
-    ]
+    messages = (
+        client.new_messages().add_system_text("You are a helpful assistant.").add_user_text("list 100 random words")
+    )
 
     models = [GPT_4o_2024_08_06, SONNET_3_5]
     runs = 3
2 changes: 1 addition & 1 deletion spice/__init__.py
@@ -1,3 +1,3 @@
 from .spice import Spice, SpiceResponse, StreamingSpiceResponse, EmbeddingResponse, TranscriptionResponse # noqa
-from .spice_message import SpiceMessage, SpiceMessages # noqa
 from .utils import print_stream # noqa
+from .spice_message import SpiceMessages, SpiceMessage # noqa
11 changes: 5 additions & 6 deletions spice/call_args.py
@@ -1,15 +1,14 @@
-from dataclasses import dataclass, field
-from typing import Any, AsyncIterator, Callable, Collection, Dict, Generic, List, Optional, TypeVar, cast
+from typing import List, Optional
 
 from openai.types.chat.completion_create_params import ResponseFormat
+from pydantic import BaseModel
 
-from spice.spice_message import MessagesEncoder, SpiceMessage
+from spice.spice_message import SpiceMessage
 
 
-@dataclass
-class SpiceCallArgs:
+class SpiceCallArgs(BaseModel):
     model: str
-    messages: Collection[SpiceMessage]
+    messages: List[SpiceMessage]
     stream: bool = False
     temperature: Optional[float] = None
     max_tokens: Optional[int] = None
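
Moving `SpiceCallArgs` from `@dataclass` to Pydantic's `BaseModel` is what enables the `model_copy(update=...)` calls in the retry strategies below (replacing `dataclasses.replace`) and the `model_dump_json` call added to `scripts/run.py`. A minimal sketch of the pattern, using a hypothetical stand-in model:

```python
from typing import List, Optional

from pydantic import BaseModel


class CallArgs(BaseModel):  # hypothetical stand-in for SpiceCallArgs
    model: str
    messages: List[str]
    temperature: Optional[float] = None


args = CallArgs(model="sonnet-3.5", messages=["hello"])

# What was dataclasses.replace(args, messages=...) becomes model_copy:
retry_args = args.model_copy(update={"messages": ["hello", "try again"]})

assert args.messages == ["hello"]  # the original is left untouched
print(retry_args.model_dump_json(indent=2))  # serialization comes for free
```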
4 changes: 2 additions & 2 deletions spice/retry_strategy/__init__.py
@@ -1,8 +1,8 @@
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Any, Callable, Generic, Optional, TypeVar
+from typing import Generic, TypeVar
 
-from spice.spice import SpiceCallArgs
+from spice.call_args import SpiceCallArgs
 
 T = TypeVar("T")

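
The imports here suggest this module defines the `Behavior` enum and a generic `RetryStrategy` base class. A hedged reconstruction of that interface, inferred only from the usages in the strategy files below (the real definitions may differ):

```python
# Hedged reconstruction of the interface implied by the strategy files below;
# parameter order and enum values are guesses, not the library's actual code.
from abc import ABC, abstractmethod
from enum import Enum
from typing import Generic, Optional, Tuple, TypeVar

from spice.call_args import SpiceCallArgs

T = TypeVar("T")


class Behavior(Enum):
    RETURN = "return"
    RETRY = "retry"


class RetryStrategy(ABC, Generic[T]):
    @abstractmethod
    def decide(
        self, call_args: SpiceCallArgs, attempt_number: int, model_output: str, name: str
    ) -> Tuple[Behavior, SpiceCallArgs, Optional[T], str]:
        """Inspect one attempt and return what to do next."""
```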
15 changes: 7 additions & 8 deletions spice/retry_strategy/converter_strategy.py
@@ -1,9 +1,8 @@
-import dataclasses
-from typing import Any, Callable, Optional
+from typing import Any, Callable
 
 from spice.retry_strategy import Behavior, RetryStrategy
 from spice.spice import SpiceCallArgs
-from spice.spice_message import assistant_message, user_message
+from spice.spice_message import SpiceMessages
 
 
 def default_failure_message(message: str) -> str:
@@ -29,10 +28,10 @@ def decide(
             return Behavior.RETURN, call_args, result, name
         except Exception as e:
             if attempt_number < self.retries:
-                messages = list(call_args.messages)
-                messages.append(assistant_message(model_output))
-                messages.append(user_message(self.render_failure_message(str(e))))
-                call_args = dataclasses.replace(call_args, messages=messages)
-                return Behavior.RETRY, call_args, None, f"{name}-retry-{attempt_number}-fail"
+                messages = SpiceMessages(messages=call_args.messages)
+                messages.add_assistant_text(model_output)
+                messages.add_user_text(self.render_failure_message(str(e)))
+                new_call_args = call_args.model_copy(update={"messages": messages})
+                return Behavior.RETRY, new_call_args, None, f"{name}-retry-{attempt_number}-fail"
             else:
                 raise ValueError("Failed to get a valid response after all retries")
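
The change above swaps list-appends and `dataclasses.replace` for `SpiceMessages` and `model_copy`, but the feedback loop is the same: on a conversion failure, the model's bad output and a corrective user message are appended, and the call is retried. A self-contained sketch of that pattern with illustrative stand-ins (not Spice's actual driver):

```python
# Self-contained sketch of the converter-retry feedback loop shown above;
# generate/convert and the message type are illustrative stand-ins.
from typing import Callable, List


def retry_convert(
    generate: Callable[[List[str]], str],  # stand-in for the model call
    convert: Callable[[str], int],         # raises on unusable output
    messages: List[str],
    retries: int = 2,
) -> int:
    for attempt in range(retries + 1):
        output = generate(messages)
        try:
            return convert(output)
        except Exception as e:
            if attempt < retries:
                # Same trick as the decide() method above: feed the failure back.
                messages = messages + [output, f"Error: {e}. Please try again."]
    raise ValueError("Failed to get a valid response after all retries")


# Toy usage: the fake "model" only succeeds after seeing feedback.
calls = {"n": 0}


def fake_model(msgs: List[str]) -> str:
    calls["n"] += 1
    return "not a number" if calls["n"] == 1 else "42"


print(retry_convert(fake_model, int, ["answer with a number"]))  # -> 42
```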
6 changes: 2 additions & 4 deletions spice/retry_strategy/default_strategy.py
@@ -1,10 +1,8 @@
-import dataclasses
-from abc import ABC, abstractmethod
-from enum import Enum
-from typing import Any, Callable, Optional, TypeVar
+from typing import Any, Callable, Optional
 
+from spice.call_args import SpiceCallArgs
 from spice.retry_strategy import Behavior, RetryStrategy, T
-from spice.spice import SpiceCallArgs
 
 
 class DefaultRetryStrategy(RetryStrategy):
15 changes: 7 additions & 8 deletions spice/retry_strategy/validator_strategy.py
@@ -1,9 +1,8 @@
-import dataclasses
-from typing import Any, Callable, Optional, Tuple
+from typing import Any, Callable, Tuple
 
 from spice.call_args import SpiceCallArgs
 from spice.retry_strategy import Behavior, RetryStrategy
-from spice.spice_message import assistant_message, user_message
+from spice.spice_message import SpiceMessages
 
 
 def default_failure_message(message: str) -> str:
@@ -40,11 +39,11 @@ def decide(
         passed, message = self.validator(model_output)
         if not passed:
             if attempt_number < self.retries:
-                messages = list(call_args.messages)
-                messages.append(assistant_message(model_output))
-                messages.append(user_message(self.render_failure_message(message)))
-                call_args = dataclasses.replace(call_args, messages=messages)
-                return Behavior.RETRY, call_args, None, f"{name}-retry-{attempt_number}-fail"
+                messages = SpiceMessages(messages=call_args.messages)
+                messages.add_assistant_text(model_output)
+                messages.add_user_text(self.render_failure_message(message))
+                new_call_args = call_args.model_copy(update={"messages": messages})
+                return Behavior.RETRY, new_call_args, None, f"{name}-retry-{attempt_number}-fail"
             else:
                 raise ValueError("Failed to get a valid response after all retries")
         else: