add provider checking based on model name and provider #571

Merged
merged 6 commits on Apr 22, 2024
17 changes: 10 additions & 7 deletions mentat/conversation.py
@@ -14,6 +14,7 @@
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
)
from spice.errors import InvalidProviderError, UnknownModelError

from mentat.llm_api_handler import (
TOKEN_COUNT_WARNING,
@@ -91,9 +92,11 @@ async def count_tokens(
) -> int:
ctx = SESSION_CONTEXT.get()

_messages = await self.get_messages(system_prompt=system_prompt, include_code_message=include_code_message)
model = ctx.config.model
return ctx.llm_api_handler.spice.count_prompt_tokens(_messages, model)
try:
_messages = await self.get_messages(system_prompt=system_prompt, include_code_message=include_code_message)
return ctx.llm_api_handler.spice.count_prompt_tokens(_messages, ctx.config.model, ctx.config.provider)
except (UnknownModelError, InvalidProviderError):
return 0

async def get_messages(
self,
@@ -126,7 +129,7 @@ async def get_messages(

if include_code_message:
code_message = await ctx.code_context.get_code_message(
ctx.llm_api_handler.spice.count_prompt_tokens(_messages, ctx.config.model),
ctx.llm_api_handler.spice.count_prompt_tokens(_messages, ctx.config.model, ctx.config.provider),
prompt=(
prompt # Prompt can be image as well as text
if isinstance(prompt, str)
@@ -186,7 +189,7 @@ async def _stream_model_response(
terminate=True,
)

num_prompt_tokens = llm_api_handler.spice.count_prompt_tokens(messages, config.model)
num_prompt_tokens = llm_api_handler.spice.count_prompt_tokens(messages, config.model, config.provider)
stream.send(f"Total token count: {num_prompt_tokens}", style="info")
if num_prompt_tokens > TOKEN_COUNT_WARNING:
stream.send(
@@ -220,7 +223,7 @@ async def get_model_response(self) -> ParsedLLMResponse:
llm_api_handler = session_context.llm_api_handler

messages_snapshot = await self.get_messages(include_code_message=True)
tokens_used = llm_api_handler.spice.count_prompt_tokens(messages_snapshot, config.model)
tokens_used = llm_api_handler.spice.count_prompt_tokens(messages_snapshot, config.model, config.provider)
raise_if_context_exceeds_max(tokens_used)

try:
@@ -238,7 +241,7 @@ async def get_model_response(self) -> ParsedLLMResponse:
async def remaining_context(self) -> int | None:
ctx = SESSION_CONTEXT.get()
return get_max_tokens() - ctx.llm_api_handler.spice.count_prompt_tokens(
await self.get_messages(), ctx.config.model
await self.get_messages(), ctx.config.model, ctx.config.provider
)

async def can_add_to_context(self, message: str) -> bool:
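
The upshot for token counting: count_prompt_tokens now takes the configured provider alongside the model, and count_tokens degrades gracefully instead of raising on an unrecognized name. A minimal sketch of the pattern, using only the spice names visible in this diff; the model and provider strings are placeholders, not values the PR prescribes:

from spice import Spice
from spice.errors import InvalidProviderError, UnknownModelError

spice = Spice()
messages = []  # list of SpiceMessage dicts, as produced by get_messages()

try:
    # The provider is now passed explicitly so the matching tokenizer is used.
    tokens = spice.count_prompt_tokens(messages, "example-model", "example-provider")
except (UnknownModelError, InvalidProviderError):
    tokens = 0  # unknown model or provider: report zero rather than raising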
88 changes: 73 additions & 15 deletions mentat/llm_api_handler.py
@@ -19,10 +19,10 @@
from dotenv import load_dotenv
from openai.types.chat.completion_create_params import ResponseFormat
from spice import EmbeddingResponse, Spice, SpiceMessage, SpiceResponse, StreamingSpiceResponse, TranscriptionResponse
from spice.errors import APIConnectionError, NoAPIKeyError
from spice.errors import APIConnectionError, AuthenticationError, InvalidProviderError, NoAPIKeyError
from spice.models import WHISPER_1
from spice.providers import OPEN_AI
from spice.spice import UnknownModelError, get_model_from_name
from spice.spice import UnknownModelError, get_model_from_name, get_provider_from_name

from mentat.errors import MentatError, ReturnToUser
from mentat.session_context import SESSION_CONTEXT
@@ -58,13 +58,20 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> RetType:
assert not is_test_environment(), "OpenAI call attempted in non-benchmark test environment!"
try:
return await func(*args, **kwargs)
except AuthenticationError:
raise MentatError("Authentication error: Check your api key and try again.")
except APIConnectionError:
raise MentatError("API connection error: please check your internet connection and try again.")
raise MentatError("API connection error: Check your internet connection and try again.")
except UnknownModelError:
SESSION_CONTEXT.get().stream.send(
"Unknown model. Use /config provider <provider> and try again.", style="error"
)
raise ReturnToUser()
except InvalidProviderError:
SESSION_CONTEXT.get().stream.send(
"Unknown provider. Use /config provider <provider> and try again.", style="error"
)
raise ReturnToUser()

return async_wrapper # pyright: ignore[reportReturnType]
else:
@@ -73,13 +80,20 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> RetType:
assert not is_test_environment(), "OpenAI call attempted in non-benchmark test environment!"
try:
return func(*args, **kwargs)
except AuthenticationError:
raise MentatError("Authentication error: Check your api key and try again.")
except APIConnectionError:
raise MentatError("API connection error: please check your internet connection and try again.")
raise MentatError("API connection error: Check your internet connection and try again.")
except UnknownModelError:
SESSION_CONTEXT.get().stream.send(
"Unknown model. Use /config provider <provider> and try again.", style="error"
)
raise ReturnToUser()
except InvalidProviderError:
SESSION_CONTEXT.get().stream.send(
"Unknown provider. Use /config provider <provider> and try again.", style="error"
)
raise ReturnToUser()

return sync_wrapper

@@ -142,19 +156,63 @@ async def initialize_client(self):
if not load_dotenv(mentat_dir_path / ".env") and not load_dotenv(ctx.cwd / ".env"):
load_dotenv()

try:
self.spice.load_provider(OPEN_AI)
except NoAPIKeyError:
from mentat.session_input import collect_user_input

user_provider = get_model_from_name(ctx.config.model).provider
if ctx.config.provider is not None:
try:
user_provider = get_provider_from_name(ctx.config.provider)
except InvalidProviderError:
ctx.stream.send(
f"Unknown provider {ctx.config.provider}. Use /config provider <provider> to set your provider.",
style="warning",
)
elif user_provider is None:
ctx.stream.send(
"No OpenAI api key detected. To avoid entering your api key on startup, create a .env file in"
" ~/.mentat/.env or in your workspace root.",
f"Unknown model {ctx.config.model}. Use /config provider <provider> to set your provider.",
style="warning",
)
ctx.stream.send("Enter your api key:", style="info")
key = (await collect_user_input(log_input=False)).data
os.environ["OPENAI_API_KEY"] = key

# ragdaemon always needs an openai provider
providers = [OPEN_AI]
if user_provider is not None:
providers.append(user_provider)

for provider in providers:
try:
self.spice.load_provider(provider)
except NoAPIKeyError:
from mentat.session_input import collect_user_input

match provider.name:
case "open_ai" | "openai":
env_variable = "OPENAI_API_KEY"
case "anthropic":
env_variable = "ANTHROPIC_API_KEY"
case "azure":
if os.getenv("AZURE_OPENAI_ENDPOINT") is None:
ctx.stream.send(
f"No Azure OpenAI endpoint detected. To avoid entering your endpoint on startup, create a .env file in"
" ~/.mentat/.env or in your workspace root and set AZURE_OPENAI_ENDPOINT.",
style="warning",
)
ctx.stream.send("Enter your endpoint:", style="info")
endpoint = (await collect_user_input(log_input=False)).data
os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint
if os.getenv("AZURE_OPENAI_KEY") is not None:
return
env_variable = "AZURE_OPENAI_KEY"
case _:
raise MentatError(
f"No api key detected for provider {provider.name}. Create a .env file in ~/.mentat/.env or in your workspace root with your api key"
)

ctx.stream.send(
f"No {provider.name} api key detected. To avoid entering your api key on startup, create a .env file in"
" ~/.mentat/.env or in your workspace root.",
style="warning",
)
ctx.stream.send("Enter your api key:", style="info")
key = (await collect_user_input(log_input=False)).data
os.environ[env_variable] = key

@overload
async def call_llm_api(
@@ -191,7 +249,7 @@ async def call_llm_api(
config = session_context.config

# Confirm that model has enough tokens remaining
tokens = self.spice.count_prompt_tokens(messages, model)
tokens = self.spice.count_prompt_tokens(messages, model, provider)
raise_if_context_exceeds_max(tokens)

with sentry_sdk.start_span(description="LLM Call") as span:
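
initialize_client now resolves the user's provider in two steps — an explicit /config provider setting takes precedence, otherwise the provider is inferred from the model name — then always loads OPEN_AI (ragdaemon requires it) and prompts for whichever credential is missing, storing it in the provider's environment variable (OPENAI_API_KEY, ANTHROPIC_API_KEY, or AZURE_OPENAI_KEY, with AZURE_OPENAI_ENDPOINT collected first for Azure). A condensed sketch of the resolution step; resolve_user_provider is a hypothetical helper for illustration, and it assumes (consistent with the diff) that get_model_from_name returns an object whose provider attribute may be None for unrecognized models:

from spice.errors import InvalidProviderError
from spice.spice import get_model_from_name, get_provider_from_name

def resolve_user_provider(model: str, provider_name: str | None):
    # Infer from the model name first; an explicit /config provider overrides it.
    user_provider = get_model_from_name(model).provider  # may be None for unknown models
    if provider_name is not None:
        try:
            user_provider = get_provider_from_name(provider_name)
        except InvalidProviderError:
            pass  # Mentat warns about the unknown provider and keeps the inferred one
    return user_provider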
2 changes: 1 addition & 1 deletion mentat/revisor/revisor.py
@@ -59,7 +59,7 @@ async def revise_edit(file_edit: FileEdit):
ChatCompletionSystemMessageParam(content=f"Diff:\n{diff}", role="system"),
]
code_message = await ctx.code_context.get_code_message(
ctx.llm_api_handler.spice.count_prompt_tokens(messages, ctx.config.model)
ctx.llm_api_handler.spice.count_prompt_tokens(messages, ctx.config.model, ctx.config.provider)
)
messages.insert(1, ChatCompletionSystemMessageParam(content=code_message, role="system"))

135 changes: 64 additions & 71 deletions mentat/session.py
@@ -9,12 +9,7 @@

import attr
import sentry_sdk
from openai import (
APITimeoutError,
BadRequestError,
PermissionDeniedError,
RateLimitError,
)
from spice.errors import APIConnectionError, APIError, AuthenticationError

from mentat.agent_handler import AgentHandler
from mentat.auto_completer import AutoCompleter
@@ -162,76 +157,74 @@ async def _main(self):
code_file_manager = session_context.code_file_manager
agent_handler = session_context.agent_handler

await session_context.llm_api_handler.initialize_client()

await code_context.refresh_daemon()

check_model()
try:
await session_context.llm_api_handler.initialize_client()
await code_context.refresh_daemon()

check_model()

need_user_request = True
while True:
try:
await code_context.refresh_context_display()
if need_user_request:
# Normally, the code_file_manager pushes the edits; but when agent mode is on, we want all
# edits made between user input to be collected together.
if agent_handler.agent_enabled:
code_file_manager.history.push_edits()
stream.send(
"Use /undo to undo all changes from agent mode since last input.",
style="success",
)
message = await collect_input_with_commands()
if message.data.strip() == "":
continue
conversation.add_user_message(message.data)

parsed_llm_response = await conversation.get_model_response()
file_edits = [file_edit for file_edit in parsed_llm_response.file_edits if file_edit.is_valid()]
for file_edit in file_edits:
file_edit.resolve_conflicts()
if file_edits:
if session_context.config.revisor:
await revise_edits(file_edits)

if session_context.config.sampler:
session_context.sampler.set_active_diff()

self.send_file_edits(file_edits)
if self.apply_edits:
if not agent_handler.agent_enabled:
file_edits, need_user_request = await get_user_feedback_on_edits(file_edits)
applied_edits = await code_file_manager.write_changes_to_files(file_edits)
stream.send(
("Changes applied." if applied_edits else "No changes applied."),
style="input",
)
else:
need_user_request = True

need_user_request = True
while True:
await code_context.refresh_context_display()
try:
if need_user_request:
# Normally, the code_file_manager pushes the edits; but when agent mode is on, we want all
# edits made between user input to be collected together.
if agent_handler.agent_enabled:
code_file_manager.history.push_edits()
stream.send(
"Use /undo to undo all changes from agent mode since last input.",
style="success",
)
message = await collect_input_with_commands()
if message.data.strip() == "":
continue
conversation.add_user_message(message.data)

parsed_llm_response = await conversation.get_model_response()
file_edits = [file_edit for file_edit in parsed_llm_response.file_edits if file_edit.is_valid()]
for file_edit in file_edits:
file_edit.resolve_conflicts()
if file_edits:
if session_context.config.revisor:
await revise_edits(file_edits)

if session_context.config.sampler:
session_context.sampler.set_active_diff()

self.send_file_edits(file_edits)
if self.apply_edits:
if not agent_handler.agent_enabled:
file_edits, need_user_request = await get_user_feedback_on_edits(file_edits)
applied_edits = await code_file_manager.write_changes_to_files(file_edits)
stream.send(
("Changes applied." if applied_edits else "No changes applied."),
style="input",
)
if agent_handler.agent_enabled:
if parsed_llm_response.interrupted:
need_user_request = True
else:
need_user_request = await agent_handler.add_agent_context()
else:
need_user_request = True
stream.send(bool(file_edits), channel="edits_complete")

if agent_handler.agent_enabled:
if parsed_llm_response.interrupted:
need_user_request = True
else:
need_user_request = await agent_handler.add_agent_context()
else:
except ReturnToUser:
stream.send(None, channel="loading", terminate=True)
need_user_request = True
stream.send(bool(file_edits), channel="edits_complete")
except SessionExit:
stream.send(None, channel="client_exit")
break
except ReturnToUser:
stream.send(None, channel="loading", terminate=True)
need_user_request = True
continue
except (
APITimeoutError,
RateLimitError,
BadRequestError,
PermissionDeniedError,
) as e:
stream.send(f"Error accessing OpenAI API: {e.message}", style="error")
break
continue
except SessionExit:
stream.send(None, channel="client_exit")
except AuthenticationError:
raise MentatError("Authentication error: Check your api key and try again.")
except APIConnectionError:
raise MentatError("API connection error: Check your internet connection and try again.")
except APIError as e:
stream.send(f"Error accessing OpenAI API: {e}", style="error")

async def listen_for_session_exit(self):
await self.stream.recv(channel="session_exit")