[Paddle-pipelines] Add index data parallel support (#6886)
* Add multi thread support

* Update chatpaper index

* Set fixed number_of_shards
w5688414 authored Sep 3, 2023
1 parent 7588a21 commit 2f3eac3
Showing 5 changed files with 97 additions and 132 deletions.
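The commit reorganizes the chatpaper corpus into a two-level layout: a root index (weipu_abstract) that holds one chunk per abstract passage, and a single shared child index (weipu_full_text) that holds the body chunks of every paper, each tagged with the paper's id so one paper can be isolated with a metadata filter. A rough sketch of the two document shapes that end up in Elasticsearch (field values are illustrative, and the exact metadata keys depend on the input JSON):

# Root index (weipu_abstract): one entry per abstract chunk, carrying the paper's metadata.
root_doc = {
    "content": "……摘要切片……",
    "meta": {"id": "doc-0001", "title": "商誉私法保护研究", "key_words": "商誉;私法保护"},
}

# Child index (weipu_full_text): many body chunks per paper, filterable by meta["id"].
child_doc = {
    "content": "……正文切片……",
    "meta": {"id": "doc-0001", "name": "商誉私法保护研究",
             "title": "商誉私法保护研究", "key_words": "商誉;私法保护"},
}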
40 changes: 28 additions & 12 deletions pipelines/examples/chatpaper/hierarchical_search_example.py
@@ -31,7 +31,8 @@
# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to run the dense_qa system on, defaults to gpu.")
parser.add_argument("--index_name", default='dureader_nano_query_encoder', type=str, help="The ann index name of ANN.")
parser.add_argument("--root_index_name", default="weipu_abstract", type=str, help="The index name of the ANN search engine")
parser.add_argument("--child_index_name", default="weipu_full_text", type=str, help="The index name of the ANN search engine")
parser.add_argument('--username', type=str, default="", help='Username of ANN search engine')
parser.add_argument('--password', type=str, default="", help='Password of ANN search engine')
parser.add_argument("--search_engine", choices=['elastic', 'bes'], default="elastic", help="The type of ANN search engine.")
@@ -67,7 +68,7 @@ def get_retrievers(use_gpu):
embedding_dim=args.embedding_dim,
vector_type="dense_vector",
search_fields=["content", "meta"],
index=args.index_name,
index=args.root_index_name,
)
else:
document_store_with_docs = BaiduElasticsearchDocumentStore(
@@ -79,7 +80,7 @@ def get_retrievers(use_gpu):
similarity="dot_prod",
vector_type="bpack_vector",
search_fields=["content", "meta"],
index=args.index_name,
index=args.root_index_name,
)

# Semantic indexing model
@@ -117,7 +118,7 @@ def hierarchical_search_tutorial():
dpr_retriever, bm_retriever = get_retrievers(use_gpu)

# Ranker
ranker = ErnieRanker(model_name_or_path="rocketqa-nano-cross-encoder", use_gpu=use_gpu)
ranker = ErnieRanker(model_name_or_path="rocketqa-base-cross-encoder", use_gpu=use_gpu)

# Pipeline
pipeline = Pipeline()
@@ -127,31 +128,46 @@ def hierarchical_search_tutorial():
component=JoinDocuments(join_mode="concatenate"), name="JoinResults", inputs=["BMRetriever", "DenseRetriever"]
)
pipeline.add_node(component=ranker, name="Ranker", inputs=["JoinResults"])

# Abstract search
prediction = pipeline.run(
query="P2P网络借贷的风险有哪些?",
query="商誉私法保护研究",
params={
"BMRetriever": {"top_k": args.bm_topk, "index": args.index_name},
"DenseRetriever": {"top_k": args.dense_topk, "index": args.index_name},
"BMRetriever": {"top_k": args.bm_topk, "index": args.root_index_name},
"DenseRetriever": {
"top_k": args.dense_topk,
"index": args.root_index_name,
},
"Ranker": {"top_k": args.rank_topk},
},
)
print_documents(prediction)

# Main body search
documents = prediction["documents"]
sub_index_name = documents[0].meta["name"]
file_id = documents[0].meta["id"]

# filters = {
# "$and": {
# "id": {"$eq": "6bc0c021ef4ec96a81fbc5707e1c7016"},
# }
# }
pipe = Pipeline()
pipe.add_node(component=dpr_retriever, name="DenseRetriever", inputs=["Query"])
pipe.add_node(component=ranker, name="Ranker", inputs=["DenseRetriever"])

filters = {
"$and": {
"id": {"$eq": file_id},
}
}
results = pipe.run(
query="P2P网络借贷的研究背景是什么?",
query="商誉私法保护的目的是什么?",
params={
"DenseRetriever": {"top_k": args.dense_topk, "index": sub_index_name.lower()},
"DenseRetriever": {"top_k": args.dense_topk, "index": args.child_index_name, "filters": filters},
},
)

print_documents(results)
print_documents(results, print_meta=True)


if __name__ == "__main__":
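In the updated example the second stage no longer targets a per-paper index derived from the document name; it queries the shared child index and narrows the results with a metadata filter on the paper's id. A minimal sketch of that pattern, assuming the pipe and args objects built above (the helper name search_full_text is hypothetical, not part of the example):

def search_full_text(pipe, query, file_id, index_name, top_k=10):
    # Keep only chunks whose meta["id"] matches the paper chosen in the abstract stage.
    filters = {"$and": {"id": {"$eq": file_id}}}
    return pipe.run(
        query=query,
        params={"DenseRetriever": {"top_k": top_k, "index": index_name, "filters": filters}},
    )

# e.g. with file_id taken from the best abstract hit:
# results = search_full_text(pipe, "商誉私法保护的目的是什么?", file_id, args.child_index_name)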
182 changes: 63 additions & 119 deletions pipelines/examples/chatpaper/offline_ann_example.py
@@ -13,6 +13,7 @@
# limitations under the License.

import argparse
from concurrent.futures import ThreadPoolExecutor

import pandas as pd

@@ -21,16 +22,13 @@
ElasticsearchDocumentStore,
MilvusDocumentStore,
)
from pipelines.nodes import (
CharacterTextSplitter,
DensePassageRetriever,
EmbeddingRetriever,
)
from pipelines.nodes import DensePassageRetriever, EmbeddingRetriever, SpacyTextSplitter
from pipelines.utils import launch_es

# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--index_name", default="baike_cities", type=str, help="The index name of the ANN search engine")
parser.add_argument("--root_index_name", default="weipu_abstract", type=str, help="The index name of the ANN search engine")
parser.add_argument("--child_index_name", default="weipu_full_text", type=str, help="The index name of the ANN search engine")
parser.add_argument("--file_name", default="data/baike/", type=str, help="The doc path of the corpus")
parser.add_argument('--username', type=str, default="", help='Username of ANN search engine')
parser.add_argument('--password', type=str, default="", help='Password of ANN search engine')
@@ -104,8 +102,6 @@ def offline_ann(index_name, docs):
index=index_name,
search_fields=args.search_fields,  # When multi-way recall is used and search fields other than content are configured, those extra fields must also be set when building the index, e.g. ['content', 'name'].
)
# Write the documents into the database
# document_store.write_documents(docs)
# Semantic indexing model
if args.model_type == "ernie-embedding-v1":
retriever = EmbeddingRetriever(
@@ -130,10 +126,16 @@
embed_title=args.embed_title,
)

# Manually indexing
res = retriever.run_indexing(docs)
documents = res[0]["documents"]
document_store.write_documents(documents)
log_file = open("log.txt", "a")
try:
# Manually index the documents
res = retriever.run_indexing(docs)
documents = res[0]["documents"]
document_store.write_documents(documents)
log_file.write(index_name + "\t" + "success" + "\n")
except Exception as e:
print("Indexing failed, please try again.")
log_file.write(index_name + "\t" + str(e) + "\n")


def delete_data(index_name):
@@ -176,134 +178,76 @@ def delete_data(index_name):
print("Delete an existing elasticsearch index {} Done.".format(index_name))


from langdetect import detect


def extract_meta_data(content):
abstracts = []
key_words = ""
for index, sentence in enumerate(content):
if "关键词" in sentence:
key_words = sentence
break
elif "关键字" in sentence:
key_words = sentence
break
else:
try:
if len(sentence.strip()) == 0:
continue
res = detect(sentence)
if res == "en":
# print(sentence)
break
abstracts.append(sentence)
except Exception as e:
print(sentence)
print(e)
# If no ending keyword is found, return the top 5 sentences as the abstract
if index + 1 == len(content):
# print(index)
abstracts.append(content[:5])
key_words = ""
return abstracts, key_words


def process_abstract(csv_name):
data = pd.read_csv(csv_name)
list_data = []
for index, row in data.iterrows():
# print(index, row["title"], row["content"])
paragraphs = row["content"].split("\n")
abstracts, key_words = extract_meta_data(paragraphs)
abstracts = "\n".join(abstracts[1:])
key_words = key_words.replace("关键词", "").replace("关键字", "")
if len(abstracts) > 2000:
print(index, len(abstracts))
print(row["title"])
doc = {"abstract": abstracts, "key_words": key_words, "name": row["title"]}
list_data.append(doc)
return list_data


def extract_all_contents(content):
text_body = []
for index, sentence in enumerate(content):
try:
if len(sentence.strip()) == 0:
continue
elif "参考文献" in sentence:
break
res = detect(sentence)
# skip English sentences
if res == "en":
# print(sentence)
continue
text_body.append(sentence)
except Exception as e:
print(sentence)
print(e)
return text_body


def process_content(csv_name):
data = pd.read_csv(csv_name)
def read_data(file_path):
data = pd.read_json(path_or_buf=file_path, lines=True)
list_data = []
for index, row in data.iterrows():
paragraphs = row["content"].split("\n")
processed_content = extract_all_contents(paragraphs)
doc = {"content": "\n".join(processed_content), "name": row["title"]}
doc = row.to_dict()
list_data.append(doc)
return list_data


def indexing_abstract(csv_name):
dataset = process_abstract(csv_name)
list_documents = [
{"content": item["abstract"] + "\n" + item["key_words"], "name": item["name"]} for item in dataset
]

text_splitter = CharacterTextSplitter(separator="\n", chunk_size=300, chunk_overlap=0, filters=["\n"])
dataset = read_data(csv_name)
text_splitter = SpacyTextSplitter(separator="\n", chunk_size=320, chunk_overlap=10, filters=["\n"])
datasets = []
for document in list_documents:
text = document["content"]
for document in dataset:
text = document["abstracts"]
text_splits = text_splitter.split_text(text)
for txt in text_splits:
meta_data = {
"name": document["name"],
}
meta_data = {}
meta_data.update(document)
meta_data.pop("content")
meta_data.pop("abstracts")
datasets.append({"content": txt, "meta": meta_data})
# Add all abstracts into a single index
offline_ann(args.index_name, datasets)
offline_ann(index_name=args.root_index_name, docs=datasets)


def run_thread_index(data):
docs = data["content"]
offline_ann(args.child_index_name, docs)


def run_multi_process_splitter(document):
file_log = open("log_process.txt", "a")
text_splitter = SpacyTextSplitter(separator="\n", chunk_size=320, chunk_overlap=10, filters=["\n"])
text = document["content"]
text_splits = text_splitter.split_text(text)

datasets = []
for txt in text_splits:
meta_data = {
"name": document["title"],
"id": document["id"],
"title": document["title"],
"key_words": document["key_words"],
}
datasets.append({"content": txt, "meta": meta_data})
file_log.write(document["id"] + "\tsuccess" + "\n")
return {"index_name": document["id"], "content": datasets}


from multiprocessing import Pool


def indexing_main_body(csv_name):
dataset = process_content(csv_name)
text_splitter = CharacterTextSplitter(separator="\n", chunk_size=300, chunk_overlap=0, filters=["\n"])
all_data = []
for document in dataset:
text = document["content"]
text_splits = text_splitter.split_text(text)
datasets = []
for txt in text_splits:
meta_data = {
"name": document["name"],
}
datasets.append({"content": txt, "meta": meta_data})
all_data.append({"index_name": document["name"].lower(), "content": datasets})
dataset = read_data(csv_name)
# Multiprocessing for splitting text
pool = Pool(processes=10)
all_data = pool.map(run_multi_process_splitter, dataset)

# Add body into separate index
for data in all_data:
index_name = data["index_name"]
print(index_name)
datasets = data["content"]
offline_ann(index_name, datasets)
thread_count = 10
with ThreadPoolExecutor(max_workers=thread_count) as executor:
executor.map(run_thread_index, all_data)


if __name__ == "__main__":
if args.delete_index:
delete_data(args.index_name)
delete_data(args.root_index_name)
delete_data(args.child_index_name)
# Hierarchical index: abstracts and keywords
# indexing_abstract(args.file_name)
indexing_abstract(args.file_name)
# Hierarchical index: main body
indexing_main_body(args.file_name)
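The rewritten indexing_main_body splits papers in a multiprocessing Pool and then writes the resulting chunk lists to the search engine from a ThreadPoolExecutor, so the CPU-bound splitting and the I/O-bound index writes are each parallelized. A self-contained sketch of that split-then-index pattern (function names and the naive newline splitter are illustrative stand-ins for run_multi_process_splitter, run_thread_index, and offline_ann):

from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Pool


def split_document(document):
    # CPU-bound step: chunk one paper's body into retrievable passages.
    chunks = [
        {"content": part, "meta": {"id": document["id"], "title": document["title"]}}
        for part in document["content"].split("\n")
        if part.strip()
    ]
    return {"index_name": document["id"], "content": chunks}


def index_chunks(data):
    # I/O-bound step: write one paper's chunks to the search engine.
    print(f"indexed {len(data['content'])} chunks for {data['index_name']}")


if __name__ == "__main__":
    dataset = [{"id": "doc-0001", "title": "示例论文", "content": "第一段\n第二段\n第三段"}]
    with Pool(processes=10) as pool:                          # parallel splitting
        all_data = pool.map(split_document, dataset)
    with ThreadPoolExecutor(max_workers=10) as executor:      # parallel index writes
        list(executor.map(index_chunks, all_data))            # consume to surface worker errors

Consuming the iterator returned by executor.map (here via list) also surfaces exceptions raised in the worker threads; otherwise failures are only visible through the log files the example writes.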
5 changes: 4 additions & 1 deletion pipelines/pipelines/document_stores/elasticsearch.py
@@ -451,6 +451,8 @@ def _create_document_index(self, index_name: str, headers: Optional[Dict[str, st
mapping["mappings"]["properties"].update({field: {"type": "text"}})

if self.embedding_field:
mapping["settings"]["number_of_shards"] = 1
mapping["settings"]["number_of_replicas"] = 2
mapping["mappings"]["properties"][self.embedding_field] = {
"type": self.vector_type,
"dims": self.embedding_dim,
@@ -485,7 +487,8 @@ def _create_label_index(self, index_name: str, headers: Optional[Dict[str, str]]
"updated_at": {"type": "date", "format": "yyyy-MM-dd HH:mm:ss||yyyy-MM-dd||epoch_millis"}
# TODO add pipeline_hash and pipeline_name once we migrated the REST API to pipelines
}
}
},
"settings": {"number_of_shards": 1, "number_of_replicas": 2},
}
try:
self.client.indices.create(index=index_name, body=mapping, headers=headers)
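The document store now sets the shard layout explicitly instead of relying on cluster defaults: every index it creates gets one primary shard and two replicas. A rough sketch of the index body this produces for an embedding index, as it would be sent to indices.create (field names and dimensions are illustrative):

index_body = {
    "mappings": {
        "properties": {
            "content": {"type": "text"},
            "embedding": {"type": "dense_vector", "dims": 768},  # dims follow embedding_dim
        }
    },
    # Fixed by this change: one primary shard, two replicas per index.
    "settings": {"number_of_shards": 1, "number_of_replicas": 2},
}
# es_client.indices.create(index="weipu_abstract", body=index_body)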
1 change: 1 addition & 0 deletions pipelines/pipelines/nodes/__init__.py
@@ -41,6 +41,7 @@
CharacterTextSplitter,
PreProcessor,
RecursiveCharacterTextSplitter,
SpacyTextSplitter,
)
from pipelines.nodes.prompt import PromptModel, PromptNode, Shaper
from pipelines.nodes.question_generator import QuestionGenerator
1 change: 1 addition & 0 deletions pipelines/pipelines/nodes/preprocessor/__init__.py
@@ -17,4 +17,5 @@
from pipelines.nodes.preprocessor.text_splitter import (
CharacterTextSplitter,
RecursiveCharacterTextSplitter,
SpacyTextSplitter,
)
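With SpacyTextSplitter now exported from pipelines.nodes and pipelines.nodes.preprocessor, the offline example can split text on sentence-aware boundaries instead of fixed character runs. A minimal usage sketch with the same settings the example passes (the sample text is illustrative, and a spaCy language model may need to be installed depending on the splitter's implementation):

from pipelines.nodes import SpacyTextSplitter

# Same construction as in offline_ann_example.py: newline separator,
# chunks of about 320 units with a 10-unit overlap between neighbours.
text_splitter = SpacyTextSplitter(separator="\n", chunk_size=320, chunk_overlap=10, filters=["\n"])

for chunk in text_splitter.split_text("第一段内容。\n第二段内容。\n第三段内容。"):
    print(chunk)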
