Update spice #556

Merged · 3 commits · Apr 5, 2024

benchmarks/benchmark_runner.py (3 additions & 1 deletion)
@@ -8,9 +8,11 @@
 import re
 from datetime import datetime
 from pathlib import Path
+from typing import List
 from uuid import uuid4

 from openai.types.chat.completion_create_params import ResponseFormat
+from spice import SpiceMessage

 from benchmarks.arg_parser import common_benchmark_parser
 from benchmarks.benchmark_result import BenchmarkResult
@@ -43,7 +45,7 @@ def git_diff_from_comparison_commit(sample: Sample, comparison_commit: str) -> s

 async def grade(to_grade, prompt, model="gpt-4-1106-preview"):
     try:
-        messages = [
+        messages: List[SpiceMessage] = [
             {"role": "system", "content": prompt},
             {"role": "user", "content": to_grade},
         ]

benchmarks/exercism_practice.py (3 additions & 1 deletion)
@@ -5,6 +5,8 @@
 from datetime import datetime
 from functools import partial
 from pathlib import Path
+from typing import List
+from spice import SpiceMessage

 import tqdm
 from openai import BadRequestError
@@ -56,7 +58,7 @@ async def failure_analysis(exercise_runner, language):
     test_results = exercise_runner.read_test_results()

     final_message = f"All instructions:\n{instructions}\nCode to review:\n{code}\nTest" f" results:\n{test_results}"
-    messages = [
+    messages: List[SpiceMessage] = [
         {"role": "system", "content": prompt},
         {"role": "user", "content": final_message},
     ]
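
Both benchmark files make the same change: the chat payload passed to the grading model is now typed. A minimal sketch of what the annotation buys, assuming (as these diffs do) that spiceai's SpiceMessage accepts plain role/content dicts:

from typing import List

from spice import SpiceMessage

# Hypothetical payload in the same shape the benchmarks build; with the
# annotation, a type checker can flag a misspelled role or a missing
# "content" key instead of letting it fail at request time.
messages: List[SpiceMessage] = [
    {"role": "system", "content": "You are a strict grader."},
    {"role": "user", "content": "Grade this submission."},
]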

mentat/conversation.py (2 additions & 6 deletions)
@@ -198,17 +198,13 @@ async def _stream_model_response(

         stream.send("Streaming...\n")
         async with stream.interrupt_catcher(parser.shutdown):
-            parsed_llm_response = await parser.stream_and_parse_llm_response(add_newline(response.stream()))
+            parsed_llm_response = await parser.stream_and_parse_llm_response(add_newline(response))

         # Sampler and History require previous_file_lines
         for file_edit in parsed_llm_response.file_edits:
             file_edit.previous_file_lines = code_file_manager.file_lines.get(file_edit.file_path, []).copy()

-        # TODO: this is janky come up with better solution
-        # if the stream was interrupted, then the finally block in the response.stream() async generator
-        # will wait for an opportunity to run. This sleep call gives it that opportunity.
-        # the finally block runs the logging callback
-        await asyncio.sleep(0.01)
+        cost_tracker.log_api_call_stats(response.current_response())
         cost_tracker.display_last_api_call()

         messages.append(
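
The deleted TODO block documented why the sleep existed: when the user interrupts a stream, the finally block inside the old response.stream() async generator, which ran the logging callback, only executes once the event loop gets an opportunity to finalize the generator. A toy sketch of that asyncio behavior (illustration only, not Mentat code):

import asyncio


async def stream():
    try:
        yield "chunk"
    finally:
        # The old design ran its logging callback here; after a `break`,
        # the event loop finalizes the generator asynchronously.
        print("finally: log api call stats")


async def main():
    async for _ in stream():
        break  # simulate the user interrupting mid-stream
    # Yielding to the event loop (the old `await asyncio.sleep(0.01)`)
    # gives the generator's finalizer a chance to run first.
    await asyncio.sleep(0)
    print("stats available")


asyncio.run(main())

The new spice API sidesteps this entirely: the stats are read synchronously from response.current_response() once the stream ends, so both the sleep and the callback are gone.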

mentat/llm_api_handler.py (59 additions & 32 deletions)
@@ -12,8 +12,10 @@
     Callable,
     Dict,
     List,
+    Literal,
     Optional,
     cast,
+    overload,
 )

 import attr
@@ -29,7 +31,10 @@
 )
 from openai.types.chat.completion_create_params import ResponseFormat
 from PIL import Image
-from spice import APIConnectionError, Spice, SpiceEmbeddings, SpiceError, SpiceResponse, SpiceWhisper
+from spice import APIConnectionError, Spice, SpiceError, SpiceMessage, SpiceResponse, StreamingSpiceResponse
+from spice.errors import NoAPIKeyError
+from spice.models import WHISPER_1
+from spice.providers import OPEN_AI

 from mentat.errors import MentatError, ReturnToUser
 from mentat.session_context import SESSION_CONTEXT
@@ -323,61 +328,83 @@ async def initialize_client(self):
         if not load_dotenv(mentat_dir_path / ".env") and not load_dotenv(ctx.cwd / ".env"):
             load_dotenv()

-        if os.getenv("AZURE_OPENAI_KEY") is not None:
-            embedding_and_whisper_provider = "azure"
-        else:
-            embedding_and_whisper_provider = "openai"
+        self.spice = Spice()
+
+        try:
+            self.spice.load_provider(OPEN_AI)
+        except NoAPIKeyError:
+            from mentat.session_input import collect_user_input
+
+            ctx.stream.send(
+                "No OpenAI api key detected. To avoid entering your api key on startup, create a .env file in"
+                " ~/.mentat/.env or in your workspace root.",
+                style="warning",
+            )
+            ctx.stream.send("Enter your api key:", style="info")
+            key = (await collect_user_input(log_input=False)).data
+            os.environ["OPENAI_API_KEY"] = key

-        self.spice_client = Spice()
-        self.spice_embedding_client = SpiceEmbeddings(provider=embedding_and_whisper_provider)
-        self.spice_whisper_client = SpiceWhisper(provider=embedding_and_whisper_provider)
+    @overload
+    async def call_llm_api(
+        self,
+        messages: List[SpiceMessage],
+        model: str,
+        stream: Literal[False],
+        response_format: ResponseFormat = ResponseFormat(type="text"),
+    ) -> SpiceResponse:
+        ...
+
+    @overload
+    async def call_llm_api(
+        self,
+        messages: List[SpiceMessage],
+        model: str,
+        stream: Literal[True],
+        response_format: ResponseFormat = ResponseFormat(type="text"),
+    ) -> StreamingSpiceResponse:
+        ...

     @api_guard
     async def call_llm_api(
         self,
-        messages: list[ChatCompletionMessageParam],
+        messages: List[SpiceMessage],
         model: str,
         stream: bool,
         response_format: ResponseFormat = ResponseFormat(type="text"),
-    ) -> SpiceResponse:
+    ) -> SpiceResponse | StreamingSpiceResponse:
         session_context = SESSION_CONTEXT.get()
         config = session_context.config
         cost_tracker = session_context.cost_tracker

         if "claude" in config.model:
             messages = normalize_messages_for_anthropic(messages)

         # Confirm that model has enough tokens remaining.
         tokens = prompt_tokens(messages, model)
         raise_if_context_exceeds_max(tokens)

-        # TODO: make spice message format and use across codebase consistently
-        _messages = [
-            {"role": cast(str, message["role"]), "content": cast(str, message["content"])} for message in messages
-        ]
-        if "type" in response_format and response_format["type"] == "json_object":
-            _response_format = {"type": "json_object"}
-        else:
-            _response_format = {"type": "text"}
-
         with sentry_sdk.start_span(description="LLM Call") as span:
             span.set_tag("model", model)

-            response = await self.spice_client.call_llm(
-                model=model,
-                messages=_messages,
-                stream=stream,
-                temperature=config.temperature,
-                response_format=_response_format,
-                logging_callback=cost_tracker.log_api_call_stats,
-            )
+            if not stream:
+                response = await self.spice.get_response(
+                    model=model,
+                    messages=messages,
+                    temperature=config.temperature,
+                    response_format=response_format,  # pyright: ignore
+                )
+                cost_tracker.log_api_call_stats(response)
+            else:
+                response = await self.spice.stream_response(
+                    model=model,
+                    messages=messages,
+                    temperature=config.temperature,
+                    response_format=response_format,  # pyright: ignore
+                )

         return response

     @api_guard
     def call_embedding_api(self, input_texts: list[str], model: str = "text-embedding-ada-002") -> Embeddings:
-        return self.spice_embedding_client.get_embeddings(input_texts, model)
+        return self.spice.get_embeddings_sync(input_texts, model)  # pyright: ignore

     @api_guard
     async def call_whisper_api(self, audio_path: Path) -> str:
-        return await self.spice_whisper_client.get_whisper_transcription(audio_path)
+        return await self.spice.get_transcription(audio_path, model=WHISPER_1)
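
The pair of @overload stubs lets the type checker narrow call_llm_api's return type from the literal value of stream, so callers get SpiceResponse or StreamingSpiceResponse without casting. A self-contained sketch of the same typing pattern, using a hypothetical fetch function rather than anything from this PR:

from typing import AsyncIterator, Literal, overload


@overload
async def fetch(stream: Literal[False]) -> str: ...


@overload
async def fetch(stream: Literal[True]) -> AsyncIterator[str]: ...


async def fetch(stream: bool) -> str | AsyncIterator[str]:
    # Single runtime implementation; the stubs above are erased at runtime
    # and exist only for the type checker.
    if not stream:
        return "whole response"

    async def chunks() -> AsyncIterator[str]:
        yield "partial "
        yield "response"

    return chunks()

Under this pattern, await fetch(False) type-checks as str and await fetch(True) as AsyncIterator[str], mirroring the non-streaming and streaming branches above.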

mentat/server/mentat_server.py (1 addition & 1 deletion)
Expand Up @@ -18,7 +18,7 @@ class MentatServer:
def __init__(self, cwd: Path, config: Config) -> None:
self.cwd = cwd
self.stopped = Event()
self.session = Session(self.cwd, config=config, apply_edits=False)
self.session = Session(self.cwd, config=config, apply_edits=False, show_update=False)
PCSwingle marked this conversation as resolved.
Show resolved Hide resolved

async def _client_listener(self):
with open(3) as fd_input:
Expand Down

mentat/session.py (3 additions & 1 deletion)
Expand Up @@ -58,6 +58,7 @@ def __init__(
config: Config = Config(),
# Set to false for clients that apply the edits themselves (like vscode)
apply_edits: bool = True,
show_update: bool = True,
PCSwingle marked this conversation as resolved.
Show resolved Hide resolved
):
# All errors thrown here need to be caught here
self.stopped = Event()
Expand Down Expand Up @@ -112,7 +113,8 @@ def __init__(
self.error = None

# Functions that require session_context
check_version()
if show_update:
check_version()
config.send_errors_to_stream()
for path in paths:
code_context.include(path, exclude_patterns=exclude_paths)
Expand Down

requirements.txt (1 addition & 1 deletion)
Expand Up @@ -21,6 +21,7 @@ selenium==4.15.2
sentry-sdk==1.34.0
sounddevice==0.4.6
soundfile==0.12.1
spiceai==0.1.8
PCSwingle marked this conversation as resolved.
Show resolved Hide resolved
termcolor==2.3.0
textual==0.47.1
textual-autocomplete==2.1.0b0
Expand All @@ -29,4 +30,3 @@ typing_extensions==4.8.0
tqdm==4.66.1
webdriver_manager==4.0.1
watchfiles==0.21.0
spiceai==0.1.7