diff --git a/pyproject.toml b/pyproject.toml
index 4351e6f..1372a75 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,7 @@ packages=["spice"]
 
 [project]
 name = "spiceai"
-version = "0.4.3"
+version = "0.4.4"
 license = {text = "Apache-2.0"}
 description = "A Python library for building AI-powered applications."
 readme = "README.md"
diff --git a/scripts/o1.py b/scripts/o1.py
new file mode 100644
index 0000000..ca67216
--- /dev/null
+++ b/scripts/o1.py
@@ -0,0 +1,34 @@
+import asyncio
+import os
+import sys
+from timeit import default_timer as timer
+
+from spice.models import o1_mini, o1_preview
+
+# Modify sys.path to ensure the script can run even when it's not part of the installed library.
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+from spice import Spice
+
+
+async def run_o1():
+    client = Spice()
+
+    messages = client.new_messages().add_user_text("list 5 random words")
+
+    models = [o1_mini, o1_preview]
+
+    for model in models:
+        print(f"\nRunning {model.name}:")
+        start = timer()
+        response = await client.get_response(messages=messages, model=model)
+        end = timer()
+        print(response.text)
+        print(f"input tokens: {response.input_tokens}")
+        print(f"output tokens: {response.output_tokens}")
+        print(f"reasoning tokens: {response.reasoning_tokens}")
+        print(f"total time: {end - start:.2f}s")
+
+
+if __name__ == "__main__":
+    asyncio.run(run_o1())
diff --git a/spice/models.py b/spice/models.py
index d0fa915..edf15ea 100644
--- a/spice/models.py
+++ b/spice/models.py
@@ -69,6 +69,17 @@ class UnknownModel(TextModel, EmbeddingModel, TranscriptionModel):
     pass
 
 
+o1_preview = TextModel("o1-preview", OPEN_AI, input_cost=15, output_cost=60, context_length=128000)
+
+o1_preview_2024_09_12 = TextModel(
+    "o1-preview-2024-09-12", OPEN_AI, input_cost=15, output_cost=60, context_length=128000
+)
+
+o1_mini = TextModel("o1-mini", OPEN_AI, input_cost=3, output_cost=12, context_length=128000)
+
+o1_mini_2024_09_12 = TextModel("o1-mini-2024-09-12", OPEN_AI, input_cost=3, output_cost=12, context_length=128000)
+
+
 GPT_4o = TextModel("gpt-4o", OPEN_AI, input_cost=500, output_cost=1500, context_length=128000)
 """Warning: This model always points to OpenAI's latest GPT-4o model (currently gpt-4o-2024-05-13), so the input
 and output costs may be incorrect. We recommend using specific versions of GPT-4o instead."""
diff --git a/spice/spice.py b/spice/spice.py
index 86a4be6..06eb490 100644
--- a/spice/spice.py
+++ b/spice/spice.py
@@ -68,6 +68,10 @@ class SpiceResponse(BaseModel, Generic[T]):
         description="""The number of output tokens given by the model in this response. May be
         inaccurate for incomplete streamed responses."""
     )
+    reasoning_tokens: int = Field(
+        description="""The number of reasoning tokens given by the model in this response. These are
+        also counted in output_tokens. Only applies to OpenAI o1 models.""",
+    )
     completed: bool = Field(
         description="""Whether or not this response was fully completed.
         This will only ever be false for incomplete streamed responses."""
@@ -195,6 +199,7 @@ def current_response(self) -> SpiceResponse:
             cache_creation_input_tokens=cache_creation_input_tokens,
             cache_read_input_tokens=cache_read_input_tokens,
             output_tokens=output_tokens,
+            reasoning_tokens=0,
             completed=self._finished,
             cost=cost,
         )
@@ -532,6 +537,7 @@ async def get_response(
             cache_creation_input_tokens=text_and_tokens.cache_creation_input_tokens,  # type: ignore
             cache_read_input_tokens=text_and_tokens.cache_read_input_tokens,  # type: ignore
             output_tokens=text_and_tokens.output_tokens,  # type: ignore
+            reasoning_tokens=text_and_tokens.reasoning_tokens or 0,  # type: ignore
             completed=True,
             cost=cost,
             result=result,
diff --git a/spice/wrapped_clients.py b/spice/wrapped_clients.py
index 2852bdf..e26277a 100644
--- a/spice/wrapped_clients.py
+++ b/spice/wrapped_clients.py
@@ -60,6 +60,7 @@ class TextAndTokens(BaseModel):
     cache_creation_input_tokens: Optional[int] = None
     cache_read_input_tokens: Optional[int] = None
     output_tokens: Optional[int] = None
+    reasoning_tokens: Optional[int] = None
 
 
 class WrappedClient(ABC):
@@ -130,18 +131,22 @@ def _convert_messages(self, messages: Collection[SpiceMessage]) -> List[ChatComp
 
     @override
     async def get_chat_completion_or_stream(self, call_args: SpiceCallArgs):
+        # If using vision you have to set max_tokens or api errors
+        if call_args.max_tokens is None and "gpt-4" in call_args.model:
+            max_tokens = 4096
+        else:
+            max_tokens = call_args.max_tokens
+
         # WrappedOpenAIClient can be used with a proxy to a non openai llm, which may not support response_format
         maybe_kwargs: Dict[str, Any] = {}
         if call_args.response_format is not None and "type" in call_args.response_format:
             maybe_kwargs["response_format"] = call_args.response_format
         if call_args.stream:
             maybe_kwargs["stream_options"] = {"include_usage": True}
-
-        # If using vision you have to set max_tokens or api errors
-        if call_args.max_tokens is None and "gpt-4" in call_args.model:
-            max_tokens = 4096
-        else:
-            max_tokens = call_args.max_tokens
+        if max_tokens is not None:
+            maybe_kwargs["max_tokens"] = max_tokens
+        if call_args.temperature is not None:
+            maybe_kwargs["temperature"] = call_args.temperature
 
         converted_messages = self._convert_messages(call_args.messages)
 
@@ -149,8 +154,6 @@ async def get_chat_completion_or_stream(self, call_args: SpiceCallArgs):
             model=call_args.model,
             messages=converted_messages,
             stream=call_args.stream,
-            temperature=call_args.temperature,
-            max_tokens=max_tokens,
             **maybe_kwargs,
         )
 
@@ -169,12 +172,18 @@ def process_chunk(self, chunk, call_args: SpiceCallArgs):
 
     @override
     def extract_text_and_tokens(self, chat_completion, call_args: SpiceCallArgs):
+        # not working on azure
+        if hasattr(chat_completion.usage, "completion_tokens_details"):
+            reasoning_tokens = chat_completion.usage.completion_tokens_details["reasoning_tokens"]
+        else:
+            reasoning_tokens = 0
         return TextAndTokens(
             text=chat_completion.choices[0].message.content,
             input_tokens=chat_completion.usage.prompt_tokens,
             cache_creation_input_tokens=0,
             cache_read_input_tokens=0,
             output_tokens=chat_completion.usage.completion_tokens,
+            reasoning_tokens=reasoning_tokens,
        )
 
     @override
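
Reviewer note (not part of the patch): a minimal, standalone sketch of the fallback behaviour the `hasattr` guard in `extract_text_and_tokens` is aiming for — usage payloads that lack `completion_tokens_details` (e.g. Azure OpenAI, per the in-code comment) should still yield `reasoning_tokens == 0`, which the `or 0` in `get_response` then preserves. The helper name and the dict/attribute handling below are illustrative assumptions, not the library's API.

```python
from types import SimpleNamespace


def reasoning_tokens_from_usage(usage) -> int:
    """Hypothetical helper mirroring the hasattr guard added in WrappedOpenAIClient.extract_text_and_tokens."""
    if getattr(usage, "completion_tokens_details", None):
        details = usage.completion_tokens_details
        # Depending on SDK version, the details may be a dict (as indexed in the diff) or an object.
        if isinstance(details, dict):
            return details.get("reasoning_tokens", 0) or 0
        return getattr(details, "reasoning_tokens", 0) or 0
    return 0


# o1-style usage with reasoning-token accounting (values made up for illustration).
o1_usage = SimpleNamespace(completion_tokens=450, completion_tokens_details={"reasoning_tokens": 320})
# Azure-style usage without completion_tokens_details.
azure_usage = SimpleNamespace(prompt_tokens=10, completion_tokens=50)

assert reasoning_tokens_from_usage(o1_usage) == 320
assert reasoning_tokens_from_usage(azure_usage) == 0
```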