Swe bench results #549

Merged · 10 commits · Apr 2, 2024
2 changes: 2 additions & 0 deletions benchmarks/benchmark_result.py
@@ -36,6 +36,8 @@ class BenchmarkResult:
missing_functionality: Optional[bool] = attr.ib(default=None, metadata={"aggregation": "percent"})
extra_functionality: Optional[bool] = attr.ib(default=None, metadata={"aggregation": "percent"})
referenced_format: Optional[bool] = attr.ib(default=None, metadata={"aggregation": "percent"})
test_eval_results: Optional[dict] = attr.ib(default=None, metadata={"display": "json"})
test_eval_passed: Optional[bool] = attr.ib(default=None, metadata={"aggregation": "percent"})

def display_color(self) -> str:
if self.passed is None:
11 changes: 8 additions & 3 deletions benchmarks/benchmark_runner.py
@@ -223,10 +223,12 @@ async def run(self, retries: int = 1) -> list[BenchmarkResult]:
family=formatted_title,
)
try:
sample_result = await run_sample(sample)
sample_result = await run_sample(sample, config=self.config)
result.cost = sample_result["cost"]
result.tokens = sample_result["tokens"]
result.transcript = sample_result["transcript"]
result.test_eval_results = sample_result["test_eval_results"]
result.test_eval_passed = sample_result["test_eval_passed"]
Contributor:

Consider adding a try-except block around benchmark.run(retries=retries) to catch and handle any exceptions that might occur during the execution of a benchmark. This would allow the benchmarking process to continue even if one benchmark fails, improving the robustness of the benchmark runner.

Suggested change
result.test_eval_passed = sample_result["test_eval_passed"]
try:
result = asyncio.run(benchmark.run(retries=retries))
with open(results_cache, "a") as f:
for r in result:
total_cost += r.cost if r.cost else 0.0
f.write(r.to_json() + "
")
except Exception as e:
print(f"Error running benchmark {benchmark.title}: {e}")

Handling exceptions in this manner ensures that the benchmarking process is not halted due to a single failure, providing a more resilient and user-friendly experience.


if self.verify is not None:
result.verify = self.verify()

@@ -251,7 +253,7 @@ def benchmark_listed(title, benchmarks):
return False


def run_benchmarks(user_benchmarks: list[str], directory: str, retries: int = 1):
def run_benchmarks(user_benchmarks: list[str], directory: str, retries: int = 1, max_benchmarks: int | None = None):
# Load benchmarks
dir_path = Path(directory).resolve()
assert dir_path.exists(), f"Invalid directory: {directory}"
@@ -277,7 +279,9 @@ def run_benchmarks(user_benchmarks: list[str], directory: str, retries: int = 1)
results_cache = dir_path / f"benchmark_results_cache_{uuid4()}.jsonl"
results_cache.touch()
total_cost = 0.0
for benchmark in benchmarks:
for i, benchmark in enumerate(benchmarks):
if max_benchmarks and i >= max_benchmarks:
break
Contributor:

After the loop that processes each benchmark, it would be useful to log the total number of benchmarks processed and how many were skipped due to reaching the max_benchmarks limit. This information could be valuable for users to understand the scope of the benchmark run and ensure transparency.

Suggested change
break
print(f"Processed {i+1} benchmarks. Skipped {len(benchmarks) - i - 1} benchmarks due to max_benchmarks limit.")

Including such logging can enhance the user's understanding of the benchmarking process and provide insights into the execution flow.
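As a hedged sketch of the placement described above (illustrative only, not part of this diff), the summary line would follow the for loop in run_benchmarks rather than sit beside the break; the variable names below match the surrounding diff:

# Illustrative only: summary logging placed after the benchmark loop.
processed = min(len(benchmarks), max_benchmarks) if max_benchmarks else len(benchmarks)
skipped = len(benchmarks) - processed
print(f"Processed {processed} benchmarks. Skipped {skipped} due to the max_benchmarks limit.")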

# Run benchmark.run() with timeout
try:
result = asyncio.run(benchmark.run(retries=retries))
@@ -328,4 +332,5 @@ def run_benchmarks(user_benchmarks: list[str], directory: str, retries: int = 1)
args.benchmarks,
args.directory,
args.retries,
args.max_benchmarks,
)
259 changes: 86 additions & 173 deletions benchmarks/context_benchmark.py
@@ -1,193 +1,106 @@
#!/usr/bin/env python
import asyncio
import json
import os
from collections import defaultdict
from itertools import islice
from datetime import datetime
from pathlib import Path
from typing import Any

from git import Repo

from benchmarks.arg_parser import common_benchmark_parser
from mentat.code_context import CodeContext
from mentat.code_feature import CodeFeature
from mentat.code_file_manager import CodeFileManager
from benchmarks.run_sample import setup_sample
from benchmarks.swe_bench_runner import get_swe_samples, SWE_BENCH_SAMPLES_DIR
from mentat import Mentat
from mentat.config import Config
from mentat.cost_tracker import CostTracker
from mentat.llm_api_handler import count_tokens, model_context_size
from mentat.sampler.utils import clone_repo
from mentat.session_context import SESSION_CONTEXT, SessionContext


class MockStream:
def send(self, message, **kwargs):
end = kwargs.get("end", "\n")
print(message, end=end)


def _load_benchmarks() -> dict[str, dict[str, Any]]:
"""Load all benchmarks found in benchmark_repos"""
benchmarks = {}
benchmarks_dir = Path(__file__).parent / "../benchmark_repos"
for repo_dir in benchmarks_dir.iterdir():
benchmarks_path = repo_dir / "benchmarks.json"
if benchmarks_path.exists():
with open(benchmarks_path, "r") as f:
benchmarks.update(json.load(f))
return benchmarks


def _convert_features_to_line_sets(git_root: Path, features: list[CodeFeature]) -> defaultdict[set]:
"""Convert a list of features to a dict of {path: set(lines)} for comparison"""
lines = defaultdict(set)
for feature in features:
# Non-explicit features (e.g. CodeMaps) are considered false positives.
# Using negative numbers here as that affect.

path = feature.path.relative_to(git_root)
interval = feature.interval
lines[path].update(range(interval.start, interval.end + 1))
return lines


def evaluate(
git_root: Path,
actual: list[CodeFeature],
expected: list[CodeFeature],
) -> dict[str, float]:
"""Compare two lists of features and return precision, recall and f1 scores"""
actual_lines = _convert_features_to_line_sets(git_root, actual)
expected_lines = _convert_features_to_line_sets(git_root, expected)

_TP, _FP, _FN = 0, 0, 0
for file in actual_lines | expected_lines:
actual_set = actual_lines[file]
expected_set = expected_lines[file]
_TP += len(actual_set & expected_set)
_FP += len(actual_set - expected_set)
_FN += len(expected_set - actual_set)
from mentat.sampler.sample import Sample
from mentat.session_context import SESSION_CONTEXT

precision, recall, f1 = None, None, None
if (_TP + _FP) > 0:
precision = _TP / (_TP + _FP)
if (_TP + _FN) > 0:
recall = _TP / (_TP + _FN)
if precision and recall:
f1 = 2 * precision * recall / (precision + recall)

return {"precision": precision, "recall": recall, "f1": f1}
def _score(predicted: set[Path], expected: set[Path]) -> dict[str, Any]:
true_positives = predicted.intersection(expected)
false_positives = predicted.difference(expected)
false_negatives = expected.difference(predicted)
precision = len(true_positives) / (len(true_positives) + len(false_positives))
recall = len(true_positives) / (len(true_positives) + len(false_negatives))
return {"precision": precision, "recall": recall}


async def select_features_for_benchmark(session_context, benchmark, eval=True, use_expected=False, use_llm=True):
"""Select features for benchmark using expected edits as a guide"""
git_root = session_context.git_root
config = session_context.config
parser = config.parser
code_context = session_context.code_context
async def run_auto_context_benchmark(
sample: Sample, config: Config, cwd: Path | str | None = None, include_context: bool = False
) -> dict[str, Any]:
"""Run a sample using Mentat and return the resulting diff"""
starting_dir = Path.cwd()

# The longest context that could have been included to generate expected_edits
model = config.model
mentat_prompt_tokens = count_tokens(parser.get_system_prompt(), model)
expected_edits, expected_edits_tokens = None, 0
if use_expected:
expected_edits = benchmark["expected_edits"]
expected_edits_tokens = count_tokens(expected_edits, model)
max_context_tokens = model_context_size(model) - mentat_prompt_tokens - expected_edits_tokens
# Fill-in available context
config.auto_context_tokens = 8000
code_context.use_llm = use_llm
await code_context.get_code_message(benchmark["prompt"], max_context_tokens, expected_edits)
git_root_length = len(str(git_root)) + 1
selected_features = [f.ref()[git_root_length:] for f in code_context.features]

selector_performance = {}
if eval:
edited_features = [CodeFeature(git_root / f) for f in benchmark["edited_features"]]
selector_performance = evaluate(git_root, code_context.features, edited_features)
return {"features": selected_features, "score": selector_performance}


async def test_code_context_performance(benchmarks, max_benchmarks=10):
"""Run a set of benchmarks and evaluate performance

Run standalone:
`./benchmarks/context_benchmark.py`
"""
# Load applicable benchmarks
all_benchmarks = _load_benchmarks()
if len(benchmarks) > 0:
benchmarks_to_run = {k: v for k, v in all_benchmarks.items() if k in benchmarks}
else:
benchmarks_to_run = dict(islice(all_benchmarks.items(), max_benchmarks))

# Run each one
scores = {}
for benchmark in benchmarks_to_run.values():
print("\n" + benchmark["prompt"])

# Setup the cwd the same way as in generate
url = benchmark["codebase_url"]
codebase = clone_repo(url=url, local_dir_name=url.split("/")[-1], refresh=False)
os.chdir(codebase)
repo = Repo(".")
repo.git.checkout(benchmark["commit"])

# Initialize a full SESSION_CONTEXT
stream = MockStream()
config = Config()
code_context = CodeContext(stream, os.getcwd())
session_context = SessionContext(
stream,
CostTracker(),
Path.cwd(),
config,
code_context,
CodeFileManager(),
None,
if not config.auto_context_tokens or not sample.context:
raise ValueError(
"In order to run the auto-context benchmark, sample.context must not "
"be empty (ground truth) and config.auto_context_tokens must be > 0."
)
SESSION_CONTEXT.set(session_context)

# Run the benchmark and print results
scores = []
for use_llm in [False, True]:
for use_expected in [False, True]:
try:
if not use_llm and use_expected:
continue # Not relevant
results = await select_features_for_benchmark(
session_context,
benchmark,
eval=True,
use_expected=use_expected,
use_llm=use_llm,
)
score = {
**results["score"],
"selected_features": results["features"],
"edited_features": benchmark["edited_features"],
"use_llm": use_llm,
"use_expected": use_expected,
}
scores.append(score)
print(
f" UseExpected={use_expected}\t"
f"| LLM={use_llm}\t"
f"| Recall={(score['recall'] or 0.):.3f}\t"
f"| Precision={(score['precision'] or 0.):.3f}"
)
except Exception as e:
print(f"Error: '{e}'; skipping")

return scores
paths = [] if not include_context else [Path(a) for a in sample.context]

try:
_, cwd, _, _ = setup_sample(sample, None, skip_test_exec=True)
exclude_paths = [cwd / ".venv"]
mentat = Mentat(cwd=cwd, paths=paths, exclude_paths=exclude_paths, config=config or Config())
await mentat.startup()
await asyncio.sleep(0.01) # Required to initialize llm_api_handler for embeddings

# TODO: If there's a conversation history, we might consider the cumulative context.
# Setup a mock for the LLM response and run the conversation until this point.
code_context = SESSION_CONTEXT.get().code_context
_ = await code_context.get_code_message(0, sample.message_prompt)
predicted = set(path.relative_to(cwd) for path in code_context.include_files.keys())
actual = {Path(a) for a in sample.context}
score = _score(predicted, actual)

await mentat.shutdown()
return score
finally:
os.chdir(starting_dir)


def main(user_samples: list[str], directory: str):
# Load benchmarks
dir_path = Path(directory).resolve()
assert dir_path.exists(), f"Invalid directory: {directory}"
print(f"Running benchmarks from {dir_path}")
samples: list[Sample] = []
for root, dirs, files in os.walk(dir_path):
for file in files:
path = Path(root) / file
if file.endswith(".json"):
sample = Sample.load(path)
else:
continue
if user_samples and not any(s in sample.title for s in user_samples):
continue
samples.append(sample)
print("Found Samples:\n" + "\n".join(s.title for s in samples))
print("*" * 80)

config = Config(auto_context_tokens=8000)
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
results_path = dir_path / f"context_benchmark_results_{timestamp}.jsonl"
for sample in samples:
print(f"Running benchmark for {sample.title}")
accuracy = asyncio.run(run_auto_context_benchmark(sample, config, cwd=dir_path))
print(f"Results: {accuracy}")
print("*" * 80)
with open(results_path, "a") as f:
f.write(json.dumps({sample.id: accuracy}) + "\n")


if __name__ == "__main__":
parser = common_benchmark_parser()
args = parser.parse_args()
asyncio.run(
test_code_context_performance(
args.benchmarks,
args.max_benchmarks,
)
if args.swe_bench:
if args.swe_bench not in {"dev", "train", "test"}:
print("Invalid SWE-Bench split.")
exit(1)
# Download and save SWE benchmarks as Samples
samples = get_swe_samples(args.swe_bench, args.max_benchmarks)
sample_titles = [sample.title for sample in samples]
args.benchmarks = sample_titles
args.directory = SWE_BENCH_SAMPLES_DIR / args.swe_bench
main(
args.benchmarks,
args.directory,
)
Contributor:

Given the potentially long duration of validate_swe_samples, consider providing users with an option to skip this step, especially if they are re-running benchmarks and are confident in the validity of their samples. This could save a significant amount of time in certain scenarios.

Suggested change
args.directory,
if not args.skip_validation:
validate_swe_samples()

Introducing a command-line argument like --skip-validation could offer users more control over the benchmarking process, making it more flexible and user-friendly.
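As a hedged sketch of how such a flag might be wired up (plain argparse for illustration; the flag name and the validate_swe_samples() stand-in are hypothetical and do not appear in this diff, and the repo's common_benchmark_parser() could be extended the same way):

# Illustrative sketch of a --skip-validation flag; names here are hypothetical, not part of this PR.
import argparse

def validate_swe_samples() -> None:
    # Stand-in for the potentially slow sample-validation step the comment refers to.
    print("Validating SWE-Bench samples...")

parser = argparse.ArgumentParser(description="Benchmark runner (sketch)")
parser.add_argument(
    "--skip-validation",
    action="store_true",
    help="Skip sample validation when re-running benchmarks whose samples are already trusted.",
)
args = parser.parse_args()
if not args.skip_validation:
    validate_swe_samples()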
