From 2e8daab060d820990402f99af126b1e3655d95c6 Mon Sep 17 00:00:00 2001 From: qingzhong1 Date: Wed, 6 Sep 2023 17:03:27 +0800 Subject: [PATCH] modify chatpaper_1 --- pipelines/examples/chatpaper/README.md | 2 +- pipelines/examples/chatpaper/chat_paper.py | 229 +++++++++--------- pipelines/examples/chatpaper/requirements.txt | 2 +- pipelines/examples/chatpaper/utils.py | 12 + 4 files changed, 135 insertions(+), 110 deletions(-) diff --git a/pipelines/examples/chatpaper/README.md b/pipelines/examples/chatpaper/README.md index dc5a829cfc01..0afebed0c382 100644 --- a/pipelines/examples/chatpaper/README.md +++ b/pipelines/examples/chatpaper/README.md @@ -16,7 +16,7 @@ python chat_paper.py \ --secret_key \ --bos_ak \ --bos_sk \ ---json_dir \ +--txt_file \ --retriever_api_key \ --retriever_secret_key \ --es_host \ diff --git a/pipelines/examples/chatpaper/chat_paper.py b/pipelines/examples/chatpaper/chat_paper.py index 4e5c02a06e90..01f4a21b04f9 100644 --- a/pipelines/examples/chatpaper/chat_paper.py +++ b/pipelines/examples/chatpaper/chat_paper.py @@ -19,7 +19,13 @@ import arxiv import erniebot as eb import gradio as gr -from utils import _apply_token, get_shown_context, pdf2image, retrieval +from utils import ( + _apply_token, + get_shown_context, + load_all_json_path, + pdf2image, + retrieval, +) paper_id_list = [] single_paper_id = "" @@ -57,7 +63,7 @@ "--retriever_embed_title", type=bool, default=False, help="whether use embedding title in retriever" ) parser.add_argument("--retriever_threshold", type=float, default=0.95, help="the threshold of retriever") -parser.add_argument("--json_dir", type=str, default="", help="the directory of json files created by papers") +parser.add_argument("--txt_file", type=str, default="", help="the path of a txt file which includes all papers path") parser.add_argument("--max_token", type=int, default=11200, help=" the max number of tokens of LLM") args = parser.parse_args() PROMPT_RETRIVER = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案, @@ -65,6 +71,8 @@ <已知信息>{documents} <问题>{query}""" +all_json_id = load_all_json_path(args.txt_file) + def clear_input(): """ @@ -91,50 +99,16 @@ def retrieval_papers(query, state={}): query = query.strip().replace("
", "\n") context = state.setdefault("context", []) global paper_id_list - if not paper_id_list: - context.append({"system": args.system_prompt, "role": "user", "content": query}) - prediction = retrieval( - query=query, - es_host=args.es_host, - es_port=args.es_port, - es_username=args.es_username, - es_password=args.es_password, - es_index=args.es_index_abstract, - es_chunk_size=args.es_chunk_size, - es_thread_count=args.es_thread_count, - es_queue_size=args.es_queue_size, - retriever_batch_size=args.retriever_batch_size, - retriever_api_key=args.retriever_api_key, - retriever_secret_key=args.retriever_secret_key, - retriever_embed_title=args.retriever_embed_title, - retriever_topk=30, - rank_topk=5, - ) - documents = prediction["documents"] - all_content = "" - for i in range(len(documents)): - if documents[i].meta["id"] not in paper_id_list: - paper_id_list.append(documents[i].meta["id"]) - key_words = documents[i].meta.get("key_words", "") - title = documents[i].meta.get("title", "") - abstract = documents[i].meta.get("abstracts", "") - paper_content = ( - "**" + str(len(paper_id_list)) + "." + title + "**" + "\n" + key_words + "\n" + abstract - ) - all_content += paper_content + "\n\n" - context.append({"role": "assistant", "content": all_content}) - shown_context = get_shown_context(context) - else: - content = "" - for id in paper_id_list: + if query: + if not paper_id_list: + context.append({"system": args.system_prompt, "role": "user", "content": query}) prediction = retrieval( query=query, - file_id=id, es_host=args.es_host, es_port=args.es_port, es_username=args.es_username, es_password=args.es_password, - es_index=args.es_index_full_text, + es_index=args.es_index_abstract, es_chunk_size=args.es_chunk_size, es_thread_count=args.es_thread_count, es_queue_size=args.es_queue_size, @@ -143,20 +117,57 @@ def retrieval_papers(query, state={}): retriever_secret_key=args.retriever_secret_key, retriever_embed_title=args.retriever_embed_title, retriever_topk=30, - rank_topk=2, + rank_topk=5, ) - content += "\n".join([item.content for item in prediction["documents"]]) - content = PROMPT_RETRIVER.format(documents=content, query=query) - content = content[: args.max_token] - context.append({"system": args.system_prompt, "role": "user", "content": content}) - eb.api_type = args.api_type - access_token = _apply_token(args.api_key, args.secret_key) - eb.access_token = access_token - model = "ernie-bot-3.5" if args.ernie_model is None or args.ernie_model.strip() == "" else args.ernie_model - response = eb.ChatCompletion.create(model=model, messages=context, stream=False) - bot_response = response.result - context.append({"role": "assistant", "content": bot_response}) - context[-2]["content"] = query + documents = prediction["documents"] + all_content = "" + for i in range(len(documents)): + if documents[i].meta["id"] not in paper_id_list: + paper_id_list.append(documents[i].meta["id"]) + key_words = documents[i].meta.get("key_words", "") + title = documents[i].meta.get("title", "") + abstract = documents[i].meta.get("abstracts", "") + paper_content = ( + "**" + str(len(paper_id_list)) + "." + title + "**" + "\n" + key_words + "\n" + abstract + ) + all_content += paper_content + "\n\n" + context.append({"role": "assistant", "content": all_content}) + shown_context = get_shown_context(context) + else: + content = "" + for id in paper_id_list: + prediction = retrieval( + query=query, + file_id=id, + es_host=args.es_host, + es_port=args.es_port, + es_username=args.es_username, + es_password=args.es_password, + es_index=args.es_index_full_text, + es_chunk_size=args.es_chunk_size, + es_thread_count=args.es_thread_count, + es_queue_size=args.es_queue_size, + retriever_batch_size=args.retriever_batch_size, + retriever_api_key=args.retriever_api_key, + retriever_secret_key=args.retriever_secret_key, + retriever_embed_title=args.retriever_embed_title, + retriever_topk=30, + rank_topk=2, + ) + content += "\n".join([item.content for item in prediction["documents"]]) + content = PROMPT_RETRIVER.format(documents=content, query=query) + content = content[: args.max_token] + context.append({"system": args.system_prompt, "role": "user", "content": content}) + eb.api_type = args.api_type + access_token = _apply_token(args.api_key, args.secret_key) + eb.access_token = access_token + model = "ernie-bot-3.5" if args.ernie_model is None or args.ernie_model.strip() == "" else args.ernie_model + response = eb.ChatCompletion.create(model=model, messages=context, stream=False) + bot_response = response.result + context.append({"role": "assistant", "content": bot_response}) + context[-2]["content"] = query + shown_context = get_shown_context(context) + else: shown_context = get_shown_context(context) return None, shown_context, state @@ -194,40 +205,43 @@ def infer(query, state): eb.access_token = access_token query = query.strip().replace("
", "\n") context = state.setdefault("context", []) - if single_paper_id: - prediction = retrieval( - query=query, - file_id=single_paper_id, - es_host=args.es_host, - es_port=args.es_port, - es_username=args.es_username, - es_password=args.es_password, - es_index=args.es_index_full_text, - es_chunk_size=args.es_chunk_size, - es_thread_count=args.es_thread_count, - es_queue_size=args.es_queue_size, - retriever_batch_size=args.retriever_batch_size, - retriever_api_key=args.retriever_api_key, - retriever_secret_key=args.retriever_secret_key, - retriever_embed_title=args.retriever_embed_title, - retriever_topk=30, - rank_topk=2, - ) - content = "\n".join([item.content for item in prediction["documents"]]) - content = PROMPT_RETRIVER.format(documents=content, query=query) - content = content[: args.max_token] - context.append({"system": args.system_prompt, "role": "user", "content": content}) - model = "ernie-bot-3.5" if args.ernie_model is None or args.ernie_model.strip() == "" else args.ernie_model - response = eb.ChatCompletion.create(model=model, messages=context, stream=False) - bot_response = response.result - context.append({"role": "assistant", "content": bot_response}) - context[-2]["content"] = query - shown_context = get_shown_context(context) + if query: + if single_paper_id: + prediction = retrieval( + query=query, + file_id=single_paper_id, + es_host=args.es_host, + es_port=args.es_port, + es_username=args.es_username, + es_password=args.es_password, + es_index=args.es_index_full_text, + es_chunk_size=args.es_chunk_size, + es_thread_count=args.es_thread_count, + es_queue_size=args.es_queue_size, + retriever_batch_size=args.retriever_batch_size, + retriever_api_key=args.retriever_api_key, + retriever_secret_key=args.retriever_secret_key, + retriever_embed_title=args.retriever_embed_title, + retriever_topk=30, + rank_topk=2, + ) + content = "\n".join([item.content for item in prediction["documents"]]) + content = PROMPT_RETRIVER.format(documents=content, query=query) + content = content[: args.max_token] + context.append({"system": args.system_prompt, "role": "user", "content": content}) + model = "ernie-bot-3.5" if args.ernie_model is None or args.ernie_model.strip() == "" else args.ernie_model + response = eb.ChatCompletion.create(model=model, messages=context, stream=False) + bot_response = response.result + context.append({"role": "assistant", "content": bot_response}) + context[-2]["content"] = query + shown_context = get_shown_context(context) + else: + context.append({"system": args.system_prompt, "role": "user", "content": query}) + response = eb.ChatFile.create(messages=context, stream=False) + bot_response = response.result + context.append({"role": "assistant", "content": bot_response}) + shown_context = get_shown_context(context) else: - context.append({"system": args.system_prompt, "role": "user", "content": query}) - response = eb.ChatFile.create(messages=context, stream=False) - bot_response = response.result - context.append({"role": "assistant", "content": bot_response}) shown_context = get_shown_context(context) return None, shown_context, state @@ -239,29 +253,28 @@ def upload_file(file_name, file_url, file_upload, state={}): global single_paper_id single_paper_id = "" if file_name: - json_file_path, file_id = retrieval_title(file_name) - json_file_path = json_file_path.replace("/", "_").replace(".pdf", "") - json_file_path = os.path.join(args.json_dir, json_file_path) - single_paper_id = file_id - if os.path.isfile(json_file_path): - + try: + json_file_path, file_id = retrieval_title(file_name) + json_file_path = json_file_path.replace("/", "_").replace(".pdf", "") + json_file_path = all_json_id[json_file_path] + single_paper_id = file_id with open(json_file_path, mode="r") as json_file: json_content = json.load(json_file) content = json_content["content"] - return ( - gr.Gallery.update(visible=False), - gr.File.update(visible=False), - None, - state, - gr.Chatbot.update( - [["", content]], - visible=True, - scale=30, - height=600, - ), - ) - else: - return gr.Gallery.update(visible=False), gr.File.update(visible=False), None, state, None + except: + content = "这篇论文目前尚未加入到论文库中" + return ( + gr.Gallery.update(visible=False), + gr.File.update(visible=False), + None, + state, + gr.Chatbot.update( + [["", content]], + visible=True, + scale=30, + height=600, + ), + ) elif file_url: root_path = "./" paper = next(arxiv.Search(id_list=[file_url.split("/")[-1]]).results()) @@ -385,9 +398,9 @@ def upload_file(file_name, file_url, file_upload, state={}): clear.click(clear_input, inputs=[], outputs=[file_name, file_url, file_upload]) submit_btn.click(infer, inputs=[message, state], outputs=[message, chatbot, state]) clear_btn.click( - lambda _: (None, None, None, None, {}), + lambda _: (None, {}), inputs=clear_btn, - outputs=[ori_paper, ori_pdf, chatbot, ori_json, state], + outputs=[chatbot, state], api_name="clear", show_progress=False, ) diff --git a/pipelines/examples/chatpaper/requirements.txt b/pipelines/examples/chatpaper/requirements.txt index 42e8c1d6876e..472f7bba2373 100644 --- a/pipelines/examples/chatpaper/requirements.txt +++ b/pipelines/examples/chatpaper/requirements.txt @@ -1,5 +1,5 @@ -fitz scipdf PyMuPDF==1.20.2 arxiv erniebot +gradio==3.41.2 \ No newline at end of file diff --git a/pipelines/examples/chatpaper/utils.py b/pipelines/examples/chatpaper/utils.py index 15905b050cc1..9be228c9f588 100644 --- a/pipelines/examples/chatpaper/utils.py +++ b/pipelines/examples/chatpaper/utils.py @@ -37,6 +37,18 @@ def pdf2image(pdfPath, imgPath, zoom_x=10, zoom_y=10, rotation_angle=0): return image_path +def load_all_json_path(path): + json_path = {} + with open(path, encoding="utf-8", mode="r") as f: + for line in f: + try: + json_id, json_name = line.strip().split() + json_path[json_id] = json_name + except: + continue + return json_path + + def _apply_token(api_key, secret_key): """ Gererate an access token.