remove unused code_feature methods / funcs
granawkins committed Apr 10, 2024
1 parent e7bf6e6 commit eeac156
Showing 8 changed files with 12 additions and 450 deletions.
2 changes: 2 additions & 0 deletions mentat/code_context.py
@@ -409,6 +409,8 @@ async def search(
all_nodes_sorted = self.daemon.search(query, max_results)
all_features_sorted = list[tuple[CodeFeature, float]]()
for node in all_nodes_sorted:
if node.get("type") not in {"file", "chunk"}:
continue
distance = node["distance"]
path, interval = split_intervals_from_path(Path(node["ref"]))
intervals = parse_intervals(interval)
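The two lines added to search() skip any ragdaemon node whose type is not "file" or "chunk" before it is turned into a CodeFeature. Below is a minimal, self-contained sketch of that filter; the node dicts are hypothetical stand-ins for what the daemon's search returns (real nodes also carry "distance" and "ref" fields, as the surrounding hunk shows).

# Hypothetical node dicts standing in for ragdaemon search results.
nodes = [
    {"type": "file", "ref": "mentat/code_feature.py", "distance": 0.12},
    {"type": "diff", "ref": "HEAD", "distance": 0.30},
    {"type": "chunk", "ref": "mentat/code_context.py:1-40", "distance": 0.18},
]
# Same membership test as the added code: anything other than a file or chunk node is dropped.
kept = [node for node in nodes if node.get("type") in {"file", "chunk"}]
print([node["ref"] for node in kept])  # ['mentat/code_feature.py', 'mentat/code_context.py:1-40']
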
125 changes: 8 additions & 117 deletions mentat/code_feature.py
@@ -1,16 +1,13 @@
from __future__ import annotations

import asyncio
import logging
from collections import OrderedDict, defaultdict
from collections import defaultdict
from pathlib import Path
from typing import Optional

import attr
from ragdaemon.utils import get_document

from mentat.diff_context import annotate_file_message, parse_diff
from mentat.errors import MentatError
from mentat.git_handler import get_diff_for_file
from mentat.interval import INTERVAL_FILE_END, Interval
from mentat.llm_api_handler import count_tokens
from mentat.session_context import SESSION_CONTEXT
@@ -62,118 +59,12 @@ def interval_string(self) -> str:
def __str__(self, cwd: Optional[Path] = None) -> str:
return self.rel_path(cwd) + self.interval_string()

def get_code_message(self, standalone: bool = True) -> list[str]:
"""
Gets this code features code message.
If standalone is true, will include the filename at top and extra newline at the end.
If feature contains entire file, will add inline diff annotations; otherwise, will append them to the end.
"""
if not self.path.exists() or self.path.is_dir():
return []

session_context = SESSION_CONTEXT.get()
code_file_manager = session_context.code_file_manager
parser = session_context.config.parser
code_context = session_context.code_context

code_message: list[str] = []

if standalone:
# We always want to give GPT posix paths
code_message_path = get_relative_path(self.path, session_context.cwd)
code_message.append(str(code_message_path.as_posix()))

# Get file lines
file_lines = code_file_manager.read_file(self.path)
for i, line in enumerate(file_lines):
if self.interval.contains(i + 1):
if parser.provide_line_numbers():
code_message.append(f"{i + parser.line_number_starting_index()}:{line}")
else:
code_message.append(f"{line}")

if standalone:
code_message.append("")

if self.path in code_context.diff_context.diff_files():
diff = get_diff_for_file(code_context.diff_context.target, self.path)
diff_annotations = parse_diff(diff)
if self.interval.whole_file():
code_message = annotate_file_message(code_message, diff_annotations)
else:
for section in diff_annotations:
# TODO: Place diff_annotations inside interval where they belong
if section.start >= self.interval.start and section.start < self.interval.end:
code_message += section.message
return code_message

def get_checksum(self) -> str:
# TODO: Only update checksum if last modified time of file updates to conserve file system reads
session_context = SESSION_CONTEXT.get()
code_file_manager = session_context.code_file_manager

return code_file_manager.get_file_checksum(self.path, self.interval)

def count_tokens(self, model: str) -> int:
code_message = self.get_code_message()
return count_tokens("\n".join(code_message), model, full_message=False)


async def count_feature_tokens(features: list[CodeFeature], model: str) -> list[int]:
"""Return the number of tokens in each feature."""
sem = asyncio.Semaphore(10)

feature_tokens = list[int]()
for feature in features:
async with sem:
tokens = feature.count_tokens(model)
feature_tokens.append(tokens)
return feature_tokens


def _get_code_message_from_intervals(features: list[CodeFeature]) -> list[str]:
"""
Merge multiple features for the same file into a single code message.
"""
features_sorted = sorted(features, key=lambda f: f.interval)
posix_path = features_sorted[0].get_code_message()[0]
code_message = [posix_path]
next_line = 1
for feature in features_sorted:
starting_line = feature.interval.start
if starting_line < next_line:
logging.info(f"Features overlap: {feature}")
if feature.interval.end <= next_line:
continue
feature = CodeFeature(
feature.path,
interval=Interval(next_line, feature.interval.end),
name=feature.name,
)
elif starting_line > next_line:
code_message += ["..."]
code_message += feature.get_code_message(standalone=False)
next_line = feature.interval.end
return code_message + [""]


def get_code_message_from_features(features: list[CodeFeature]) -> list[str]:
"""
Generate a code message from a list of features.
Will automatically handle overlapping intervals.
"""
code_message = list[str]()
features_by_path: dict[Path, list[CodeFeature]] = OrderedDict()
for feature in features:
if feature.path not in features_by_path:
features_by_path[feature.path] = list[CodeFeature]()
features_by_path[feature.path].append(feature)
for path_features in features_by_path.values():
if len(path_features) == 1:
code_message += path_features[0].get_code_message()
else:
code_message += _get_code_message_from_intervals(path_features)
return code_message

def count_feature_tokens(feature: CodeFeature, model: str) -> int:
cwd = SESSION_CONTEXT.get().cwd
ref = feature.__str__(cwd)
document = get_document(ref, cwd)
return count_tokens(document, model, full_message=False)


def get_consolidated_feature_refs(features: list[CodeFeature]) -> list[str]:
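Taken together, the code_feature.py changes move token counting off the CodeFeature instance (the removed count_tokens method and the async count_feature_tokens helper) and into a single module-level count_feature_tokens that asks ragdaemon for the rendered document via get_document. A rough usage sketch follows; it assumes an active Mentat session so that SESSION_CONTEXT is populated, assumes the single-argument CodeFeature constructor defaults to the whole file, and uses illustrative path and model names.

from pathlib import Path

from mentat.code_feature import CodeFeature, count_feature_tokens

feature = CodeFeature(Path("mentat/code_feature.py"))  # whole-file feature (interval default assumed)
# Before this commit: tokens = feature.count_tokens("gpt-4")
tokens = count_feature_tokens(feature, "gpt-4")  # new module-level helper backed by ragdaemon's get_document
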
3 changes: 2 additions & 1 deletion mentat/command/commands/search.py
@@ -2,6 +2,7 @@

from typing_extensions import override

from mentat.code_feature import count_feature_tokens
from mentat.command.command import Command, CommandArgument
from mentat.errors import UserError
from mentat.session_context import SESSION_CONTEXT
@@ -60,7 +61,7 @@ async def apply(self, *args: str) -> None:
file_interval = feature.interval_string()
stream.send(file_interval, color="bright_cyan", end="")

tokens = feature.count_tokens(config.model)
tokens = count_feature_tokens(feature, config.model)
cumulative_tokens += tokens
tokens_str = f" ({tokens} tokens)"
stream.send(tokens_str, color="yellow")
89 changes: 0 additions & 89 deletions mentat/diff_context.py
@@ -2,9 +2,6 @@
from pathlib import Path
from typing import List, Literal, Optional

import attr

from mentat.errors import MentatError
from mentat.git_handler import (
check_head_exists,
get_diff_for_file,
@@ -13,83 +10,10 @@
get_treeish_metadata,
get_untracked_files,
)
from mentat.interval import Interval
from mentat.session_context import SESSION_CONTEXT
from mentat.session_stream import SessionStream


@attr.define(frozen=True)
class DiffAnnotation(Interval):
start: int | float = attr.field()
message: List[str] = attr.field()
end: int | float = attr.field(
default=attr.Factory(
lambda self: self.start + sum(bool(line.startswith("-")) for line in self.message),
takes_self=True,
)
)


def parse_diff(diff: str) -> list[DiffAnnotation]:
"""Parse diff into a list of annotations."""
annotations: list[DiffAnnotation] = []
active_annotation: Optional[DiffAnnotation] = None
lines = diff.splitlines()
for line in lines:
if line.startswith(("---", "+++", "//", "diff", "index")):
continue
elif line.startswith("@@"):
if active_annotation:
annotations.append(active_annotation)
_new_index = line.split(" ")[2]
if "," in _new_index:
new_start = _new_index[1:].split(",")[0]
else:
new_start = _new_index[1:]
active_annotation = DiffAnnotation(start=int(new_start), message=[])
elif line.startswith(("+", "-")):
if not active_annotation:
raise MentatError("Invalid diff")
active_annotation.message.append(line)
if active_annotation:
annotations.append(active_annotation)
annotations.sort(key=lambda a: a.start)
return annotations


def annotate_file_message(code_message: list[str], annotations: list[DiffAnnotation]) -> list[str]:
"""Return the code_message with annotations inserted."""
active_index = 0
annotated_message: list[str] = []
for annotation in annotations:
# Fill-in lines between annotations
if active_index < annotation.start:
unaffected_lines = code_message[active_index : annotation.start]
annotated_message += unaffected_lines
active_index = annotation.start
if annotation.start == 0:
# Make sure the PATH stays on line 1
annotated_message.append(code_message[0])
active_index += 1
i_minus = None
for line in annotation.message:
sign = line[0]
if sign == "+":
# Add '+' lines in place of code_message lines
annotated_message.append(f"{active_index}:{line}")
active_index += 1
i_minus = None
elif sign == "-":
# Insert '-' lines at the point they were removed
i_minus = 0 if i_minus is None else i_minus
annotated_message.append(f"{annotation.start + i_minus}:{line}")
i_minus += 1
if active_index < len(code_message):
annotated_message += code_message[active_index:]

return annotated_message


class DiffContext:
target: str = ""
name: str = "index (last commit)"
@@ -184,12 +108,6 @@ def refresh(self):
self._diff_files = [(ctx.cwd / f).resolve() for f in get_files_in_diff(self.target)]
self._untracked_files = [(ctx.cwd / f).resolve() for f in get_untracked_files(ctx.cwd)]

def get_annotations(self, rel_path: Path) -> list[DiffAnnotation]:
if not self.git_root:
return []
diff = get_diff_for_file(self.target, rel_path)
return parse_diff(diff)

def get_display_context(self) -> Optional[str]:
if not self.git_root:
return None
@@ -204,13 +122,6 @@ def get_display_context(self) -> Optional[str]:
num_lines += len([line for line in diff_lines if line.startswith(("+ ", "- "))])
return f" {self.name} | {num_files} files | {num_lines} lines"

def annotate_file_message(self, rel_path: Path, file_message: list[str]) -> list[str]:
"""Return file_message annotated with active diff."""
if not self.git_root:
return []
annotations = self.get_annotations(rel_path)
return annotate_file_message(file_message, annotations)


TreeishType = Literal["commit", "branch", "relative", "compare"]

38 changes: 0 additions & 38 deletions scripts/sampler/__main__.py
@@ -10,9 +10,7 @@
from pathlib import Path
from typing import Any

from add_context import add_context
from finetune import generate_finetune
from remove_context import remove_context
from validate import validate_sample

from mentat.llm_api_handler import count_tokens, prompt_tokens
@@ -50,13 +48,6 @@ async def main():
help="Validate samples conform to spec",
)
parser.add_argument("--finetune", "-f", action="store_true", help="Generate fine-tuning examples")
parser.add_argument("--add-context", "-a", action="store_true", help="Add extra context to samples")
parser.add_argument(
"--remove-context",
"-r",
action="store_true",
help="Remove context from samples",
)
args = parser.parse_args()
sample_files = []
if args.sample_ids:
@@ -81,11 +72,6 @@ async def main():
except Exception as e:
warn(f"Error loading sample {sample_file}: {e}")
continue
if (args.add_context or args.remove_context) and (
"[ADDED CONTEXT]" in sample.title or "[REMOVED CONTEXT]" in sample.title
):
warn(f"Skipping {sample.id[:8]}: has already been modified.")
continue
if args.validate:
is_valid, reason = await validate_sample(sample)
status = "\033[92mPASSED\033[0m" if is_valid else f"\033[91mFAILED: {reason}\033[0m"
@@ -104,26 +90,6 @@ async def main():
logs.append(example)
except Exception as e:
warn(f"Error generating finetune example for sample {sample.id[:8]}: {e}")
elif args.add_context:
try:
new_sample = await add_context(sample)
sample_file = SAMPLES_DIR / f"sample_{new_sample.id}.json"
new_sample.save(sample_file)
print(f"Generated new sample with extra context: {sample_file}")
logs.append({"id": new_sample.id, "prototype_id": sample.id})
except Exception as e:
warn(f"Error adding extra context to sample {sample.id[:8]}: {e}")
elif args.remove_context:
if not sample.context or len(sample.context) == 1:
warn(f"Skipping {sample.id[:8]}: no context to remove.")
continue
try:
new_sample = await remove_context(sample)
new_sample.save(SAMPLES_DIR / f"sample_{new_sample.id}.json")
print(f"Generated new sample with context removed: {sample_file}")
logs.append({"id": new_sample.id, "prototype_id": sample.id})
except Exception as e:
warn(f"Error removing context from sample {sample.id[:8]}: {e}")
else:
print(f"Running sample {sample.id[:8]}")
print(f" Prompt: {sample.message_prompt}")
Expand Down Expand Up @@ -161,10 +127,6 @@ async def main():
del log["tokens"]
f.write(json.dumps(log) + "\n")
print(f"{len(logs)} fine-tuning examples ({tokens} tokens) saved to {fname}.")
elif args.add_context:
print(f"{len(logs)} samples with extra context generated.")
elif args.remove_context:
print(f"{len(logs)} samples with context removed generated.")


if __name__ == "__main__":
(diffs for the remaining 3 changed files are not shown)
