Skip to content

Commit

Permalink
Frontier (#1958)
Browse files Browse the repository at this point in the history
* update welcome svg

* fix loading chatglm3 (#1937)

* update welcome svg

* update welcome message

* fix loading chatglm3

---------

Co-authored-by: binary-husky <qingxu.fu@outlook.com>
Co-authored-by: binary-husky <96192199+binary-husky@users.noreply.github.com>

* begin rag project with llama index

* rag version one

* rag beta release

* add social worker (proto)

* fix llamaindex version

---------

Co-authored-by: moetayuko <loli@yuko.moe>
  • Loading branch information
binary-husky and moetayuko committed Sep 8, 2024
1 parent 16f4fd6 commit dd66ca2
Show file tree
Hide file tree
Showing 19 changed files with 1,103 additions and 12 deletions.
1 change: 1 addition & 0 deletions TODO
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
RAG忘了触发保存了!
5 changes: 4 additions & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
"gpt-4", "gpt-4-32k", "azure-gpt-4", "glm-4", "glm-4v", "glm-3-turbo",
"gemini-1.5-pro", "chatglm3"
]

EMBEDDING_MODEL = "text-embedding-3-small"

# --- --- --- ---
# P.S. 其他可用的模型还包括
# AVAIL_LLM_MODELS = [
Expand Down Expand Up @@ -296,7 +299,7 @@

# 除了连接OpenAI之外,还有哪些场合允许使用代理,请尽量不要修改
WHEN_TO_USE_PROXY = ["Download_LLM", "Download_Gradio_Theme", "Connect_Grobid",
"Warmup_Modules", "Nougat_Download", "AutoGen"]
"Warmup_Modules", "Nougat_Download", "AutoGen", "Connect_OpenAI_Embedding"]


# 启用插件热加载
Expand Down
8 changes: 8 additions & 0 deletions crazy_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
def get_crazy_functions():
from crazy_functions.读文章写摘要 import 读文章写摘要
from crazy_functions.生成函数注释 import 批量生成函数注释
from crazy_functions.Rag_Interface import Rag问答
from crazy_functions.SourceCode_Analyse import 解析项目本身
from crazy_functions.SourceCode_Analyse import 解析一个Python项目
from crazy_functions.SourceCode_Analyse import 解析一个Matlab项目
Expand Down Expand Up @@ -50,6 +51,13 @@ def get_crazy_functions():
from crazy_functions.SourceCode_Comment import 注释Python项目

function_plugins = {
"Rag智能召回": {
"Group": "对话",
"Color": "stop",
"AsButton": False,
"Info": "将问答数据记录到向量库中,作为长期参考。",
"Function": HotReload(Rag问答),
},
"虚空终端": {
"Group": "对话|编程|学术|智能体",
"Color": "stop",
Expand Down
75 changes: 75 additions & 0 deletions crazy_functions/Rag_Interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker

RAG_WORKER_REGISTER = {}

MAX_HISTORY_ROUND = 5
MAX_CONTEXT_TOKEN_LIMIT = 4096
REMEMBER_PREVIEW = 1000

@CatchException
def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):

# 1. we retrieve rag worker from global context
user_name = chatbot.get_user()
if user_name in RAG_WORKER_REGISTER:
rag_worker = RAG_WORKER_REGISTER[user_name]
else:
rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
user_name,
llm_kwargs,
checkpoint_dir=get_log_folder(user_name, plugin_name='experimental_rag'),
auto_load_checkpoint=True)

chatbot.append([txt, '正在召回知识 ...'])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面

# 2. clip history to reduce token consumption
# 2-1. reduce chat round
txt_origin = txt

if len(history) > MAX_HISTORY_ROUND * 2:
history = history[-(MAX_HISTORY_ROUND * 2):]
txt_clip, history, flags = input_clipping(txt, history, max_token_limit=MAX_CONTEXT_TOKEN_LIMIT, return_clip_flags=True)
input_is_clipped_flag = (flags["original_input_len"] != flags["clipped_input_len"])

# 2-2. if input is clipped, add input to vector store before retrieve
if input_is_clipped_flag:
yield from update_ui_lastest_msg('检测到长输入, 正在向量化 ...', chatbot, history, delay=0) # 刷新界面
# save input to vector store
rag_worker.add_text_to_vector_store(txt_origin)
yield from update_ui_lastest_msg('向量化完成 ...', chatbot, history, delay=0) # 刷新界面
if len(txt_origin) > REMEMBER_PREVIEW:
HALF = REMEMBER_PREVIEW//2
i_say_to_remember = txt[:HALF] + f" ...\n...(省略{len(txt_origin)-REMEMBER_PREVIEW}字)...\n... " + txt[-HALF:]
if (flags["original_input_len"] - flags["clipped_input_len"]) > HALF:
txt_clip = txt_clip + f" ...\n...(省略{len(txt_origin)-len(txt_clip)-HALF}字)...\n... " + txt[-HALF:]
else:
pass
i_say = txt_clip
else:
i_say_to_remember = i_say = txt_clip
else:
i_say_to_remember = i_say = txt_clip

# 3. we search vector store and build prompts
nodes = rag_worker.retrieve_from_store_with_query(i_say)
prompt = rag_worker.build_prompt(query=i_say, nodes=nodes)

# 4. it is time to query llms
if len(chatbot) != 0: chatbot.pop(-1) # pop temp chat, because we are going to add them again inside `request_gpt_model_in_new_thread_with_ui_alive`
model_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=prompt, inputs_show_user=i_say,
llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
sys_prompt=system_prompt,
retry_times_at_unknown_error=0
)

# 5. remember what has been asked / answered
yield from update_ui_lastest_msg(model_say + '</br></br>' + '对话记忆中, 请稍等 ...', chatbot, history, delay=0.5) # 刷新界面
rag_worker.remember_qa(i_say_to_remember, model_say)
history.extend([i_say, model_say])

yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0) # 刷新界面
65 changes: 65 additions & 0 deletions crazy_functions/Social_Helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
import pickle, os

SOCIAL_NETWOK_WORKER_REGISTER = {}

class SocialNetwork():
def __init__(self):
self.people = []

class SocialNetworkWorker():
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
self.user_name = user_name
self.checkpoint_dir = checkpoint_dir
if auto_load_checkpoint:
self.social_network = self.load_from_checkpoint(checkpoint_dir)
else:
self.social_network = SocialNetwork()

def does_checkpoint_exist(self, checkpoint_dir=None):
import os, glob
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
if not os.path.exists(checkpoint_dir): return False
if len(glob.glob(os.path.join(checkpoint_dir, "social_network.pkl"))) == 0: return False
return True

def save_to_checkpoint(self, checkpoint_dir=None):
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
with open(os.path.join(checkpoint_dir, 'social_network.pkl'), "wb+") as f:
pickle.dump(self.social_network, f)
return

def load_from_checkpoint(self, checkpoint_dir=None):
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
with open(os.path.join(checkpoint_dir, 'social_network.pkl'), "rb") as f:
social_network = pickle.load(f)
return social_network
else:
return SocialNetwork()


@CatchException
def I人助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request, num_day=5):

# 1. we retrieve worker from global context
user_name = chatbot.get_user()
checkpoint_dir=get_log_folder(user_name, plugin_name='experimental_rag')
if user_name in SOCIAL_NETWOK_WORKER_REGISTER:
social_network_worker = SOCIAL_NETWOK_WORKER_REGISTER[user_name]
else:
social_network_worker = SOCIAL_NETWOK_WORKER_REGISTER[user_name] = SocialNetworkWorker(
user_name,
llm_kwargs,
checkpoint_dir=checkpoint_dir,
auto_load_checkpoint=True
)

# 2. save
social_network_worker.social_network.people.append("张三")
social_network_worker.save_to_checkpoint(checkpoint_dir)
chatbot.append(["good", "work"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面

25 changes: 21 additions & 4 deletions crazy_functions/crazy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import logging

def input_clipping(inputs, history, max_token_limit):
def input_clipping(inputs, history, max_token_limit, return_clip_flags=False):
"""
当输入文本 + 历史文本超出最大限制时,采取措施丢弃一部分文本。
输入:
Expand All @@ -20,17 +20,20 @@ def input_clipping(inputs, history, max_token_limit):
enc = model_info["gpt-3.5-turbo"]['tokenizer']
def get_token_num(txt): return len(enc.encode(txt, disallowed_special=()))


mode = 'input-and-history'
# 当 输入部分的token占比 小于 全文的一半时,只裁剪历史
input_token_num = get_token_num(inputs)
original_input_len = len(inputs)
if input_token_num < max_token_limit//2:
mode = 'only-history'
max_token_limit = max_token_limit - input_token_num

everything = [inputs] if mode == 'input-and-history' else ['']
everything.extend(history)
n_token = get_token_num('\n'.join(everything))
full_token_num = n_token = get_token_num('\n'.join(everything))
everything_token = [get_token_num(e) for e in everything]
everything_token_num = sum(everything_token)
delta = max(everything_token) // 16 # 截断时的颗粒度

while n_token > max_token_limit:
Expand All @@ -43,10 +46,24 @@ def get_token_num(txt): return len(enc.encode(txt, disallowed_special=()))

if mode == 'input-and-history':
inputs = everything[0]
full_token_num = everything_token_num
else:
pass
full_token_num = everything_token_num + input_token_num

history = everything[1:]
return inputs, history

flags = {
"mode": mode,
"original_input_token_num": input_token_num,
"original_full_token_num": full_token_num,
"original_input_len": original_input_len,
"clipped_input_len": len(inputs),
}

if not return_clip_flags:
return inputs, history
else:
return inputs, history, flags

def request_gpt_model_in_new_thread_with_ui_alive(
inputs, inputs_show_user, llm_kwargs,
Expand Down
122 changes: 122 additions & 0 deletions crazy_functions/rag_fns/llama_index_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import llama_index
from llama_index.core import Document
from llama_index.core.schema import TextNode
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
from shared_utils.connect_void_terminal import get_chat_default_kwargs
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from llama_index.core.ingestion import run_transformations
from llama_index.core import PromptTemplate
from llama_index.core.response_synthesizers import TreeSummarize

DEFAULT_QUERY_GENERATION_PROMPT = """\
Now, you have context information as below:
---------------------
{context_str}
---------------------
Answer the user request below (use the context information if necessary, otherwise you can ignore them):
---------------------
{query_str}
"""

QUESTION_ANSWER_RECORD = """\
{{
"type": "This is a previous conversation with the user",
"question": "{question}",
"answer": "{answer}",
}}
"""


class SaveLoad():

def does_checkpoint_exist(self, checkpoint_dir=None):
import os, glob
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
if not os.path.exists(checkpoint_dir): return False
if len(glob.glob(os.path.join(checkpoint_dir, "*.json"))) == 0: return False
return True

def save_to_checkpoint(self, checkpoint_dir=None):
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)

def load_from_checkpoint(self, checkpoint_dir=None):
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
print('loading checkpoint from disk')
from llama_index.core import StorageContext, load_index_from_storage
storage_context = StorageContext.from_defaults(persist_dir=checkpoint_dir)
self.vs_index = load_index_from_storage(storage_context, embed_model=self.embed_model)
return self.vs_index
else:
return self.create_new_vs()

def create_new_vs(self):
return GptacVectorStoreIndex.default_vector_store(embed_model=self.embed_model)


class LlamaIndexRagWorker(SaveLoad):
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
self.debug_mode = True
self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
self.user_name = user_name
self.checkpoint_dir = checkpoint_dir
if auto_load_checkpoint:
self.vs_index = self.load_from_checkpoint(checkpoint_dir)
else:
self.vs_index = self.create_new_vs()

def assign_embedding_model(self):
pass

def inspect_vector_store(self):
# This function is for debugging
self.vs_index.storage_context.index_store.to_dict()
docstore = self.vs_index.storage_context.docstore.docs
vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
print('\n++ --------inspect_vector_store begin--------')
print(vector_store_preview)
print('oo --------inspect_vector_store end--------')
return vector_store_preview

def add_documents_to_vector_store(self, document_list):
documents = [Document(text=t) for t in document_list]
documents_nodes = run_transformations(
documents, # type: ignore
self.vs_index._transformations,
show_progress=True
)
self.vs_index.insert_nodes(documents_nodes)
if self.debug_mode: self.inspect_vector_store()

def add_text_to_vector_store(self, text):
node = TextNode(text=text)
documents_nodes = run_transformations(
[node],
self.vs_index._transformations,
show_progress=True
)
self.vs_index.insert_nodes(documents_nodes)
if self.debug_mode: self.inspect_vector_store()

def remember_qa(self, question, answer):
formatted_str = QUESTION_ANSWER_RECORD.format(question=question, answer=answer)
self.add_text_to_vector_store(formatted_str)

def retrieve_from_store_with_query(self, query):
if self.debug_mode: self.inspect_vector_store()
retriever = self.vs_index.as_retriever()
return retriever.retrieve(query)

def build_prompt(self, query, nodes):
context_str = self.generate_node_array_preview(nodes)
return DEFAULT_QUERY_GENERATION_PROMPT.format(context_str=context_str, query_str=query)

def generate_node_array_preview(self, nodes):
buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
if self.debug_mode: print(buf)
return buf



Loading

0 comments on commit dd66ca2

Please sign in to comment.