Commit f8a61fd: update spice

PCSwingle committed Apr 5, 2024
1 parent f60f362

Showing 7 changed files with 71 additions and 43 deletions.
benchmarks/benchmark_runner.py (3 additions & 1 deletion)

@@ -8,9 +8,11 @@
 import re
 from datetime import datetime
 from pathlib import Path
+from typing import List
 from uuid import uuid4
 
 from openai.types.chat.completion_create_params import ResponseFormat
+from spice import SpiceMessage
 
 from benchmarks.arg_parser import common_benchmark_parser
 from benchmarks.benchmark_result import BenchmarkResult
@@ -43,7 +45,7 @@ def git_diff_from_comparison_commit(sample: Sample, comparison_commit: str) -> s

 async def grade(to_grade, prompt, model="gpt-4-1106-preview"):
     try:
-        messages = [
+        messages: List[SpiceMessage] = [
             {"role": "system", "content": prompt},
             {"role": "user", "content": to_grade},
         ]
benchmarks/exercism_practice.py (3 additions & 1 deletion)

@@ -5,6 +5,8 @@
 from datetime import datetime
 from functools import partial
 from pathlib import Path
+from typing import List
+from spice import SpiceMessage
 
 import tqdm
 from openai import BadRequestError
@@ -56,7 +58,7 @@ async def failure_analysis(exercise_runner, language):
     test_results = exercise_runner.read_test_results()
 
     final_message = f"All instructions:\n{instructions}\nCode to review:\n{code}\nTest" f" results:\n{test_results}"
-    messages = [
+    messages: List[SpiceMessage] = [
         {"role": "system", "content": prompt},
         {"role": "user", "content": final_message},
     ]
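Both benchmark files make the same change: the message list handed to the grading model is now annotated as List[SpiceMessage] rather than left as an inferred dict list. A minimal sketch of the pattern, assuming SpiceMessage is a typed dict with "role" and "content" keys, as the literals above suggest (the helper function name is illustrative):

from typing import List

from spice import SpiceMessage


def build_grading_messages(prompt: str, to_grade: str) -> List[SpiceMessage]:
    # Annotating the list lets the type checker validate each dict against
    # SpiceMessage instead of inferring a loose dict[str, str] type.
    messages: List[SpiceMessage] = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": to_grade},
    ]
    return messages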
mentat/conversation.py (2 additions & 6 deletions)

@@ -198,17 +198,13 @@ async def _stream_model_response(

         stream.send("Streaming...\n")
         async with stream.interrupt_catcher(parser.shutdown):
-            parsed_llm_response = await parser.stream_and_parse_llm_response(add_newline(response.stream()))
+            parsed_llm_response = await parser.stream_and_parse_llm_response(add_newline(response))
 
         # Sampler and History require previous_file_lines
         for file_edit in parsed_llm_response.file_edits:
             file_edit.previous_file_lines = code_file_manager.file_lines.get(file_edit.file_path, []).copy()
 
-        # TODO: this is janky come up with better solution
-        # if the stream was interrupted, then the finally block in the response.stream() async generator
-        # will wait for an opportunity to run. This sleep call gives it that opportunity.
-        # the finally block runs the logging callback
-        await asyncio.sleep(0.01)
-        cost_tracker.log_api_call_stats(response.current_response())
+        cost_tracker.display_last_api_call()
 
         messages.append(
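The parser now consumes the response object directly instead of response.stream(), which also removes the asyncio.sleep(0.01) workaround that previously gave the stream generator's finally block a chance to run its logging callback; cost stats are displayed explicitly afterward. Since add_newline() receives the response itself, StreamingSpiceResponse is presumably async-iterable. A rough sketch under that assumption (the model name and prompt are illustrative, and the exact spice iteration API may differ):

import asyncio

from spice import Spice


async def main() -> None:
    client = Spice()
    # stream_response returns a StreamingSpiceResponse, assumed here to be
    # an async iterable that yields text chunks as they arrive.
    response = await client.stream_response(
        model="gpt-4-1106-preview",
        messages=[{"role": "user", "content": "Say hello."}],
    )
    async for chunk in response:
        print(chunk, end="", flush=True)


asyncio.run(main())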
mentat/llm_api_handler.py (58 additions & 32 deletions)

@@ -12,8 +12,10 @@
     Callable,
     Dict,
     List,
+    Literal,
     Optional,
     cast,
+    overload,
 )

 import attr

@@ -29,10 +31,14 @@
 )
 from openai.types.chat.completion_create_params import ResponseFormat
 from PIL import Image
-from spice import APIConnectionError, Spice, SpiceEmbeddings, SpiceError, SpiceResponse, SpiceWhisper
+from spice import APIConnectionError, Spice, SpiceError, SpiceMessage, SpiceResponse, StreamingSpiceResponse
+from spice.errors import NoAPIKeyError
+from spice.models import WHISPER_1
+from spice.providers import OPEN_AI
 
 from mentat.errors import MentatError, ReturnToUser
 from mentat.session_context import SESSION_CONTEXT
+from mentat.session_input import collect_user_input
 from mentat.utils import mentat_dir_path
 
 TOKEN_COUNT_WARNING = 32000
@@ -323,61 +329,81 @@ async def initialize_client(self):
         if not load_dotenv(mentat_dir_path / ".env") and not load_dotenv(ctx.cwd / ".env"):
             load_dotenv()
 
-        if os.getenv("AZURE_OPENAI_KEY") is not None:
-            embedding_and_whisper_provider = "azure"
-        else:
-            embedding_and_whisper_provider = "openai"
+        self.spice = Spice()
 
-        self.spice_client = Spice()
-        self.spice_embedding_client = SpiceEmbeddings(provider=embedding_and_whisper_provider)
-        self.spice_whisper_client = SpiceWhisper(provider=embedding_and_whisper_provider)
+        try:
+            self.spice.load_provider(OPEN_AI)
+        except NoAPIKeyError:
+            ctx.stream.send(
+                "No OpenAI api key detected. To avoid entering your api key on startup, create a .env file in"
+                " ~/.mentat/.env or in your workspace root.",
+                style="warning",
+            )
+            ctx.stream.send("Enter your api key:", style="info")
+            key = (await collect_user_input(log_input=False)).data
+            os.environ["OPENAI_API_KEY"] = key
+
+    @overload
+    async def call_llm_api(
+        self,
+        messages: List[SpiceMessage],
+        model: str,
+        stream: Literal[False],
+        response_format: ResponseFormat = ResponseFormat(type="text"),
+    ) -> SpiceResponse:
+        ...
+
+    @overload
+    async def call_llm_api(
+        self,
+        messages: List[SpiceMessage],
+        model: str,
+        stream: Literal[True],
+        response_format: ResponseFormat = ResponseFormat(type="text"),
+    ) -> StreamingSpiceResponse:
+        ...
 
     @api_guard
     async def call_llm_api(
         self,
-        messages: list[ChatCompletionMessageParam],
+        messages: List[SpiceMessage],
         model: str,
         stream: bool,
         response_format: ResponseFormat = ResponseFormat(type="text"),
-    ) -> SpiceResponse:
+    ) -> SpiceResponse | StreamingSpiceResponse:
         session_context = SESSION_CONTEXT.get()
         config = session_context.config
         cost_tracker = session_context.cost_tracker
 
-        if "claude" in config.model:
-            messages = normalize_messages_for_anthropic(messages)
-
         # Confirm that model has enough tokens remaining.
         tokens = prompt_tokens(messages, model)
         raise_if_context_exceeds_max(tokens)
 
-        # TODO: make spice message format and use across codebase consistently
-        _messages = [
-            {"role": cast(str, message["role"]), "content": cast(str, message["content"])} for message in messages
-        ]
-        if "type" in response_format and response_format["type"] == "json_object":
-            _response_format = {"type": "json_object"}
-        else:
-            _response_format = {"type": "text"}
-
         with sentry_sdk.start_span(description="LLM Call") as span:
             span.set_tag("model", model)
 
-            response = await self.spice_client.call_llm(
-                model=model,
-                messages=_messages,
-                stream=stream,
-                temperature=config.temperature,
-                response_format=_response_format,
-                logging_callback=cost_tracker.log_api_call_stats,
-            )
+            if not stream:
+                response = await self.spice.get_response(
+                    model=model,
+                    messages=messages,
+                    temperature=config.temperature,
+                    response_format=response_format, # pyright: ignore
+                )
+                cost_tracker.log_api_call_stats(response)
+            else:
+                response = await self.spice.stream_response(
+                    model=model,
+                    messages=messages,
+                    temperature=config.temperature,
+                    response_format=response_format, # pyright: ignore
+                )
 
         return response
 
     @api_guard
     def call_embedding_api(self, input_texts: list[str], model: str = "text-embedding-ada-002") -> Embeddings:
-        return self.spice_embedding_client.get_embeddings(input_texts, model)
+        return self.spice.get_embeddings_sync(input_texts, model) # pyright: ignore
 
     @api_guard
     async def call_whisper_api(self, audio_path: Path) -> str:
-        return await self.spice_whisper_client.get_whisper_transcription(audio_path)
+        return await self.spice.get_transcription(audio_path, model=WHISPER_1)
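The two @overload signatures let pyright pick the return type of call_llm_api from the literal value of stream: callers passing stream=False get a SpiceResponse, callers passing stream=True get a StreamingSpiceResponse, with no casting at call sites. A self-contained sketch of the technique using illustrative stand-in classes (not part of mentat or spice):

from __future__ import annotations

from typing import Literal, overload


class WholeResponse:
    """Stand-in for SpiceResponse: the full text, available at once."""

    def __init__(self, text: str) -> None:
        self.text = text


class ChunkStream:
    """Stand-in for StreamingSpiceResponse: chunks arriving over time."""

    def __init__(self, chunks: list[str]) -> None:
        self.chunks = chunks


@overload
async def call(prompt: str, stream: Literal[False]) -> WholeResponse: ...
@overload
async def call(prompt: str, stream: Literal[True]) -> ChunkStream: ...


async def call(prompt: str, stream: bool) -> WholeResponse | ChunkStream:
    # The @overload stubs are erased at runtime; only this body executes.
    # Type checkers match the Literal[...] stream argument at each call site
    # to select the precise return type.
    if stream:
        return ChunkStream(["he", "llo"])
    return WholeResponse("hello")

The embedding and whisper helpers follow the same consolidation: the separate SpiceEmbeddings and SpiceWhisper clients are gone, and the single Spice client exposes get_embeddings_sync and get_transcription instead.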
mentat/server/mentat_server.py (1 addition & 1 deletion)

@@ -18,7 +18,7 @@ class MentatServer:
     def __init__(self, cwd: Path, config: Config) -> None:
         self.cwd = cwd
         self.stopped = Event()
-        self.session = Session(self.cwd, config=config, apply_edits=False)
+        self.session = Session(self.cwd, config=config, apply_edits=False, show_update=False)
 
     async def _client_listener(self):
         with open(3) as fd_input:
mentat/session.py (3 additions & 1 deletion)

@@ -58,6 +58,7 @@ def __init__(
         config: Config = Config(),
         # Set to false for clients that apply the edits themselves (like vscode)
         apply_edits: bool = True,
+        show_update: bool = True,
     ):
         # All errors thrown here need to be caught here
         self.stopped = Event()
@@ -112,7 +113,8 @@ def __init__(
         self.error = None
 
         # Functions that require session_context
-        check_version()
+        if show_update:
+            check_version()
         config.send_errors_to_stream()
         for path in paths:
             code_context.include(path, exclude_patterns=exclude_paths)
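The new show_update flag lets embedding clients suppress the version-check notice at startup; mentat_server above passes show_update=False for exactly that reason. A hypothetical usage sketch (import paths assumed, not verified against the repo layout):

from pathlib import Path

from mentat.config import Config
from mentat.session import Session

# A headless client applies edits itself and does not want the update
# notice written to its stream, so both flags are disabled.
session = Session(Path.cwd(), config=Config(), apply_edits=False, show_update=False)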
requirements.txt (1 addition & 1 deletion)

@@ -21,6 +21,7 @@ selenium==4.15.2
 sentry-sdk==1.34.0
 sounddevice==0.4.6
 soundfile==0.12.1
+spiceai==0.1.8
 termcolor==2.3.0
 textual==0.47.1
 textual-autocomplete==2.1.0b0
@@ -29,4 +30,3 @@ typing_extensions==4.8.0
 tqdm==4.66.1
 webdriver_manager==4.0.1
 watchfiles==0.21.0
-spiceai==0.1.7
