Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Anthropic Prompt Caching and more #108

Merged
merged 21 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ packages=["spice"]

[project]
name = "spiceai"
version = "0.3.20"
version = "0.4.0"
license = {text = "Apache-2.0"}
description = "A Python library for building AI-powered applications."
readme = "README.md"
Expand All @@ -31,5 +31,6 @@ dev = [
"ruff",
"pyright",
"pytest",
"pytest-asyncio"
"pytest-asyncio",
"termcolor",
]
57 changes: 57 additions & 0 deletions scripts/caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import asyncio
import os
import random
import sys

import requests
from termcolor import cprint

# Modify sys.path to ensure the script can run even when it's not part of the installed library.
# This must run before the `spice` imports below: an import is resolved against sys.path
# at the moment it executes, so inserting the repository root afterwards has no effect on it.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from spice import Spice  # noqa: E402
from spice.models import SONNET_3_5  # noqa: E402
from spice.utils import print_stream  # noqa: E402


def display_response_stats(response):
    """Print token usage, timing and cost figures for a model response.

    Args:
        response: Object exposing ``input_tokens``, ``cache_creation_input_tokens``,
            ``cache_read_input_tokens``, ``output_tokens``, ``total_time`` and ``cost``.
    """
    cprint(f"Input tokens: {response.input_tokens}", "cyan")
    cprint(f"Cache Creation Input Tokens: {response.cache_creation_input_tokens}", "cyan")
    cprint(f"Cache Read Input Tokens: {response.cache_read_input_tokens}", "cyan")
    cprint(f"Output tokens: {response.output_tokens}", "cyan")
    cprint(f"Total time: {response.total_time:.2f}s", "green")
    # A missing cost is shown as $0.00 instead of raising TypeError on `None / 100`.
    # The value is divided by 100 before being printed as dollars (presumably it is
    # stored in cents -- confirm against the response type).
    cprint(f"Cost: ${(response.cost or 0) / 100:.2f}", "green")


async def run(cache: bool):
    """Ask the model two questions about the same book and print stats for each answer.

    The book text is downloaded once and placed in a shared message prefix; each
    question is appended to a copy of that prefix so both requests start from the
    same messages.

    Args:
        cache: Forwarded as ``cache=`` on the user message that carries the book text.

    Raises:
        requests.RequestException: If the book download fails, times out, or
            returns an HTTP error status.
    """
    cprint(f"Caching on: {cache}", "cyan")
    client = Spice()
    model = SONNET_3_5

    # NOTE: requests.get blocks the event loop. That is acceptable here because this
    # coroutine is the only task being run. The timeout keeps a stalled download from
    # hanging the script, and raise_for_status stops an HTTP error page from being
    # sent to the model as if it were the book.
    book_response = requests.get("https://www.gutenberg.org/cache/epub/42671/pg42671.txt", timeout=60)
    book_response.raise_for_status()
    book_text = book_response.text

    # The random seed makes the system message differ between runs
    # (presumably so each run starts without a previously cached prefix -- confirm).
    messages = (
        client.new_messages()
        .add_system_message(f"Answer questions about a book. Seed:{random.random()}")
        .add_user_message(f"<book>{book_text}</book>", cache=cache)
    )

    questions = (
        ("First", "how many chapters are there? no elaboration."),
        ("Second", "how many volumes are there? no elaboration."),
    )
    for ordinal, question in questions:
        cprint(f"{ordinal} model response:", "green")
        response = await client.get_response(
            # Copy so the shared prefix is not extended by the previous question.
            messages=messages.copy().add_user_message(question),
            model=model,
            streaming_callback=print_stream,
        )
        display_response_stats(response)


if __name__ == "__main__":
    # Caching is switched on when either the long or the short flag is on the command line.
    cache = not {"--cache", "-c"}.isdisjoint(sys.argv)
    asyncio.run(run(cache))
40 changes: 29 additions & 11 deletions scripts/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from spice import Spice
from spice.spice_message import SpiceMessage


async def basic_example():
client = Spice()

messages: List[SpiceMessage] = [
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "list 5 random words"},
]
Expand All @@ -22,6 +21,21 @@ async def basic_example():
print(response.text)


async def messages_example():
    """Build a conversation with the message convenience helpers and print the reply."""
    spice_client = Spice()

    # Each add_* call returns the object the next call is made on, so the
    # conversation is built up step by step.
    conversation = spice_client.new_messages()
    conversation = conversation.add_system_message("You are a helpful assistant.")
    conversation = conversation.add_user_message("list 5 random species of birds")

    reply = await spice_client.get_response(messages=conversation, model="gpt-4o")

    print(reply.text)


async def streaming_example():
# You can set a default model for the client instead of passing it with each call
client = Spice(default_text_model="claude-3-opus-20240229")
Expand All @@ -30,8 +44,8 @@ async def streaming_example():
client.load_prompt("scripts/prompt.txt", name="my prompt")

# Spice can also automatically render Jinja templates.
messages: List[SpiceMessage] = [
{"role": "system", "content": client.get_rendered_prompt("my prompt", assistant_name="Ryan Reynolds")},
messages = [
{"role": "system", "content": client.get_rendered_prompt("my prompt", assistant_name="Friendly Robot")},
{"role": "user", "content": "list 5 random words"},
]
stream = await client.stream_response(messages=messages)
Expand Down Expand Up @@ -62,7 +76,7 @@ async def multiple_providers_example():

client = Spice(model_aliases=model_aliases)

messages: List[SpiceMessage] = [
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "list 5 random words"},
]
Expand All @@ -86,7 +100,7 @@ async def multiple_providers_example():
async def azure_example():
client = Spice()

messages: List[SpiceMessage] = [
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "list 5 random words"},
]
Expand All @@ -113,17 +127,19 @@ async def vision_example():

# Spice makes it easy to add images from files or the internet
from spice import SpiceMessage, SpiceMessages
from spice.models import CLAUDE_3_OPUS_20240229, GPT_4_1106_VISION_PREVIEW
from spice.models import CLAUDE_3_OPUS_20240229, GPT_4o
from spice.spice_message import file_image_message, user_message

messages: List[SpiceMessage] = [user_message("What do you see?"), file_image_message("~/.mentat/picture.png")]
response = await client.get_response(messages, GPT_4_1106_VISION_PREVIEW)
response = await client.get_response(messages, GPT_4o)
print(response.text)

# Alternatively, you can use the SpiceMessages wrapper to easily create your prompts
spice_messages: SpiceMessages = SpiceMessages(client)
spice_messages.add_user_message("What do you see?")
spice_messages.add_file_image_message("~/.mentat/picture.png")
spice_messages: SpiceMessages = (
SpiceMessages(client)
.add_file_image_message("~/.mentat/picture.png")
.add_user_message("What do you see? Describe the objects, colors, and style.")
)
response = await client.get_response(spice_messages, CLAUDE_3_OPUS_20240229)
print(response.text)

Expand All @@ -137,7 +153,9 @@ async def embeddings_and_transcription_example():

embeddings = await client.get_embeddings(input_texts, TEXT_EMBEDDING_ADA_002)
transcription = await client.get_transcription("~/.mentat/logs/audio/talk_transcription.wav", WHISPER_1)

print(transcription.text)
print(f"{len(embeddings.embeddings)} embeddings fetched for ${(embeddings.cost or 0) / 100:.2f}")


async def run_all_examples():
Expand Down
12 changes: 6 additions & 6 deletions scripts/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@
import os
import sys
from timeit import default_timer as timer
from typing import List

from spice.models import SONNET_3_5, GPT_4o_2024_08_06

# Modify sys.path to ensure the script can run even when it's not part of the installed library.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from spice import Spice, SpiceMessage
from spice import Spice


async def speed_compare():
client = Spice()

messages: List[SpiceMessage] = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "list 100 random words"},
]
messages = (
client.new_messages()
.add_system_message("You are a helpful assistant.")
.add_user_message("list 100 random words")
)

models = [GPT_4o_2024_08_06, SONNET_3_5]
runs = 3
Expand Down
6 changes: 3 additions & 3 deletions spice/call_args.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Callable, Collection, Dict, Generic, List, Optional, TypeVar, cast
from dataclasses import dataclass
from typing import Collection, Optional

from openai.types.chat.completion_create_params import ResponseFormat

from spice.spice_message import MessagesEncoder, SpiceMessage
from spice.spice_message import SpiceMessage


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion spice/retry_strategy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Generic, Optional, TypeVar
from typing import Generic, TypeVar

from spice.spice import SpiceCallArgs

Expand Down
2 changes: 1 addition & 1 deletion spice/retry_strategy/converter_strategy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import Any, Callable, Optional
from typing import Any, Callable

from spice.retry_strategy import Behavior, RetryStrategy
from spice.spice import SpiceCallArgs
Expand Down
4 changes: 1 addition & 3 deletions spice/retry_strategy/default_strategy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import dataclasses
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Optional, TypeVar
from typing import Any, Callable, Optional

from spice.retry_strategy import Behavior, RetryStrategy, T
from spice.spice import SpiceCallArgs
Expand Down
2 changes: 1 addition & 1 deletion spice/retry_strategy/validator_strategy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Tuple

from spice.call_args import SpiceCallArgs
from spice.retry_strategy import Behavior, RetryStrategy
Expand Down
Loading
Loading