diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py
index a3aba54c8..12cbda924 100755
--- a/benchmarks/benchmark_runner.py
+++ b/benchmarks/benchmark_runner.py
@@ -8,9 +8,11 @@
 import re
 from datetime import datetime
 from pathlib import Path
+from typing import List
 from uuid import uuid4
 
 from openai.types.chat.completion_create_params import ResponseFormat
+from spice import SpiceMessage
 
 from benchmarks.arg_parser import common_benchmark_parser
 from benchmarks.benchmark_result import BenchmarkResult
@@ -43,7 +45,7 @@ def git_diff_from_comparison_commit(sample: Sample, comparison_commit: str) -> s
 
 async def grade(to_grade, prompt, model="gpt-4-1106-preview"):
     try:
-        messages = [
+        messages: List[SpiceMessage] = [
             {"role": "system", "content": prompt},
             {"role": "user", "content": to_grade},
         ]
diff --git a/benchmarks/exercism_practice.py b/benchmarks/exercism_practice.py
index 1d1a4cd39..6334c75a3 100755
--- a/benchmarks/exercism_practice.py
+++ b/benchmarks/exercism_practice.py
@@ -5,6 +5,8 @@
 from datetime import datetime
 from functools import partial
 from pathlib import Path
+from typing import List
+from spice import SpiceMessage
 
 import tqdm
 from openai import BadRequestError
@@ -56,7 +58,7 @@ async def failure_analysis(exercise_runner, language):
 
     test_results = exercise_runner.read_test_results()
     final_message = f"All instructions:\n{instructions}\nCode to review:\n{code}\nTest" f" results:\n{test_results}"
-    messages = [
+    messages: List[SpiceMessage] = [
         {"role": "system", "content": prompt},
         {"role": "user", "content": final_message},
     ]
diff --git a/mentat/conversation.py b/mentat/conversation.py
index 5e4f1c60f..3538c5bbf 100644
--- a/mentat/conversation.py
+++ b/mentat/conversation.py
@@ -198,17 +198,13 @@ async def _stream_model_response(
 
         stream.send("Streaming...\n")
         async with stream.interrupt_catcher(parser.shutdown):
-            parsed_llm_response = await parser.stream_and_parse_llm_response(add_newline(response.stream()))
+            parsed_llm_response = await parser.stream_and_parse_llm_response(add_newline(response))
 
         # Sampler and History require previous_file_lines
         for file_edit in parsed_llm_response.file_edits:
             file_edit.previous_file_lines = code_file_manager.file_lines.get(file_edit.file_path, []).copy()
 
-        # TODO: this is janky come up with better solution
-        # if the stream was interrupted, then the finally block in the response.stream() async generator
-        # will wait for an opportunity to run. This sleep call gives it that opportunity.
-        # the finally block runs the logging callback
-        await asyncio.sleep(0.01)
+        cost_tracker.log_api_call_stats(response.current_response())
         cost_tracker.display_last_api_call()
 
         messages.append(
diff --git a/mentat/llm_api_handler.py b/mentat/llm_api_handler.py
index f48eb8b64..5613c3af0 100644
--- a/mentat/llm_api_handler.py
+++ b/mentat/llm_api_handler.py
@@ -12,8 +12,10 @@
     Callable,
     Dict,
     List,
+    Literal,
     Optional,
     cast,
+    overload,
 )
 
 import attr
@@ -29,10 +31,14 @@
 )
 from openai.types.chat.completion_create_params import ResponseFormat
 from PIL import Image
-from spice import APIConnectionError, Spice, SpiceEmbeddings, SpiceError, SpiceResponse, SpiceWhisper
+from spice import APIConnectionError, Spice, SpiceError, SpiceMessage, SpiceResponse, StreamingSpiceResponse
+from spice.errors import NoAPIKeyError
+from spice.models import WHISPER_1
+from spice.providers import OPEN_AI
 
 from mentat.errors import MentatError, ReturnToUser
 from mentat.session_context import SESSION_CONTEXT
+from mentat.session_input import collect_user_input
 from mentat.utils import mentat_dir_path
 
 TOKEN_COUNT_WARNING = 32000
@@ -323,61 +329,81 @@ async def initialize_client(self):
         if not load_dotenv(mentat_dir_path / ".env") and not load_dotenv(ctx.cwd / ".env"):
             load_dotenv()
 
-        if os.getenv("AZURE_OPENAI_KEY") is not None:
-            embedding_and_whisper_provider = "azure"
-        else:
-            embedding_and_whisper_provider = "openai"
+        self.spice = Spice()
 
-        self.spice_client = Spice()
-        self.spice_embedding_client = SpiceEmbeddings(provider=embedding_and_whisper_provider)
-        self.spice_whisper_client = SpiceWhisper(provider=embedding_and_whisper_provider)
+        try:
+            self.spice.load_provider(OPEN_AI)
+        except NoAPIKeyError:
+            ctx.stream.send(
+                "No OpenAI api key detected. To avoid entering your api key on startup, create a .env file in"
+                " ~/.mentat/.env or in your workspace root.",
+                style="warning",
+            )
+            ctx.stream.send("Enter your api key:", style="info")
+            key = (await collect_user_input(log_input=False)).data
+            os.environ["OPENAI_API_KEY"] = key
+
+    @overload
+    async def call_llm_api(
+        self,
+        messages: List[SpiceMessage],
+        model: str,
+        stream: Literal[False],
+        response_format: ResponseFormat = ResponseFormat(type="text"),
+    ) -> SpiceResponse:
+        ...
+
+    @overload
+    async def call_llm_api(
+        self,
+        messages: List[SpiceMessage],
+        model: str,
+        stream: Literal[True],
+        response_format: ResponseFormat = ResponseFormat(type="text"),
+    ) -> StreamingSpiceResponse:
+        ...
 
     @api_guard
     async def call_llm_api(
         self,
-        messages: list[ChatCompletionMessageParam],
+        messages: List[SpiceMessage],
         model: str,
         stream: bool,
         response_format: ResponseFormat = ResponseFormat(type="text"),
-    ) -> SpiceResponse:
+    ) -> SpiceResponse | StreamingSpiceResponse:
         session_context = SESSION_CONTEXT.get()
         config = session_context.config
         cost_tracker = session_context.cost_tracker
 
-        if "claude" in config.model:
-            messages = normalize_messages_for_anthropic(messages)
-
         # Confirm that model has enough tokens remaining.
         tokens = prompt_tokens(messages, model)
         raise_if_context_exceeds_max(tokens)
 
-        # TODO: make spice message format and use across codebase consistently
-        _messages = [
-            {"role": cast(str, message["role"]), "content": cast(str, message["content"])} for message in messages
-        ]
-        if "type" in response_format and response_format["type"] == "json_object":
-            _response_format = {"type": "json_object"}
-        else:
-            _response_format = {"type": "text"}
-
         with sentry_sdk.start_span(description="LLM Call") as span:
             span.set_tag("model", model)
-            response = await self.spice_client.call_llm(
-                model=model,
-                messages=_messages,
-                stream=stream,
-                temperature=config.temperature,
-                response_format=_response_format,
-                logging_callback=cost_tracker.log_api_call_stats,
-            )
+            if not stream:
+                response = await self.spice.get_response(
+                    model=model,
+                    messages=messages,
+                    temperature=config.temperature,
+                    response_format=response_format,  # pyright: ignore
+                )
+                cost_tracker.log_api_call_stats(response)
+            else:
+                response = await self.spice.stream_response(
+                    model=model,
+                    messages=messages,
+                    temperature=config.temperature,
+                    response_format=response_format,  # pyright: ignore
+                )
 
         return response
 
     @api_guard
     def call_embedding_api(self, input_texts: list[str], model: str = "text-embedding-ada-002") -> Embeddings:
-        return self.spice_embedding_client.get_embeddings(input_texts, model)
+        return self.spice.get_embeddings_sync(input_texts, model)  # pyright: ignore
 
     @api_guard
     async def call_whisper_api(self, audio_path: Path) -> str:
-        return await self.spice_whisper_client.get_whisper_transcription(audio_path)
+        return await self.spice.get_transcription(audio_path, model=WHISPER_1)
 
diff --git a/mentat/server/mentat_server.py b/mentat/server/mentat_server.py
index 67d1c762a..fc42abd0c 100644
--- a/mentat/server/mentat_server.py
+++ b/mentat/server/mentat_server.py
@@ -18,7 +18,7 @@ class MentatServer:
     def __init__(self, cwd: Path, config: Config) -> None:
         self.cwd = cwd
         self.stopped = Event()
-        self.session = Session(self.cwd, config=config, apply_edits=False)
+        self.session = Session(self.cwd, config=config, apply_edits=False, show_update=False)
 
     async def _client_listener(self):
         with open(3) as fd_input:
diff --git a/mentat/session.py b/mentat/session.py
index 02c8ad76b..b9d7c762d 100644
--- a/mentat/session.py
+++ b/mentat/session.py
@@ -58,6 +58,7 @@ def __init__(
         config: Config = Config(),
         # Set to false for clients that apply the edits themselves (like vscode)
         apply_edits: bool = True,
+        show_update: bool = True,
     ):
         # All errors thrown here need to be caught here
         self.stopped = Event()
@@ -112,7 +113,8 @@ def __init__(
         self.error = None
 
         # Functions that require session_context
-        check_version()
+        if show_update:
+            check_version()
         config.send_errors_to_stream()
         for path in paths:
             code_context.include(path, exclude_patterns=exclude_paths)
diff --git a/requirements.txt b/requirements.txt
index e1390485d..ebca0d533 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,6 +21,7 @@ selenium==4.15.2
 sentry-sdk==1.34.0
 sounddevice==0.4.6
 soundfile==0.12.1
+spiceai==0.1.8
 termcolor==2.3.0
 textual==0.47.1
 textual-autocomplete==2.1.0b0
@@ -29,4 +30,3 @@ typing_extensions==4.8.0
 tqdm==4.66.1
 webdriver_manager==4.0.1
 watchfiles==0.21.0
-spiceai==0.1.7