Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix issue #76 #77 #78 and other known bugs #79

Merged
merged 1 commit into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Available options include:
| `--dataset-name` | str | Name of the Huggingface dataset | `wentingzhao/commit0_combined` |
| `--dataset-split` | str | Split of the Huggingface dataset | `test` |
| `--base-dir` | str | Base directory to clone repos to | `repos/` |
| `--commit0-dot-file-path` | str | Storing path for stateful commit0 configs | `.commit0.yaml` |
| `--commit0-config-file` | str | Path where the stateful commit0 config is stored | `.commit0.yaml` |

### Build

Expand All @@ -64,7 +64,7 @@ Available options include:
| Argument | Type | Description | Default |
|----------|------|-------------|---------|
| `--num-workers` | int | Number of workers | `8` |
| `--commit0-dot-file-path` | str | Path to the commit0 dot file | `.commit0.yaml` |
| `--commit0-config-file` | str | Path to the commit0 config file | `.commit0.yaml` |
| `--verbose` | int | Verbosity level (1 or 2) | `1` |

### Get Tests
Expand All @@ -91,7 +91,7 @@ Available options include:
| `--reference` | bool | Test the reference commit | `False` |
| `--coverage` | bool | Get coverage information | `False` |
| `--rebuild` | bool | Rebuild an image | `False` |
| `--commit0-dot-file-path` | str | Path to the commit0 dot file | `.commit0.yaml` |
| `--commit0-config-file` | str | Path to the commit0 config file | `.commit0.yaml` |
| `--verbose` | int | Verbosity level (1 or 2) | `1` |
| `--stdin` | bool | Read test names from stdin | `False` |

Expand All @@ -109,7 +109,7 @@ Available options include:
| `--num-workers` | int | Number of workers to use | `8` |
| `--reference` | bool | Evaluate the reference commit | `False` |
| `--coverage` | bool | Get coverage information | `False` |
| `--commit0-dot-file-path` | str | Path to the commit0 dot file | `.commit0.yaml` |
| `--commit0-config-file` | str | Path to the commit0 config file | `.commit0.yaml` |
| `--rebuild` | bool | Rebuild images | `False` |

### Lint
Expand All @@ -121,7 +121,7 @@ Available options include:
|----------|------|-------------|---------|
| `repo_or_repo_dir` | str | Directory of the repository to test | |
| `--files` | List[Path] | Files to lint (optional) | |
| `--commit0-dot-file-path` | str | Path to the commit0 dot file | `.commit0.yaml` |
| `--commit0-config-file` | str | Path to the commit0 config file | `.commit0.yaml` |
| `--verbose` | int | Verbosity level (1 or 2) | `1` |

### Save
Expand All @@ -134,7 +134,7 @@ Available options include:
| `owner` | str | Owner of the repository | |
| `branch` | str | Branch to save | |
| `--github-token` | str | GitHub token for authentication | |
| `--commit0-dot-file-path` | str | Path to the commit0 dot file | `.commit0.yaml` |
| `--commit0-config-file` | str | Path to the commit0 config file | `.commit0.yaml` |

## Agent

Expand Down
128 changes: 106 additions & 22 deletions agent/agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,9 @@ def get_target_edit_files(
local_repo: git.Repo,
src_dir: str,
test_dir: str,
latest_commit: str,
branch: str,
reference_commit: str,
use_topo_sort_dependencies: bool = True,
) -> tuple[list[str], dict]:
"""Find the files with functions with the pass statement."""
target_dir = str(local_repo.working_dir)
Expand Down Expand Up @@ -269,7 +270,7 @@ def get_target_edit_files(
), "all files should be included"

# change to latest commit
local_repo.git.checkout(latest_commit)
local_repo.git.checkout(branch)

# Remove the base_dir prefix
topological_sort_files = [
Expand All @@ -282,35 +283,88 @@ def get_target_edit_files(
key_without_prefix = key.replace(target_dir, "").lstrip("/")
value_without_prefix = [v.replace(target_dir, "").lstrip("/") for v in value]
import_dependencies_without_prefix[key_without_prefix] = value_without_prefix
if use_topo_sort_dependencies:
return topological_sort_files, import_dependencies_without_prefix
else:
filtered_files = [
file.replace(target_dir, "").lstrip("/") for file in filtered_files
]
return filtered_files, import_dependencies_without_prefix


def get_target_edit_files_from_patch(
    local_repo: git.Repo, patch: str, use_topo_sort_dependencies: bool = True
) -> tuple[list[str], dict]:
    """Get the target files from the patch.

    The patch is scanned for unified-diff file headers (``--- <path>`` and
    ``+++ <path>``). Each referenced path is recorded once, in order of
    first appearance.

    Args:
    ----
        local_repo (git.Repo): Repository whose ``working_dir`` is used to build
            absolute paths for the dependency sort.
        patch (str): Text of a unified diff.
        use_topo_sort_dependencies (bool): When True, order the files with
            ``topological_sort_based_on_dependencies`` and return its
            dependency map; when False, return the files in patch order with
            an empty map.

    Returns:
    -------
        tuple[list[str], dict]: The target files relative to the repository
        working directory, and the import-dependency map (keys and values
        relative to the working directory as well).

    Raises:
    ------
        ValueError: If the topological sort returns more files than it was given.

    """
    working_dir = str(local_repo.working_dir)

    def _to_relative(path: str) -> str:
        # Only a *leading* working_dir is a prefix; a blanket str.replace would
        # also eat the same text appearing deeper inside the path.
        if path.startswith(working_dir):
            path = path[len(working_dir) :]
        return path.lstrip("/")

    # A dict keeps first-seen order, so the result is deterministic (a set is not).
    seen: dict[str, None] = {}
    for line in patch.splitlines():
        if not line.startswith(("--- ", "+++ ")):
            continue
        parts = line.split()
        if len(parts) < 2:
            # Header marker without a path; nothing to record.
            continue
        file_path = parts[1]
        if file_path == "/dev/null":
            # Placeholder side of a file creation/deletion, not a real file.
            continue
        # Strip exactly one diff prefix; stripping "a/" and then "b/" would
        # turn "a/b/foo.py" into "foo.py".
        if file_path.startswith(("a/", "b/")):
            file_path = file_path[2:]
        seen[file_path] = None

    target_files_list = [os.path.join(working_dir, file_path) for file_path in seen]

    if not use_topo_sort_dependencies:
        return [_to_relative(file) for file in target_files_list], {}

    topological_sort_files, import_dependencies = (
        topological_sort_based_on_dependencies(target_files_list)
    )
    if len(topological_sort_files) > len(target_files_list):
        raise ValueError(
            "topological_sort_files should not be longer than target_files_list"
        )
    if len(topological_sort_files) < len(target_files_list):
        # Files the sorter dropped are appended at the end, in patch order.
        already_sorted = set(topological_sort_files)
        topological_sort_files = topological_sort_files + [
            file for file in target_files_list if file not in already_sorted
        ]
    assert len(topological_sort_files) == len(
        target_files_list
    ), "all files should be included"

    topological_sort_files = [_to_relative(file) for file in topological_sort_files]
    # Keys are made relative too, so the map can be indexed with the same
    # relative paths that are returned in the file list.
    import_dependencies = {
        _to_relative(key): [_to_relative(v) for v in value]
        for key, value in import_dependencies.items()
    }
    return topological_sort_files, import_dependencies


def get_message(
agent_config: AgentConfig,
repo_path: str,
test_dir: str | None = None,
test_file: str | None = None,
test_files: list[str] | None = None,
) -> str:
"""Get the message to Aider."""
prompt = f"{PROMPT_HEADER}" + agent_config.user_prompt

if agent_config.use_unit_tests_info and test_dir:
unit_tests_info = (
f"\n{UNIT_TESTS_INFO_HEADER} "
+ get_dir_info(
dir_path=Path(os.path.join(repo_path, test_dir)),
prefix="",
include_stubs=True,
)[: agent_config.max_unit_tests_info_length]
)
elif agent_config.use_unit_tests_info and test_file:
unit_tests_info = (
f"\n{UNIT_TESTS_INFO_HEADER} "
+ get_file_info(
# if agent_config.use_unit_tests_info and test_file:
# unit_tests_info = (
# f"\n{UNIT_TESTS_INFO_HEADER} "
# + get_file_info(
# file_path=Path(os.path.join(repo_path, test_file)), prefix=""
# )[: agent_config.max_unit_tests_info_length]
# )
if agent_config.use_unit_tests_info and test_files:
unit_tests_info = f"\n{UNIT_TESTS_INFO_HEADER} "
for test_file in test_files:
unit_tests_info += get_file_info(
file_path=Path(os.path.join(repo_path, test_file)), prefix=""
)[: agent_config.max_unit_tests_info_length]
)
)
unit_tests_info = unit_tests_info[: agent_config.max_unit_tests_info_length]
else:
unit_tests_info = ""

Expand Down Expand Up @@ -405,6 +459,33 @@ def create_branch(repo: git.Repo, branch: str, from_commit: str) -> None:
raise RuntimeError(f"Failed to create or switch to branch '{branch}': {e}")


def get_changed_files_from_commits(
    repo: git.Repo, commit1: str, commit2: str
) -> list[str]:
    """Return the Python files that differ between two commits.

    Both revisions are resolved through ``repo.commit`` and diffed
    (``commit1`` against ``commit2``). The ``a_path`` of every diff entry is
    collected and only paths ending in ``.py`` are kept.

    Any exception raised along the way is printed and an empty list is
    returned instead.
    """
    try:
        older = repo.commit(commit1)
        newer = repo.commit(commit2)
        # Keep only the entries whose a-side path is a Python source file.
        return [
            entry.a_path
            for entry in older.diff(newer)
            if entry.a_path.endswith(".py")
        ]
    except Exception as e:
        print(f"An error occurred: {e}")
        return []


def args2string(agent_config: AgentConfig) -> str:
"""Converts specific fields from an `AgentConfig` object into a formatted string.
Expand Down Expand Up @@ -453,13 +534,14 @@ def get_changed_files(repo: git.Repo) -> list[str]:
return files_changed


def get_lint_cmd(repo_name: str, use_lint_info: bool) -> str:
def get_lint_cmd(repo_name: str, use_lint_info: bool, commit0_config_file: str) -> str:
"""Generate a linting command based on whether to include files.
Args:
----
repo_name (str): The name of the repository.
use_lint_info (bool): A flag indicating whether to include changed files in the lint command.
commit0_config_file (str): The path to the commit0 dot file.
Returns:
-------
Expand All @@ -469,7 +551,9 @@ def get_lint_cmd(repo_name: str, use_lint_info: bool) -> str:
"""
lint_cmd = "python -m commit0 lint "
if use_lint_info:
lint_cmd += repo_name + " --files "
lint_cmd += (
repo_name + " --commit0-config-file " + commit0_config_file + " --files "
)
else:
lint_cmd = ""
return lint_cmd
Expand Down
69 changes: 38 additions & 31 deletions agent/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from aider.models import Model
from aider.io import InputOutput
import re
import os


def handle_logging(logging_name: str, log_file: Path) -> None:
Expand All @@ -24,6 +25,23 @@ def handle_logging(logging_name: str, log_file: Path) -> None:
class AgentReturn(ABC):
    """Result of an agent run: the log file it wrote and the last recorded cost."""

    def __init__(self, log_file: Path):
        # The cost starts at zero here; subclasses may overwrite it.
        self.last_cost = 0.0
        self.log_file = log_file


class Agents(ABC):
    """Abstract base class for agents; stores ``max_iteration`` for subclasses."""

    def __init__(self, max_iteration: int):
        self.max_iteration = max_iteration

    @abstractmethod
    def run(self) -> AgentReturn:
        """Run the agent; concrete subclasses must override this."""
        raise NotImplementedError


class AiderReturn(AgentReturn):
def __init__(self, log_file: Path):
super().__init__(log_file)
self.last_cost = self.get_money_cost()

def get_money_cost(self) -> float:
Expand All @@ -40,20 +58,25 @@ def get_money_cost(self) -> float:
return last_cost


class Agents(ABC):
def __init__(self, max_iteration: int):
self.max_iteration = max_iteration

@abstractmethod
def run(self) -> AgentReturn:
"""Start agent"""
raise NotImplementedError


class AiderAgents(Agents):
def __init__(self, max_iteration: int, model_name: str):
    """Create the aider model and verify that an API key is exported for it.

    The environment variable to check is chosen by the first of the
    substrings "gpt", "claude", "gemini" found in ``model_name``.

    Raises
    ------
        ValueError: If ``model_name`` contains none of those substrings, or
            if the selected environment variable is unset or empty.

    """
    super().__init__(max_iteration)
    self.model = Model(model_name)

    # Ordered (substring, environment variable) pairs; the first match wins.
    # NOTE(review): Gemini is looked up under the generic API_KEY variable —
    # confirm that this is the intended name.
    key_env_by_marker = (
        ("gpt", "OPENAI_API_KEY"),
        ("claude", "ANTHROPIC_API_KEY"),
        ("gemini", "API_KEY"),
    )
    for marker, env_name in key_env_by_marker:
        if marker in model_name:
            api_key = os.environ.get(env_name, None)
            break
    else:
        raise ValueError(f"Unsupported model: {model_name}")

    if api_key:
        return
    raise ValueError(
        "API Key Error: There is no API key associated with the model for this agent. "
        "Edit model_name parameter in .agent.yaml, export API key for that model, and try again."
    )

def run(
self,
Expand All @@ -63,6 +86,7 @@ def run(
fnames: list[str],
log_dir: Path,
test_first: bool = False,
lint_first: bool = False,
) -> AgentReturn:
"""Start aider agent"""
if test_cmd:
Expand Down Expand Up @@ -90,11 +114,6 @@ def run(
sys.stdout = open(log_file, "a")
sys.stderr = open(log_file, "a")

# Log the message
agent_message_log_file = log_dir / "agent_message.log"
with open(agent_message_log_file, "a") as f:
f.write(f"Message Sent: {message}\n\n")

# Configure httpx and backoff logging
handle_logging("httpx", log_file)
handle_logging("backoff", log_file)
Expand All @@ -113,36 +132,24 @@ def run(
test_cmd=test_cmd,
io=io,
)
coder.max_reflection = self.max_iteration
coder.max_reflections = self.max_iteration
coder.stream = True

# Run the agent
if test_first:
test_errors = coder.commands.cmd_test(test_cmd)
if test_errors:
coder.run(test_errors)
elif lint_first:
coder.commands.cmd_lint(fnames=fnames)
else:
coder.run(message)

# #### TMP

# #### TMP
# import time
# import random

# time.sleep(random.random() * 5)
# n = random.random() / 10
# with open(log_file, "a") as f:
# f.write(
# f"> Tokens: 33k sent, 1.3k received. Cost: $0.12 message, ${n} session. \n"
# )
# #### TMP

# Close redirected stdout and stderr
sys.stdout.close()
sys.stderr.close()
# Restore original stdout and stderr
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__

return AgentReturn(log_file)
return AiderReturn(log_file)
3 changes: 3 additions & 0 deletions agent/class_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@ class AgentConfig:
model_name: str
use_user_prompt: bool
user_prompt: str
use_topo_sort_dependencies: bool
add_import_module_to_context: bool
use_repo_info: bool
max_repo_info_length: int
use_unit_tests_info: bool
max_unit_tests_info_length: int
use_spec_info: bool
max_spec_info_length: int
use_lint_info: bool
run_entire_dir_lint: bool
max_lint_info_length: int
pre_commit_config_path: str
run_tests: bool
Expand Down
Loading
Loading