1 parent bbf9e3b · commit f7d459b
Showing 6 changed files with 127 additions and 215 deletions.
@@ -1,193 +1,106 @@
#!/usr/bin/env python
import asyncio
import json
import os
from collections import defaultdict
from itertools import islice
from datetime import datetime
from pathlib import Path
from typing import Any

from git import Repo

from benchmarks.arg_parser import common_benchmark_parser
from mentat.code_context import CodeContext
from mentat.code_feature import CodeFeature
from mentat.code_file_manager import CodeFileManager
from benchmarks.run_sample import setup_sample
from benchmarks.swe_bench_runner import get_swe_samples, SWE_BENCH_SAMPLES_DIR
from mentat import Mentat
from mentat.config import Config
from mentat.cost_tracker import CostTracker
from mentat.llm_api_handler import count_tokens, model_context_size
from mentat.sampler.utils import clone_repo
from mentat.session_context import SESSION_CONTEXT, SessionContext


class MockStream:
    def send(self, message, **kwargs):
        end = kwargs.get("end", "\n")
        print(message, end=end)


def _load_benchmarks() -> dict[str, dict[str, Any]]:
    """Load all benchmarks found in benchmark_repos"""
    benchmarks = {}
    benchmarks_dir = Path(__file__).parent / "../benchmark_repos"
    for repo_dir in benchmarks_dir.iterdir():
        benchmarks_path = repo_dir / "benchmarks.json"
        if benchmarks_path.exists():
            with open(benchmarks_path, "r") as f:
                benchmarks.update(json.load(f))
    return benchmarks


def _convert_features_to_line_sets(git_root: Path, features: list[CodeFeature]) -> defaultdict[set]:
    """Convert a list of features to a dict of {path: set(lines)} for comparison"""
    lines = defaultdict(set)
    for feature in features:
        # Non-explicit features (e.g. CodeMaps) are considered false positives.
        # Using negative numbers here as that affect.

        path = feature.path.relative_to(git_root)
        interval = feature.interval
        lines[path].update(range(interval.start, interval.end + 1))
    return lines


def evaluate(
    git_root: Path,
    actual: list[CodeFeature],
    expected: list[CodeFeature],
) -> dict[str, float]:
    """Compare two lists of features and return precision, recall and f1 scores"""
    actual_lines = _convert_features_to_line_sets(git_root, actual)
    expected_lines = _convert_features_to_line_sets(git_root, expected)

    _TP, _FP, _FN = 0, 0, 0
    for file in actual_lines | expected_lines:
        actual_set = actual_lines[file]
        expected_set = expected_lines[file]
        _TP += len(actual_set & expected_set)
        _FP += len(actual_set - expected_set)
        _FN += len(expected_set - actual_set)
from mentat.sampler.sample import Sample
from mentat.session_context import SESSION_CONTEXT

    precision, recall, f1 = None, None, None
    if (_TP + _FP) > 0:
        precision = _TP / (_TP + _FP)
    if (_TP + _FN) > 0:
        recall = _TP / (_TP + _FN)
    if precision and recall:
        f1 = 2 * precision * recall / (precision + recall)

    return {"precision": precision, "recall": recall, "f1": f1}
def _score(predicted: set[Path], expected: set[Path]) -> dict[str, Any]:
    true_positives = predicted.intersection(expected)
    false_positives = predicted.difference(expected)
    false_negatives = expected.difference(predicted)
    precision = len(true_positives) / (len(true_positives) + len(false_positives))
    recall = len(true_positives) / (len(true_positives) + len(false_negatives))
    return {"precision": precision, "recall": recall}

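For orientation, a minimal usage sketch of the newly added _score helper follows; the paths and scores are made up for illustration and are not part of this commit. Note that the precision denominator is len(predicted) and the recall denominator is len(expected), so both sets must be non-empty to avoid a ZeroDivisionError.

# Illustrative sketch (not part of the commit): exercising _score with made-up paths.
from pathlib import Path

predicted = {Path("src/app.py"), Path("src/utils.py")}      # selected by auto-context
expected = {Path("src/app.py"), Path("tests/test_app.py")}  # ground truth (sample.context)
print(_score(predicted, expected))  # {'precision': 0.5, 'recall': 0.5}
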
async def select_features_for_benchmark(session_context, benchmark, eval=True, use_expected=False, use_llm=True):
    """Select features for benchmark using expected edits as a guide"""
    git_root = session_context.git_root
    config = session_context.config
    parser = config.parser
    code_context = session_context.code_context
async def run_auto_context_benchmark(
    sample: Sample, config: Config, cwd: Path | str | None = None, include_context: bool = False
) -> dict[str, Any]:
    """Run a sample using Mentat and return the resulting diff"""
    starting_dir = Path.cwd()

    # The longest context that could have been included to generate expected_edits
    model = config.model
    mentat_prompt_tokens = count_tokens(parser.get_system_prompt(), model)
    expected_edits, expected_edits_tokens = None, 0
    if use_expected:
        expected_edits = benchmark["expected_edits"]
        expected_edits_tokens = count_tokens(expected_edits, model)
    max_context_tokens = model_context_size(model) - mentat_prompt_tokens - expected_edits_tokens
    # Fill-in available context
    config.auto_context_tokens = 8000
    code_context.use_llm = use_llm
    await code_context.get_code_message(benchmark["prompt"], max_context_tokens, expected_edits)
    git_root_length = len(str(git_root)) + 1
    selected_features = [f.ref()[git_root_length:] for f in code_context.features]

    selector_performance = {}
    if eval:
        edited_features = [CodeFeature(git_root / f) for f in benchmark["edited_features"]]
        selector_performance = evaluate(git_root, code_context.features, edited_features)
    return {"features": selected_features, "score": selector_performance}


async def test_code_context_performance(benchmarks, max_benchmarks=10):
    """Run a set of benchmarks and evaluate performance
    Run standalone:
        `./benchmarks/context_benchmark.py`
    """
    # Load applicable benchmarks
    all_benchmarks = _load_benchmarks()
    if len(benchmarks) > 0:
        benchmarks_to_run = {k: v for k, v in all_benchmarks.items() if k in benchmarks}
    else:
        benchmarks_to_run = dict(islice(all_benchmarks.items(), max_benchmarks))

    # Run each one
    scores = {}
    for benchmark in benchmarks_to_run.values():
        print("\n" + benchmark["prompt"])

        # Setup the cwd the same way as in generate
        url = benchmark["codebase_url"]
        codebase = clone_repo(url=url, local_dir_name=url.split("/")[-1], refresh=False)
        os.chdir(codebase)
        repo = Repo(".")
        repo.git.checkout(benchmark["commit"])

        # Initialize a full SESSION_CONTEXT
        stream = MockStream()
        config = Config()
        code_context = CodeContext(stream, os.getcwd())
        session_context = SessionContext(
            stream,
            CostTracker(),
            Path.cwd(),
            config,
            code_context,
            CodeFileManager(),
            None,
    if not config.auto_context_tokens or not sample.context:
        raise ValueError(
            "In order to run the auto-context benchmark, sample.context must not "
            "be empty (ground truth) and config.auto_context_tokens must be > 0."
        )
        SESSION_CONTEXT.set(session_context)

        # Run the benchmark and print results
        scores = []
        for use_llm in [False, True]:
            for use_expected in [False, True]:
                try:
                    if not use_llm and use_expected:
                        continue  # Not relevant
                    results = await select_features_for_benchmark(
                        session_context,
                        benchmark,
                        eval=True,
                        use_expected=use_expected,
                        use_llm=use_llm,
                    )
                    score = {
                        **results["score"],
                        "selected_features": results["features"],
                        "edited_features": benchmark["edited_features"],
                        "use_llm": use_llm,
                        "use_expected": use_expected,
                    }
                    scores.append(score)
                    print(
                        f" UseExpected={use_expected}\t"
                        f"| LLM={use_llm}\t"
                        f"| Recall={(score['recall'] or 0.):.3f}\t"
                        f"| Precision={(score['precision'] or 0.):.3f}"
                    )
                except Exception as e:
                    print(f"Error: '{e}'; skipping")

    return scores
    paths = [] if not include_context else [Path(a) for a in sample.context]

    try:
        _, cwd, _, _ = setup_sample(sample, None, skip_test_exec=True)
        exclude_paths = [cwd / ".venv"]
        mentat = Mentat(cwd=cwd, paths=paths, exclude_paths=exclude_paths, config=config or Config())
        await mentat.startup()
        await asyncio.sleep(0.01)  # Required to initialize llm_api_handler for embeddings

        # TODO: If there's a conversation history, we might consider the cumulative context.
        # Setup a mock for the LLM response and run the conversation until this point.
        code_context = SESSION_CONTEXT.get().code_context
        _ = await code_context.get_code_message(0, sample.message_prompt)
        predicted = set(path.relative_to(cwd) for path in code_context.include_files.keys())
        actual = {Path(a) for a in sample.context}
        score = _score(predicted, actual)

        await mentat.shutdown()
        return score
    finally:
        os.chdir(starting_dir)

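As a rough usage sketch of the new benchmark entry point (it mirrors what main() does below; the sample file path is hypothetical):

# Illustrative sketch (not part of the commit): scoring auto-context selection for one sample.
sample = Sample.load(Path("benchmark_repos/example_sample.json"))  # hypothetical path
config = Config(auto_context_tokens=8000)  # must be > 0, per the ValueError guard above
score = asyncio.run(run_auto_context_benchmark(sample, config))
print(score)  # e.g. {'precision': 0.75, 'recall': 0.6}
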
def main(user_samples: list[str], directory: str):
    # Load benchmarks
    dir_path = Path(directory).resolve()
    assert dir_path.exists(), f"Invalid directory: {directory}"
    print(f"Running benchmarks from {dir_path}")
    samples: list[Sample] = []
    for root, dirs, files in os.walk(dir_path):
        for file in files:
            path = Path(root) / file
            if file.endswith(".json"):
                sample = Sample.load(path)
            else:
                continue
            if user_samples and not any(s in sample.title for s in user_samples):
                continue
            samples.append(sample)
    print("Found Samples:\n" + "\n".join(s.title for s in samples))
    print("*" * 80)

    config = Config(auto_context_tokens=8000)
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    results_path = dir_path / f"context_benchmark_results_{timestamp}.jsonl"
    for sample in samples:
        print(f"Running benchmark for {sample.title}")
        accuracy = asyncio.run(run_auto_context_benchmark(sample, config, cwd=dir_path))
        print(f"Results: {accuracy}")
        print("*" * 80)
        with open(results_path, "a") as f:
            f.write(json.dumps({sample.id: accuracy}) + "\n")

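Each run appends one JSON object per sample, keyed by sample.id, to the timestamped results file; a sketch of what a single line looks like (the id and scores are made up):

# Illustrative sketch (not part of the commit): one line of context_benchmark_results_<timestamp>.jsonl
json.dumps({"example-sample-id": {"precision": 0.67, "recall": 0.5}})
# -> '{"example-sample-id": {"precision": 0.67, "recall": 0.5}}'
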
if __name__ == "__main__":
    parser = common_benchmark_parser()
    args = parser.parse_args()
    asyncio.run(
        test_code_context_performance(
            args.benchmarks,
            args.max_benchmarks,
        )
    if args.swe_bench:
        if args.swe_bench not in {"dev", "train", "test"}:
            print("Invalid SWE-Bench split.")
            exit(1)
        # Download and save SWE benchmarks as Samples
        samples = get_swe_samples(args.swe_bench, args.max_benchmarks)
        sample_titles = [sample.title for sample in samples]
        args.benchmarks = sample_titles
        args.directory = SWE_BENCH_SAMPLES_DIR / args.swe_bench
    main(
        args.benchmarks,
        args.directory,
    )