Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Fix dropout is null problem #51

Merged
merged 4 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion repo_agent/chat_with_repo/gradio_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import markdown
from repo_agent.log import logger


class GradioInterface:
def __init__(self, respond_function):
self.respond = respond_function
Expand Down Expand Up @@ -171,4 +172,4 @@ def respond_function(msg, system):
"""
return msg, RAG, "Embedding_recall_output", "Key_words_output", "Code_output"

gradio_interface = GradioInterface(respond_function)
gradio_interface = GradioInterface(respond_function)
88 changes: 38 additions & 50 deletions repo_agent/chat_with_repo/json_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,25 @@
import sys
from repo_agent.log import logger


class JsonFileProcessor:
def __init__(self, file_path):
self.file_path = file_path

def read_json_file(self):
# 读取 JSON 文件作为数据库
with open(self.file_path, 'r', encoding = 'utf-8') as file:
data = json.load(file)
return data
def extract_md_contents(self):
"""
Extracts the contents of 'md_content' from a JSON file.
try:
with open(self.file_path, "r", encoding="utf-8") as file:
data = json.load(file)
return data
except FileNotFoundError:
logger.exception(f"File not found: {self.file_path}")
sys.exit(1)

Returns:
A list of strings representing the contents of 'md_content'.
"""
def extract_data(self):
# Load JSON data from a file
json_data = self.read_json_file()
md_contents = []
extracted_contents = []
# Iterate through each file in the JSON data
for file, items in json_data.items():
# Check if the value is a list (new format)
Expand All @@ -31,64 +31,52 @@ def extract_md_contents(self):
if "md_content" in item and item["md_content"]:
# Append the first element of 'md_content' to the result list
md_contents.append(item["md_content"][0])
return md_contents

def extract_metadata(self):
"""
Extracts metadata from JSON data.

Returns:
A list of dictionaries containing the extracted metadata.
"""
# Load JSON data from a file
json_data = self.read_json_file()
extracted_contents = []
# Iterate through each file in the JSON data
for file_name, items in json_data.items():
# Check if the value is a list (new format)
if isinstance(items, list):
# Iterate through each item in the list
for item in items:
# Build a dictionary containing the required information
item_dict = {
"type": item.get("type", "UnknownType"),
"name": item.get("name", "Unnamed"),
"code_start_line": item.get("code_start_line", -1),
"code_end_line": item.get("code_end_line", -1),
"have_return": item.get("have_return", False),
"code_content": item.get("code_content", "NoContent"),
"name_column": item.get("name_column", 0),
"item_status": item.get("item_status", "UnknownStatus"),
# Adapt or remove fields based on new structure requirements
}
extracted_contents.append(item_dict)
return extracted_contents
# Build a dictionary containing the required information
item_dict = {
"type": item.get("type", "UnknownType"),
"name": item.get("name", "Unnamed"),
"code_start_line": item.get("code_start_line", -1),
"code_end_line": item.get("code_end_line", -1),
"have_return": item.get("have_return", False),
"code_content": item.get("code_content", "NoContent"),
"name_column": item.get("name_column", 0),
"item_status": item.get("item_status", "UnknownStatus"),
# Adapt or remove fields based on new structure requirements
}
extracted_contents.append(item_dict)
return md_contents,extracted_contents

def recursive_search(self, data_item, search_text, code_results, md_results):
    """Recursively walk a nested JSON structure collecting items by name.

    Args:
        data_item: Current node being searched (dict, list, or scalar).
        search_text: The 'name' value to match against dict items in lists.
        code_results: Output list; receives each match's 'code_content'.
        md_results: Output list; receives each match's 'md_content'
            (the whole value, not just its first element).

    Returns:
        None. Matches are accumulated in the two output lists.
    """
    if isinstance(data_item, dict):
        # No direct key == search_text comparison in the new data format;
        # only values that can nest further need to be searched.
        for value in data_item.values():
            if isinstance(value, (dict, list)):
                self.recursive_search(value, search_text, code_results, md_results)
    elif isinstance(data_item, list):
        for item in data_item:
            # Items are matched on their 'name' key.
            if isinstance(item, dict) and item.get('name') == search_text:
                if 'code_content' in item:
                    code_results.append(item['code_content'])
                    # NOTE(review): assumes 'md_content' is always present
                    # alongside 'code_content' — would raise KeyError otherwise.
                    md_results.append(item['md_content'])
            # Recurse into every item to handle nested lists or dicts.
            self.recursive_search(item, search_text, code_results, md_results)

def search_code_contents_by_name(self, file_path, search_text):
# Attempt to retrieve code from the JSON file
try:
with open(file_path, 'r', encoding='utf-8') as file:
data = json.load(file)
results = [] # List to store matching items' code_content
self.recursive_search(data, search_text, results)
return results if results else "No matching item found."
code_results = []
md_results = [] # List to store matching items' code_content and md_content
self.recursive_search(data, search_text, code_results, md_results)
# 确保无论结果如何都返回两个值
if code_results or md_results:
return code_results, md_results
else:
return ["No matching item found."], ["No matching item found."]
except FileNotFoundError:
return "File not found."
except json.JSONDecodeError:
Expand All @@ -99,4 +87,4 @@ def search_code_contents_by_name(self, file_path, search_text):

if __name__ == "__main__":
processor = JsonFileProcessor("database.json")
md_contents = processor.extract_md_contents()
md_contents = processor.extract_md_contents()
5 changes: 2 additions & 3 deletions repo_agent/chat_with_repo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@ def main():
api_key = CONFIG["api_keys"][_model][0]["api_key"]
api_base = CONFIG["api_keys"][_model][0]["base_url"]
db_path = os.path.join(
CONFIG["repo_path"], CONFIG["project_hierarchy"], ".project_hierarchy.json"
CONFIG["repo_path"], CONFIG["project_hierarchy"], "project_hierarchy.json"
)

assistant = RepoAssistant(api_key, api_base, db_path)
md_contents = assistant.json_data.extract_md_contents()
meta_data = assistant.json_data.extract_metadata()
md_contents,meta_data = assistant.json_data.extract_data()
assistant.chroma_data.create_vector_store(md_contents, meta_data)
GradioInterface(assistant.respond)

Expand Down
11 changes: 5 additions & 6 deletions repo_agent/chat_with_repo/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from repo_agent.log import logger
from repo_agent.chat_with_repo.json_handler import JsonFileProcessor

# logger.add("./log.txt", level="DEBUG", format="{time} - {name} - {level} - {message}")

class TextAnalysisTool:
def __init__(self, llm, db_path):
Expand All @@ -25,8 +24,8 @@ def format_chat_prompt(self, message, instruction):
return prompt

def queryblock(self, message):
    """Look up code and markdown blocks whose 'name' matches *message*.

    Args:
        message: The exact function or class name to search for.

    Returns:
        A (code_results, md_results) tuple, as produced by
        JsonFileProcessor.search_code_contents_by_name.
    """
    search_result, md = self.jsonsearch.search_code_contents_by_name(self.db_path, message)
    return search_result, md

def list_to_markdown(self, search_result):
markdown_str = ""
Expand All @@ -38,14 +37,14 @@ def list_to_markdown(self, search_result):
return markdown_str

def nerquery(self, message):
    """Ask the LLM to extract a single function or class name from *message*.

    Args:
        message: Free-form text to mine for the most relevant identifier.

    Returns:
        The raw LLM completion, expected to be exactly one bare
        function or class name.
    """
    instruction = """
    The output must strictly be a pure function name or class name, without any additional characters.
    For example:
    Pure function names: calculateSum, processData
    Pure class names: MyClass, DataProcessor
    The output function name or class name should be only one.
    """
    # Prompt template: typo "instrcution" and broken grammar fixed.
    query = (
        f"Extract the most relevant class or function based on the following "
        f"instruction: {instruction}\nHere is the input:\n{message}\nOutput:"
    )
    response = self.llm.complete(query)
    # logger.debug(f"Input: {message}, Output: {response}")
    return response
Expand All @@ -56,4 +55,4 @@ def nerquery(self, message):
log_file = "your_logfile_path"
llm = OpenAI(api_key=api_key, api_base=api_base)
db_path = "your_database_path"
test = TextAnalysisTool(llm, db_path)
test = TextAnalysisTool(llm, db_path)
70 changes: 54 additions & 16 deletions repo_agent/chat_with_repo/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from repo_agent.log import logger
from llama_index import PromptTemplate
from llama_index.llms import OpenAI
import json
from openai import OpenAI as AI

# logger.add("./log.txt", level="DEBUG", format="{time} - {name} - {level} - {message}")

class RepoAssistant:
def __init__(self, api_key, api_base, db_path):
Expand All @@ -16,6 +17,7 @@ def __init__(self, api_key, api_base, db_path):
self.md_contents = []
self.llm = OpenAI(api_key=api_key, api_base=api_base,model="gpt-3.5-turbo-1106")
self.client = OpenAI(api_key=api_key, api_base=api_base,model="gpt-4-1106-preview")
self.lm = AI(api_key = api_key, base_url = api_base)
self.textanslys = TextAnalysisTool(self.llm,db_path)
self.json_data = JsonFileProcessor(db_path)
self.chroma_data = ChromaManager(api_key, api_base)
Expand All @@ -37,10 +39,26 @@ def generate_queries(self, query_str: str, num_queries: int = 4):
queries = response.text.split("\n")
return queries

def rerank(self, query, docs):
    """Rerank candidate documents for *query* using an LLM judge.

    Args:
        query: The user query to rank against.
        docs: Candidate documents to be scored.

    Returns:
        The 'content' of the top five documents, ordered by descending
        relevance_score as judged by the model.

    Raises:
        KeyError / json.JSONDecodeError: if the model reply does not match
            the requested {"documents": [{"content", "relevance_score"}]} shape.
    """
    response = self.lm.chat.completions.create(
        model='gpt-4-1106-preview',
        # Force a JSON object reply so the score list can be parsed reliably.
        response_format={"type": "json_object"},
        temperature=0,
        messages=[
            {"role": "system", "content": "You are an expert relevance ranker. Given a list of documents and a query, your job is to determine how relevant each document is for answering the query. Your output is JSON, which is a list of documents. Each document has two fields, content and score. relevance_score is from 0.0 to 100.0. Higher relevance means higher score."},
            {"role": "user", "content": f"Query: {query} Docs: {docs}"}
        ]
    )
    scores = json.loads(response.choices[0].message.content)["documents"]
    logger.debug(f"scores: {scores}")
    sorted_data = sorted(scores, key=lambda x: x['relevance_score'], reverse=True)
    top_5_contents = [doc['content'] for doc in sorted_data[:5]]
    return top_5_contents

def rag(self, query, retrieved_documents):
    """Answer *query* using only the retrieved repository documents (RAG).

    Args:
        query: The user's question.
        retrieved_documents: Relevant text snippets from the repository.

    Returns:
        The LLM completion answering the question from the supplied
        information only.
    """
    # Join snippets into a single context block for the prompt.
    information = "\n\n".join(retrieved_documents)
    messages = f"You are a helpful expert repo research assistant. Your users are asking questions about information contained in repo . You will be shown the user's question, and the relevant information from the repo. Answer the user's question using only this information.\nQuestion: {query}. \nInformation: {information}"
    response = self.llm.complete(messages)
    return response
Expand All @@ -49,7 +67,7 @@ def list_to_markdown(self,list_items):

# 对于列表中的每个项目,添加一个带数字的列表项
for index, item in enumerate(list_items, start=1):
markdown_content += f"[{index}] {item}\n"
markdown_content += f"{index}. {item}\n"

return markdown_content
def rag_ar(self, query, related_code, embedding_recall, project_name):
Expand Down Expand Up @@ -83,7 +101,7 @@ def respond(self, message, instruction):
# return answer
prompt = self.textanslys.format_chat_prompt(message, instruction)
questions = self.textanslys.keyword(prompt)
logger.debug(f"Questions: {questions}")
# logger.debug(f"Questions: {questions}")
promptq = self.generate_queries(prompt,3)
all_results = []
all_ids = []
Expand All @@ -93,7 +111,8 @@ def respond(self, message, instruction):
all_ids.extend(query_result['ids'][0])

logger.debug(f"all_ids: {all_ids},{all_results}")
unique_ids = [id for id in all_ids if all_ids.count(id) == 1]
unique_ids = list(dict.fromkeys(all_ids))
# unique_ids = [id for id in all_ids if all_ids.count(id) == 1]
logger.debug(f"uniqueid: {unique_ids}")
unique_documents = []
unique_code = []
Expand All @@ -102,30 +121,49 @@ def respond(self, message, instruction):
if id in unique_ids:
unique_documents.append(doc)
unique_code.append(code.get("code_content"))
unique_code=self.textanslys.list_to_markdown(unique_code)
retrieved_documents = unique_documents

retrieved_documents = self.rerank(message,unique_documents)
# logger.debug(f"retrieveddocuments: {retrieved_documents}")
response = self.rag(prompt,retrieved_documents)
chunkrecall = self.list_to_markdown(retrieved_documents)
bot_message = str(response)
keyword = str(self.textanslys.nerquery(bot_message))
keywords = str(self.textanslys.nerquery(str(prompt)+str(questions)))
codez=self.textanslys.queryblock(keyword)
codey=self.textanslys.queryblock(keywords)
codez,mdz=self.textanslys.queryblock(keyword)
codey,mdy=self.textanslys.queryblock(keywords)
if not isinstance(codez, list):
codex = [codez]
codez = [codez]
if not isinstance(mdz, list):
mdz = [mdz]
# 确保 codey 是列表,如果不是,则将其转换为列表
if not isinstance(codey, list):
codey = [codey]
if not isinstance(mdy, list):
mdy = [mdy]

codex = codez+codey
codex = self.textanslys.list_to_markdown(codex)
bot_message = self.rag_ar(prompt,unique_code,retrieved_documents,"test")
bot_message = str(bot_message) +'\n'+ str(self.textanslys.tree(bot_message))
return message, bot_message,chunkrecall,questions,unique_code,codex

md = mdz + mdy
unique_mdx = list(set([item for sublist in md for item in sublist]))
uni_codex = []
uni_md = []
uni_codex = list(dict.fromkeys(codex))
uni_md = list(dict.fromkeys(unique_mdx))
codex = self.textanslys.list_to_markdown(uni_codex)
retrieved_documents = retrieved_documents+uni_md
retrieved_documents = list(dict.fromkeys(retrieved_documents))
retrieved_documents = self.rerank(message,retrieved_documents[:6])
uni_code = uni_codex+unique_code
uni_code = list(dict.fromkeys(uni_code))
uni_code = self.rerank(message,uni_code[:6])
unique_code=self.textanslys.list_to_markdown(unique_code)
bot_message = self.rag_ar(prompt,uni_code,retrieved_documents,"test")
bot_message = str(bot_message)
return message, bot_message, chunkrecall, questions, unique_code, codex


if __name__ == "__main__":
api_key = ""
api_base = ""
db_path = ""
log_file = ""
assistant = RepoAssistant(api_key, api_base, db_path, log_file)
assistant = RepoAssistant(api_key, api_base, db_path, log_file)
Loading