Anthropic Prompt Caching and more #108

Merged
merged 21 commits on Aug 29, 2024
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -12,7 +12,7 @@ packages=["spice"]

[project]
name = "spiceai"
version = "0.3.20"
version = "0.4.0"
license = {text = "Apache-2.0"}
description = "A Python library for building AI-powered applications."
readme = "README.md"
@@ -31,5 +31,6 @@ dev = [
"ruff",
"pyright",
"pytest",
"pytest-asyncio"
"pytest-asyncio",
"termcolor",
]
57 changes: 57 additions & 0 deletions scripts/caching.py
@@ -0,0 +1,57 @@
import asyncio
import os
import random
import sys

import requests
from termcolor import cprint

# Modify sys.path to ensure the script can run even when it's not part of the installed library.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from spice import Spice
from spice.models import SONNET_3_5
from spice.utils import print_stream


def display_response_stats(response):
    cprint(f"Input tokens: {response.input_tokens}", "cyan")
    cprint(f"Cache Creation Input Tokens: {response.cache_creation_input_tokens}", "cyan")
    cprint(f"Cache Read Input Tokens: {response.cache_read_input_tokens}", "cyan")
    cprint(f"Output tokens: {response.output_tokens}", "cyan")
    cprint(f"Total time: {response.total_time:.2f}s", "green")
    cprint(f"Cost: ${response.cost / 100:.2f}", "green")


async def run(cache: bool):
    cprint(f"Caching on: {cache}", "cyan")
    client = Spice()
    model = SONNET_3_5
    book_text = requests.get("https://www.gutenberg.org/cache/epub/42671/pg42671.txt").text

    messages = (
        client.new_messages()
        .add_system_text(f"Answer questions about a book. Seed:{random.random()}")
        .add_user_text(f"<book>{book_text}</book>", cache=cache)
    )

    cprint("First model response:", "green")
    response = await client.get_response(
        messages=messages.copy().add_user_text("how many chapters are there? no elaboration."),
        model=model,
        streaming_callback=print_stream,
    )
    display_response_stats(response)

    cprint("Second model response:", "green")
    response = await client.get_response(
        messages=messages.copy().add_user_text("how many volumes are there? no elaboration."),
        model=model,
        streaming_callback=print_stream,
    )
    display_response_stats(response)


if __name__ == "__main__":
    cache = any(arg in sys.argv for arg in ["--cache", "-c"])
    asyncio.run(run(cache))
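The `cache=` flag on `add_user_text` is the core of the new Anthropic prompt-caching support: the flagged block (here, an entire book) is written to the provider's prompt cache on the first request and read back at a reduced per-token rate on later requests, which the `cache_creation_input_tokens` and `cache_read_input_tokens` fields on the response make visible. A minimal sketch of the same API outside the script (the document text and question are placeholders):

```python
import asyncio

from spice import Spice
from spice.models import SONNET_3_5


async def main():
    client = Spice()
    # Flag the large, stable prefix with cache=True; follow-up turns reuse it.
    messages = (
        client.new_messages()
        .add_system_text("Answer questions about the provided document.")
        .add_user_text("<doc>...lots of stable text...</doc>", cache=True)
    )
    response = await client.get_response(
        messages=messages.copy().add_user_text("Summarize the document in one line."),
        model=SONNET_3_5,
    )
    # Nonzero cache_read_input_tokens on a repeat call indicates a cache hit.
    print(response.cache_creation_input_tokens, response.cache_read_input_tokens)


asyncio.run(main())
```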
64 changes: 35 additions & 29 deletions scripts/run.py
@@ -1,26 +1,30 @@
import asyncio
import os
import sys
-from typing import List

# Modify sys.path to ensure the script can run even when it's not part of the installed library.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from spice import Spice
-from spice.spice_message import SpiceMessage


async def basic_example():
    client = Spice()
+    model = "gpt-4o"

-    messages: List[SpiceMessage] = [
-        {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": "list 5 random words"},
-    ]
-    response = await client.get_response(messages=messages, model="gpt-4o")
+    messages = client.new_messages()
+    messages.add_system_text("You are a helpful assistant.")
+    messages.add_user_text("list 5 random words")
+
+    tokens = client.count_prompt_tokens(messages, model=model)
+    print(f"Prompt tokens: {tokens}")
+
+    response = await client.get_response(messages=messages, model=model)
    print(response.text)

    json_response = response.model_dump_json(indent=2)
    print(json_response)


async def streaming_example():
    # You can set a default model for the client instead of passing it with each call
@@ -30,10 +30,10 @@ async def streaming_example():
    client.load_prompt("scripts/prompt.txt", name="my prompt")

    # Spice can also automatically render Jinja templates.
-    messages: List[SpiceMessage] = [
-        {"role": "system", "content": client.get_rendered_prompt("my prompt", assistant_name="Ryan Reynolds")},
-        {"role": "user", "content": "list 5 random words"},
-    ]
+    messages = client.new_messages()
+    messages.add_system_prompt("my prompt", assistant_name="Friendly Robot")
+    messages.add_user_text("list 5 random words")

    stream = await client.stream_response(messages=messages)

    async for text in stream:
@@ -62,10 +66,10 @@ async def multiple_providers_example():

    client = Spice(model_aliases=model_aliases)

-    messages: List[SpiceMessage] = [
-        {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": "list 5 random words"},
-    ]
+    messages = client.new_messages()
+    messages.add_system_text("You are a helpful assistant.")
+    messages.add_user_text("list 5 random words")

    responses = await asyncio.gather(
        client.get_response(messages=messages, model="task1_model"),
        client.get_response(messages=messages, model="task2_model"),
@@ -86,10 +90,9 @@ async def multiple_providers_example():
async def azure_example():
    client = Spice()

-    messages: List[SpiceMessage] = [
-        {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": "list 5 random words"},
-    ]
+    messages = client.new_messages()
+    messages.add_system_text("You are a helpful assistant.")
+    messages.add_user_text("list 5 random words")

    # To use Azure, specify the provider and the deployment model name
    response = await client.get_response(messages=messages, model="first-gpt35", provider="azure")
@@ -112,19 +115,20 @@ async def vision_example():
    client = Spice()

    # Spice makes it easy to add images from files or the internet
-    from spice import SpiceMessage, SpiceMessages
-    from spice.models import CLAUDE_3_OPUS_20240229, GPT_4_1106_VISION_PREVIEW
-    from spice.spice_message import file_image_message, user_message
+    from spice.models import CLAUDE_3_OPUS_20240229, GPT_4o

-    messages: List[SpiceMessage] = [user_message("What do you see?"), file_image_message("~/.mentat/picture.png")]
-    response = await client.get_response(messages, GPT_4_1106_VISION_PREVIEW)
+    messages = client.new_messages()
+    messages.add_user_image_from_file("~/.mentat/picture.png")
+    messages.add_user_text("What do you see?")
+    response = await client.get_response(messages, GPT_4o)
    print(response.text)

-    # Alternatively, you can use the SpiceMessages wrapper to easily create your prompts
-    spice_messages: SpiceMessages = SpiceMessages(client)
-    spice_messages.add_user_message("What do you see?")
-    spice_messages.add_file_image_message("~/.mentat/picture.png")
-    response = await client.get_response(spice_messages, CLAUDE_3_OPUS_20240229)
+    messages = (
+        client.new_messages()
+        .add_user_image_from_file("~/.mentat/picture.png")
+        .add_user_text("What do you see? Describe the objects, colors, and style.")
+    )
+    response = await client.get_response(messages, CLAUDE_3_OPUS_20240229)
    print(response.text)


@@ -137,7 +141,9 @@ async def embeddings_and_transcription_example():

    embeddings = await client.get_embeddings(input_texts, TEXT_EMBEDDING_ADA_002)
    transcription = await client.get_transcription("~/.mentat/logs/audio/talk_transcription.wav", WHISPER_1)
+
    print(transcription.text)
+    print(f"{len(embeddings.embeddings)} embeddings fetched for ${(embeddings.cost or 0) / 100:.2f}")


async def run_all_examples():
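The recurring change across run.py is the move from dict-based `List[SpiceMessage]` literals to the `SpiceMessages` builder returned by `client.new_messages()`. A compact sketch of the builder surface as it appears in this diff (the chaining behavior is inferred from the examples above; the prompt name and image path are illustrative):

```python
from spice import Spice

client = Spice()

# Builder methods appear to return the collection itself, so both statement
# style and fluent chaining work; copy() forks a shared prompt prefix.
messages = (
    client.new_messages()
    .add_system_text("You are a helpful assistant.")
    .add_user_text("list 5 random words")
)
follow_up = messages.copy().add_user_text("now list 5 more")

# Jinja prompts registered via client.load_prompt("scripts/prompt.txt", name="my prompt")
# render with keyword arguments:
messages.add_system_prompt("my prompt", assistant_name="Friendly Robot")

# Images attach straight from disk:
messages.add_user_image_from_file("~/.mentat/picture.png")
```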
10 changes: 4 additions & 6 deletions scripts/timing.py
@@ -2,23 +2,21 @@
import os
import sys
from timeit import default_timer as timer
-from typing import List

from spice.models import SONNET_3_5, GPT_4o_2024_08_06

# Modify sys.path to ensure the script can run even when it's not part of the installed library.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

-from spice import Spice, SpiceMessage
+from spice import Spice


async def speed_compare():
    client = Spice()

-    messages: List[SpiceMessage] = [
-        {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": "list 100 random words"},
-    ]
+    messages = (
+        client.new_messages().add_system_text("You are a helpful assistant.").add_user_text("list 100 random words")
+    )

    models = [GPT_4o_2024_08_06, SONNET_3_5]
    runs = 3
2 changes: 1 addition & 1 deletion spice/__init__.py
@@ -1,3 +1,3 @@
from .spice import Spice, SpiceResponse, StreamingSpiceResponse, EmbeddingResponse, TranscriptionResponse # noqa
-from .spice_message import SpiceMessage, SpiceMessages # noqa
from .utils import print_stream # noqa
+from .spice_message import SpiceMessages, SpiceMessage # noqa
11 changes: 5 additions & 6 deletions spice/call_args.py
@@ -1,15 +1,14 @@
-from dataclasses import dataclass, field
-from typing import Any, AsyncIterator, Callable, Collection, Dict, Generic, List, Optional, TypeVar, cast
+from typing import List, Optional

from openai.types.chat.completion_create_params import ResponseFormat
+from pydantic import BaseModel

-from spice.spice_message import MessagesEncoder, SpiceMessage
+from spice.spice_message import SpiceMessage


-@dataclass
-class SpiceCallArgs:
+class SpiceCallArgs(BaseModel):
    model: str
-    messages: Collection[SpiceMessage]
+    messages: List[SpiceMessage]
    stream: bool = False
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
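Making `SpiceCallArgs` a pydantic `BaseModel` instead of a dataclass is what enables the `model_copy(update=...)` calls in the retry strategies below, replacing `dataclasses.replace`. A minimal standalone sketch of that pattern (the class here is a toy stand-in, not the real `SpiceCallArgs`):

```python
from typing import List, Optional

from pydantic import BaseModel


class CallArgs(BaseModel):  # toy stand-in for SpiceCallArgs
    model: str
    messages: List[str]
    temperature: Optional[float] = None


args = CallArgs(model="some-model", messages=["hello"])

# model_copy(update=...) returns a copy with the given fields replaced
# (no re-validation), leaving the original untouched; it is the pydantic
# counterpart of dataclasses.replace().
retried = args.model_copy(update={"messages": args.messages + ["try again"]})
print(retried.messages)  # ['hello', 'try again']
print(args.messages)     # ['hello'] (unchanged)
```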
4 changes: 2 additions & 2 deletions spice/retry_strategy/__init__.py
@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
from enum import Enum
-from typing import Any, Callable, Generic, Optional, TypeVar
+from typing import Generic, TypeVar

-from spice.spice import SpiceCallArgs
+from spice.call_args import SpiceCallArgs

T = TypeVar("T")

15 changes: 7 additions & 8 deletions spice/retry_strategy/converter_strategy.py
@@ -1,9 +1,8 @@
-import dataclasses
-from typing import Any, Callable, Optional
+from typing import Any, Callable

from spice.retry_strategy import Behavior, RetryStrategy
from spice.spice import SpiceCallArgs
-from spice.spice_message import assistant_message, user_message
+from spice.spice_message import SpiceMessages


def default_failure_message(message: str) -> str:
@@ -29,10 +28,10 @@ def decide(
            return Behavior.RETURN, call_args, result, name
        except Exception as e:
            if attempt_number < self.retries:
-                messages = list(call_args.messages)
-                messages.append(assistant_message(model_output))
-                messages.append(user_message(self.render_failure_message(str(e))))
-                call_args = dataclasses.replace(call_args, messages=messages)
-                return Behavior.RETRY, call_args, None, f"{name}-retry-{attempt_number}-fail"
+                messages = SpiceMessages(messages=call_args.messages)
+                messages.add_assistant_text(model_output)
+                messages.add_user_text(self.render_failure_message(str(e)))
+                new_call_args = call_args.model_copy(update={"messages": messages})
+                return Behavior.RETRY, new_call_args, None, f"{name}-retry-{attempt_number}-fail"
            else:
                raise ValueError("Failed to get a valid response after all retries")
6 changes: 2 additions & 4 deletions spice/retry_strategy/default_strategy.py
@@ -1,10 +1,8 @@
import dataclasses
-from abc import ABC, abstractmethod
-from enum import Enum
-from typing import Any, Callable, Optional, TypeVar
+from typing import Any, Callable, Optional

+from spice.call_args import SpiceCallArgs
from spice.retry_strategy import Behavior, RetryStrategy, T
-from spice.spice import SpiceCallArgs


class DefaultRetryStrategy(RetryStrategy):
15 changes: 7 additions & 8 deletions spice/retry_strategy/validator_strategy.py
@@ -1,9 +1,8 @@
-import dataclasses
-from typing import Any, Callable, Optional, Tuple
+from typing import Any, Callable, Tuple

from spice.call_args import SpiceCallArgs
from spice.retry_strategy import Behavior, RetryStrategy
-from spice.spice_message import assistant_message, user_message
+from spice.spice_message import SpiceMessages


def default_failure_message(message: str) -> str:
@@ -40,11 +39,11 @@ def decide(
        passed, message = self.validator(model_output)
        if not passed:
            if attempt_number < self.retries:
-                messages = list(call_args.messages)
-                messages.append(assistant_message(model_output))
-                messages.append(user_message(self.render_failure_message(message)))
-                call_args = dataclasses.replace(call_args, messages=messages)
-                return Behavior.RETRY, call_args, None, f"{name}-retry-{attempt_number}-fail"
+                messages = SpiceMessages(messages=call_args.messages)
+                messages.add_assistant_text(model_output)
+                messages.add_user_text(self.render_failure_message(message))
+                new_call_args = call_args.model_copy(update={"messages": messages})
+                return Behavior.RETRY, new_call_args, None, f"{name}-retry-{attempt_number}-fail"
            else:
                raise ValueError("Failed to get a valid response after all retries")
        else:
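Both converter and validator strategies share the same shape: on a failed attempt they append the model's output as an assistant turn plus a corrective user message, then hand back updated call args with `Behavior.RETRY`. A hypothetical driver loop consuming `decide()` (the `call_llm` helper and the exact parameter order of `decide` are assumptions, since the full signature is not shown in this diff):

```python
from spice.retry_strategy import Behavior


async def run_with_retries(strategy, call_args, call_llm, name="call"):
    """Hypothetical sketch: re-invoke the model until the strategy says RETURN."""
    attempt_number = 0
    while True:
        model_output = await call_llm(call_args)
        behavior, call_args, result, name = strategy.decide(
            call_args, attempt_number, model_output, name
        )
        if behavior == Behavior.RETURN:
            return result
        # Behavior.RETRY: call_args now carry the failure-feedback messages.
        attempt_number += 1
```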