diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py
index 63a628a57..0f647bf40 100644
--- a/pr_agent/algo/__init__.py
+++ b/pr_agent/algo/__init__.py
@@ -9,6 +9,7 @@
     'gpt-4-0613': 8000,
     'gpt-4-32k': 32000,
     'gpt-4-1106-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
+    'gpt-4-0125-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
     'claude-instant-1': 100000,
     'claude-2': 100000,
     'command-nightly': 4096,
diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py
index 1e482dbfb..61061bbd5 100644
--- a/pr_agent/algo/pr_processing.py
+++ b/pr_agent/algo/pr_processing.py
@@ -11,7 +11,7 @@
 from pr_agent.algo.language_handler import sort_files_by_main_languages
 from pr_agent.algo.file_filter import filter_ignored
 from pr_agent.algo.token_handler import TokenHandler
-from pr_agent.algo.utils import get_max_tokens
+from pr_agent.algo.utils import get_max_tokens, ModelType
 from pr_agent.config_loader import get_settings
 from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider, EDIT_TYPE
 from pr_agent.log import get_logger
@@ -220,8 +220,8 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
     return patches, modified_files_list, deleted_files_list, added_files_list
 
 
-async def retry_with_fallback_models(f: Callable):
-    all_models = _get_all_models()
+async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
+    all_models = _get_all_models(model_type)
     all_deployments = _get_all_deployments(all_models)
     # try each (model, deployment_id) pair until one is successful, otherwise raise exception
     for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)):
@@ -243,8 +243,11 @@ async def retry_with_fallback_models(f: Callable):
                 raise  # Re-raise the last exception
 
 
-def _get_all_models() -> List[str]:
-    model = get_settings().config.model
+def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
+    if model_type == ModelType.TURBO:
+        model = get_settings().config.model_turbo
+    else:
+        model = get_settings().config.model
     fallback_models = get_settings().config.fallback_models
     if not isinstance(fallback_models, list):
         fallback_models = [m.strip() for m in fallback_models.split(",")]
diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index 9a150864a..e92c32197 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -5,6 +5,7 @@
 import re
 import textwrap
 from datetime import datetime
+from enum import Enum
 from typing import Any, List
 
 import yaml
@@ -15,6 +16,9 @@
 from pr_agent.config_loader import get_settings, global_settings
 from pr_agent.log import get_logger
 
+class ModelType(str, Enum):
+    REGULAR = "regular"
+    TURBO = "turbo"
 
 def get_setting(key: str) -> Any:
     try:
diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml
index 3a9b8d39c..5c7a35eb7 100644
--- a/pr_agent/settings/configuration.toml
+++ b/pr_agent/settings/configuration.toml
@@ -1,5 +1,6 @@
 [config]
 model="gpt-4" # "gpt-4-0125-preview"
+model_turbo="gpt-4-0125-preview"
 fallback_models=["gpt-3.5-turbo-16k"]
 git_provider="github"
 publish_output=true
@@ -68,17 +69,18 @@
 enable_help_text=true
 
 [pr_code_suggestions] # /improve #
+max_context_tokens=8000
 num_code_suggestions=4
 summarize = true
 extra_instructions = ""
 rank_suggestions = false
 enable_help_text=true
 # params for '/improve --extended' mode
-auto_extended_mode=false
-num_code_suggestions_per_chunk=8
-rank_extended_suggestions = true
-max_number_of_calls = 5
-final_clip_factor = 0.9
+auto_extended_mode=true
+num_code_suggestions_per_chunk=5
+rank_extended_suggestions = false
+max_number_of_calls = 3
+final_clip_factor = 0.8
 
 [pr_add_docs] # /add_docs #
 extra_instructions = ""
diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py
index 381c02a65..fb7258488 100644
--- a/pr_agent/tools/pr_code_suggestions.py
+++ b/pr_agent/tools/pr_code_suggestions.py
@@ -8,7 +8,7 @@
 from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
 from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
 from pr_agent.algo.token_handler import TokenHandler
-from pr_agent.algo.utils import load_yaml, replace_code_tags
+from pr_agent.algo.utils import load_yaml, replace_code_tags, ModelType
 from pr_agent.config_loader import get_settings
 from pr_agent.git_providers import get_git_provider
 from pr_agent.git_providers.git_provider import get_main_pr_language
@@ -26,6 +26,14 @@ def __init__(self, pr_url: str, cli_mode=False, args: list = None,
             self.git_provider.get_languages(), self.git_provider.get_files()
         )
 
+        # limit context specifically for the improve command, which has hard input to parse:
+        if get_settings().pr_code_suggestions.max_context_tokens:
+            MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
+            if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
+                get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
+                get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE
+
         # extended mode
         try:
             self.is_extended = self._get_is_extended(args or [])
@@ -64,10 +72,10 @@ async def run(self):
 
             get_logger().info('Preparing PR code suggestions...')
             if not self.is_extended:
-                await retry_with_fallback_models(self._prepare_prediction)
+                await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
                 data = self._prepare_pr_code_suggestions()
             else:
-                data = await retry_with_fallback_models(self._prepare_prediction_extended)
+                data = await retry_with_fallback_models(self._prepare_prediction_extended, ModelType.TURBO)
 
             if (not data) or (not 'code_suggestions' in data):
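
For reviewers, here is a minimal, self-contained sketch of the model-routing behavior this patch introduces: a caller that passes `ModelType.TURBO` (currently only the `/improve` flow in `pr_code_suggestions.py`) resolves `config.model_turbo` as the primary model, every other caller keeps `config.model`, and the shared `fallback_models` list is appended in both cases. The `SimpleNamespace` stand-in for `get_settings().config` below is hypothetical; in the real project these values come from `configuration.toml` via the settings loader.

```python
# Sketch only: mirrors _get_all_models() in pr_agent/algo/pr_processing.py
# under a stubbed-out settings object (hypothetical, for illustration).
from enum import Enum
from types import SimpleNamespace
from typing import List


class ModelType(str, Enum):
    REGULAR = "regular"
    TURBO = "turbo"


# Hypothetical stand-in for get_settings().config
config = SimpleNamespace(
    model="gpt-4",
    model_turbo="gpt-4-0125-preview",
    fallback_models=["gpt-3.5-turbo-16k"],
)


def get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
    # Pick the primary model by type, then append the configured fallbacks.
    model = config.model_turbo if model_type == ModelType.TURBO else config.model
    fallback_models = config.fallback_models
    if not isinstance(fallback_models, list):
        # Settings may deliver fallbacks as a comma-separated string.
        fallback_models = [m.strip() for m in fallback_models.split(",")]
    return [model] + fallback_models


print(get_all_models())                  # ['gpt-4', 'gpt-3.5-turbo-16k']
print(get_all_models(ModelType.TURBO))   # ['gpt-4-0125-preview', 'gpt-3.5-turbo-16k']
```

`retry_with_fallback_models` then tries this list in order, so a turbo-first call still degrades to the same fallback chain as a regular call.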