diff --git a/mentat/conversation.py b/mentat/conversation.py
index 11bcc8733..b2bfbb4d5 100644
--- a/mentat/conversation.py
+++ b/mentat/conversation.py
@@ -14,6 +14,7 @@
     ChatCompletionSystemMessageParam,
     ChatCompletionUserMessageParam,
 )
+from spice.errors import InvalidProviderError, UnknownModelError
 
 from mentat.llm_api_handler import (
     TOKEN_COUNT_WARNING,
@@ -91,9 +92,11 @@ async def count_tokens(
     ) -> int:
         ctx = SESSION_CONTEXT.get()
-        _messages = await self.get_messages(system_prompt=system_prompt, include_code_message=include_code_message)
-        model = ctx.config.model
-        return ctx.llm_api_handler.spice.count_prompt_tokens(_messages, model)
+        try:
+            _messages = await self.get_messages(system_prompt=system_prompt, include_code_message=include_code_message)
+            return ctx.llm_api_handler.spice.count_prompt_tokens(_messages, ctx.config.model, ctx.config.provider)
+        except (UnknownModelError, InvalidProviderError):
+            return 0
 
     async def get_messages(
         self,
@@ -126,7 +129,7 @@ async def get_messages(
 
         if include_code_message:
             code_message = await ctx.code_context.get_code_message(
-                ctx.llm_api_handler.spice.count_prompt_tokens(_messages, ctx.config.model),
+                ctx.llm_api_handler.spice.count_prompt_tokens(_messages, ctx.config.model, ctx.config.provider),
                 prompt=(
                     prompt  # Prompt can be image as well as text
                     if isinstance(prompt, str)
@@ -186,7 +189,7 @@ async def _stream_model_response(
             terminate=True,
         )
 
-        num_prompt_tokens = llm_api_handler.spice.count_prompt_tokens(messages, config.model)
+        num_prompt_tokens = llm_api_handler.spice.count_prompt_tokens(messages, config.model, config.provider)
         stream.send(f"Total token count: {num_prompt_tokens}", style="info")
         if num_prompt_tokens > TOKEN_COUNT_WARNING:
             stream.send(
@@ -220,7 +223,7 @@ async def get_model_response(self) -> ParsedLLMResponse:
         llm_api_handler = session_context.llm_api_handler
 
         messages_snapshot = await self.get_messages(include_code_message=True)
-        tokens_used = llm_api_handler.spice.count_prompt_tokens(messages_snapshot, config.model)
+        tokens_used = llm_api_handler.spice.count_prompt_tokens(messages_snapshot, config.model, config.provider)
         raise_if_context_exceeds_max(tokens_used)
 
         try:
@@ -238,7 +241,7 @@ async def get_model_response(self) -> ParsedLLMResponse:
     async def remaining_context(self) -> int | None:
         ctx = SESSION_CONTEXT.get()
         return get_max_tokens() - ctx.llm_api_handler.spice.count_prompt_tokens(
-            await self.get_messages(), ctx.config.model
+            await self.get_messages(), ctx.config.model, ctx.config.provider
         )
 
     async def can_add_to_context(self, message: str) -> bool:
diff --git a/mentat/llm_api_handler.py b/mentat/llm_api_handler.py
index 5152e88cd..98cdbecc3 100644
--- a/mentat/llm_api_handler.py
+++ b/mentat/llm_api_handler.py
@@ -19,10 +19,10 @@
 from dotenv import load_dotenv
 from openai.types.chat.completion_create_params import ResponseFormat
 from spice import EmbeddingResponse, Spice, SpiceMessage, SpiceResponse, StreamingSpiceResponse, TranscriptionResponse
-from spice.errors import APIConnectionError, NoAPIKeyError
+from spice.errors import APIConnectionError, AuthenticationError, InvalidProviderError, NoAPIKeyError
 from spice.models import WHISPER_1
 from spice.providers import OPEN_AI
-from spice.spice import UnknownModelError, get_model_from_name
+from spice.spice import UnknownModelError, get_model_from_name, get_provider_from_name
 
 from mentat.errors import MentatError, ReturnToUser
 from mentat.session_context import SESSION_CONTEXT
@@ -58,13 +58,20 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> RetType:
             assert not is_test_environment(), "OpenAI call attempted in non-benchmark test environment!"
             try:
                 return await func(*args, **kwargs)
+            except AuthenticationError:
+                raise MentatError("Authentication error: Check your api key and try again.")
             except APIConnectionError:
-                raise MentatError("API connection error: please check your internet connection and try again.")
+                raise MentatError("API connection error: Check your internet connection and try again.")
             except UnknownModelError:
                 SESSION_CONTEXT.get().stream.send(
                     "Unknown model. Use /config provider and try again.", style="error"
                 )
                 raise ReturnToUser()
+            except InvalidProviderError:
+                SESSION_CONTEXT.get().stream.send(
+                    "Unknown provider. Use /config provider and try again.", style="error"
+                )
+                raise ReturnToUser()
 
         return async_wrapper  # pyright: ignore[reportReturnType]
     else:
@@ -73,13 +80,20 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> RetType:
             assert not is_test_environment(), "OpenAI call attempted in non-benchmark test environment!"
             try:
                 return func(*args, **kwargs)
+            except AuthenticationError:
+                raise MentatError("Authentication error: Check your api key and try again.")
             except APIConnectionError:
-                raise MentatError("API connection error: please check your internet connection and try again.")
+                raise MentatError("API connection error: Check your internet connection and try again.")
             except UnknownModelError:
                 SESSION_CONTEXT.get().stream.send(
                     "Unknown model. Use /config provider and try again.", style="error"
                 )
                 raise ReturnToUser()
+            except InvalidProviderError:
+                SESSION_CONTEXT.get().stream.send(
+                    "Unknown provider. Use /config provider and try again.", style="error"
+                )
+                raise ReturnToUser()
 
         return sync_wrapper
@@ -142,19 +156,63 @@ async def initialize_client(self):
         if not load_dotenv(mentat_dir_path / ".env") and not load_dotenv(ctx.cwd / ".env"):
             load_dotenv()
 
-        try:
-            self.spice.load_provider(OPEN_AI)
-        except NoAPIKeyError:
-            from mentat.session_input import collect_user_input
-
+        user_provider = get_model_from_name(ctx.config.model).provider
+        if ctx.config.provider is not None:
+            try:
+                user_provider = get_provider_from_name(ctx.config.provider)
+            except InvalidProviderError:
+                ctx.stream.send(
+                    f"Unknown provider {ctx.config.provider}. Use /config provider to set your provider.",
+                    style="warning",
+                )
+        elif user_provider is None:
             ctx.stream.send(
-                "No OpenAI api key detected. To avoid entering your api key on startup, create a .env file in"
-                " ~/.mentat/.env or in your workspace root.",
+                f"Unknown model {ctx.config.model}. Use /config provider to set your provider.",
                 style="warning",
             )
-            ctx.stream.send("Enter your api key:", style="info")
-            key = (await collect_user_input(log_input=False)).data
-            os.environ["OPENAI_API_KEY"] = key
+
+        # ragdaemon always needs an openai provider
+        providers = [OPEN_AI]
+        if user_provider is not None:
+            providers.append(user_provider)
+
+        for provider in providers:
+            try:
+                self.spice.load_provider(provider)
+            except NoAPIKeyError:
+                from mentat.session_input import collect_user_input
+
+                match provider.name:
+                    case "open_ai" | "openai":
+                        env_variable = "OPENAI_API_KEY"
+                    case "anthropic":
+                        env_variable = "ANTHROPIC_API_KEY"
+                    case "azure":
+                        if os.getenv("AZURE_OPENAI_ENDPOINT") is None:
+                            ctx.stream.send(
+                                f"No Azure OpenAI endpoint detected. To avoid entering your endpoint on startup, create a .env file in"
+                                " ~/.mentat/.env or in your workspace root and set AZURE_OPENAI_ENDPOINT.",
+                                style="warning",
+                            )
+                            ctx.stream.send("Enter your endpoint:", style="info")
+                            endpoint = (await collect_user_input(log_input=False)).data
+                            os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint
+                        if os.getenv("AZURE_OPENAI_KEY") is not None:
+                            return
+                        env_variable = "AZURE_OPENAI_KEY"
+                    case _:
+                        raise MentatError(
+                            f"No api key detected for provider {provider.name}. Create a .env file in ~/.mentat/.env or in your workspace root with your api key"
+                        )
+
+                ctx.stream.send(
+                    f"No {provider.name} api key detected. To avoid entering your api key on startup, create a .env file in"
+                    " ~/.mentat/.env or in your workspace root.",
+                    style="warning",
+                )
+                ctx.stream.send("Enter your api key:", style="info")
+                key = (await collect_user_input(log_input=False)).data
+                os.environ[env_variable] = key
 
     @overload
     async def call_llm_api(
@@ -191,7 +249,7 @@ async def call_llm_api(
         config = session_context.config
 
         # Confirm that model has enough tokens remaining
-        tokens = self.spice.count_prompt_tokens(messages, model)
+        tokens = self.spice.count_prompt_tokens(messages, model, provider)
        raise_if_context_exceeds_max(tokens)
 
         with sentry_sdk.start_span(description="LLM Call") as span:
diff --git a/mentat/revisor/revisor.py b/mentat/revisor/revisor.py
index fbe7e5296..9f3321ce9 100644
--- a/mentat/revisor/revisor.py
+++ b/mentat/revisor/revisor.py
@@ -59,7 +59,7 @@ async def revise_edit(file_edit: FileEdit):
         ChatCompletionSystemMessageParam(content=f"Diff:\n{diff}", role="system"),
     ]
     code_message = await ctx.code_context.get_code_message(
-        ctx.llm_api_handler.spice.count_prompt_tokens(messages, ctx.config.model)
+        ctx.llm_api_handler.spice.count_prompt_tokens(messages, ctx.config.model, ctx.config.provider)
     )
     messages.insert(1, ChatCompletionSystemMessageParam(content=code_message, role="system"))
 
diff --git a/mentat/session.py b/mentat/session.py
index 088cfc723..6a1ffb6ed 100644
--- a/mentat/session.py
+++ b/mentat/session.py
@@ -9,12 +9,7 @@
 
 import attr
 import sentry_sdk
-from openai import (
-    APITimeoutError,
-    BadRequestError,
-    PermissionDeniedError,
-    RateLimitError,
-)
+from spice.errors import APIConnectionError, APIError, AuthenticationError
 
 from mentat.agent_handler import AgentHandler
 from mentat.auto_completer import AutoCompleter
@@ -162,76 +157,74 @@ async def _main(self):
         code_file_manager = session_context.code_file_manager
         agent_handler = session_context.agent_handler
 
-        await session_context.llm_api_handler.initialize_client()
-
-        await code_context.refresh_daemon()
-
-        check_model()
+        try:
+            await session_context.llm_api_handler.initialize_client()
+            await code_context.refresh_daemon()
+
+            check_model()
+
+            need_user_request = True
+            while True:
+                try:
+                    await code_context.refresh_context_display()
+                    if need_user_request:
+                        # Normally, the code_file_manager pushes the edits; but when agent mode is on, we want all
+                        # edits made between user input to be collected together.
+                        if agent_handler.agent_enabled:
+                            code_file_manager.history.push_edits()
+                            stream.send(
+                                "Use /undo to undo all changes from agent mode since last input.",
+                                style="success",
+                            )
+                        message = await collect_input_with_commands()
+                        if message.data.strip() == "":
+                            continue
+                        conversation.add_user_message(message.data)
+
+                    parsed_llm_response = await conversation.get_model_response()
+                    file_edits = [file_edit for file_edit in parsed_llm_response.file_edits if file_edit.is_valid()]
+                    for file_edit in file_edits:
+                        file_edit.resolve_conflicts()
+                    if file_edits:
+                        if session_context.config.revisor:
+                            await revise_edits(file_edits)
+
+                        if session_context.config.sampler:
+                            session_context.sampler.set_active_diff()
+
+                        self.send_file_edits(file_edits)
+                        if self.apply_edits:
+                            if not agent_handler.agent_enabled:
+                                file_edits, need_user_request = await get_user_feedback_on_edits(file_edits)
+                            applied_edits = await code_file_manager.write_changes_to_files(file_edits)
+                            stream.send(
+                                ("Changes applied." if applied_edits else "No changes applied."),
+                                style="input",
+                            )
+                        else:
+                            need_user_request = True
-        need_user_request = True
-        while True:
-            await code_context.refresh_context_display()
-            try:
-                if need_user_request:
-                    # Normally, the code_file_manager pushes the edits; but when agent mode is on, we want all
-                    # edits made between user input to be collected together.
-                    if agent_handler.agent_enabled:
-                        code_file_manager.history.push_edits()
-                        stream.send(
-                            "Use /undo to undo all changes from agent mode since last input.",
-                            style="success",
-                        )
-                    message = await collect_input_with_commands()
-                    if message.data.strip() == "":
-                        continue
-                    conversation.add_user_message(message.data)
-
-                parsed_llm_response = await conversation.get_model_response()
-                file_edits = [file_edit for file_edit in parsed_llm_response.file_edits if file_edit.is_valid()]
-                for file_edit in file_edits:
-                    file_edit.resolve_conflicts()
-                if file_edits:
-                    if session_context.config.revisor:
-                        await revise_edits(file_edits)
-
-                    if session_context.config.sampler:
-                        session_context.sampler.set_active_diff()
-
-                    self.send_file_edits(file_edits)
-                    if self.apply_edits:
-                        if not agent_handler.agent_enabled:
-                            file_edits, need_user_request = await get_user_feedback_on_edits(file_edits)
-                        applied_edits = await code_file_manager.write_changes_to_files(file_edits)
-                        stream.send(
-                            ("Changes applied." if applied_edits else "No changes applied."),
-                            style="input",
-                        )
+                    if agent_handler.agent_enabled:
+                        if parsed_llm_response.interrupted:
+                            need_user_request = True
+                        else:
+                            need_user_request = await agent_handler.add_agent_context()
                     else:
                         need_user_request = True
+                    stream.send(bool(file_edits), channel="edits_complete")
-                if agent_handler.agent_enabled:
-                    if parsed_llm_response.interrupted:
-                        need_user_request = True
-                    else:
-                        need_user_request = await agent_handler.add_agent_context()
-                else:
+                except ReturnToUser:
+                    stream.send(None, channel="loading", terminate=True)
                     need_user_request = True
-                stream.send(bool(file_edits), channel="edits_complete")
-            except SessionExit:
-                stream.send(None, channel="client_exit")
-                break
-            except ReturnToUser:
-                stream.send(None, channel="loading", terminate=True)
-                need_user_request = True
-                continue
-            except (
-                APITimeoutError,
-                RateLimitError,
-                BadRequestError,
-                PermissionDeniedError,
-            ) as e:
-                stream.send(f"Error accessing OpenAI API: {e.message}", style="error")
-                break
+                    continue
+        except SessionExit:
+            stream.send(None, channel="client_exit")
+        except AuthenticationError:
+            raise MentatError("Authentication error: Check your api key and try again.")
+        except APIConnectionError:
+            raise MentatError("API connection error: Check your internet connection and try again.")
+        except APIError as e:
+            stream.send(f"Error accessing OpenAI API: {e}", style="error")
 
     async def listen_for_session_exit(self):
         await self.stream.recv(channel="session_exit")