Commit
expanded chunking to more different blocks
sauravpanda committed Sep 16, 2024
1 parent 71a5a3b commit 3c893a3
Showing 5 changed files with 167 additions and 77 deletions.
100 changes: 84 additions & 16 deletions kaizen/retriever/code_chunker.py
@@ -7,43 +7,100 @@
 def chunk_code(code: str, language: str) -> ParsedBody:
     parser = ParserFactory.get_parser(language)
     tree = parser.parse(code.encode("utf8"))
 
+    code_bytes = code.encode("utf8")
     body: ParsedBody = {
+        "imports": [],
+        "global_variables": [],
+        "type_definitions": [],
         "functions": {},
+        "async_functions": {},
         "classes": {},
         "hooks": {},
         "components": {},
+        "jsx_elements": [],
         "other_blocks": [],
     }
-    # code_bytes = code.encode("utf8")
 
     def process_node(node):
-        result = parse_code(code, language)
+        result = parse_code(node, code_bytes)
         if result:
             # Assuming parse_code is modified to return line numbers
             start_line = result.get("start_line", 0)
             end_line = result.get("end_line", 0)
 
-            if result["type"] == "function":
+            if result["type"] == "import_statement":
+                body["imports"].append(
+                    {
+                        "code": result["code"],
+                        "start_line": start_line,
+                        "end_line": end_line,
+                    }
+                )
+            elif (
+                result["type"] == "variable_declaration"
+                and node.parent.type == "program"
+            ):
+                body["global_variables"].append(
+                    {
+                        "code": result["code"],
+                        "start_line": start_line,
+                        "end_line": end_line,
+                    }
+                )
+            elif result["type"] in ["type_alias", "interface_declaration"]:
+                body["type_definitions"].append(
+                    {
+                        "name": result["name"],
+                        "code": result["code"],
+                        "start_line": start_line,
+                        "end_line": end_line,
+                    }
+                )
+            elif result["type"] == "function":
                 if is_react_hook(result["name"]):
                     body["hooks"][result["name"]] = {
                         "code": result["code"],
                         "start_line": start_line,
                         "end_line": end_line,
                     }
                 elif is_react_component(result["code"]):
-                    body["components"][result["name"]] = result["code"]
+                    body["components"][result["name"]] = {
+                        "code": result["code"],
+                        "start_line": start_line,
+                        "end_line": end_line,
+                    }
+                elif "async" in result["code"].split()[0]:
+                    body["async_functions"][result["name"]] = {
+                        "code": result["code"],
+                        "start_line": start_line,
+                        "end_line": end_line,
+                    }
                 else:
-                    body["functions"][result["name"]] = result["code"]
+                    body["functions"][result["name"]] = {
+                        "code": result["code"],
+                        "start_line": start_line,
+                        "end_line": end_line,
+                    }
             elif result["type"] == "class":
                 if is_react_component(result["code"]):
-                    body["components"][result["name"]] = result["code"]
+                    body["components"][result["name"]] = {
+                        "code": result["code"],
+                        "start_line": start_line,
+                        "end_line": end_line,
+                    }
                 else:
-                    body["classes"][result["name"]] = result["code"]
-            elif result["type"] == "component":
-                body["components"][result["name"]] = result["code"]
-            elif result["type"] == "impl":
-                body["classes"][result["name"]] = result["code"]
+                    body["classes"][result["name"]] = {
+                        "code": result["code"],
+                        "start_line": start_line,
+                        "end_line": end_line,
+                    }
+            elif result["type"] == "jsx_element":
+                body["jsx_elements"].append(
+                    {
+                        "code": result["code"],
+                        "start_line": start_line,
+                        "end_line": end_line,
+                    }
+                )
             else:
                 for child in node.children:
                     process_node(child)
@@ -55,8 +112,14 @@ def process_node(node):
     for section in body.values():
         if isinstance(section, dict):
             for code_block in section.values():
-                start = code.index(code_block)
-                collected_ranges.append((start, start + len(code_block)))
+                collected_ranges.append(
+                    (code_block["start_line"], code_block["end_line"])
+                )
+        elif isinstance(section, list):
+            for code_block in section:
+                collected_ranges.append(
+                    (code_block["start_line"], code_block["end_line"])
+                )
 
     collected_ranges.sort()
     last_end = 0
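
The rest of this hunk is collapsed in the diff view, but the sorted (start_line, end_line) pairs are evidently what drives the `other_blocks` sweep: any line range not claimed by a recognized chunk becomes an "other" block. A rough sketch of that gap-filling pattern, assuming 1-based inclusive ranges (the names below are guesses, not the commit's actual code):

```python
# Sketch: collect source lines not covered by any recognized block
# into body["other_blocks"]. Assumes 1-based, inclusive line ranges.
lines = code.splitlines()
collected_ranges.sort()
last_end = 0
for start_line, end_line in collected_ranges:
    if start_line > last_end + 1:
        # Lines between the previous block and this one are "other" code.
        gap = "\n".join(lines[last_end : start_line - 1])
        if gap.strip():
            body["other_blocks"].append(gap)
    last_end = max(last_end, end_line)
# Anything after the final recognized block is also an "other" block.
if last_end < len(lines):
    tail = "\n".join(lines[last_end:])
    if tail.strip():
        body["other_blocks"].append(tail)
```

Switching from `code.index` character offsets to stored line numbers also removes a latent bug: `code.index` only ever found the first occurrence of a duplicated snippet.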
@@ -76,5 +139,10 @@ def is_react_hook(name: str) -> bool:

 def is_react_component(code: str) -> bool:
     return (
-        "React" in code or "jsx" in code.lower() or "tsx" in code.lower() or "<" in code
+        "React" in code
+        or "jsx" in code.lower()
+        or "tsx" in code.lower()
+        or "<" in code
+        or "props" in code
+        or "render" in code
     )
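
Taken together, the chunker now buckets imports, global variables, type definitions, async functions, and JSX elements alongside the existing functions, classes, hooks, and components. A minimal usage sketch; the TypeScript snippet and printed fields are illustrative, and the exact `result["type"]` labels depend on the tree-sitter grammar:

```python
from kaizen.retriever.code_chunker import chunk_code

# Hypothetical TypeScript input exercising several of the new buckets.
source = '''
import React from "react";

const API_URL = "https://example.com/api";

async function fetchData() {
    const res = await fetch(API_URL);
    return res.json();
}

function App(props) {
    return <div>{props.title}</div>;
}
'''

body = chunk_code(source, "typescript")

# Every chunk now carries its line range instead of a bare code string.
print([imp["start_line"] for imp in body["imports"]])
print(list(body["async_functions"]))        # e.g. ['fetchData']
print(body["components"]["App"]["code"])    # classified via is_react_component
```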
129 changes: 75 additions & 54 deletions kaizen/retriever/llama_index_retriever.py
@@ -12,7 +12,7 @@
 from llama_index.embeddings.litellm import LiteLLMEmbedding
 from sqlalchemy import create_engine, text
 from kaizen.retriever.qdrant_vector_store import QdrantVectorStore
-
+import json
 
 # Set up logging
 logging.basicConfig(
@@ -43,10 +43,18 @@ def __init__(self, repo_id=1):
         )
         logger.info("RepositoryAnalyzer initialized successfully")
 
-    def setup_repository(self, repo_path: str, node_query: str = None):
+    def setup_repository(
+        self,
+        repo_path: str,
+        node_query: str = None,
+        file_query: str = None,
+        function_query: str = None,
+    ):
         self.total_usage = self.llm_provider.DEFAULT_USAGE
         self.total_files_processed = 0
         self.node_query = node_query
+        self.file_query = file_query
+        self.function_query = function_query
         self.embedding_usage = {"prompt_tokens": 10, "total_tokens": 10}
         logger.info(f"Starting repository setup for: {repo_path}")
         self.parse_repository(repo_path)
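
Usage sketch for the widened signature; the override below is hypothetical and simply mirrors the default `file_query` shown later in this diff:

```python
from kaizen.retriever.llama_index_retriever import RepositoryAnalyzer

analyzer = RepositoryAnalyzer(repo_id=1)

# Hypothetical override: customize how files are registered, while the
# node and function queries fall back to the built-in defaults.
custom_file_query = """
    INSERT INTO files (repo_id, file_path, file_name, file_ext, programming_language)
    VALUES (:repo_id, :file_path, :file_name, :file_ext, :programming_language)
    ON CONFLICT (repo_id, file_path) DO UPDATE SET file_path = EXCLUDED.file_path
    RETURNING file_id
"""

analyzer.setup_repository("./my_repo", file_query=custom_file_query)
```

Holding the queries on `self` keeps `store_code_in_db`'s signature stable while still letting callers customize the SQL once per `setup_repository` run.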
@@ -130,7 +138,7 @@ def process_code_block(
             return  # Skip this code block
 
         language = self.get_language_from_extension(file_path)
-        abstraction, usage = self.generate_abstraction(code, language)
+        abstraction, usage = self.generate_abstraction(code, language, section)
         self.total_usage = self.llm_provider.update_usage(
             total_usage=self.total_usage, current_usage=usage
         )
@@ -185,70 +193,85 @@ def store_abstraction_and_embedding(self, function_id: int, abstraction: str):
logger.debug(f"Abstraction and embedding stored for function_id: {function_id}")

def generate_abstraction(
self, code_block: str, language: str, max_tokens: int = 300
self, code_block: str, language: str, section: str, max_tokens: int = 300
) -> str:
prompt = f"""Analyze the following {language} code block and generate a structured abstraction.
Your response should be in YAML format and include the following sections:
Your response should be in JSON format and include the following sections:
{{
"summary": "A concise one-sentence summary of the function's primary purpose.",
summary: A concise one-sentence summary of the function's primary purpose.
"functionality": "A detailed explanation of what the function does, including its main steps and logic. Use multiple lines if needed for clarity.",
functionality: |
A detailed explanation of what the function does, including its main steps and logic.
Use multiple lines if needed for clarity.
"inputs": [
{{
"name": "The parameter name",
"type": "The parameter type",
"description": "A brief description of the parameter's purpose",
"default_value": "The default value, if any (or null if not applicable)"
}}
],
inputs:
- name: The parameter name
type: The parameter type
description: A brief description of the parameter's purpose
default_value: The default value, if any (or null if not applicable)
"output": {{
"type": "The return type of the function",
"description": "A description of what is returned and under what conditions. Use multiple lines if needed."
}},
output:
type: The return type of the function
description: |
A description of what is returned and under what conditions.
Use multiple lines if needed.
"dependencies": [
{{
"name": "Name of the external library or module",
"purpose": "Brief explanation of its use in this function"
}}
],
dependencies:
- name: Name of the external library or module
purpose: Brief explanation of its use in this function
"algorithms": [
{{
"name": "Name of the algorithm or data structure",
"description": "Brief explanation of its use and importance"
}}
],
algorithms:
- name: Name of the algorithm or data structure
description: Brief explanation of its use and importance
"edge_cases": [
"A list of potential edge cases or special conditions the function handles or should handle"
],
edge_cases:
- A list of potential edge cases or special conditions the function handles or should handle
"error_handling": "A description of how errors are handled or propagated. Include specific error types if applicable.",
error_handling: |
A description of how errors are handled or propagated.
Include specific error types if applicable.
"usage_context": "A brief explanation of how this function might be used by parent functions or in a larger system. Include typical scenarios and any important considerations for its use.",
usage_context: |
A brief explanation of how this function might be used by parent functions or in a larger system.
Include typical scenarios and any important considerations for its use.
"complexity": {{
"time": "Estimated time complexity (e.g., O(n))",
"space": "Estimated space complexity (e.g., O(1))",
"explanation": "Brief explanation of the complexity analysis"
}},
complexity:
time: Estimated time complexity (e.g., O(n))
space: Estimated space complexity (e.g., O(1))
"tags": ["List", "of", "relevant", "tags"],
code_snippet: |
```{language}
{code_block}
```
"testing_considerations": "Suggestions for unit tests or test cases to cover key functionality and edge cases",
Provide your analysis in this clear, structured YAML format. If any section is not applicable, use an empty list [] or null value as appropriate. Ensure that multi-line descriptions are properly indented under their respective keys.
"version_compatibility": "Information about language versions or dependency versions this code is compatible with",
"performance_considerations": "Any notes on performance optimizations or potential bottlenecks",
"security_considerations": "Any security-related notes or best practices relevant to this code",
"maintainability_score": "A subjective score from 1-10 on how easy the code is to maintain, with a brief explanation"
}}
Provide your analysis in this clear, structured JSON format. If any section is not applicable, use an empty list [] or null value as appropriate. Ensure that multi-line descriptions are properly formatted as strings.
Code to analyze:
```{language}
{code_block}
```
Language: {language}
Block Type: {section}
Code Block:
```{code_block}```
"""

estimated_prompt_tokens = len(tokenizer.encode(prompt))
adjusted_max_tokens = min(max(150, estimated_prompt_tokens), 1000)

try:
abstraction, usage = self.llm_provider.chat_completion(
abstraction, usage = self.llm_provider.chat_completion_with_json(
prompt="",
messages=[
{
@@ -259,7 +282,7 @@ def generate_abstraction(
                 ],
                 custom_model={"max_tokens": adjusted_max_tokens, "model": "small"},
             )
-            return abstraction, usage
+            return json.dumps(abstraction), usage
 
         except Exception as e:
             raise e
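
Since `chat_completion_with_json` presumably hands back an already-parsed dict, `json.dumps` keeps the stored abstraction a plain string. A sketch of the assumed round trip, with `analyzer` as in the earlier sketch:

```python
import json

# Assumed: generate_abstraction now returns (json_string, usage_dict).
abstraction_json, usage = analyzer.generate_abstraction(
    code_block="def add(a, b):\n    return a + b",
    language="python",
    section="functions",
)

# Downstream consumers can recover the structured fields from the stored text.
abstraction = json.loads(abstraction_json)
print(abstraction["summary"], abstraction.get("maintainability_score"))
```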
@@ -272,21 +295,19 @@ def store_code_in_db(
         section: str,
         name: str,
         start_line: int,
-        file_query: str = None,
-        function_query: str = None,
     ) -> int:
         logger.debug(f"Storing code in DB: {file_path} - {section} - {name}")
         with self.engine.begin() as connection:
             # Insert into files table (assuming this part is already correct)
-            if not file_query:
-                file_query = """
+            if not self.file_query:
+                self.file_query = """
                 INSERT INTO files (repo_id, file_path, file_name, file_ext, programming_language)
                 VALUES (:repo_id, :file_path, :file_name, :file_ext, :programming_language)
                 ON CONFLICT (repo_id, file_path) DO UPDATE SET file_path = EXCLUDED.file_path
                 RETURNING file_id
                 """
             file_id = connection.execute(
-                text(file_query),
+                text(self.file_query),
                 {
                     "repo_id": self.repo_id,
                     "file_path": file_path,
@@ -297,15 +318,15 @@ def store_code_in_db(
             ).scalar_one()
 
             # Insert into function_abstractions table
-            if not function_query:
-                function_query = """
+            if not self.function_query:
+                self.function_query = """
                 INSERT INTO function_abstractions
                 (file_id, function_name, function_signature, abstract_functionality, start_line, end_line)
                 VALUES (:file_id, :function_name, :function_signature, :abstract_functionality, :start_line, :end_line)
                 RETURNING function_id
                 """
             function_id = connection.execute(
-                text(function_query),
+                text(self.function_query),
                 {
                     "file_id": file_id,
                     "function_name": name,
5 changes: 4 additions & 1 deletion kaizen/retriever/qdrant_vector_store.py
@@ -11,6 +11,7 @@ class QdrantVectorStore:
     def __init__(self, collection_name, vector_size, max_retries=3, retry_delay=2):
         self.HOST = os.getenv("QDRANT_HOST", "localhost")
         self.PORT = os.getenv("QDRANT_PORT", "6333")
+        self.QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
         self.collection_name = collection_name
         self.max_retries = max_retries
         self.retry_delay = retry_delay
@@ -49,7 +50,9 @@ def _create_collection(self, vector_size):

     def add(self, nodes):
         points = [
-            PointStruct(id=node.id_, vector=node.embedding, payload=node.metadata)
+            PointStruct(
+                id=node["id"], vector=node["embedding"], payload=node["metadata"]
+            )
             for node in nodes
         ]
         self.client.upsert(collection_name=self.collection_name, points=points)
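
Two things change here: the constructor now reads `QDRANT_API_KEY` from the environment, and `add()` expects plain dicts rather than objects exposing `id_`/`embedding`/`metadata` attributes. A sketch of the assumed payload shape; the collection name and vector size below are illustrative:

```python
import uuid
from kaizen.retriever.qdrant_vector_store import QdrantVectorStore

store = QdrantVectorStore(collection_name="function_abstractions", vector_size=1536)

# Assumed node shape after this commit: dicts keyed "id", "embedding", "metadata".
nodes = [
    {
        "id": str(uuid.uuid4()),   # Qdrant point IDs must be UUIDs or unsigned ints
        "embedding": [0.01] * 1536,  # length must match vector_size
        "metadata": {"function_id": 42, "file_path": "kaizen/retriever/code_chunker.py"},
    }
]
store.add(nodes)
```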
8 changes: 3 additions & 5 deletions kaizen/retriever/tree_sitter_utils.py
@@ -101,13 +101,11 @@ def traverse_tree(node, code_bytes: bytes) -> Dict[str, Any]:
         return None
 
 
-def parse_code(code: str, language: str) -> Dict[str, Any]:
+def parse_code(node: Any, code_bytes: bytes) -> Dict[str, Any]:
     try:
-        parser = ParserFactory.get_parser(language)
-        tree = parser.parse(bytes(code, "utf8"))
-        return traverse_tree(tree.root_node, code.encode("utf8"))
+        return traverse_tree(node, code_bytes)
     except Exception as e:
-        logger.error(f"Failed to parse {language} code: {str(e)}")
+        logger.error(f"Failed to parse code: {str(e)}")
         raise


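With this signature change, `parse_code` no longer builds its own parser: the caller parses once and hands individual tree-sitter nodes down, which is exactly what `chunk_code`'s `process_node` now does. A minimal sketch of the new calling convention, assuming `ParserFactory` is importable from this module and accepts "python" as a language name:

```python
from kaizen.retriever.tree_sitter_utils import ParserFactory, parse_code

code = 'def greet(name):\n    return f"hello {name}"\n'
code_bytes = code.encode("utf8")

# The caller now owns parsing (as chunk_code does above)...
parser = ParserFactory.get_parser("python")
tree = parser.parse(code_bytes)

# ...and parse_code just classifies individual tree-sitter nodes, returning
# type/name/code plus the start_line/end_line fields the chunker expects.
for child in tree.root_node.children:
    result = parse_code(child, code_bytes)
    if result:
        print(result["type"], result.get("name"), result.get("start_line"))
```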

0 comments on commit 3c893a3
