-
Notifications
You must be signed in to change notification settings - Fork 71
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/issue 25/create chat with repo component #38
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
api_key: "" | ||
api_base: "" | ||
db_path: './project_hierachy.json' | ||
log_file: './log.txt' |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import gradio as gr | ||
|
||
class GradioInterface:
    """Gradio front-end for the repo-chat retrieval test.

    Construction is eager: building the instance assembles the Blocks UI
    and launches a shared Gradio app immediately (the launch call blocks).
    """

    def __init__(self, respond_function):
        # Callback invoked as respond_function(message, system_prompt);
        # it must return the six values wired to the outputs below.
        self.respond = respond_function
        self.setup_gradio_interface()

    def setup_gradio_interface(self):
        """Assemble the layout, wire the callbacks, and launch the app."""
        css = """
        .markdown-container:nth-of-type(2){
            max-height: 200px; /* 设置最大高度 */
            overflow-y: auto; /* 超出部分显示滚动条 */
        }

        .output-container:nth-of-type(5) {
            max-height: 150px;
            overflow-y: auto;
        }
        """
        with gr.Blocks(css=css) as demo:
            gr.Markdown("""
            # RepoChat Test
            This is a test for retrieval repo
            """)
            with gr.Row():
                with gr.Column(scale=2):
                    msg = gr.Textbox(label="Question Input")
                    gr.Markdown("### question")
                    question = gr.Markdown(label="qa")
                    with gr.Accordion(label="Advanced options", open=False):
                        system = gr.Textbox(label="System message", lines=2, value="A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.")
                    output1 = gr.Textbox(label="RAG")

                with gr.Column(scale=1):
                    output2 = gr.Textbox(label="Embedding recall")
                with gr.Column(scale=1):
                    output3 = gr.Textbox(label="key words")
                    code = gr.Textbox(label="code")

            # Both the button and pressing Enter in the textbox submit.
            wired_inputs = [msg, system]
            wired_outputs = [msg, output1, output2, output3, code, question]
            submit_btn = gr.Button("Submit")
            submit_btn.click(self.respond, inputs=wired_inputs, outputs=wired_outputs)
            msg.submit(self.respond, inputs=wired_inputs, outputs=wired_outputs)

        gr.close_all()
        demo.queue().launch(share=True)
|
||
# Usage example: run this module directly to launch the UI with a stub
# responder that returns fixed placeholder strings.
if __name__ == "__main__":
    def respond_function(msg, system):
        # Replace with real logic; the tuple order must match the six
        # Gradio outputs: msg, RAG, embedding recall, key words, code, QA.
        return msg, "RAG_output", "Embedding_recall_output", "Key_words_output", "Code_output", "QA_output"

    gradio_interface = GradioInterface(respond_function)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import json | ||
|
||
class JsonFileProcessor:
    """Read-only helper around the project-hierarchy JSON database."""

    def __init__(self, file_path):
        # Default database location. search_in_json_nested still accepts an
        # explicit path so callers can query a different file.
        self.file_path = file_path

    def read_json_file(self):
        """Load the whole JSON database and return the parsed objects.

        Raises:
            FileNotFoundError: if the database file does not exist.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        with open(self.file_path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def extract_md_contents(self):
        """Collect every "md_content" value under files[*].objects[*].

        Returns a flat list. Robustness fix: records missing the expected
        "files"/"objects" keys are skipped instead of raising KeyError,
        so a partially-built database no longer crashes the UI.
        """
        json_data = self.read_json_file()
        md_contents = []
        for file_entry in json_data.get("files", []):
            for obj in file_entry.get("objects", []):
                if "md_content" in obj:
                    md_contents.append(obj["md_content"])
        return md_contents

    def search_in_json_nested(self, file_path, search_text):
        """Depth-first search for the first dict whose "name" contains
        *search_text* (case-insensitive) anywhere in the JSON tree.

        Returns the matching dict, or a human-readable error string when
        nothing matches or the file cannot be read. Callers distinguish
        the two cases with ``isinstance(result, dict)``, so the string
        returns are part of the contract and are kept unchanged.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                data = json.load(file)

            def recursive_search(data_item):
                # Returns the first match, or None (implicitly) if absent.
                if isinstance(data_item, dict):
                    if 'name' in data_item and search_text.lower() in data_item['name'].lower():
                        return data_item
                    for key, value in data_item.items():
                        if isinstance(value, (dict, list)):
                            result = recursive_search(value)
                            if result:
                                return result
                elif isinstance(data_item, list):
                    for item in data_item:
                        result = recursive_search(item)
                        if result:
                            return result

            result = recursive_search(data)
            if result:
                return result
            return "No matching item found."
        except FileNotFoundError:
            return "File not found."
        except json.JSONDecodeError:
            return "Invalid JSON file."
        except Exception as e:
            # Broad catch keeps the chat UI alive on unexpected data shapes.
            return f"An error occurred: {e}"
|
||
# Ad-hoc smoke test: extract the markdown contents from a local database.
if __name__ == "__main__":
    processor = JsonFileProcessor("database.json")
    md_contents = processor.extract_md_contents()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import logging | ||
|
||
class LoggerManager:
    """Thin wrapper that configures a file-backed logger.

    NOTE(review): every instance configures the *same* underlying logger
    object (logging.getLogger(__name__)); the duplicate-handler guard
    below keeps repeated construction from multiplying log records.
    """

    def __init__(self, log_file, log_level=logging.DEBUG):
        import os  # local import: only needed to normalize handler paths

        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(log_level)
        # Bug fix: the original unconditionally added a new FileHandler on
        # every construction, so each new instance caused every record to
        # be written once more. Skip if an equivalent handler is attached.
        target = os.path.abspath(log_file)
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in self.logger.handlers
        )
        if not already_attached:
            file_handler = logging.FileHandler(log_file, encoding='utf-8')
            file_handler.setLevel(log_level)
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            file_handler.setFormatter(formatter)
            self.logger.addHandler(file_handler)

    def get_logger(self):
        """Return the configured logging.Logger instance."""
        return self.logger
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from repo_agent.chat_with_repo.gradio_ui import GradioInterface | ||
import yaml | ||
from rag import RepoAssistant | ||
|
||
|
||
def load_config(config_file):
    """Parse the YAML config file and return its contents (a dict).

    Consistency/robustness fix: utf-8 is forced explicitly so parsing
    does not depend on the platform's default locale encoding — every
    other file access in this codebase already passes encoding='utf-8'.
    """
    with open(config_file, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)
|
||
|
||
def main():
    """Wire up and launch the repo-chat demo.

    Steps: load config -> build the assistant -> extract markdown from
    the JSON database -> index it in the vector store -> start the UI.
    """
    config = load_config("config.yml")
    api_key = config['api_key']
    api_base = config['api_base']
    db_path = config['db_path']
    log_file = config['log_file']

    assistant = RepoAssistant(api_key, api_base, db_path, log_file)
    md_contents = assistant.json_data.extract_md_contents()
    assistant.chroma_data.create_vector_store(md_contents)
    GradioInterface(assistant.respond)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个写法我持保留意见,我觉得当前 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这四部指的是1,传入相关配置,2.获取md内容3,放到向量数据库,4启动查询UI,主要rag已经封装了,上次说的能单例化的我基本都单例开了个类 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我是针对 |
||
|
||
|
||
# Script entry point.
if __name__ == "__main__":
    main()
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
|
||
from llama_index.llms import OpenAI | ||
from logger import LoggerManager | ||
from repo_agent.chat_with_repo.json_handle import JsonFileProcessor | ||
class TextAnalysisTool:
    """LLM-backed helpers: keyword extraction, hierarchy trees, chat
    prompt formatting, and code lookup against the JSON database."""

    def __init__(self, llm, logger, db_path):
        # llm must expose .complete(prompt); logger is a LoggerManager.
        self.jsonsearch = JsonFileProcessor(db_path)
        self.llm = llm
        self.logger = logger.get_logger()
        self.db_path = db_path

    def keyword(self, query):
        """Ask the LLM for up to three keywords related to *query*."""
        prompt = f"Please provide a list of keywords related to the following query, requests output no more than 3 keywords, Input: {query}, Output:"
        return self.llm.complete(prompt)

    def tree(self, query):
        """Ask the LLM to render *query* as a hierarchy tree."""
        prompt = f"Please analyze the following text and generate a tree structure based on its hierarchy:\n\n{query}"
        return self.llm.complete(prompt)

    def format_chat_prompt(self, message, instruction):
        """Build the System/User/Assistant chat template string."""
        return f"System:{instruction}\nUser: {message}\nAssistant:"

    def queryblock(self, message):
        """Search the JSON database for *message* and return the matched
        code (or the search error string) as text."""
        search_result = self.jsonsearch.search_in_json_nested(self.db_path, message)
        if isinstance(search_result, dict):
            # NOTE(review): assumes every matched record carries a
            # 'code_content' key — raises KeyError otherwise; confirm schema.
            search_result = search_result['code_content']
        return str(search_result)

    def nerquery(self, message):
        """Extract a single function/class name from *message* via the LLM."""
        query1 = """
        The output must strictly be a pure function name or class name, without any additional characters.
        For example:
        Pure function names: calculateSum, processData
        Pure class names: MyClass, DataProcessor
        The output function name or class name should be only one.
        """
        query = f"Extract the most relevant class or function from the following{query1}input:\n{message}\nOutput:"
        response = self.llm.complete(query)
        self.logger.debug(f"Input: {message}, Output: {response}")
        return response
|
||
# Manual smoke test: fill in real credentials and paths before running.
if __name__ == "__main__":
    api_base = "https://api.openai.com/v1"
    api_key = "your_api_key"
    log_file = "your_logfile_path"
    db_path = "your_database_path"
    llm = OpenAI(api_key=api_key, api_base=api_base)
    logger = LoggerManager(log_file)
    test = TextAnalysisTool(llm, logger, db_path)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在下个版本中使用
from loguru import logger
吧,是一个简化了 Logging 配置的库There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,下次更新就换掉