
Commit

Merge pull request #51 from innovation64/chat_with_repo
Fix: Fix dropout is null problem
LOGIC-10 authored Feb 18, 2024
2 parents 2a15f13 + b1921b7 commit 3d4a3cb
Showing 6 changed files with 115 additions and 88 deletions.
3 changes: 2 additions & 1 deletion repo_agent/chat_with_repo/gradio_interface.py
@@ -2,6 +2,7 @@
import markdown
from repo_agent.log import logger


class GradioInterface:
def __init__(self, respond_function):
self.respond = respond_function
@@ -171,4 +172,4 @@ def respond_function(msg, system):
"""
return msg, RAG, "Embedding_recall_output", "Key_words_output", "Code_output"

gradio_interface = GradioInterface(respond_function)
gradio_interface = GradioInterface(respond_function)
88 changes: 38 additions & 50 deletions repo_agent/chat_with_repo/json_handler.py
@@ -2,25 +2,25 @@
import sys
from repo_agent.log import logger


class JsonFileProcessor:
def __init__(self, file_path):
self.file_path = file_path

def read_json_file(self):
# Read the JSON file that serves as the database
with open(self.file_path, 'r', encoding = 'utf-8') as file:
data = json.load(file)
return data
def extract_md_contents(self):
"""
Extracts the contents of 'md_content' from a JSON file.
try:
with open(self.file_path, "r", encoding="utf-8") as file:
data = json.load(file)
return data
except FileNotFoundError:
logger.exception(f"File not found: {self.file_path}")
sys.exit(1)

Returns:
A list of strings representing the contents of 'md_content'.
"""
def extract_data(self):
# Load JSON data from a file
json_data = self.read_json_file()
md_contents = []
extracted_contents = []
# Iterate through each file in the JSON data
for file, items in json_data.items():
# Check if the value is a list (new format)
@@ -31,64 +31,52 @@ def extract_md_contents(self):
if "md_content" in item and item["md_content"]:
# Append the first element of 'md_content' to the result list
md_contents.append(item["md_content"][0])
return md_contents

def extract_metadata(self):
"""
Extracts metadata from JSON data.
Returns:
A list of dictionaries containing the extracted metadata.
"""
# Load JSON data from a file
json_data = self.read_json_file()
extracted_contents = []
# Iterate through each file in the JSON data
for file_name, items in json_data.items():
# Check if the value is a list (new format)
if isinstance(items, list):
# Iterate through each item in the list
for item in items:
# Build a dictionary containing the required information
item_dict = {
"type": item.get("type", "UnknownType"),
"name": item.get("name", "Unnamed"),
"code_start_line": item.get("code_start_line", -1),
"code_end_line": item.get("code_end_line", -1),
"have_return": item.get("have_return", False),
"code_content": item.get("code_content", "NoContent"),
"name_column": item.get("name_column", 0),
"item_status": item.get("item_status", "UnknownStatus"),
# Adapt or remove fields based on new structure requirements
}
extracted_contents.append(item_dict)
return extracted_contents
# Build a dictionary containing the required information
item_dict = {
"type": item.get("type", "UnknownType"),
"name": item.get("name", "Unnamed"),
"code_start_line": item.get("code_start_line", -1),
"code_end_line": item.get("code_end_line", -1),
"have_return": item.get("have_return", False),
"code_content": item.get("code_content", "NoContent"),
"name_column": item.get("name_column", 0),
"item_status": item.get("item_status", "UnknownStatus"),
# Adapt or remove fields based on new structure requirements
}
extracted_contents.append(item_dict)
return md_contents,extracted_contents

def recursive_search(self, data_item, search_text, results):
def recursive_search(self, data_item, search_text, code_results, md_results):
if isinstance(data_item, dict):
# Direct comparison is removed as there's no direct key==search_text in the new format
for key, value in data_item.items():
# Recursively search through dictionary values and lists
if isinstance(value, (dict, list)):
self.recursive_search(value, search_text, results)
self.recursive_search(value, search_text,code_results, md_results)
elif isinstance(data_item, list):
for item in data_item:
# Now we check for the 'name' key in each item of the list
if isinstance(item, dict) and item.get('name') == search_text:
# If 'code_content' exists, append it to results
if 'code_content' in item:
results.append(item['code_content'])
code_results.append(item['code_content'])
md_results.append(item['md_content'])
# Recursive call in case of nested lists or dicts
self.recursive_search(item, search_text, results)
self.recursive_search(item, search_text, code_results, md_results)

def search_code_contents_by_name(self, file_path, search_text):
# Attempt to retrieve code from the JSON file
try:
with open(file_path, 'r', encoding='utf-8') as file:
data = json.load(file)
results = [] # List to store matching items' code_content
self.recursive_search(data, search_text, results)
return results if results else "No matching item found."
code_results = []
md_results = [] # List to store matching items' code_content and md_content
self.recursive_search(data, search_text, code_results, md_results)
# Make sure two values are always returned, whatever the outcome
if code_results or md_results:
return code_results, md_results
else:
return ["No matching item found."], ["No matching item found."]
except FileNotFoundError:
return "File not found."
except json.JSONDecodeError:
@@ -99,4 +87,4 @@ def search_code_contents_by_name(self, file_path, search_text):

if __name__ == "__main__":
processor = JsonFileProcessor("database.json")
md_contents = processor.extract_md_contents()
md_contents = processor.extract_md_contents()
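
Note on the json_handler.py change above: `extract_md_contents` and `extract_metadata` are folded into a single `extract_data` pass, and `search_code_contents_by_name` now always returns a `(code_results, md_results)` pair. A minimal caller sketch under those signatures; the file path and search name below are placeholders, not taken from this commit:

```python
from repo_agent.chat_with_repo.json_handler import JsonFileProcessor

# Placeholder path: point this at the real project hierarchy JSON.
processor = JsonFileProcessor("project_hierarchy.json")

# One pass now yields both the markdown contents and the per-item metadata.
md_contents, meta_data = processor.extract_data()

# The search helper returns two parallel lists, even when nothing matches.
code_results, md_results = processor.search_code_contents_by_name(
    processor.file_path, "read_json_file"
)
```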
5 changes: 2 additions & 3 deletions repo_agent/chat_with_repo/main.py
@@ -9,12 +9,11 @@ def main():
api_key = CONFIG["api_keys"][_model][0]["api_key"]
api_base = CONFIG["api_keys"][_model][0]["base_url"]
db_path = os.path.join(
CONFIG["repo_path"], CONFIG["project_hierarchy"], ".project_hierarchy.json"
CONFIG["repo_path"], CONFIG["project_hierarchy"], "project_hierarchy.json"
)

assistant = RepoAssistant(api_key, api_base, db_path)
md_contents = assistant.json_data.extract_md_contents()
meta_data = assistant.json_data.extract_metadata()
md_contents,meta_data = assistant.json_data.extract_data()
assistant.chroma_data.create_vector_store(md_contents, meta_data)
GradioInterface(assistant.respond)

11 changes: 5 additions & 6 deletions repo_agent/chat_with_repo/prompt.py
@@ -2,7 +2,6 @@
from repo_agent.log import logger
from repo_agent.chat_with_repo.json_handler import JsonFileProcessor

# logger.add("./log.txt", level="DEBUG", format="{time} - {name} - {level} - {message}")

class TextAnalysisTool:
def __init__(self, llm, db_path):
@@ -25,8 +24,8 @@ def format_chat_prompt(self, message, instruction):
return prompt

def queryblock(self, message):
search_result = self.jsonsearch.search_code_contents_by_name(self.db_path, message)
return search_result
search_result,md = self.jsonsearch.search_code_contents_by_name(self.db_path, message)
return search_result,md

def list_to_markdown(self, search_result):
markdown_str = ""
@@ -38,14 +37,14 @@ def list_to_markdown(self, search_result):
return markdown_str

def nerquery(self, message):
instruction = """
query1 = """
The output must strictly be a pure function name or class name, without any additional characters.
For example:
Pure function names: calculateSum, processData
Pure class names: MyClass, DataProcessor
The output function name or class name should be only one.
"""
query = f"{instruction}\nExtract the most relevant class or function from the following input:\n{message}\nOutput:"
query = f"Extract the most relevant class or function base following instrcution {query1},here is input:\n{message}\nOutput:"
response = self.llm.complete(query)
# logger.debug(f"Input: {message}, Output: {response}")
return response
@@ -56,4 +55,4 @@ def nerquery(self, message):
log_file = "your_logfile_path"
llm = OpenAI(api_key=api_key, api_base=api_base)
db_path = "your_database_path"
test = TextAnalysisTool(llm, db_path)
test = TextAnalysisTool(llm, db_path)
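
For the prompt.py change above, `queryblock` now forwards both halves of the search result, so callers unpack two values. A short sketch, assuming the llama_index `OpenAI` wrapper used elsewhere in this package and the same placeholder credentials as the file's own `__main__` block:

```python
from llama_index.llms import OpenAI
from repo_agent.chat_with_repo.prompt import TextAnalysisTool

# Placeholder credentials and path, mirroring the test harness above.
llm = OpenAI(api_key="your_api_key", api_base="your_api_base")
db_path = "your_database_path"
analysis = TextAnalysisTool(llm, db_path)

# queryblock() now returns (code_contents, md_contents) instead of one list.
code, md = analysis.queryblock("extract_data")
markdown_view = analysis.list_to_markdown(code)
```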
70 changes: 54 additions & 16 deletions repo_agent/chat_with_repo/rag.py
@@ -4,8 +4,9 @@
from repo_agent.log import logger
from llama_index import PromptTemplate
from llama_index.llms import OpenAI
import json
from openai import OpenAI as AI

# logger.add("./log.txt", level="DEBUG", format="{time} - {name} - {level} - {message}")

class RepoAssistant:
def __init__(self, api_key, api_base, db_path):
@@ -16,6 +17,7 @@ def __init__(self, api_key, api_base, db_path):
self.md_contents = []
self.llm = OpenAI(api_key=api_key, api_base=api_base,model="gpt-3.5-turbo-1106")
self.client = OpenAI(api_key=api_key, api_base=api_base,model="gpt-4-1106-preview")
self.lm = AI(api_key = api_key, base_url = api_base)
self.textanslys = TextAnalysisTool(self.llm,db_path)
self.json_data = JsonFileProcessor(db_path)
self.chroma_data = ChromaManager(api_key, api_base)
@@ -37,10 +39,26 @@ def generate_queries(self, query_str: str, num_queries: int = 4):
queries = response.text.split("\n")
return queries

def rerank(self, query ,docs):
response = self.lm.chat.completions.create(
model='gpt-4-1106-preview',
response_format={"type": "json_object"},
temperature=0,
messages=[
{"role": "system", "content": "You are an expert relevance ranker. Given a list of documents and a query, your job is to determine how relevant each document is for answering the query. Your output is JSON, which is a list of documents. Each document has two fields, content and score. relevance_score is from 0.0 to 100.0. Higher relevance means higher score."},
{"role": "user", "content": f"Query: {query} Docs: {docs}"}
]
)
scores = json.loads(response.choices[0].message.content)["documents"]
logger.debug(f"scores: {scores}")
sorted_data = sorted(scores, key=lambda x: x['relevance_score'], reverse=True)
top_5_contents = [doc['content'] for doc in sorted_data[:5]]
return top_5_contents

def rag(self, query, retrieved_documents):
# rag
information = "\n\n".join(retrieved_documents)
messages = f"You are a helpful expert repo research assistant. Your users are asking questions about information contained in a repository. You will be shown the user's question, and the relevant information from the repository. Answer the user's question using only the information given.\nQuestion: {query}. \nInformation: {information}"
messages = f"You are a helpful expert repo research assistant. Your users are asking questions about information contained in repo . You will be shown the user's question, and the relevant information from the repo. Answer the user's question using only this information.\nQuestion: {query}. \nInformation: {information}"
response = self.llm.complete(messages)
content = response
return content
@@ -49,7 +67,7 @@ def list_to_markdown(self,list_items):

# Add a numbered list entry for each item in the list
for index, item in enumerate(list_items, start=1):
markdown_content += f"[{index}] {item}\n"
markdown_content += f"{index}. {item}\n"

return markdown_content
def rag_ar(self, query, related_code, embedding_recall, project_name):
Expand Down Expand Up @@ -83,7 +101,7 @@ def respond(self, message, instruction):
# return answer
prompt = self.textanslys.format_chat_prompt(message, instruction)
questions = self.textanslys.keyword(prompt)
logger.debug(f"Questions: {questions}")
# logger.debug(f"Questions: {questions}")
promptq = self.generate_queries(prompt,3)
all_results = []
all_ids = []
@@ -93,7 +111,8 @@
all_ids.extend(query_result['ids'][0])

logger.debug(f"all_ids: {all_ids},{all_results}")
unique_ids = [id for id in all_ids if all_ids.count(id) == 1]
unique_ids = list(dict.fromkeys(all_ids))
# unique_ids = [id for id in all_ids if all_ids.count(id) == 1]
logger.debug(f"uniqueid: {unique_ids}")
unique_documents = []
unique_code = []
@@ -102,30 +121,49 @@
if id in unique_ids:
unique_documents.append(doc)
unique_code.append(code.get("code_content"))
unique_code=self.textanslys.list_to_markdown(unique_code)
retrieved_documents = unique_documents

retrieved_documents = self.rerank(message,unique_documents)
# logger.debug(f"retrieveddocuments: {retrieved_documents}")
response = self.rag(prompt,retrieved_documents)
chunkrecall = self.list_to_markdown(retrieved_documents)
bot_message = str(response)
keyword = str(self.textanslys.nerquery(bot_message))
keywords = str(self.textanslys.nerquery(str(prompt)+str(questions)))
codez=self.textanslys.queryblock(keyword)
codey=self.textanslys.queryblock(keywords)
codez,mdz=self.textanslys.queryblock(keyword)
codey,mdy=self.textanslys.queryblock(keywords)
if not isinstance(codez, list):
codex = [codez]
codez = [codez]
if not isinstance(mdz, list):
mdz = [mdz]
# Make sure codey is a list; if it is not, wrap it in one
if not isinstance(codey, list):
codey = [codey]
if not isinstance(mdy, list):
mdy = [mdy]

codex = codez+codey
codex = self.textanslys.list_to_markdown(codex)
bot_message = self.rag_ar(prompt,unique_code,retrieved_documents,"test")
bot_message = str(bot_message) +'\n'+ str(self.textanslys.tree(bot_message))
return message, bot_message,chunkrecall,questions,unique_code,codex

md = mdz + mdy
unique_mdx = list(set([item for sublist in md for item in sublist]))
uni_codex = []
uni_md = []
uni_codex = list(dict.fromkeys(codex))
uni_md = list(dict.fromkeys(unique_mdx))
codex = self.textanslys.list_to_markdown(uni_codex)
retrieved_documents = retrieved_documents+uni_md
retrieved_documents = list(dict.fromkeys(retrieved_documents))
retrieved_documents = self.rerank(message,retrieved_documents[:6])
uni_code = uni_codex+unique_code
uni_code = list(dict.fromkeys(uni_code))
uni_code = self.rerank(message,uni_code[:6])
unique_code=self.textanslys.list_to_markdown(unique_code)
bot_message = self.rag_ar(prompt,uni_code,retrieved_documents,"test")
bot_message = str(bot_message)
return message, bot_message, chunkrecall, questions, unique_code, codex


if __name__ == "__main__":
api_key = ""
api_base = ""
db_path = ""
log_file = ""
assistant = RepoAssistant(api_key, api_base, db_path, log_file)
assistant = RepoAssistant(api_key, api_base, db_path, log_file)
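
Two details in the rag.py change above are easy to miss: duplicate retrieval IDs are now removed with an order-preserving `dict.fromkeys` pass (the old comprehension discarded every ID that occurred more than once), and `rerank` expects the model's JSON reply to carry a `documents` list with `content` and `relevance_score` fields, which it sorts and truncates to the top five. A self-contained sketch of both pieces, with made-up sample data:

```python
import json

# Order-preserving de-duplication, as now used for the retrieved IDs.
all_ids = ["doc_a", "doc_b", "doc_a", "doc_c"]
unique_ids = list(dict.fromkeys(all_ids))  # ["doc_a", "doc_b", "doc_c"]

# Shape of the JSON object rerank() parses out of the chat completion.
raw_reply = json.dumps({
    "documents": [
        {"content": "doc_b body", "relevance_score": 91.0},
        {"content": "doc_a body", "relevance_score": 42.5},
        {"content": "doc_c body", "relevance_score": 77.0},
    ]
})
scores = json.loads(raw_reply)["documents"]
ranked = sorted(scores, key=lambda d: d["relevance_score"], reverse=True)
top_5_contents = [d["content"] for d in ranked[:5]]
print(unique_ids, top_5_contents)
```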