Move patch functionality from Coder to its own file
TechNickAI committed Aug 13, 2023
1 parent a686786 commit 9c9c7f5
Showing 6 changed files with 353 additions and 339 deletions.
164 changes: 0 additions & 164 deletions aicodebot/coder.py
@@ -2,7 +2,6 @@
from aicodebot.lm import token_size
from pathlib import Path
from pygments.lexers import ClassNotFound, get_lexer_for_mimetype, guess_lexer_for_filename
from types import SimpleNamespace
import fnmatch, mimetypes, re, subprocess


@@ -13,41 +12,6 @@ class Coder:

    UNKNOWN_FILE_TYPE = "unknown"

    @staticmethod
    def apply_patch(patch_string, is_rebuilt=False):
        """Applies a patch to the local file system using git apply."""
        try:
            result = subprocess.run(
                [
                    "git",
                    "apply",
                    "--verbose",
                    "--recount",
                    "--inaccurate-eof",
                ],
                input=patch_string.encode("utf-8"),
                check=True,
                capture_output=True,
            )
            logger.debug(f"git apply output {result.stdout}")
        except subprocess.CalledProcessError as e:
            logger.error("Failed to apply patch:")
            print(patch_string)  # noqa: T201
            logger.error(e.stderr)

            # Rebuild it and try again
            if not is_rebuilt:
                rebuilt_patch = Coder.rebuild_patch(patch_string)
                if patch_string != rebuilt_patch:
                    logger.error("Received an invalid patch from the LM, fixing.")
                    logger.error(f"Original patch: {patch_string}")
                    logger.error(f"Rebuilt patch: {rebuilt_patch}")
                    return Coder.apply_patch(rebuilt_patch, is_rebuilt=True)

            return False
        else:
            return True

    @staticmethod
    def auto_file_context(max_tokens, max_file_tokens):
        """Automatically generate a file context based on what we think the user is working on"""
@@ -318,131 +282,3 @@ def parse_github_url(repo_url):

        owner, repo = match.groups()
        return owner, repo

    @staticmethod
    def rebuild_patch(patch_string):  # noqa: PLR0915
        """We ask the LM to respond with unified patch format. It often gets it wrong, especially the chunk headers.
        This function looks at the intent of the patch and rebuilds it in a [hopefully] correct format."""

        def parse_line(line):  # noqa: PLR0911
            """Parse a line of the patch and return a SimpleNamespace with the line, type, and parsed line."""
            if line.startswith(("diff --git", "index")):
                return SimpleNamespace(line=line, type="header", parsed=line)
            elif line.startswith("---"):
                return SimpleNamespace(line=line, type="source_file", parsed=line[6:])
            elif line.startswith("+++"):
                return SimpleNamespace(line=line, type="destination_file", parsed=line[6:])
            elif line.startswith("@@"):
                chunk_header_match = re.match(r"@@ -(\d+),(\d+) \+(\d+),(\d+) @@", line)
                if not chunk_header_match:
                    raise ValueError(f"Invalid chunk header: {line}")

                chunk_header = SimpleNamespace(
                    start1=int(chunk_header_match.group(1)),
                    count1=int(chunk_header_match.group(2)),
                    start2=int(chunk_header_match.group(3)),
                    count2=int(chunk_header_match.group(4)),
                )

                return SimpleNamespace(line=line, type="chunk_header", parsed=chunk_header)
            elif line.startswith("+"):
                return SimpleNamespace(line=line, type="addition", parsed=line[1:])
            elif line.startswith("-"):
                return SimpleNamespace(line=line, type="subtraction", parsed=line[1:])
            elif line.startswith(" "):
                return SimpleNamespace(line=line, type="context", parsed=line[1:])
            else:
                raise ValueError(f"Invalid line: '{line}'")

        # ------------------------- Parse the incoming patch ------------------------- #
        parsed_lines = []
        chunk_header = None
        for line in patch_string.lstrip().splitlines():
            if chunk_header and not line.startswith(("+", "-", " ")):
                # Sometimes the LM will add a context line without a space
                # If we see that, we'll assume it's a context line
                line = " " + line  # noqa: PLW2901

            parsed_line = parse_line(line)
            parsed_lines.append(parsed_line)
            if parsed_lines[-1].type == "chunk_header":
                chunk_header = parsed_lines[-1].parsed

        # Check for critical fields
        source_file_line = next(line for line in parsed_lines if line.type == "source_file")
        if not source_file_line:
            raise ValueError("No source file found in patch")

        first_context_line = next(line for line in parsed_lines if line.type == "context")
        if not first_context_line:
            raise ValueError("No context line found in patch")

        if not chunk_header:
            # Chunk header missing. This shouldn't happen, but we should be able to recover
            chunk_header = SimpleNamespace(start1=0, count1=0, start2=0, count2=0)

        start1 = chunk_header.start1
        first_change_line = next(line for line in parsed_lines if line.type in ("addition", "subtraction"))
        lines_of_context = 3

        # ------------------------- Rebuild the context lines ------------------------ #
        # Get the correct start line from the first context line, by looking at the source file
        source_file = source_file_line.parsed
        source_file_contents = []
        if source_file != "/dev/null" and Path(source_file).exists():
            source_file_contents = Path(source_file).read_text().splitlines()

        # Determine the correct line of the first change
        # We will start looking at start1 - 1, and walk until we find it
        for i in range(start1 - 1, len(source_file_contents)):
            if source_file_contents[i] == first_change_line.parsed:
                first_change_line_number = i + 1
                break
        else:
            raise ValueError(f"Could not find first change line in source file: {first_change_line.parsed}")

        # Disregard the existing context lines from the parsed lines
        parsed_lines = [line for line in parsed_lines if line.type != "context"]

        # Add x lines of context before the first change
        for i in range(first_change_line_number - lines_of_context, first_change_line_number):
            # Get the index number of the first changed line in parsed_lines
            first_change_line_index = next(
                i for i, line in enumerate(parsed_lines) if line.type in ("addition", "subtraction")
            )
            parsed_lines.insert(first_change_line_index, parse_line(f" {source_file_contents[i-1]}"))

        # Add x lines of context after the last change
        number_of_subtractions = len([line for line in parsed_lines if line.type == "subtraction"])
        start_trailing_context = first_change_line_number + number_of_subtractions
        for i in range(start_trailing_context, start_trailing_context + lines_of_context):
            parsed_lines.append(parse_line(f" {source_file_contents[i-1]}"))

        # ------------------------- Rebuild the chunk header ------------------------- #

        # Add up the number of context lines, additions, and subtractions
        # This will be the new count1 and count2
        start2 = start1
        count1 = count2 = 0
        for line in parsed_lines:
            if line.type in ("context", "subtraction"):
                count1 += 1
            if line.type in ("context", "addition"):
                count2 += 1

        new_chunk_header = f"@@ -{start1},{count1} +{start2},{count2} @@"

        # ----------------------------- Rebuild the patch ---------------------------- #

        new_patch = []
        for line in parsed_lines:
            if line.type == "chunk_header":
                new_patch.append(new_chunk_header)
            elif line.type == "source_file":
                new_patch.append(f"--- a/{line.parsed}")
            elif line.type == "destination_file":
                new_patch.append(f"+++ b/{line.parsed}")
            else:
                new_patch.append(f"{line.line}")

        return "\n".join(new_patch) + "\n"
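Aside (not part of this commit): for readers unfamiliar with the chunk headers the rebuild_patch docstring refers to, the short Python sketch below shows what the four numbers in a unified diff chunk header mean, parsed with the same regular expression used in the code above. The sample header is taken from this very diff.

import re

# "@@ -start1,count1 +start2,count2 @@": the hunk covers count1 lines starting
# at line start1 of the old file and count2 lines starting at line start2 of
# the new file.
header = "@@ -13,41 +12,6 @@"
match = re.match(r"@@ -(\d+),(\d+) \+(\d+),(\d+) @@", header)
start1, count1, start2, count2 = (int(g) for g in match.groups())
print(start1, count1, start2, count2)  # 13 41 12 6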
4 changes: 3 additions & 1 deletion aicodebot/commands/sidekick.py
@@ -97,7 +97,9 @@ def sidekick(request, no_files, max_file_tokens, files):  # noqa: PLR0915
        # have a record of what they asked for on their terminal
        console.print(parsed_human_input)
        try:
            with Live(OurMarkdown(f"Talking to {lmm.model_name} via {lmm.provider}"), auto_refresh=True) as live:
            with Live(
                OurMarkdown(f"Sending task to {lmm.model_name} via {lmm.provider}"), auto_refresh=False
            ) as live:
                chain = lmm.chain_factory(
                    prompt=prompt,
                    streaming=True,
3 changes: 2 additions & 1 deletion aicodebot/input.py
@@ -1,5 +1,6 @@
from aicodebot.coder import Coder
from aicodebot.lm import token_size
from aicodebot.patch import Patch
from pathlib import Path
from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
@@ -70,7 +71,7 @@ def parse_human_input(self, human_input):  # noqa: PLR0911, PLR0915
            for diff_block in self.diff_blocks:
                # Apply the diff with git apply
                count += 1
                if Coder.apply_patch(diff_block):
                if Patch.apply_patch(diff_block):
                    self.console.print(Panel(f"✅ change {count} applied."))
            return self.CONTINUE

169 changes: 169 additions & 0 deletions aicodebot/patch.py
@@ -0,0 +1,169 @@
from aicodebot.helpers import logger
from pathlib import Path
from types import SimpleNamespace
import re, subprocess


class Patch:
"""Handle patches in unified diff format for making changes to the local file system."""

@staticmethod
def apply_patch(patch_string, is_rebuilt=False):
"""Applies a patch to the local file system using git apply."""
try:
result = subprocess.run(
[
"git",
"apply",
"--verbose",
"--recount",
"--inaccurate-eof",
],
input=patch_string.encode("utf-8"),
check=True,
capture_output=True,
)
logger.debug(f"git apply output {result.stdout}")
except subprocess.CalledProcessError as e:
logger.error("Failed to apply patch:")
print(patch_string) # noqa: T201
logger.error(e.stderr)

# Rebuild it and try again
if not is_rebuilt:
rebuilt_patch = Patch.rebuild_patch(patch_string)
if patch_string != rebuilt_patch:
return Patch.apply_patch(rebuilt_patch, is_rebuilt=True)

return False
else:
return True

@staticmethod
def parse_line(line): # noqa: PLR0911
"""Parse a line of the patch and return a SimpleNamespace with the line, type, and parsed line."""
if line.startswith(("diff --git", "index")):
return SimpleNamespace(line=line, type="header", parsed=line)
elif line.startswith("---"):
return SimpleNamespace(line=line, type="source_file", parsed=line[6:])
elif line.startswith("+++"):
return SimpleNamespace(line=line, type="destination_file", parsed=line[6:])
elif line.startswith("@@"):
chunk_header_match = re.match(r"@@ -(\d+),(\d+) \+(\d+),(\d+) @@", line)
if not chunk_header_match:
raise ValueError(f"Invalid chunk header: {line}")

chunk_header = SimpleNamespace(
start1=int(chunk_header_match.group(1)),
count1=int(chunk_header_match.group(2)),
start2=int(chunk_header_match.group(3)),
count2=int(chunk_header_match.group(4)),
)

return SimpleNamespace(line=line, type="chunk_header", parsed=chunk_header)
elif line.startswith("+"):
return SimpleNamespace(line=line, type="addition", parsed=line[1:])
elif line.startswith("-"):
return SimpleNamespace(line=line, type="subtraction", parsed=line[1:])
elif line.startswith(" "):
return SimpleNamespace(line=line, type="context", parsed=line[1:])
else:
raise ValueError(f"Invalid line: '{line}'")

@staticmethod
def rebuild_patch(patch_string): # noqa: PLR0915
"""We ask the LM to respond with unified patch format. It often gets it wrong, especially the chunk headers.
This function looks at the intent of the patch and rebuilds it in a [hopefully] correct format."""

# ------------------------- Parse the incoming patch ------------------------- #
parsed_lines = []
chunk_header = None
for line in patch_string.lstrip().splitlines():
if chunk_header and not line.startswith(("+", "-", " ")):
# Sometimes the LM will add a context line without a space
# If we see that, we'll assume it's a context line
line = " " + line # noqa: PLW2901

parsed_line = Patch.parse_line(line)
parsed_lines.append(parsed_line)
if parsed_lines[-1].type == "chunk_header":
chunk_header = parsed_lines[-1].parsed

# Check for critical fields
source_file_line = next(line for line in parsed_lines if line.type == "source_file")
if not source_file_line:
raise ValueError("No source file found in patch")

first_context_line = next(line for line in parsed_lines if line.type == "context")
if not first_context_line:
raise ValueError("No context line found in patch")

if not chunk_header:
# Chunk header missing. This shouldn't happen, but we should be able to recover
chunk_header = SimpleNamespace(start1=0, count1=0, start2=0, count2=0)

start1 = chunk_header.start1
first_change_line = next(line for line in parsed_lines if line.type in ("addition", "subtraction"))
lines_of_context = 3

# ------------------------- Rebuild the context lines ------------------------ #
# Get the correct start line from the first context line, by looking at the source file
source_file = source_file_line.parsed
source_file_contents = []
if source_file != "/dev/null" and Path(source_file).exists():
source_file_contents = Path(source_file).read_text().splitlines()

# Determine the correct line of the first change
# We will start looking at start1 - 1, and walk until we find it
for i in range(start1 - 1, len(source_file_contents)):
if source_file_contents[i] == first_change_line.parsed:
first_change_line_number = i + 1
break
else:
raise ValueError(f"Could not find first change line in source file: {first_change_line.parsed}")

# Disregard the existing context lines from the parsed lines
parsed_lines = [line for line in parsed_lines if line.type != "context"]

# Add x lines of context before the first change
for i in range(first_change_line_number - lines_of_context, first_change_line_number):
# Get the index number of the first changed line in parsed_lines
first_change_line_index = next(
i for i, line in enumerate(parsed_lines) if line.type in ("addition", "subtraction")
)
parsed_lines.insert(first_change_line_index, Patch.parse_line(f" {source_file_contents[i-1]}"))

# Add x lines of context after the last change
number_of_subtractions = len([line for line in parsed_lines if line.type == "subtraction"])
start_trailing_context = first_change_line_number + number_of_subtractions
for i in range(start_trailing_context, start_trailing_context + lines_of_context):
parsed_lines.append(Patch.parse_line(f" {source_file_contents[i-1]}"))

# ------------------------- Rebuild the chunk header ------------------------- #

# Add up the number of context lines, additions, and subtractions
# This will be the new count1 and count2
start2 = start1
count1 = count2 = 0
for line in parsed_lines:
if line.type in ("context", "subtraction"):
count1 += 1
if line.type in ("context", "addition"):
count2 += 1

new_chunk_header = f"@@ -{start1},{count1} +{start2},{count2} @@"

# ----------------------------- Rebuild the patch ---------------------------- #

new_patch = []
for line in parsed_lines:
if line.type == "chunk_header":
new_patch.append(new_chunk_header)
elif line.type == "source_file":
new_patch.append(f"--- a/{line.parsed}")
elif line.type == "destination_file":
new_patch.append(f"+++ b/{line.parsed}")
else:
new_patch.append(f"{line.line}")

return "\n".join(new_patch) + "\n"
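Aside (not part of this commit): a minimal sketch of how the new Patch class could be exercised. The hello.py file, its contents, and the malformed patch text below are made-up assumptions; only the aicodebot.patch.Patch import and its methods come from the diff above, and apply_patch additionally requires running inside a git work tree.

# Illustrative sketch only: hello.py and the patch text are hypothetical.
from pathlib import Path

from aicodebot.patch import Patch

# Create a small file to patch (assumed to live in the current directory).
Path("hello.py").write_text(
    "import sys\n"
    "\n"
    "\n"
    "def greet():\n"
    '    print("hello")\n'
    "\n"
    "\n"
    "def main():\n"
    "    greet()\n"
)

# A patch the way an LM might emit it: the chunk header counts are wrong.
bad_patch = """\
--- a/hello.py
+++ b/hello.py
@@ -1,1 +1,1 @@
 def greet():
-    print("hello")
+    print("hello, world")
"""

# rebuild_patch locates the changed line in hello.py, regenerates three lines
# of surrounding context, and recomputes the counts in the chunk header
# (count1 = context + subtractions, count2 = context + additions).
print(Patch.rebuild_patch(bad_patch))

# apply_patch pipes the patch to `git apply --recount`; if that fails, it
# rebuilds the patch once and retries, returning True or False.
# Patch.apply_patch(bad_patch)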