
Commit

Use Spice (#543)
biobootloader authored Apr 3, 2024
1 parent cc6716a commit f3cabe6
Showing 22 changed files with 133 additions and 315 deletions.
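The change repeated across most of the files below is that call_llm_api now returns a Spice response object (note the SpiceResponse import added to mentat/cost_tracker.py), so callers read response.text instead of response.choices[0].message.content. A minimal sketch of that before/after pattern, using a stand-in class rather than the real spice types:

    # Illustrative stub only; in the commit the object comes from the spice library.
    from dataclasses import dataclass

    @dataclass
    class FakeSpiceResponse:
        text: str  # full completion text, replacing choices[0].message.content

    response = FakeSpiceResponse(text="hello")
    # before: content = response.choices[0].message.content or ""
    content = response.text
    print(content)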
2 changes: 1 addition & 1 deletion .github/workflows/lint_and_test.yml
@@ -22,7 +22,7 @@ jobs:
. .venv/bin/activate
echo PATH=$PATH >> $GITHUB_ENV
- name: ruff (linter)
-        run: ruff check .
+        run: ruff check --select I .
- name: ruff (formatter)
run: ruff format --check .

2 changes: 1 addition & 1 deletion benchmarks/benchmark_runner.py
@@ -58,7 +58,7 @@ async def grade(to_grade, prompt, model="gpt-4-1106-preview"):

llm_api_handler = SESSION_CONTEXT.get().llm_api_handler
llm_grade = await llm_api_handler.call_llm_api(messages, model, False, ResponseFormat(type="json_object"))
-        content = llm_grade.choices[0].message.content
+        content = llm_grade.text
return json.loads(content)
except Exception as e:
return {"error": str(e)}
2 changes: 1 addition & 1 deletion benchmarks/context_benchmark.py
@@ -7,7 +7,7 @@

from benchmarks.arg_parser import common_benchmark_parser
from benchmarks.run_sample import setup_sample
-from benchmarks.swe_bench_runner import get_swe_samples, SWE_BENCH_SAMPLES_DIR
+from benchmarks.swe_bench_runner import SWE_BENCH_SAMPLES_DIR, get_swe_samples
from mentat import Mentat
from mentat.config import Config
from mentat.sampler.sample import Sample
2 changes: 1 addition & 1 deletion benchmarks/exercism_practice.py
@@ -64,7 +64,7 @@ async def failure_analysis(exercise_runner, language):
try:
llm_api_handler = SESSION_CONTEXT.get().llm_api_handler
llm_grade = await llm_api_handler.call_llm_api(messages, model, False)
-        response = llm_grade.choices[0].message.content
+        response = llm_grade.text
except BadRequestError:
response = "Unable to analyze test case\nreason: too many tokens to analyze"

4 changes: 2 additions & 2 deletions benchmarks/run_sample.py
@@ -5,16 +5,16 @@
from pathlib import Path
from typing import Any

-from git import Repo
import tqdm
+from git import Repo

from mentat import Mentat
from mentat.config import Config
from mentat.errors import SampleError
from mentat.git_handler import get_git_diff
from mentat.parsers.git_parser import GitParser
from mentat.sampler.sample import Sample
-from mentat.sampler.utils import get_active_snapshot_commit, setup_repo, apply_diff_to_repo
+from mentat.sampler.utils import apply_diff_to_repo, get_active_snapshot_commit, setup_repo
from mentat.session_context import SESSION_CONTEXT


5 changes: 2 additions & 3 deletions benchmarks/swe_bench_runner.py
@@ -10,11 +10,10 @@
from pathlib import Path
from typing import Any

-from datasets import load_dataset, DatasetDict  # type: ignore
+from datasets import DatasetDict, load_dataset  # type: ignore

-from mentat.sampler.sample import Sample
from benchmarks.run_sample import validate_test_fields

+from mentat.sampler.sample import Sample

SWE_BENCH_SAMPLES_DIR = Path(__file__).parent / "benchmarks" / "swe_bench_samples"
SWE_VALIDATION_RESULTS_PATH = (
27 changes: 16 additions & 11 deletions docs/source/user/alternative_models.rst
@@ -3,7 +3,19 @@
🦙 Alternative Models
=====================

-Azure
+Anthropic's Claude 3
+---------
+To use Anthropic models, provide the :code:`ANTHROPIC_API_KEY` environment variable instead of :code:`OPENAI_API_KEY`, and set the model `claude-3-opus-20240229` in the :code:`.mentat_config.json` file:
+
+.. code-block:: bash
+
+    # in ~/.mentat/.env
+    ANTHROPIC_API_KEY=sk-*************
+
+    # In ~/.mentat/.mentat_config.json
+    { "model": "claude-3-opus-20240229" }
+
+OpenAI models on Azure
-----
To use the Azure API, provide the :code:`AZURE_OPENAI_ENDPOINT` (:code:`https://<your-instance-name>.openai.azure.com/`) and :code:`AZURE_OPENAI_KEY` environment variables instead of :code:`OPENAI_API_KEY`.
@@ -13,31 +25,24 @@ In addition, Mentat uses the :code:`gpt-4-1106-preview` model by default. When u
.. warning::
Due to changes in the OpenAI Python SDK, you can no longer use :code:`OPENAI_API_BASE` to access the Azure API with Mentat.
-Anthropic
+Using Other Models
---------
-Mentat uses the OpenAI SDK to retrieve chat completions. This means that setting the `OPENAI_API_BASE` environment variable is enough to use any model that has the same response schema as OpenAI. To use models with different response schemas, we recommend setting up a litellm proxy as described `here <https://docs.litellm.ai/docs/proxy/quick_start>`__ and pointing `OPENAI_API_BASE` to the proxy. For example with anthropic:
+Mentat uses the OpenAI SDK to retrieve chat completions. This means that setting the `OPENAI_API_BASE` environment variable is enough to use any model that has the same response schema as OpenAI. To use models with different response schemas, we recommend setting up a litellm proxy as described `here <https://docs.litellm.ai/docs/proxy/quick_start>`__ and pointing `OPENAI_API_BASE` to the proxy. For example:
.. code-block:: bash

    pip install 'litellm[proxy]'
-    export ANTHROPIC_API_KEY=sk-*************
-    litellm --model claude-3-opus-2024-0229 --drop_params
+    litellm --model huggingface/bigcode/starcoder --drop_params
    # Should see: Uvicorn running on http://0.0.0.0:8000
.. code-block:: bash

    # In ~/.mentat/.env
    OPENAI_API_BASE=http://localhost:8000
-    # In ~/.mentat/.mentat_config.json
-    { "model": "claude" }
    # or
    export OPENAI_API_BASE=http://localhost:8000
    mentat
-.. note::
-    Anthropic has slightly different requirements for system messages so you must set your model to a string with "claude" in it. Other than that it isn't important as the exact model is set by the litellm proxy server flag.
🦙 Local Models
---------------
4 changes: 2 additions & 2 deletions mentat/agent_handler.py
@@ -48,7 +48,7 @@ async def enable_agent_mode(self):
]
model = ctx.config.model
response = await ctx.llm_api_handler.call_llm_api(messages, model, False)
-        content = response.choices[0].message.content or ""
+        content = response.text

paths = [Path(path) for path in content.strip().split("\n") if Path(path).exists()]
self.agent_file_message = ""
@@ -87,7 +87,7 @@ async def _determine_commands(self) -> List[str]:
ctx.stream.send(f"Error accessing OpenAI API: {e.message}", style="error")
return []

-        content = response.choices[0].message.content or ""
+        content = response.text

messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
parsed_llm_response = await ctx.config.parser.parse_llm_response(content)
20 changes: 9 additions & 11 deletions mentat/conversation.py
@@ -198,20 +198,18 @@ async def _stream_model_response(

stream.send("Streaming...\n")
async with stream.interrupt_catcher(parser.shutdown):
-            parsed_llm_response = await parser.stream_and_parse_llm_response(add_newline(response))
+            parsed_llm_response = await parser.stream_and_parse_llm_response(add_newline(response.stream()))

# Sampler and History require previous_file_lines
for file_edit in parsed_llm_response.file_edits:
file_edit.previous_file_lines = code_file_manager.file_lines.get(file_edit.file_path, []).copy()
-        if not parsed_llm_response.interrupted:
-            cost_tracker.display_last_api_call()
-        else:
-            # Generator doesn't log the api call if we interrupt it
-            cost_tracker.log_api_call_stats(
-                num_prompt_tokens,
-                count_tokens(parsed_llm_response.full_response, config.model, full_message=False),
-                config.model,
-                display=True,
-            )

+        # TODO: this is janky come up with better solution
+        # if the stream was interrupted, then the finally block in the response.stream() async generator
+        # will wait for an opportunity to run. This sleep call gives it that opportunity.
+        # the finally block runs the logging callback
+        await asyncio.sleep(0.01)
+        cost_tracker.display_last_api_call()

messages.append(
ChatCompletionAssistantMessageParam(role="assistant", content=parsed_llm_response.full_response)
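The TODO comment above leans on a general asyncio behavior: when a stream is abandoned partway through, the finally block of the underlying async generator is only scheduled on the event loop, so the caller has to yield control (here via the short sleep) before it actually runs. A self-contained illustration of that behavior, assuming nothing from the commit beyond the idea of a logging callback in finally:

    import asyncio

    async def stream():
        try:
            while True:
                yield "chunk"
                await asyncio.sleep(0)
        finally:
            # in conversation.py this is where the logging callback fires
            print("finally: log the api call")

    async def main():
        agen = stream()
        async for _ in agen:
            break  # simulate the user interrupting the stream
        del agen  # drop the generator; asyncio schedules its aclose() on the loop
        print("before sleep")
        await asyncio.sleep(0.01)  # give the loop a chance to run the finally block
        print("after sleep")

    asyncio.run(main())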
68 changes: 27 additions & 41 deletions mentat/cost_tracker.py
@@ -1,11 +1,9 @@
import logging
from dataclasses import dataclass
-from timeit import default_timer
-from typing import AsyncIterator, Optional

-from openai.types.chat import ChatCompletionChunk
+from spice import SpiceResponse

-from mentat.llm_api_handler import count_tokens, model_price_per_1000_tokens
+from mentat.llm_api_handler import model_price_per_1000_tokens
from mentat.session_context import SESSION_CONTEXT


@@ -18,37 +16,47 @@ class CostTracker:

def log_api_call_stats(
self,
-        num_prompt_tokens: int,
-        num_sampled_tokens: int,
-        model: str,
-        call_time: Optional[float] = None,
-        decimal_places: int = 2,
-        display: bool = False,
+        response: SpiceResponse,
) -> None:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
+        decimal_places = 2
+
+        model = response.call_args.model
+        input_tokens = response.input_tokens
+        output_tokens = response.output_tokens
+        total_time = response.total_time

speed_and_cost_string = ""
-        self.total_tokens += num_prompt_tokens + num_sampled_tokens
-        if num_sampled_tokens > 0 and call_time is not None:
-            tokens_per_second = num_sampled_tokens / call_time
+        self.total_tokens += response.total_tokens
+        if output_tokens > 0:
+            tokens_per_second = output_tokens / total_time
speed_and_cost_string += f"Speed: {tokens_per_second:.{decimal_places}f} tkns/s"
cost = model_price_per_1000_tokens(model)
if cost:
-            prompt_cost = (num_prompt_tokens / 1000) * cost[0]
-            sampled_cost = (num_sampled_tokens / 1000) * cost[1]
+            prompt_cost = (input_tokens / 1000) * cost[0]
+            sampled_cost = (output_tokens / 1000) * cost[1]
call_cost = prompt_cost + sampled_cost
self.total_cost += call_cost
if speed_and_cost_string:
speed_and_cost_string += " | "
speed_and_cost_string += f"Cost: ${call_cost:.{decimal_places}f}"
-        if display:
-            stream.send(speed_and_cost_string, style="info")

costs_logger = logging.getLogger("costs")
costs_logger.info(speed_and_cost_string)
self.last_api_call = speed_and_cost_string

+    def log_embedding_call_stats(self, tokens: int, model: str, total_time: float):
+        cost = model_price_per_1000_tokens(model)
+        # TODO: handle unknown models better / port to spice
+        if cost is None:
+            return
+
+        cost = cost[0]
+        call_cost = (tokens / 1000) * cost
+        self.total_cost += call_cost
+        costs_logger = logging.getLogger("costs")
+        costs_logger.info(f"Cost: ${call_cost:.2f}")
+        self.last_api_call = f"Embedding call time and cost: {total_time:.2f}s, ${call_cost:.2f}"

def display_last_api_call(self):
"""
Used so that places that call the llm can print the api call stats after they finish printing everything else.
@@ -60,25 +68,3 @@ def display_last_api_call(self):

def log_whisper_call_stats(self, seconds: float):
self.total_cost += seconds * 0.0001

-    async def response_logger_wrapper(
-        self,
-        prompt_tokens: int,
-        response: AsyncIterator[ChatCompletionChunk],
-        model: str,
-    ) -> AsyncIterator[ChatCompletionChunk]:
-        full_response = ""
-        start_time = default_timer()
-        async for chunk in response:
-            # On Azure OpenAI, the first chunk streamed may contain only metadata relating to content filtering.
-            if len(chunk.choices) == 0:
-                continue
-            full_response += chunk.choices[0].delta.content or ""
-            yield chunk
-        time_elapsed = default_timer() - start_time
-        self.log_api_call_stats(
-            prompt_tokens,
-            count_tokens(full_response, model, full_message=False),
-            model,
-            time_elapsed,
-        )
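With the new signature the tracker derives the model, token counts, and timing from the response object itself instead of taking them as arguments. A self-contained sketch of the arithmetic above, using a stand-in response and illustrative per-1000-token prices rather than the real spice.SpiceResponse and model_price_per_1000_tokens values:

    from dataclasses import dataclass

    @dataclass
    class FakeResponse:  # stand-in for spice.SpiceResponse
        input_tokens: int = 1200
        output_tokens: int = 300
        total_time: float = 6.0  # seconds

    response = FakeResponse()
    price_in, price_out = 0.01, 0.03  # illustrative dollars per 1000 tokens
    tokens_per_second = response.output_tokens / response.total_time
    call_cost = (response.input_tokens / 1000) * price_in + (response.output_tokens / 1000) * price_out
    print(f"Speed: {tokens_per_second:.2f} tkns/s | Cost: ${call_cost:.2f}")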
6 changes: 3 additions & 3 deletions mentat/embeddings.py
@@ -138,12 +138,12 @@ async def get_feature_similarity_scores(
start_time = default_timer()
stream.send(None, channel="loading")
collection.add(embed_checksums, embed_texts)
-    cost_tracker.log_api_call_stats(
+    cost_tracker.log_embedding_call_stats(
        sum(embed_tokens),
-        0,
        embedding_model,
-        start_time - default_timer(),
+        default_timer() - start_time,
)
cost_tracker.display_last_api_call()

# Get similarity scores
stream.send(None, channel="loading", terminate=True)
15 changes: 2 additions & 13 deletions mentat/feature_filters/llm_feature_filter.py
@@ -1,6 +1,5 @@
import json
from pathlib import Path
-from timeit import default_timer
from typing import Optional, Set

from openai.types.chat import (
@@ -16,7 +15,7 @@
from mentat.feature_filters.feature_filter import FeatureFilter
from mentat.feature_filters.truncate_filter import TruncateFilter
from mentat.include_files import get_code_features_for_path
-from mentat.llm_api_handler import count_tokens, model_context_size, prompt_tokens
+from mentat.llm_api_handler import count_tokens, model_context_size
from mentat.prompts.prompts import read_prompt
from mentat.session_context import SESSION_CONTEXT

@@ -41,7 +40,6 @@ async def filter(
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
config = session_context.config
-        cost_tracker = session_context.cost_tracker
llm_api_handler = session_context.llm_api_handler

stream.send(None, channel="loading")
@@ -89,22 +87,13 @@ async def filter(
)
)
selected_refs = list[Path]()
-        start_time = default_timer()
llm_response = await llm_api_handler.call_llm_api(
messages=messages,
model=model,
stream=False,
response_format=ResponseFormat(type="json_object"),
)
-        message = (llm_response.choices[0].message.content) or ""
-        tokens = prompt_tokens(messages, model)
-        response_tokens = count_tokens(message, model, full_message=True)
-        cost_tracker.log_api_call_stats(
-            tokens,
-            response_tokens,
-            model,
-            default_timer() - start_time,
-        )
+        message = llm_response.text
stream.send(None, channel="loading", terminate=True)

# Parse response into features

