Commit
support o1-preview and o1-mini
biobootloader committed Sep 13, 2024
1 parent 1e77579 commit afc364f
Showing 5 changed files with 69 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -12,7 +12,7 @@ packages=["spice"]

[project]
name = "spiceai"
version = "0.4.3"
version = "0.4.4"
license = {text = "Apache-2.0"}
description = "A Python library for building AI-powered applications."
readme = "README.md"
34 changes: 34 additions & 0 deletions scripts/o1.py
@@ -0,0 +1,34 @@
import asyncio
import os
import sys
from timeit import default_timer as timer

# Modify sys.path to ensure the script can run even when it's not part of the installed library.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from spice import Spice
from spice.models import o1_mini, o1_preview


async def run_o1():
    client = Spice()

    messages = client.new_messages().add_user_text("list 5 random words")

    models = [o1_mini, o1_preview]

    for model in models:
        print(f"\nRunning {model.name}:")
        start = timer()
        response = await client.get_response(messages=messages, model=model)
        end = timer()
        print(response.text)
        print(f"input tokens: {response.input_tokens}")
        print(f"output tokens: {response.output_tokens}")
        print(f"reasoning tokens: {response.reasoning_tokens}")
        print(f"total time: {end - start:.2f}s")


if __name__ == "__main__":
    asyncio.run(run_o1())
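A usage note: the script is presumably run directly (for example, python scripts/o1.py) with an OpenAI API key available in the environment, since both o1 models go through the wrapped OpenAI client; nothing in this commit changes how credentials are configured.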
11 changes: 11 additions & 0 deletions spice/models.py
@@ -69,6 +69,17 @@ class UnknownModel(TextModel, EmbeddingModel, TranscriptionModel):
    pass


o1_preview = TextModel("o1-preview", OPEN_AI, input_cost=1500, output_cost=6000, context_length=128000)

o1_preview_2024_09_12 = TextModel(
    "o1-preview-2024-09-12", OPEN_AI, input_cost=1500, output_cost=6000, context_length=128000
)

o1_mini = TextModel("o1-mini", OPEN_AI, input_cost=300, output_cost=1200, context_length=128000)

o1_mini_2024_09_12 = TextModel("o1-mini-2024-09-12", OPEN_AI, input_cost=300, output_cost=1200, context_length=128000)


GPT_4o = TextModel("gpt-4o", OPEN_AI, input_cost=500, output_cost=1500, context_length=128000)
"""Warning: This model always points to OpenAI's latest GPT-4o model (currently gpt-4o-2024-05-13), so the input and output costs may be incorrect. We recommend using specific versions of GPT-4o instead."""

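The new constants can be passed anywhere the library accepts a TextModel. Below is a minimal cost-estimation sketch, assuming (consistently with the GPT_4o entry above) that input_cost and output_cost are cents per million tokens; estimate_cost_cents is a hypothetical helper for illustration, not a library function.

from spice.models import o1_mini, o1_preview

def estimate_cost_cents(model, input_tokens: int, output_tokens: int) -> float:
    # Hypothetical helper: cost fields are assumed to be cents per 1M tokens.
    return (input_tokens * model.input_cost + output_tokens * model.output_cost) / 1_000_000

# o1 bills hidden reasoning tokens as output tokens, so output_tokens can be large
# even when the visible answer is short.
print(estimate_cost_cents(o1_preview, 2_000, 10_000))  # 63.0 cents
print(estimate_cost_cents(o1_mini, 2_000, 10_000))     # 12.6 cents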
6 changes: 6 additions & 0 deletions spice/spice.py
@@ -68,6 +68,10 @@ class SpiceResponse(BaseModel, Generic[T]):
        description="""The number of output tokens given by the model in this response.
        May be inaccurate for incomplete streamed responses."""
    )
    reasoning_tokens: int = Field(
        description="""The number of reasoning tokens given by the model in this response. These are
        also counted in output_tokens. Only applies to OpenAI o1 models.""",
    )
    completed: bool = Field(
        description="""Whether or not this response was fully completed.
        This will only ever be false for incomplete streamed responses."""
@@ -195,6 +199,7 @@ def current_response(self) -> SpiceResponse:
            cache_creation_input_tokens=cache_creation_input_tokens,
            cache_read_input_tokens=cache_read_input_tokens,
            output_tokens=output_tokens,
            reasoning_tokens=0,
            completed=self._finished,
            cost=cost,
        )
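Note that the streamed-response path above simply reports reasoning_tokens=0; based on the hunks in this commit, the reasoning-token count is only populated for non-streamed calls (the get_response hunk below).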
@@ -532,6 +537,7 @@ async def get_response(
            cache_creation_input_tokens=text_and_tokens.cache_creation_input_tokens, # type: ignore
            cache_read_input_tokens=text_and_tokens.cache_read_input_tokens, # type: ignore
            output_tokens=text_and_tokens.output_tokens, # type: ignore
            reasoning_tokens=text_and_tokens.reasoning_tokens or 0, # type: ignore
            completed=True,
            cost=cost,
            result=result,
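On the consumer side, the new field lets callers split a response's output into visible text and hidden reasoning. A small sketch using only the pieces shown above (Spice, the message builder, and the new reasoning_tokens field); the subtraction relies on the documented fact that reasoning tokens are also counted in output_tokens:

import asyncio

from spice import Spice
from spice.models import o1_mini


async def show_token_breakdown() -> None:
    client = Spice()
    messages = client.new_messages().add_user_text("Explain quicksort in two sentences.")
    response = await client.get_response(messages=messages, model=o1_mini)

    # reasoning_tokens are included in output_tokens, so the remainder is the visible completion.
    visible = response.output_tokens - response.reasoning_tokens
    print(f"visible: {visible}, reasoning: {response.reasoning_tokens}")


asyncio.run(show_token_breakdown())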
25 changes: 17 additions & 8 deletions spice/wrapped_clients.py
@@ -60,6 +60,7 @@ class TextAndTokens(BaseModel):
    cache_creation_input_tokens: Optional[int] = None
    cache_read_input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    reasoning_tokens: Optional[int] = None


class WrappedClient(ABC):
@@ -130,27 +131,29 @@ def _convert_messages(self, messages: Collection[SpiceMessage]) -> List[ChatComp

    @override
    async def get_chat_completion_or_stream(self, call_args: SpiceCallArgs):
        # If using vision you have to set max_tokens or the API errors
        if call_args.max_tokens is None and "gpt-4" in call_args.model:
            max_tokens = 4096
        else:
            max_tokens = call_args.max_tokens

        # WrappedOpenAIClient can be used with a proxy to a non-OpenAI LLM, which may not support response_format
        maybe_kwargs: Dict[str, Any] = {}
        if call_args.response_format is not None and "type" in call_args.response_format:
            maybe_kwargs["response_format"] = call_args.response_format
        if call_args.stream:
            maybe_kwargs["stream_options"] = {"include_usage": True}

        # If using vision you have to set max_tokens or the API errors
        if call_args.max_tokens is None and "gpt-4" in call_args.model:
            max_tokens = 4096
        else:
            max_tokens = call_args.max_tokens
        if max_tokens is not None:
            maybe_kwargs["max_tokens"] = max_tokens
        if call_args.temperature is not None:
            maybe_kwargs["temperature"] = call_args.temperature

        converted_messages = self._convert_messages(call_args.messages)

        return await self._client.chat.completions.create(
            model=call_args.model,
            messages=converted_messages,
            stream=call_args.stream,
            temperature=call_args.temperature,
            max_tokens=max_tokens,
            **maybe_kwargs,
        )
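The practical effect of this hunk is that temperature and max_tokens are no longer passed to chat.completions.create unconditionally; they are folded into maybe_kwargs and only sent when the caller actually set them. This is presumably what makes the o1 models usable through this client, since at launch they rejected non-default temperature values (and expect max_completion_tokens rather than max_tokens), so callers simply leave those arguments unset for o1.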

@@ -169,12 +172,18 @@ def process_chunk(self, chunk, call_args: SpiceCallArgs):

    @override
    def extract_text_and_tokens(self, chat_completion, call_args: SpiceCallArgs):
        # completion_tokens_details is not returned when running on Azure
        if hasattr(chat_completion.usage, "completion_tokens_details"):
            reasoning_tokens = chat_completion.usage.completion_tokens_details["reasoning_tokens"]
        else:
            reasoning_tokens = 0
        return TextAndTokens(
            text=chat_completion.choices[0].message.content,
            input_tokens=chat_completion.usage.prompt_tokens,
            cache_creation_input_tokens=0,
            cache_read_input_tokens=0,
            output_tokens=chat_completion.usage.completion_tokens,
            reasoning_tokens=reasoning_tokens,
        )
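Depending on the installed OpenAI SDK version, completion_tokens_details may come back as a plain dict or as a typed object exposing a reasoning_tokens attribute, and it can be missing entirely (for example on Azure, as the comment above notes). A more defensive extraction is sketched below; extract_reasoning_tokens is a hypothetical helper, not part of this commit:

def extract_reasoning_tokens(usage) -> int:
    # Hypothetical helper: tolerate dict-shaped or attribute-shaped
    # completion_tokens_details, or its absence.
    details = getattr(usage, "completion_tokens_details", None)
    if details is None:
        return 0
    if isinstance(details, dict):
        return details.get("reasoning_tokens", 0) or 0
    return getattr(details, "reasoning_tokens", 0) or 0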

    @override
