Use Spice #543

Merged: 32 commits merged into main from integrate-spice on Apr 3, 2024
Changes from all commits (32 commits):
9eade0e  add spice to reqs (biobootloader, Mar 19, 2024)
8f229d6  anthropic working on mentat (biobootloader, Mar 22, 2024)
18079f9  update check model (biobootloader, Mar 22, 2024)
51014d6  logging callback (biobootloader, Mar 22, 2024)
f5588f1  update (biobootloader, Mar 23, 2024)
2d111cf  interrupted logging (biobootloader, Mar 27, 2024)
d474ce2  spice errors (biobootloader, Mar 30, 2024)
d0c46c6  Merge branch 'main' into integrate-spice (biobootloader, Mar 30, 2024)
443fba1  spice embeddings (biobootloader, Mar 30, 2024)
d739320  spicewhisper (biobootloader, Mar 31, 2024)
f649ad7  remove (biobootloader, Mar 31, 2024)
e7d7844  comment (biobootloader, Mar 31, 2024)
5f1fe25  embedding cost logging (biobootloader, Mar 31, 2024)
ae3d9e1  cleanup (biobootloader, Mar 31, 2024)
f766d80  fixed (biobootloader, Mar 31, 2024)
528e2b9  ruff (biobootloader, Mar 31, 2024)
eeae387  convert spice error (biobootloader, Apr 1, 2024)
f9446b9  typing (biobootloader, Apr 1, 2024)
ebc9abf  get spice from pypi (biobootloader, Apr 1, 2024)
6026d0d  spice version upgrade (biobootloader, Apr 2, 2024)
da22b60  hmm (biobootloader, Apr 2, 2024)
91bc6eb  update spice (biobootloader, Apr 2, 2024)
93f192d  spice (biobootloader, Apr 2, 2024)
3f9a9e7  apache-2.0 (biobootloader, Apr 2, 2024)
36c2a5a  mocking spice response (biobootloader, Apr 3, 2024)
d727903  Merge branch 'main' into integrate-spice (biobootloader, Apr 3, 2024)
52fd9ec  fix tests (biobootloader, Apr 3, 2024)
c2bd2fa  anotha one (biobootloader, Apr 3, 2024)
4bf16e1  ruff ruff (biobootloader, Apr 3, 2024)
798e42c  format (biobootloader, Apr 3, 2024)
acb114e  docs (biobootloader, Apr 3, 2024)
4e40c30  oops (biobootloader, Apr 3, 2024)
2 changes: 1 addition & 1 deletion .github/workflows/lint_and_test.yml
@@ -22,7 +22,7 @@ jobs:
. .venv/bin/activate
echo PATH=$PATH >> $GITHUB_ENV
- name: ruff (linter)
run: ruff check .
run: ruff check --select I .
- name: ruff (formatter)
run: ruff format --check .

2 changes: 1 addition & 1 deletion benchmarks/benchmark_runner.py
@@ -58,7 +58,7 @@ async def grade(to_grade, prompt, model="gpt-4-1106-preview"):

llm_api_handler = SESSION_CONTEXT.get().llm_api_handler
llm_grade = await llm_api_handler.call_llm_api(messages, model, False, ResponseFormat(type="json_object"))
content = llm_grade.choices[0].message.content
content = llm_grade.text
return json.loads(content)
except Exception as e:
return {"error": str(e)}
2 changes: 1 addition & 1 deletion benchmarks/context_benchmark.py
@@ -7,7 +7,7 @@

from benchmarks.arg_parser import common_benchmark_parser
from benchmarks.run_sample import setup_sample
from benchmarks.swe_bench_runner import get_swe_samples, SWE_BENCH_SAMPLES_DIR
from benchmarks.swe_bench_runner import SWE_BENCH_SAMPLES_DIR, get_swe_samples
from mentat import Mentat
from mentat.config import Config
from mentat.sampler.sample import Sample
2 changes: 1 addition & 1 deletion benchmarks/exercism_practice.py
@@ -64,7 +64,7 @@ async def failure_analysis(exercise_runner, language):
try:
llm_api_handler = SESSION_CONTEXT.get().llm_api_handler
llm_grade = await llm_api_handler.call_llm_api(messages, model, False)
response = llm_grade.choices[0].message.content
response = llm_grade.text
except BadRequestError:
response = "Unable to analyze test case\nreason: too many tokens to analyze"

4 changes: 2 additions & 2 deletions benchmarks/run_sample.py
@@ -5,16 +5,16 @@
from pathlib import Path
from typing import Any

from git import Repo
import tqdm
from git import Repo

from mentat import Mentat
from mentat.config import Config
from mentat.errors import SampleError
from mentat.git_handler import get_git_diff
from mentat.parsers.git_parser import GitParser
from mentat.sampler.sample import Sample
from mentat.sampler.utils import get_active_snapshot_commit, setup_repo, apply_diff_to_repo
from mentat.sampler.utils import apply_diff_to_repo, get_active_snapshot_commit, setup_repo
from mentat.session_context import SESSION_CONTEXT


5 changes: 2 additions & 3 deletions benchmarks/swe_bench_runner.py
@@ -10,11 +10,10 @@
from pathlib import Path
from typing import Any

from datasets import load_dataset, DatasetDict # type: ignore
from datasets import DatasetDict, load_dataset # type: ignore

from mentat.sampler.sample import Sample
from benchmarks.run_sample import validate_test_fields

from mentat.sampler.sample import Sample

SWE_BENCH_SAMPLES_DIR = Path(__file__).parent / "benchmarks" / "swe_bench_samples"
SWE_VALIDATION_RESULTS_PATH = (
27 changes: 16 additions & 11 deletions docs/source/user/alternative_models.rst
@@ -3,7 +3,19 @@
🦙 Alternative Models
=====================

Azure
Anthropic's Claude 3
---------
To use Anthropic models, provide the :code:`ANTHROPIC_API_KEY` environment variable instead of :code:`OPENAI_API_KEY`, and set the model `claude-3-opus-20240229` in the :code:`.mentat_config.json` file:

.. code-block:: bash

# in ~/.mentat/.env
ANTHROPIC_API_KEY=sk-*************

# In ~/.mentat/.mentat_config.json
{ "model": "claude-3-opus-20240229" }

OpenAI models on Azure
-----

To use the Azure API, provide the :code:`AZURE_OPENAI_ENDPOINT` (:code:`https://<your-instance-name>.openai.azure.com/`) and :code:`AZURE_OPENAI_KEY` environment variables instead of :code:`OPENAI_API_KEY`.
@@ -13,31 +25,24 @@ In addition, Mentat uses the :code:`gpt-4-1106-preview` model by default. When u
.. warning::
Due to changes in the OpenAI Python SDK, you can no longer use :code:`OPENAI_API_BASE` to access the Azure API with Mentat.

Anthropic
Using Other Models
---------

Mentat uses the OpenAI SDK to retrieve chat completions. This means that setting the `OPENAI_API_BASE` environment variable is enough to use any model that has the same response schema as OpenAI. To use models with different response schemas, we recommend setting up a litellm proxy as described `here <https://docs.litellm.ai/docs/proxy/quick_start>`__ and pointing `OPENAI_API_BASE` to the proxy. For example with anthropic:
Mentat uses the OpenAI SDK to retrieve chat completions. This means that setting the `OPENAI_API_BASE` environment variable is enough to use any model that has the same response schema as OpenAI. To use models with different response schemas, we recommend setting up a litellm proxy as described `here <https://docs.litellm.ai/docs/proxy/quick_start>`__ and pointing `OPENAI_API_BASE` to the proxy. For example:

.. code-block:: bash

pip install 'litellm[proxy]'
export ANTHROPIC_API_KEY=sk-*************
litellm --model claude-3-opus-2024-0229 --drop_params
litellm --model huggingface/bigcode/starcoder --drop_params
# Should see: Uvicorn running on http://0.0.0.0:8000

.. code-block:: bash

# In ~/.mentat/.env
OPENAI_API_BASE=http://localhost:8000
# In ~/.mentat/.mentat_config.json
{ "model": "claude" }
# or
export OPENAI_API_BASE=http://localhost:8000
mentat

.. note::

Anthropic has slightly different requirements for system messages so you must set your model to a string with "claude" in it. Other than that it isn't important as the exact model is set by the litellm proxy server flag.

🦙 Local Models
---------------
4 changes: 2 additions & 2 deletions mentat/agent_handler.py
@@ -48,7 +48,7 @@ async def enable_agent_mode(self):
]
model = ctx.config.model
response = await ctx.llm_api_handler.call_llm_api(messages, model, False)
content = response.choices[0].message.content or ""
content = response.text
Contributor comment: The change to use response.text directly simplifies the code. Ensure that all instances where the LLM API response is processed are updated to this simpler approach. (A before/after sketch follows this file's diff.)

paths = [Path(path) for path in content.strip().split("\n") if Path(path).exists()]
self.agent_file_message = ""
@@ -87,7 +87,7 @@ async def _determine_commands(self) -> List[str]:
ctx.stream.send(f"Error accessing OpenAI API: {e.message}", style="error")
return []

content = response.choices[0].message.content or ""
content = response.text

messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
parsed_llm_response = await ctx.config.parser.parse_llm_response(content)
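To make the reviewer's point concrete, here is a minimal sketch of the access-pattern change this PR applies wherever a completion is read. It assumes only the two response shapes visible in the diffs: the old OpenAI SDK object exposing choices[0].message.content and the new Spice-style object exposing .text. The helper name completion_text is illustrative, not part of the PR.

```python
from typing import Any


def completion_text(response: Any) -> str:
    """Illustrative helper (not in this PR): return the completion text from
    either response shape seen in the diffs above."""
    if hasattr(response, "text"):
        # Spice-style response used after this PR: the text is exposed directly.
        return response.text
    # OpenAI SDK shape used before this PR: dig it out of the first choice.
    return response.choices[0].message.content or ""
```

The same one-line substitution appears in mentat/agent_handler.py, benchmarks/benchmark_runner.py, benchmarks/exercism_practice.py, and mentat/feature_filters/llm_feature_filter.py in this PR.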
20 changes: 9 additions & 11 deletions mentat/conversation.py
@@ -198,20 +198,18 @@ async def _stream_model_response(

stream.send("Streaming...\n")
async with stream.interrupt_catcher(parser.shutdown):
parsed_llm_response = await parser.stream_and_parse_llm_response(add_newline(response))
parsed_llm_response = await parser.stream_and_parse_llm_response(add_newline(response.stream()))

# Sampler and History require previous_file_lines
for file_edit in parsed_llm_response.file_edits:
file_edit.previous_file_lines = code_file_manager.file_lines.get(file_edit.file_path, []).copy()
if not parsed_llm_response.interrupted:
cost_tracker.display_last_api_call()
else:
# Generator doesn't log the api call if we interrupt it
cost_tracker.log_api_call_stats(
num_prompt_tokens,
count_tokens(parsed_llm_response.full_response, config.model, full_message=False),
config.model,
display=True,
)

# TODO: this is janky come up with better solution
# if the stream was interrupted, then the finally block in the response.stream() async generator
# will wait for an opportunity to run. This sleep call gives it that opportunity.
# the finally block runs the logging callback
await asyncio.sleep(0.01)
cost_tracker.display_last_api_call()

messages.append(
ChatCompletionAssistantMessageParam(role="assistant", content=parsed_llm_response.full_response)
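The TODO comment in the hunk above is an event-loop timing note: once the user interrupts, the abandoned response.stream() async generator only runs its finally block (which fires the logging callback) after the loop gets a chance to finalize it, hence the short sleep before reading the last API call stats. The standalone snippet below reproduces that behavior with a toy generator under CPython's asyncio; it is an illustration, not code from the PR.

```python
import asyncio

stats: dict[str, bool] = {}


async def stream_chunks():
    """Toy stand-in for response.stream(): the finally block plays the role
    of the logging callback."""
    try:
        for chunk in ("Hello", ", ", "world"):
            yield chunk
            await asyncio.sleep(0)  # pretend to wait on the network
    finally:
        stats["logged"] = True  # in the PR, this is where the API call gets logged


async def main() -> None:
    gen = stream_chunks()
    async for _ in gen:
        break  # simulate the user interrupting after the first chunk
    del gen  # abandon the generator, as an interrupt does

    print("before tick:", stats)  # {} -- the finally block has not run yet
    await asyncio.sleep(0.01)  # give the loop a chance to finalize the generator
    print("after tick:", stats)  # {'logged': True}


asyncio.run(main())
```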
68 changes: 27 additions & 41 deletions mentat/cost_tracker.py
@@ -1,11 +1,9 @@
import logging
from dataclasses import dataclass
from timeit import default_timer
from typing import AsyncIterator, Optional

from openai.types.chat import ChatCompletionChunk
from spice import SpiceResponse

from mentat.llm_api_handler import count_tokens, model_price_per_1000_tokens
from mentat.llm_api_handler import model_price_per_1000_tokens
from mentat.session_context import SESSION_CONTEXT


@@ -18,37 +16,47 @@ class CostTracker:

def log_api_call_stats(
self,
num_prompt_tokens: int,
num_sampled_tokens: int,
model: str,
call_time: Optional[float] = None,
decimal_places: int = 2,
display: bool = False,
response: SpiceResponse,
) -> None:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
decimal_places = 2

model = response.call_args.model
input_tokens = response.input_tokens
output_tokens = response.output_tokens
total_time = response.total_time

speed_and_cost_string = ""
self.total_tokens += num_prompt_tokens + num_sampled_tokens
if num_sampled_tokens > 0 and call_time is not None:
tokens_per_second = num_sampled_tokens / call_time
self.total_tokens += response.total_tokens
if output_tokens > 0:
tokens_per_second = output_tokens / total_time
speed_and_cost_string += f"Speed: {tokens_per_second:.{decimal_places}f} tkns/s"
cost = model_price_per_1000_tokens(model)
if cost:
prompt_cost = (num_prompt_tokens / 1000) * cost[0]
sampled_cost = (num_sampled_tokens / 1000) * cost[1]
prompt_cost = (input_tokens / 1000) * cost[0]
sampled_cost = (output_tokens / 1000) * cost[1]
call_cost = prompt_cost + sampled_cost
self.total_cost += call_cost
if speed_and_cost_string:
speed_and_cost_string += " | "
speed_and_cost_string += f"Cost: ${call_cost:.{decimal_places}f}"
if display:
stream.send(speed_and_cost_string, style="info")

costs_logger = logging.getLogger("costs")
costs_logger.info(speed_and_cost_string)
self.last_api_call = speed_and_cost_string

def log_embedding_call_stats(self, tokens: int, model: str, total_time: float):
cost = model_price_per_1000_tokens(model)
# TODO: handle unknown models better / port to spice
if cost is None:
return

cost = cost[0]
call_cost = (tokens / 1000) * cost
self.total_cost += call_cost
costs_logger = logging.getLogger("costs")
costs_logger.info(f"Cost: ${call_cost:.2f}")
self.last_api_call = f"Embedding call time and cost: {total_time:.2f}s, ${call_cost:.2f}"

def display_last_api_call(self):
"""
Used so that places that call the llm can print the api call stats after they finish printing everything else.
@@ -60,25 +68,3 @@ def display_last_api_call(self):

def log_whisper_call_stats(self, seconds: float):
self.total_cost += seconds * 0.0001

async def response_logger_wrapper(
self,
prompt_tokens: int,
response: AsyncIterator[ChatCompletionChunk],
model: str,
) -> AsyncIterator[ChatCompletionChunk]:
full_response = ""
start_time = default_timer()
async for chunk in response:
# On Azure OpenAI, the first chunk streamed may contain only metadata relating to content filtering.
if len(chunk.choices) == 0:
continue
full_response += chunk.choices[0].delta.content or ""
yield chunk
time_elapsed = default_timer() - start_time
self.log_api_call_stats(
prompt_tokens,
count_tokens(full_response, model, full_message=False),
model,
time_elapsed,
)
6 changes: 3 additions & 3 deletions mentat/embeddings.py
@@ -138,12 +138,12 @@ async def get_feature_similarity_scores(
start_time = default_timer()
stream.send(None, channel="loading")
collection.add(embed_checksums, embed_texts)
cost_tracker.log_api_call_stats(
cost_tracker.log_embedding_call_stats(
sum(embed_tokens),
0,
embedding_model,
start_time - default_timer(),
default_timer() - start_time,
)
cost_tracker.display_last_api_call()

# Get similarity scores
stream.send(None, channel="loading", terminate=True)
15 changes: 2 additions & 13 deletions mentat/feature_filters/llm_feature_filter.py
@@ -1,6 +1,5 @@
import json
from pathlib import Path
from timeit import default_timer
from typing import Optional, Set

from openai.types.chat import (
@@ -16,7 +15,7 @@
from mentat.feature_filters.feature_filter import FeatureFilter
from mentat.feature_filters.truncate_filter import TruncateFilter
from mentat.include_files import get_code_features_for_path
from mentat.llm_api_handler import count_tokens, model_context_size, prompt_tokens
from mentat.llm_api_handler import count_tokens, model_context_size
from mentat.prompts.prompts import read_prompt
from mentat.session_context import SESSION_CONTEXT

@@ -41,7 +40,6 @@ async def filter(
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
config = session_context.config
cost_tracker = session_context.cost_tracker
llm_api_handler = session_context.llm_api_handler

stream.send(None, channel="loading")
@@ -89,22 +87,13 @@ async def filter(
)
)
selected_refs = list[Path]()
start_time = default_timer()
llm_response = await llm_api_handler.call_llm_api(
messages=messages,
model=model,
stream=False,
response_format=ResponseFormat(type="json_object"),
)
message = (llm_response.choices[0].message.content) or ""
tokens = prompt_tokens(messages, model)
response_tokens = count_tokens(message, model, full_message=True)
cost_tracker.log_api_call_stats(
tokens,
response_tokens,
model,
default_timer() - start_time,
)
message = llm_response.text
Contributor comment: Given the removal of detailed LLM API call logging, consider implementing a new mechanism to log or monitor these calls for debugging and performance analysis. (A sketch of one possible approach follows this file's diff.)
stream.send(None, channel="loading", terminate=True)

# Parse response into features
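Following up on the reviewer comment kept above: the PR already centralizes per-call stats in CostTracker.log_api_call_stats(response) (see the cost_tracker.py diff), which reads the model, token counts, and total time off the SpiceResponse. A thin wrapper along the lines below could give every non-streaming call site that logging back. The helper name call_llm_and_log and its keyword passthrough are assumptions; every attribute and method it touches appears somewhere in this diff.

```python
from typing import Any

from mentat.session_context import SESSION_CONTEXT


async def call_llm_and_log(messages: Any, model: str, **kwargs: Any) -> str:
    """Hypothetical helper (not in this PR): run a non-streaming LLM call and
    route its stats through the cost tracker, as the reviewer suggests."""
    ctx = SESSION_CONTEXT.get()
    # call_llm_api(..., stream=False) returns a Spice response, per the diffs above.
    response = await ctx.llm_api_handler.call_llm_api(messages, model, False, **kwargs)
    # log_api_call_stats reads input_tokens, output_tokens, total_time and
    # call_args.model from the response (see the cost_tracker.py diff).
    ctx.cost_tracker.log_api_call_stats(response)
    ctx.cost_tracker.display_last_api_call()
    return response.text
```

This mirrors the embeddings path in this PR, which logs with log_embedding_call_stats(...) and then calls display_last_api_call().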