Asymmetric context #1114

Merged: 11 commits, Aug 11, 2024
3 changes: 2 additions & 1 deletion docs/docs/usage-guide/additional_configurations.md
@@ -66,7 +66,8 @@ By default, around any change in your PR, git patch provides three lines of context
For the `review`, `describe`, `ask` and `add_docs` tools, if the token budget allows, PR-Agent tries to increase the number of lines of context, via the parameter:
```
[config]
patch_extra_lines=3
patch_extra_lines_before=4
patch_extra_lines_after=2
```

Increasing this number provides more context to the model, but will also increase the token budget.
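
To make the effect concrete, here is a hand-worked illustration (following the arithmetic of `extend_patch` in this PR, not output captured from a real run). On the original-file side, the hunk start moves up by `patch_extra_lines_before` and the hunk size grows by `patch_extra_lines_before + patch_extra_lines_after`:

```
@@ -66,7 ... @@    # default hunk: 7 lines around the change
@@ -62,13 ... @@   # extended with before=4, after=2: starts 4 lines earlier, 6 lines larger
```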
1 change: 1 addition & 0 deletions pr_agent/algo/__init__.py
@@ -46,6 +46,7 @@
'bedrock/anthropic.claude-3-sonnet-20240229-v1:0': 100000,
'bedrock/anthropic.claude-3-haiku-20240307-v1:0': 100000,
'bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0': 100000,
'claude-3-5-sonnet': 100000,
'groq/llama3-8b-8192': 8192,
'groq/llama3-70b-8192': 8192,
'groq/mixtral-8x7b-32768': 32768,
65 changes: 38 additions & 27 deletions pr_agent/algo/git_patch_processing.py
@@ -7,19 +7,8 @@
from pr_agent.log import get_logger


def extend_patch(original_file_str, patch_str, num_lines) -> str:
"""
Extends the given patch to include a specified number of surrounding lines.

Args:
original_file_str (str): The original file to which the patch will be applied.
patch_str (str): The patch to be applied to the original file.
num_lines (int): The number of surrounding lines to include in the extended patch.

Returns:
str: The extended patch string.
"""
if not patch_str or num_lines == 0:
def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0, patch_extra_lines_after=0) -> str:
if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0):
return patch_str

if type(original_file_str) == bytes:
@@ -29,6 +18,7 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
return ""

original_lines = original_file_str.splitlines()
len_original_lines = len(original_lines)
patch_lines = patch_str.splitlines()
extended_patch_lines = []

@@ -40,10 +30,11 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
if line.startswith('@@'):
match = RE_HUNK_HEADER.match(line)
if match:
# finish previous hunk
if start1 != -1:
extended_patch_lines.extend(
original_lines[start1 + size1 - 1:start1 + size1 - 1 + num_lines])
# finish last hunk
if start1 != -1 and patch_extra_lines_after > 0:
delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
delta_lines = [f' {line}' for line in delta_lines]
extended_patch_lines.extend(delta_lines)

res = list(match.groups())
for i in range(len(res)):
@@ -55,26 +46,46 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
start1, size1, size2 = map(int, res[:3])
start2 = 0
section_header = res[4]
extended_start1 = max(1, start1 - num_lines)
extended_size1 = size1 + (start1 - extended_start1) + num_lines
extended_start2 = max(1, start2 - num_lines)
extended_size2 = size2 + (start2 - extended_start2) + num_lines

if patch_extra_lines_before > 0 or patch_extra_lines_after > 0:
extended_start1 = max(1, start1 - patch_extra_lines_before)
extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after
if extended_start1 - 1 + extended_size1 > len(original_lines):
extended_size1 = len_original_lines - extended_start1 + 1
extended_start2 = max(1, start2 - patch_extra_lines_before)
extended_size2 = size2 + (start2 - extended_start2) + patch_extra_lines_after
if extended_start2 - 1 + extended_size2 > len_original_lines:
extended_size2 = len_original_lines - extended_start2 + 1
delta_lines = original_lines[extended_start1 - 1:start1 - 1]
delta_lines = [f' {line}' for line in delta_lines]
if section_header:
for line in delta_lines:
if section_header in line:
section_header = '' # remove section header if it is in the extra delta lines
break
else:
extended_start1 = start1
extended_size1 = size1
extended_start2 = start2
extended_size2 = size2
delta_lines = []
extended_patch_lines.append(
f'@@ -{extended_start1},{extended_size1} '
f'+{extended_start2},{extended_size2} @@ {section_header}')
extended_patch_lines.extend(
original_lines[extended_start1 - 1:start1 - 1]) # one to zero based
extended_patch_lines.extend(delta_lines) # one to zero based
continue
extended_patch_lines.append(line)
except Exception as e:
if get_settings().config.verbosity_level >= 2:
get_logger().error(f"Failed to extend patch: {e}")
return patch_str

# finish previous hunk
if start1 != -1:
extended_patch_lines.extend(
original_lines[start1 + size1 - 1:start1 + size1 - 1 + num_lines])
# finish last hunk
if start1 != -1 and patch_extra_lines_after > 0:
delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
# add space at the beginning of each extra line
delta_lines = [f' {line}' for line in delta_lines]
extended_patch_lines.extend(delta_lines)

extended_patch_str = '\n'.join(extended_patch_lines)
return extended_patch_str
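
A minimal sketch of the hunk-header arithmetic this function now performs, assuming a simplified, symmetric treatment of the old- and new-file sides (the real implementation also space-prefixes the pulled-in lines and drops a section header that already appears in the extra context):

```python
# Illustrative only; mirrors the widening logic of extend_patch, not the full function.
def widen_hunk_header(start1: int, size1: int, start2: int, size2: int,
                      extra_before: int, extra_after: int, file_len: int) -> str:
    # shift the hunk start up by the requested "before" context, never past line 1
    ext_start1 = max(1, start1 - extra_before)
    ext_size1 = size1 + (start1 - ext_start1) + extra_after
    if ext_start1 - 1 + ext_size1 > file_len:       # clamp at end of file
        ext_size1 = file_len - ext_start1 + 1
    ext_start2 = max(1, start2 - extra_before)
    ext_size2 = size2 + (start2 - ext_start2) + extra_after
    if ext_start2 - 1 + ext_size2 > file_len:
        ext_size2 = file_len - ext_start2 + 1
    return f'@@ -{ext_start1},{ext_size1} +{ext_start2},{ext_size2} @@'

# a 7-line hunk at line 66, widened with 4 extra lines before and 2 after
print(widen_hunk_header(66, 7, 66, 8, 4, 2, file_len=500))  # -> @@ -62,13 +62,14 @@
```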
34 changes: 16 additions & 18 deletions pr_agent/algo/pr_processing.py
@@ -33,9 +33,11 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
large_pr_handling=False,
return_remaining_files=False):
if disable_extra_lines:
PATCH_EXTRA_LINES = 0
PATCH_EXTRA_LINES_BEFORE = 0
PATCH_EXTRA_LINES_AFTER = 0
else:
PATCH_EXTRA_LINES = get_settings().config.patch_extra_lines
PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after

try:
diff_files_original = git_provider.get_diff_files()
@@ -64,15 +66,16 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,

# generate a standard diff string, with patch extension
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
pr_languages, token_handler, add_line_numbers_to_hunks, patch_extra_lines=PATCH_EXTRA_LINES)
pr_languages, token_handler, add_line_numbers_to_hunks,
patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE, patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER)

# if we are under the limit, return the full diff
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
get_logger().info(f"Tokens: {total_tokens}, total tokens under limit: {get_max_tokens(model)}, "
f"returning full diff.")
return "\n".join(patches_extended)

# if we are over the limit, start pruning
# if we are over the limit, start pruning (If we got here, we will not extend the patches with extra lines)
get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
f"pruning diff.")
patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
@@ -174,17 +177,8 @@ def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenH
def pr_generate_extended_diff(pr_languages: list,
token_handler: TokenHandler,
add_line_numbers_to_hunks: bool,
patch_extra_lines: int = 0) -> Tuple[list, int, list]:
"""
Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff
minimization techniques if needed.

Args:
- pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding
files.
- token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
- add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.
"""
patch_extra_lines_before: int = 0,
patch_extra_lines_after: int = 0) -> Tuple[list, int, list]:
total_tokens = token_handler.prompt_tokens # initial tokens
patches_extended = []
patches_extended_tokens = []
@@ -196,7 +190,8 @@ def pr_generate_extended_diff(pr_languages: list,
continue

# extend each patch with extra lines of context
extended_patch = extend_patch(original_file_content_str, patch, num_lines=patch_extra_lines)
extended_patch = extend_patch(original_file_content_str, patch,
patch_extra_lines_before, patch_extra_lines_after)
if not extended_patch:
get_logger().warning(f"Failed to extend patch for file: {file.filename}")
continue
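
A condensed sketch of the loop around this call, with the asymmetric parameters threaded through to `extend_patch` (illustrative only: `file.original_content` and `count_tokens` stand in for the helpers the real code uses to fetch file content and count tokens):

```python
# Simplified view of pr_generate_extended_diff after this change (not the actual code).
def generate_extended_diff(files, token_handler, extra_before: int, extra_after: int):
    total_tokens = token_handler.prompt_tokens           # tokens already used by the prompts
    patches, patch_tokens = [], []
    for file in files:
        if not file.patch:
            continue
        extended = extend_patch(file.original_content, file.patch,
                                patch_extra_lines_before=extra_before,
                                patch_extra_lines_after=extra_after)
        if not extended:
            continue                                      # extension failed; skip this file
        tokens = token_handler.count_tokens(extended)     # assumed token-counting helper
        patches.append(extended)
        patch_tokens.append(tokens)
        total_tokens += tokens
    return patches, total_tokens, patch_tokens
```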
@@ -405,10 +400,13 @@ def get_pr_multi_diffs(git_provider: GitProvider,
for lang in pr_languages:
sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))


# try first a single run with standard diff string, with patch extension, and no deletions
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
pr_languages, token_handler, add_line_numbers_to_hunks=True)
pr_languages, token_handler, add_line_numbers_to_hunks=True,
patch_extra_lines_before=get_settings().config.patch_extra_lines_before,
patch_extra_lines_after=get_settings().config.patch_extra_lines_after)

# if we are under the limit, return the full diff
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
return ["\n".join(patches_extended)] if patches_extended else []
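
The extra context survives only when the extended diff still fits the model's window; otherwise the code falls back to pruning, and, as the updated comment above notes, the pruned patches are not extended. A minimal sketch of that decision, with an illustrative buffer value:

```python
OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000  # illustrative value, not the project's constant

def pick_diff(extended_patches, total_tokens, max_model_tokens, prune_fallback):
    # under the soft limit: keep the diff that includes the extra before/after context
    if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < max_model_tokens:
        return "\n".join(extended_patches)
    # over the limit: prune instead, without the extra context lines
    return prune_fallback()
```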

14 changes: 14 additions & 0 deletions pr_agent/git_providers/utils.py
@@ -47,3 +47,17 @@ def apply_repo_settings(pr_url):
os.remove(repo_settings_file)
except Exception as e:
get_logger().error(f"Failed to remove temporary settings file {repo_settings_file}", e)

# enable switching models with a short definition
if get_settings().config.model.lower()=='claude-3-5-sonnet':
set_claude_model()


def set_claude_model():
"""
set the claude-sonnet-3.5 model easily (even by users), just by stating: --config.model='claude-3-5-sonnet'
"""
model_claude = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
get_settings().set('config.model', model_claude)
get_settings().set('config.model_turbo', model_claude)
get_settings().set('config.fallback_models', [model_claude])
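
A minimal usage sketch of the shortcut (the imports reflect the module paths shown in this PR; the flow simply mirrors the check added to `apply_repo_settings` above):

```python
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.utils import set_claude_model

# the user supplies the short name, e.g. via --config.model='claude-3-5-sonnet' or a settings file
get_settings().set('config.model', 'claude-3-5-sonnet')

# apply_repo_settings() then expands it to the full Bedrock id
if get_settings().config.model.lower() == 'claude-3-5-sonnet':
    set_claude_model()

print(get_settings().config.model)   # bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
```

The `'claude-3-5-sonnet': 100000` entry added to `pr_agent/algo/__init__.py` complements this, presumably so token budgeting still works if the short name is looked up before it is expanded.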
5 changes: 3 additions & 2 deletions pr_agent/settings/configuration.toml
@@ -20,7 +20,8 @@ max_commits_tokens = 500
max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
custom_model_max_tokens=-1 # for models not in the default list
#
patch_extra_lines = 1
patch_extra_lines_before = 3 # Number of extra context lines (on top of git's default 3) to include before each hunk in the patch
patch_extra_lines_after = 1 # Number of extra context lines (on top of git's default 3) to include after each hunk in the patch
secret_provider=""
cli_mode=false
ai_disclaimer_title="" # Pro feature, title for a collapsible disclaimer to AI outputs
@@ -96,7 +97,7 @@ enable_help_text=false


[pr_code_suggestions] # /improve #
max_context_tokens=10000
max_context_tokens=14000
num_code_suggestions=4
commitable_code_suggestions = false
extra_instructions = ""
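
If a repository wants different context sizes than these defaults, the same keys can be overridden wherever pr-agent settings are normally overridden, for example in a repo-local settings file (hypothetical values; `.pr_agent.toml` is the repo settings file read by `apply_repo_settings` shown earlier):

```
[config]
patch_extra_lines_before = 6   # more context above each hunk
patch_extra_lines_after = 0    # no extra context below
```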
2 changes: 1 addition & 1 deletion pr_agent/tools/pr_code_suggestions.py
@@ -286,7 +286,7 @@ async def _prepare_prediction(self, model: str) -> dict:
self.token_handler,
model,
add_line_numbers_to_hunks=True,
disable_extra_lines=True)
disable_extra_lines=False)

if self.patches_diff:
get_logger().debug(f"PR diff", artifact=self.patches_diff)