Skip to content

Commit

Permalink
Merge pull request #636 from Codium-ai/tr/model_turbo
Browse files Browse the repository at this point in the history
moving the 'improve' command to turbo mode, with auto_extended=true
  • Loading branch information
mrT23 authored Feb 1, 2024
2 parents 2816cd2 + d04d8b6 commit cb8ff2b
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 13 deletions.
1 change: 1 addition & 0 deletions pr_agent/algo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
'gpt-4-0613': 8000,
'gpt-4-32k': 32000,
'gpt-4-1106-preview': 128000, # 128K, but may be limited by config.max_model_tokens
'gpt-4-0125-preview': 128000, # 128K, but may be limited by config.max_model_tokens
'claude-instant-1': 100000,
'claude-2': 100000,
'command-nightly': 4096,
Expand Down
13 changes: 8 additions & 5 deletions pr_agent/algo/pr_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.file_filter import filter_ignored
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import get_max_tokens
from pr_agent.algo.utils import get_max_tokens, ModelType
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider, EDIT_TYPE
from pr_agent.log import get_logger
Expand Down Expand Up @@ -220,8 +220,8 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
return patches, modified_files_list, deleted_files_list, added_files_list


async def retry_with_fallback_models(f: Callable):
all_models = _get_all_models()
async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
all_models = _get_all_models(model_type)
all_deployments = _get_all_deployments(all_models)
# try each (model, deployment_id) pair until one is successful, otherwise raise exception
for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)):
Expand All @@ -243,8 +243,11 @@ async def retry_with_fallback_models(f: Callable):
raise # Re-raise the last exception


def _get_all_models() -> List[str]:
model = get_settings().config.model
def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
if model_type == ModelType.TURBO:
model = get_settings().config.model_turbo
else:
model = get_settings().config.model
fallback_models = get_settings().config.fallback_models
if not isinstance(fallback_models, list):
fallback_models = [m.strip() for m in fallback_models.split(",")]
Expand Down
4 changes: 4 additions & 0 deletions pr_agent/algo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import textwrap
from datetime import datetime
from enum import Enum
from typing import Any, List

import yaml
Expand All @@ -15,6 +16,9 @@
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import get_logger

class ModelType(str, Enum):
REGULAR = "regular"
TURBO = "turbo"

def get_setting(key: str) -> Any:
try:
Expand Down
12 changes: 7 additions & 5 deletions pr_agent/settings/configuration.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[config]
model="gpt-4" # "gpt-4-0125-preview"
model_turbo="gpt-4-0125-preview"
fallback_models=["gpt-3.5-turbo-16k"]
git_provider="github"
publish_output=true
Expand Down Expand Up @@ -68,17 +69,18 @@ enable_help_text=true


[pr_code_suggestions] # /improve #
max_context_tokens=8000
num_code_suggestions=4
summarize = true
extra_instructions = ""
rank_suggestions = false
enable_help_text=true
# params for '/improve --extended' mode
auto_extended_mode=false
num_code_suggestions_per_chunk=8
rank_extended_suggestions = true
max_number_of_calls = 5
final_clip_factor = 0.9
auto_extended_mode=true
num_code_suggestions_per_chunk=5
rank_extended_suggestions = false
max_number_of_calls = 3
final_clip_factor = 0.8

[pr_add_docs] # /add_docs #
extra_instructions = ""
Expand Down
14 changes: 11 additions & 3 deletions pr_agent/tools/pr_code_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, replace_code_tags
from pr_agent.algo.utils import load_yaml, replace_code_tags, ModelType
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
Expand All @@ -26,6 +26,14 @@ def __init__(self, pr_url: str, cli_mode=False, args: list = None,
self.git_provider.get_languages(), self.git_provider.get_files()
)

# limit context specifically for the improve command, which has hard input to parse:
if get_settings().pr_code_suggestions.max_context_tokens:
MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE


# extended mode
try:
self.is_extended = self._get_is_extended(args or [])
Expand Down Expand Up @@ -64,10 +72,10 @@ async def run(self):

get_logger().info('Preparing PR code suggestions...')
if not self.is_extended:
await retry_with_fallback_models(self._prepare_prediction)
await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
data = self._prepare_pr_code_suggestions()
else:
data = await retry_with_fallback_models(self._prepare_prediction_extended)
data = await retry_with_fallback_models(self._prepare_prediction_extended, ModelType.TURBO)


if (not data) or (not 'code_suggestions' in data):
Expand Down

0 comments on commit cb8ff2b

Please sign in to comment.