diff --git a/mentat/code_context.py b/mentat/code_context.py
index 3990e89d6..0cf65a6c5 100644
--- a/mentat/code_context.py
+++ b/mentat/code_context.py
@@ -8,6 +8,7 @@
 from mentat.code_feature import CodeFeature, get_consolidated_feature_refs
 from mentat.diff_context import DiffContext
 from mentat.errors import PathValidationError
+from mentat.git_handler import get_git_root_for_path
 from mentat.include_files import (
     PathType,
     get_code_features_for_path,
@@ -67,6 +68,11 @@ async def refresh_daemon(self):
         cwd = ctx.cwd
         llm_api_handler = ctx.llm_api_handler
 
+        # Use print because stream is not initialized yet
+        print("Scanning codebase for updates...")
+        if not get_git_root_for_path(cwd, raise_error=False):
+            print("\033[93mWarning: Not a git repository (this might take a while)\033[0m")
+
         annotators: dict[str, dict[str, Any]] = {
             "hierarchy": {"ignore_patterns": [str(p) for p in self.ignore_patterns]},
             "chunker_line": {"lines_per_chunk": 50},
@@ -185,11 +191,15 @@ async def get_code_message(
             auto_tokens=auto_tokens,
         )
         for ref in context_builder.to_refs():
+            new_features = list[CodeFeature]()  # Save ragdaemon context back to include_files
             path, interval_str = split_intervals_from_path(Path(ref))
-            intervals = parse_intervals(interval_str)
-            for interval in intervals:
-                feature = CodeFeature(cwd / path, interval)
-                self.include_features([feature])  # Save ragdaemon context back to include_files
+            if not interval_str:
+                new_features.append(CodeFeature(cwd / path))
+            else:
+                intervals = parse_intervals(interval_str)
+                for interval in intervals:
+                    new_features.append(CodeFeature(cwd / path, interval))
+            self.include_features(new_features)
 
         # The context message is rendered by ragdaemon (ContextBuilder.render())
         context_message = context_builder.render()
@@ -417,10 +427,14 @@ async def search(
                 continue
             distance = node["distance"]
             path, interval = split_intervals_from_path(Path(node["ref"]))
-            intervals = parse_intervals(interval)
-            for _interval in intervals:
-                feature = CodeFeature(cwd / path, _interval)
+            if not interval:
+                feature = CodeFeature(cwd / path)
                 all_features_sorted.append((feature, distance))
+            else:
+                intervals = parse_intervals(interval)
+                for _interval in intervals:
+                    feature = CodeFeature(cwd / path, _interval)
+                    all_features_sorted.append((feature, distance))
         if max_results is None:
             return all_features_sorted
         else:
diff --git a/mentat/config.py b/mentat/config.py
index 97093eb7f..4468d18af 100644
--- a/mentat/config.py
+++ b/mentat/config.py
@@ -42,7 +42,7 @@ class Config:
     )
     provider: Optional[str] = attr.field(default=None, metadata={"auto_completions": ["openai", "anthropic", "azure"]})
     embedding_model: str = attr.field(
-        default="text-embedding-ada-002",
+        default="text-embedding-3-large",
         metadata={"auto_completions": [model.name for model in models if isinstance(model, EmbeddingModel)]},
    )
     embedding_provider: Optional[str] = attr.field(
diff --git a/mentat/llm_api_handler.py b/mentat/llm_api_handler.py
index eb309e4b4..5152e88cd 100644
--- a/mentat/llm_api_handler.py
+++ b/mentat/llm_api_handler.py
@@ -217,7 +217,7 @@ async def call_llm_api(
         return response
 
     @api_guard
-    def call_embedding_api(self, input_texts: list[str], model: str = "text-embedding-ada-002") -> EmbeddingResponse:
+    def call_embedding_api(self, input_texts: list[str], model: str = "text-embedding-3-large") -> EmbeddingResponse:
         ctx = SESSION_CONTEXT.get()
         return self.spice.get_embeddings_sync(input_texts, model, provider=ctx.config.embedding_provider)
 
diff --git a/mentat/session.py b/mentat/session.py
index 0ebb05078..088cfc723 100644
--- a/mentat/session.py
+++ b/mentat/session.py
@@ -164,7 +164,6 @@ async def _main(self):
 
         await session_context.llm_api_handler.initialize_client()
 
-        print("Scanning codebase for updates...")
         await code_context.refresh_daemon()
         check_model()
 