Commit

modify chatpaper_1
qingzhong1 committed Sep 6, 2023
1 parent 5fec2f4 commit 2e8daab
Showing 4 changed files with 135 additions and 110 deletions.
2 changes: 1 addition & 1 deletion pipelines/examples/chatpaper/README.md
@@ -16,7 +16,7 @@ python chat_paper.py \
     --secret_key \
     --bos_ak \
     --bos_sk \
-    --json_dir \
+    --txt_file \
     --retriever_api_key \
     --retriever_secret_key \
     --es_host \
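The flag swap above tracks a storage change: --json_dir pointed at a directory of per-paper JSON files, while --txt_file points at a plain-text index that load_all_json_path (added to utils.py, see below) parses into a key-to-path dict, one whitespace-separated "key path" pair per line. A hypothetical index file might look like this (the exact key format depends on how the index was generated; names and paths are illustrative only):

    Gradient_Harmonized_Single-stage_Detector /data/chatpaper/json/1811.05181.json
    GPT-4_Technical_Report /data/chatpaper/json/2303.08774.json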
229 changes: 121 additions & 108 deletions pipelines/examples/chatpaper/chat_paper.py
@@ -19,7 +19,13 @@
 import arxiv
 import erniebot as eb
 import gradio as gr
-from utils import _apply_token, get_shown_context, pdf2image, retrieval
+from utils import (
+    _apply_token,
+    get_shown_context,
+    load_all_json_path,
+    pdf2image,
+    retrieval,
+)
 
 paper_id_list = []
 single_paper_id = ""
@@ -57,14 +63,16 @@
     "--retriever_embed_title", type=bool, default=False, help="whether use embedding title in retriever"
 )
 parser.add_argument("--retriever_threshold", type=float, default=0.95, help="the threshold of retriever")
-parser.add_argument("--json_dir", type=str, default="", help="the directory of json files created by papers")
+parser.add_argument("--txt_file", type=str, default="", help="the path of a txt file which includes all papers path")
 parser.add_argument("--max_token", type=int, default=11200, help=" the max number of tokens of LLM")
 args = parser.parse_args()
 PROMPT_RETRIVER = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,
 请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
 <已知信息>{documents}</已知信息>
 <问题>{query}</问题>"""
 
+all_json_id = load_all_json_path(args.txt_file)
+
 
 def clear_input():
     """
@@ -91,50 +99,16 @@ def retrieval_papers(query, state={}):
     query = query.strip().replace("<br>", "\n")
     context = state.setdefault("context", [])
     global paper_id_list
-    if not paper_id_list:
-        context.append({"system": args.system_prompt, "role": "user", "content": query})
-        prediction = retrieval(
-            query=query,
-            es_host=args.es_host,
-            es_port=args.es_port,
-            es_username=args.es_username,
-            es_password=args.es_password,
-            es_index=args.es_index_abstract,
-            es_chunk_size=args.es_chunk_size,
-            es_thread_count=args.es_thread_count,
-            es_queue_size=args.es_queue_size,
-            retriever_batch_size=args.retriever_batch_size,
-            retriever_api_key=args.retriever_api_key,
-            retriever_secret_key=args.retriever_secret_key,
-            retriever_embed_title=args.retriever_embed_title,
-            retriever_topk=30,
-            rank_topk=5,
-        )
-        documents = prediction["documents"]
-        all_content = ""
-        for i in range(len(documents)):
-            if documents[i].meta["id"] not in paper_id_list:
-                paper_id_list.append(documents[i].meta["id"])
-                key_words = documents[i].meta.get("key_words", "")
-                title = documents[i].meta.get("title", "")
-                abstract = documents[i].meta.get("abstracts", "")
-                paper_content = (
-                    "**" + str(len(paper_id_list)) + "." + title + "**" + "\n" + key_words + "\n" + abstract
-                )
-                all_content += paper_content + "\n\n"
-        context.append({"role": "assistant", "content": all_content})
-        shown_context = get_shown_context(context)
-    else:
-        content = ""
-        for id in paper_id_list:
+    if query:
+        if not paper_id_list:
+            context.append({"system": args.system_prompt, "role": "user", "content": query})
             prediction = retrieval(
                 query=query,
-                file_id=id,
                 es_host=args.es_host,
                 es_port=args.es_port,
                 es_username=args.es_username,
                 es_password=args.es_password,
-                es_index=args.es_index_full_text,
+                es_index=args.es_index_abstract,
                 es_chunk_size=args.es_chunk_size,
                 es_thread_count=args.es_thread_count,
                 es_queue_size=args.es_queue_size,
@@ -143,20 +117,57 @@ def retrieval_papers(query, state={}):
                 retriever_secret_key=args.retriever_secret_key,
                 retriever_embed_title=args.retriever_embed_title,
                 retriever_topk=30,
-                rank_topk=2,
+                rank_topk=5,
             )
-            content += "\n".join([item.content for item in prediction["documents"]])
-        content = PROMPT_RETRIVER.format(documents=content, query=query)
-        content = content[: args.max_token]
-        context.append({"system": args.system_prompt, "role": "user", "content": content})
-        eb.api_type = args.api_type
-        access_token = _apply_token(args.api_key, args.secret_key)
-        eb.access_token = access_token
-        model = "ernie-bot-3.5" if args.ernie_model is None or args.ernie_model.strip() == "" else args.ernie_model
-        response = eb.ChatCompletion.create(model=model, messages=context, stream=False)
-        bot_response = response.result
-        context.append({"role": "assistant", "content": bot_response})
-        context[-2]["content"] = query
+            documents = prediction["documents"]
+            all_content = ""
+            for i in range(len(documents)):
+                if documents[i].meta["id"] not in paper_id_list:
+                    paper_id_list.append(documents[i].meta["id"])
+                    key_words = documents[i].meta.get("key_words", "")
+                    title = documents[i].meta.get("title", "")
+                    abstract = documents[i].meta.get("abstracts", "")
+                    paper_content = (
+                        "**" + str(len(paper_id_list)) + "." + title + "**" + "\n" + key_words + "\n" + abstract
+                    )
+                    all_content += paper_content + "\n\n"
+            context.append({"role": "assistant", "content": all_content})
+            shown_context = get_shown_context(context)
+        else:
+            content = ""
+            for id in paper_id_list:
+                prediction = retrieval(
+                    query=query,
+                    file_id=id,
+                    es_host=args.es_host,
+                    es_port=args.es_port,
+                    es_username=args.es_username,
+                    es_password=args.es_password,
+                    es_index=args.es_index_full_text,
+                    es_chunk_size=args.es_chunk_size,
+                    es_thread_count=args.es_thread_count,
+                    es_queue_size=args.es_queue_size,
+                    retriever_batch_size=args.retriever_batch_size,
+                    retriever_api_key=args.retriever_api_key,
+                    retriever_secret_key=args.retriever_secret_key,
+                    retriever_embed_title=args.retriever_embed_title,
+                    retriever_topk=30,
+                    rank_topk=2,
+                )
+                content += "\n".join([item.content for item in prediction["documents"]])
+            content = PROMPT_RETRIVER.format(documents=content, query=query)
+            content = content[: args.max_token]
+            context.append({"system": args.system_prompt, "role": "user", "content": content})
+            eb.api_type = args.api_type
+            access_token = _apply_token(args.api_key, args.secret_key)
+            eb.access_token = access_token
+            model = "ernie-bot-3.5" if args.ernie_model is None or args.ernie_model.strip() == "" else args.ernie_model
+            response = eb.ChatCompletion.create(model=model, messages=context, stream=False)
+            bot_response = response.result
+            context.append({"role": "assistant", "content": bot_response})
+            context[-2]["content"] = query
+            shown_context = get_shown_context(context)
+    else:
         shown_context = get_shown_context(context)
     return None, shown_context, state
 
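Taken together, the two retrieval_papers hunks above do not change the retrieval logic itself; they re-indent it under a new `if query:` guard (plus a trailing else that just re-renders the current context), so an empty Gradio submission no longer triggers an Elasticsearch query or an LLM call. The flow is still two-stage: the first query searches the abstracts index (rank_topk=5) and fills paper_id_list, and every later query runs a per-paper full-text search (rank_topk=2) whose hits are packed into PROMPT_RETRIVER for eb.ChatCompletion. The infer hunk below applies the same guard. A minimal sketch of that dispatch, with search_abstracts, search_full_text, format_hits, and chat as hypothetical stand-ins for the retrieval()/erniebot calls used here:

    def answer(query, paper_ids, context):
        """Two-stage paper QA dispatch (sketch, not the project's API)."""
        if not query:  # new guard: empty input is a no-op
            return context
        if not paper_ids:  # stage 1: discover papers via the abstracts index
            hits = search_abstracts(query, retriever_topk=30, rank_topk=5)  # hypothetical
            paper_ids.extend(hit["id"] for hit in hits)
            context.append({"role": "assistant", "content": format_hits(hits)})  # hypothetical
        else:  # stage 2: answer from the full text of the collected papers
            documents = "\n".join(
                search_full_text(query, file_id=pid, rank_topk=2)  # hypothetical
                for pid in paper_ids
            )
            prompt = PROMPT_RETRIVER.format(documents=documents, query=query)
            context.append({"role": "user", "content": prompt})
            context.append({"role": "assistant", "content": chat(context)})  # hypothetical LLM call
        return context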
@@ -194,40 +205,43 @@ def infer(query, state):
     eb.access_token = access_token
     query = query.strip().replace("<br>", "\n")
     context = state.setdefault("context", [])
-    if single_paper_id:
-        prediction = retrieval(
-            query=query,
-            file_id=single_paper_id,
-            es_host=args.es_host,
-            es_port=args.es_port,
-            es_username=args.es_username,
-            es_password=args.es_password,
-            es_index=args.es_index_full_text,
-            es_chunk_size=args.es_chunk_size,
-            es_thread_count=args.es_thread_count,
-            es_queue_size=args.es_queue_size,
-            retriever_batch_size=args.retriever_batch_size,
-            retriever_api_key=args.retriever_api_key,
-            retriever_secret_key=args.retriever_secret_key,
-            retriever_embed_title=args.retriever_embed_title,
-            retriever_topk=30,
-            rank_topk=2,
-        )
-        content = "\n".join([item.content for item in prediction["documents"]])
-        content = PROMPT_RETRIVER.format(documents=content, query=query)
-        content = content[: args.max_token]
-        context.append({"system": args.system_prompt, "role": "user", "content": content})
-        model = "ernie-bot-3.5" if args.ernie_model is None or args.ernie_model.strip() == "" else args.ernie_model
-        response = eb.ChatCompletion.create(model=model, messages=context, stream=False)
-        bot_response = response.result
-        context.append({"role": "assistant", "content": bot_response})
-        context[-2]["content"] = query
-        shown_context = get_shown_context(context)
+    if query:
+        if single_paper_id:
+            prediction = retrieval(
+                query=query,
+                file_id=single_paper_id,
+                es_host=args.es_host,
+                es_port=args.es_port,
+                es_username=args.es_username,
+                es_password=args.es_password,
+                es_index=args.es_index_full_text,
+                es_chunk_size=args.es_chunk_size,
+                es_thread_count=args.es_thread_count,
+                es_queue_size=args.es_queue_size,
+                retriever_batch_size=args.retriever_batch_size,
+                retriever_api_key=args.retriever_api_key,
+                retriever_secret_key=args.retriever_secret_key,
+                retriever_embed_title=args.retriever_embed_title,
+                retriever_topk=30,
+                rank_topk=2,
+            )
+            content = "\n".join([item.content for item in prediction["documents"]])
+            content = PROMPT_RETRIVER.format(documents=content, query=query)
+            content = content[: args.max_token]
+            context.append({"system": args.system_prompt, "role": "user", "content": content})
+            model = "ernie-bot-3.5" if args.ernie_model is None or args.ernie_model.strip() == "" else args.ernie_model
+            response = eb.ChatCompletion.create(model=model, messages=context, stream=False)
+            bot_response = response.result
+            context.append({"role": "assistant", "content": bot_response})
+            context[-2]["content"] = query
+            shown_context = get_shown_context(context)
+        else:
+            context.append({"system": args.system_prompt, "role": "user", "content": query})
+            response = eb.ChatFile.create(messages=context, stream=False)
+            bot_response = response.result
+            context.append({"role": "assistant", "content": bot_response})
+            shown_context = get_shown_context(context)
     else:
-        context.append({"system": args.system_prompt, "role": "user", "content": query})
-        response = eb.ChatFile.create(messages=context, stream=False)
-        bot_response = response.result
-        context.append({"role": "assistant", "content": bot_response})
         shown_context = get_shown_context(context)
     return None, shown_context, state

@@ -239,29 +253,28 @@ def upload_file(file_name, file_url, file_upload, state={}):
     global single_paper_id
     single_paper_id = ""
     if file_name:
-        json_file_path, file_id = retrieval_title(file_name)
-        json_file_path = json_file_path.replace("/", "_").replace(".pdf", "")
-        json_file_path = os.path.join(args.json_dir, json_file_path)
-        single_paper_id = file_id
-        if os.path.isfile(json_file_path):
-
+        try:
+            json_file_path, file_id = retrieval_title(file_name)
+            json_file_path = json_file_path.replace("/", "_").replace(".pdf", "")
+            json_file_path = all_json_id[json_file_path]
+            single_paper_id = file_id
             with open(json_file_path, mode="r") as json_file:
                 json_content = json.load(json_file)
             content = json_content["content"]
             return (
                 gr.Gallery.update(visible=False),
                 gr.File.update(visible=False),
                 None,
                 state,
                 gr.Chatbot.update(
                     [["", content]],
                     visible=True,
                     scale=30,
                     height=600,
                 ),
             )
-        else:
-            return gr.Gallery.update(visible=False), gr.File.update(visible=False), None, state, None
+        except:
+            content = "这篇论文目前尚未加入到论文库中"
+            return (
+                gr.Gallery.update(visible=False),
+                gr.File.update(visible=False),
+                None,
+                state,
+                gr.Chatbot.update(
+                    [["", content]],
+                    visible=True,
+                    scale=30,
+                    height=600,
+                ),
+            )
     elif file_url:
         root_path = "./"
         paper = next(arxiv.Search(id_list=[file_url.split("/")[-1]]).results())
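In upload_file, the commit replaces the filesystem probe (os.path.join against --json_dir plus os.path.isfile) with a lookup in the all_json_id dict built at startup, wrapped in try/except so that an unknown title answers "这篇论文目前尚未加入到论文库中" (roughly, "This paper has not yet been added to the paper library"). The new path reduces to this sketch (names as in the diff; the except clause is narrowed to KeyError here, whereas the commit uses a bare except:, which also swallows unrelated I/O and JSON errors):

    # Sketch: dict lookup replacing the old os.path.isfile() check.
    key = json_file_path.replace("/", "_").replace(".pdf", "")
    try:
        json_file_path = all_json_id[key]  # KeyError if the paper was never indexed
    except KeyError:
        content = "这篇论文目前尚未加入到论文库中"  # "not yet in the paper library"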
@@ -385,9 +398,9 @@
     clear.click(clear_input, inputs=[], outputs=[file_name, file_url, file_upload])
     submit_btn.click(infer, inputs=[message, state], outputs=[message, chatbot, state])
     clear_btn.click(
-        lambda _: (None, None, None, None, {}),
+        lambda _: (None, {}),
         inputs=clear_btn,
-        outputs=[ori_paper, ori_pdf, chatbot, ori_json, state],
+        outputs=[chatbot, state],
         api_name="clear",
         show_progress=False,
     )
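The last hunk shrinks the clear button's reset: the lambda now returns a two-element tuple matching the reduced outputs list, so clicking clear wipes only the chatbot and the session state; ori_paper, ori_pdf, and ori_json are no longer cleared here.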
2 changes: 1 addition & 1 deletion pipelines/examples/chatpaper/requirements.txt
@@ -1,5 +1,5 @@
-fitz
 scipdf
+PyMuPDF==1.20.2
 arxiv
 erniebot
 gradio==3.41.2
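One likely motive for this swap: the `import fitz` used for PDF rendering is provided by the PyMuPDF distribution, while the PyPI project actually named fitz is unrelated and does not supply that module, so depending on PyMuPDF directly (pinned to 1.20.2 here) is the reliable spelling. A quick sanity check:

    import fitz  # the fitz module ships with PyMuPDF, not with the PyPI package named "fitz"

    print(fitz.__doc__)  # should mention PyMuPDF, e.g. "PyMuPDF 1.20.2: Python bindings for ..."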
12 changes: 12 additions & 0 deletions pipelines/examples/chatpaper/utils.py
@@ -37,6 +37,18 @@ def pdf2image(pdfPath, imgPath, zoom_x=10, zoom_y=10, rotation_angle=0):
     return image_path
 
 
+def load_all_json_path(path):
+    json_path = {}
+    with open(path, encoding="utf-8", mode="r") as f:
+        for line in f:
+            try:
+                json_id, json_name = line.strip().split()
+                json_path[json_id] = json_name
+            except:
+                continue
+    return json_path
+
+
 def _apply_token(api_key, secret_key):
     """
     Generate an access token.
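Usage sketch for the new helper: load_all_json_path reads one whitespace-separated "key path" pair per line and returns a dict, with the bare except/continue silently dropping malformed lines. Note that line.strip().split() requires exactly two fields, so keys or paths containing spaces are skipped rather than loaded. A hypothetical index file and call:

    # papers.txt (hypothetical contents):
    #   paper_one /data/json/paper_one.json
    #   paper_two /data/json/paper_two.json
    from utils import load_all_json_path

    all_json_id = load_all_json_path("papers.txt")
    print(all_json_id["paper_one"])  # -> /data/json/paper_one.json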
