diff --git a/slm/pipelines/examples/structured_index/README.md b/slm/pipelines/examples/structured_index/README.md
new file mode 100644
index 000000000000..48f371fd88bc
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/README.md
@@ -0,0 +1,181 @@
+# Hierarchical Document Indexing
+
+## Method
+
+1. Loading (load): bring the PDF or HTML documents to be processed into the pipeline.
+2. Discourse structure parsing (parse): use a large language model to parse each document's discourse structure, re-segmenting the text by semantics and producing a discourse structure tree for the document.
+3. Hierarchical summarization (summary): following the discourse structure tree, generate summaries bottom-up over the parsed document, producing summaries at different levels of information.
+4. Hierarchical index construction (index): embed these multi-level summary fragments into a dense-retrieval vector space with a text encoder, building a hierarchical text index. The index captures not only local information but also higher-level global information, so it can recall information at multiple granularities to match the varying information needs in user queries; see the sketch after this list.
+
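+The four steps are wrapped by `StructuredIndexer.pipeline` (see `construct_index.py` and `src/structured_index.py`). A minimal sketch of driving the pipeline from Python, assuming the default models are available locally:
+
+```python
+from src.structured_index import StructuredIndexer
+
+# Runs load -> parse -> summarize -> index for one document and writes the
+# index cache under data/index/{encode_model_name_or_path}.
+indexer = StructuredIndexer(log_dir=".logs")
+doc = indexer.pipeline(
+    filepath="data/source/2308.12950.pdf",
+    parse_model_name_or_path="Qwen/Qwen2-7B-Instruct",
+    summarize_model_name_or_path="Qwen/Qwen2-7B-Instruct",
+    encode_model_name_or_path="BAAI/bge-large-en-v1.5",
+)
+```
+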
+## Installation
+
+### Requirements
+
+We recommend installing the GPU build of [PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/conda/linux-conda.html). Taking paddle for CUDA 11.7 as an example, the install command is:
+
+```bash
+conda install paddlepaddle-gpu==2.6.2 cudatoolkit=11.7 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/ -c conda-forge
+```
+Install the remaining dependencies:
+```bash
+pip install -r requirements.txt
+```
+
+### Data Preparation
+
+- Source documents: the document corpus to build hierarchical indexes for, e.g. the sample documents under `data/source`. Each document is a single file; PDF and HTML formats are currently supported.
+The script `data/source/download.sh` downloads the sample documents:
+```bash
+apt install jq -y # install the jq tool (needs system privileges; skip if already installed)
+cd data/source
+bash download.sh
+```
+- Query file: user query texts. JSON is currently supported, with each query stored as `query_id: query_text`; see the sample query file `data/query.json`.
+
+
+## Usage
+
+### Index Construction
+
+Build a hierarchical index for a single document file:
+```bash
+python construct_index.py \
+--source data/source/2308.12950.pdf \
+--parse_model_name_or_path Qwen/Qwen2-72B-Instruct \
+--summarize_model_name_or_path Qwen/Qwen2-72B-Instruct \
+--encode_model_name_or_path BAAI/bge-large-en-v1.5 \
+--log_dir .logs
+```
+
+Build hierarchical indexes for all document files under a directory:
+```bash
+python construct_index.py \
+--source data/source \
+--parse_model_name_or_path Qwen/Qwen2-72B-Instruct \
+--summarize_model_name_or_path Qwen/Qwen2-72B-Instruct \
+--encode_model_name_or_path BAAI/bge-large-en-v1.5 \
+--log_dir .logs
+```
+
+Configurable parameters:
+- `source`: directory containing all the source files to index, or a single source file to index
+
+- `parse_model_name_or_path`: name or path of the model used for discourse structure parsing (parse)
+
+- `parse_model_url`: URL of the model used for discourse structure parsing (parse). Omit this parameter if it is not needed
+
+- `summarize_model_name_or_path`: name or path of the model used for hierarchical summarization (summarize)
+
+- `summarize_model_url`: URL of the model used for hierarchical summarization (summarize). Omit this parameter if it is not needed
+
+- `encode_model_name_or_path`: name or path of the model used for text encoding
+
+- `log_dir`: directory where log files are saved
+
+The hierarchical index is saved under `data/index/{encode_model_name_or_path}`. Each source document has two cache files there that are used for retrieval: a `.pkl` file containing the document's hierarchical summary texts and a `.npy` file containing the corresponding summary embedding vectors.
+For example, the hierarchical index cache built for `data/source/CodeLlama.pdf` consists of `data/index/BAAI/bge-large-en-v1.5/CodeLlama.npy` and `data/index/BAAI/bge-large-en-v1.5/CodeLlama.pkl`.
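+
+The two cache files are row-aligned: entry `i` of the `.pkl` metadata list describes row `i` of the `.npy` embedding matrix (see `src/index.py`). A minimal sketch of inspecting them, using the paths above:
+
+```python
+import pickle
+
+import numpy as np
+
+embeddings = np.load("data/index/BAAI/bge-large-en-v1.5/CodeLlama.npy")
+with open("data/index/BAAI/bge-large-en-v1.5/CodeLlama.pkl", "rb") as f:
+    info = pickle.load(f)  # list of {"content", "source", "level"} dicts
+assert embeddings.shape[0] == len(info)  # one metadata entry per embedding row
+```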
+
+### Retrieval
+
+Search the hierarchical index for summary fragments relevant to each query, and output the retrieval results.
+
+Query multiple texts from a file:
+```bash
+python query.py \
+--search_result_dir data/search_result \
+--encode_model_name_or_path BAAI/bge-large-en-v1.5 \
+--log_dir .logs \
+--query_filepath data/query.json \
+--top_k 5 \
+--embedding_batch_size 128
+```
+
+Query a single text string:
+```bash
+python query.py \
+--search_result_dir data/search_result \
+--encode_model_name_or_path BAAI/bge-large-en-v1.5 \
+--log_dir .logs \
+--query_text "What is the relationship between CodeLlama and Llama?" \
+--top_k 5 \
+--embedding_batch_size 1
+```
+
+Configurable parameters:
+- `search_result_dir`: directory where retrieval results are saved
+
+- `encode_model_name_or_path`: name or path of the model used for text encoding
+
+- `query_filepath`: path to the query file. If given, it must be a JSON file containing a dictionary of queries
+
+- `query_text`: a single query text. If given, it must be a string
+
+- `top_k`: return the top_k results for each query
+
+- `embedding_batch_size`: batch size for encoding queries
+
+- `log_dir`: directory where log files are saved
+
+Retrieval results are saved under `{search_result_dir}/{encode_model_name_or_path}`. Each result file in this directory corresponds to one retrieval call and may contain several queries: every call produces a `query_{timestamp}.json` file recording its results, and each query within a call is uniquely identified by its query ID. If a query is passed as text via `query_text`, its query ID is set to `"0"`.
+
+For example, the retrieval result of the single query above looks like this:
+```json
+{
+ "0": {
+ "query": "What is the relationship between CodeLlama and Llama?",
+ "hits": [
+ {
+ "corpus_id": 122,
+ "score": 0.7032119035720825,
+ "content": "CoDE LLAMA is a family of large language models for code, based on LLAMA 2, designed for state-of-the-art performance in programming tasks, including infilling, large context handling, and zero-shot instruction-following, with a focus on safety and alignment.",
+ "source": "data/source/2308.12950.pdf",
+ "level": 0
+ },
+ {
+ "corpus_id": 127,
+ "score": 0.6490256786346436,
+ "content": "CoDE LLAMA models are general-purpose code generation tools, with specialized versions like CoDE LLAMA -PyTHON for Python code and CoDE LLAMA -INsTRUCT for understanding and executing instructions.",
+ "source": "data/source/2308.12950.pdf",
+ "level": 3
+ },
+ {
+ "corpus_id": 128,
+ "score": 0.6398724317550659,
+ "content": "CoDE LLAMA -PyTHON is specialized for Python code generation, while CoDE LLAMA -INsTRUCT models are designed to understand and execute instructions.",
+ "source": "data/source/2308.12950.pdf",
+ "level": 4
+ },
+ {
+ "corpus_id": 161,
+ "score": 0.6116989254951477,
+ "content": "CoDE LLAMA models are designed for real-world applications, excelling in infilling and large context handling, and they achieve state-of-the-art performance on code generation benchmarks while ensuring safety and alignment.",
+ "source": "data/source/2308.12950.pdf",
+ "level": 2
+ },
+ {
+ "corpus_id": 129,
+ "score": 0.6056838631629944,
+ "content": "CoDE LLAMA -INsTRUCT are instruction-following models designed to understand and execute instructions.",
+ "source": "data/source/2308.12950.pdf",
+ "level": 5
+ }
+ ]
+ }
+}
+```
+Each query's retrieval result has the following format:
+```
+query_id: {
+    "query": query text,
+    "hits": [
+        {
+            "corpus_id": index of this passage in the whole corpus,
+            "score": similarity score,
+            "content": summary content of the passage,
+            "source": path to the passage's source document,
+            "level": information-granularity level of the passage in the source document; 0 is the top level, and larger numbers mean finer granularity
+        },
+        ...
+    ]
+}
+```
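+
+A short sketch of consuming such a result file (the timestamped filename here is illustrative):
+
+```python
+import json
+
+with open("data/search_result/BAAI/bge-large-en-v1.5/query_20240101000000.json") as f:
+    results = json.load(f)
+for query_id, result in results.items():
+    print(query_id, result["query"])
+    for hit in result["hits"]:
+        print(f"  level={hit['level']} score={hit['score']:.4f} {hit['content'][:80]}")
+```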
\ No newline at end of file
diff --git a/slm/pipelines/examples/structured_index/arguments.py b/slm/pipelines/examples/structured_index/arguments.py
new file mode 100644
index 000000000000..6665da3f3145
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/arguments.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+
+@dataclass
+class StructuredIndexerArguments:
+ """
+ Arguments for StructuredIndexer.
+ """
+
+ log_dir: str = field(default=".logs", metadata={"help": "log directory"})
+
+
+@dataclass
+class StructuredIndexerEncodeArguments(StructuredIndexerArguments):
+ """
+ Arguments for encoding corpus in StructuredIndexer.
+ """
+
+ encode_model_name_or_path: str = field(
+ default="BAAI/bge-large-en-v1.5", metadata={"help": "encode model name or path"}
+ )
+
+
+@dataclass
+class StructuredIndexerPipelineArguments(StructuredIndexerEncodeArguments):
+ """
+ Arguments for building StructuredIndex pipeline for a single corpus file.
+ """
+
+ source: str = field(default="data/source", metadata={"help": "source file or directory"})
+ parse_model_name_or_path: str = field(
+ default="Qwen/Qwen2-7B-Instruct", metadata={"help": "parse model name or path"}
+ )
+    parse_model_url: Optional[str] = field(default=None, metadata={"help": "parse model url if you use api"})
+ summarize_model_name_or_path: str = field(
+ default="Qwen/Qwen2-7B-Instruct",
+ metadata={"help": "summarize model name or path"},
+ )
+    summarize_model_url: Optional[str] = field(default=None, metadata={"help": "summarize model url if you use api"})
+
+
+@dataclass
+class RetrievalArguments(StructuredIndexerEncodeArguments):
+ """
+ Arguments for StructuredIndex to retrieve.
+ """
+
+ search_result_dir: str = field(default="search_result", metadata={"help": "search result directory"})
+ query_filepath: str = field(default="query.json", metadata={"help": "query file path"})
+    query_text: Optional[str] = field(default=None, metadata={"help": "query text"})
+ top_k: int = field(default=5, metadata={"help": "top k results for each query"})
+ embedding_batch_size: int = field(default=128, metadata={"help": "embedding batch size for queries"})
diff --git a/slm/pipelines/examples/structured_index/construct_index.py b/slm/pipelines/examples/structured_index/construct_index.py
new file mode 100644
index 000000000000..f90f40ddd74b
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/construct_index.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from arguments import StructuredIndexerPipelineArguments
+from src.structured_index import StructuredIndexer
+
+from paddlenlp.trainer import PdArgumentParser
+
+if __name__ == "__main__":
+ parser = PdArgumentParser(StructuredIndexerPipelineArguments)
+ (args,) = parser.parse_args_into_dataclasses()
+
+ structured_indexer = StructuredIndexer(log_dir=args.log_dir)
+    assert os.path.exists(args.source), f"Source path {args.source} does not exist"
+ if os.path.isfile(args.source):
+ structured_indexer.pipeline(
+ filepath=args.source,
+ parse_model_name_or_path=args.parse_model_name_or_path,
+ parse_model_url=args.parse_model_url,
+ summarize_model_name_or_path=args.summarize_model_name_or_path,
+ summarize_model_url=args.summarize_model_url,
+ encode_model_name_or_path=args.encode_model_name_or_path,
+ )
+ else:
+ for root, _, files in os.walk(args.source):
+ for file in files:
+ filepath = os.path.join(root, file)
+ structured_indexer.pipeline(
+ filepath=filepath,
+ parse_model_name_or_path=args.parse_model_name_or_path,
+ parse_model_url=args.parse_model_url,
+ summarize_model_name_or_path=args.summarize_model_name_or_path,
+ summarize_model_url=args.summarize_model_url,
+ encode_model_name_or_path=args.encode_model_name_or_path,
+ )
diff --git a/slm/pipelines/examples/structured_index/data/query.json b/slm/pipelines/examples/structured_index/data/query.json
new file mode 100644
index 000000000000..62a0ce01f349
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/data/query.json
@@ -0,0 +1,7 @@
+{
+ "0" : "What is big model alignment?",
+ "1" : "What are the benefits of aligning large models?",
+ "2" : "How to improve the decoding speed of large language model inference?",
+ "3" : "What is the difference between CodeLlama and Llama?",
+ "4" : "What is Grouped Multiple-Degradation Restoration with Image Degradation Similarity?"
+}
\ No newline at end of file
diff --git a/slm/pipelines/examples/structured_index/data/source/download.sh b/slm/pipelines/examples/structured_index/data/source/download.sh
new file mode 100644
index 000000000000..eaa2af6a721f
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/data/source/download.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+
+download() {
+ local url=$1
+ local ext=$2
+
+    # get the file's basename
+ local filename=$(basename "$url")
+
+    # append the extension if the filename does not already end with it
+ if [[ "$filename" != *".$ext" ]]; then
+ filename="$filename.$ext"
+ fi
+
+    # download the file
+ echo "Downloading $url as $filename"
+ curl -o "$filename" "$url"
+}
+
+# read the JSON file and parse every extension and its URLs
+json_file="source_url.json"
+exts=$(jq -r 'keys[]' "$json_file")
+
+# iterate over each extension and download the corresponding files
+for ext in $exts; do
+ urls=$(jq -r --arg ext "$ext" '.[$ext][]' "$json_file")
+ for url in $urls; do
+ download "$url" "$ext"
+ done
+done
\ No newline at end of file
diff --git a/slm/pipelines/examples/structured_index/data/source/source_url.json b/slm/pipelines/examples/structured_index/data/source/source_url.json
new file mode 100644
index 000000000000..cc30537f6907
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/data/source/source_url.json
@@ -0,0 +1,10 @@
+{
+ "pdf": [
+ "https://arxiv.org/pdf/2406.15877",
+ "https://arxiv.org/pdf/2407.12273",
+ "https://arxiv.org/pdf/2308.12950",
+ "https://arxiv.org/pdf/1810.04805",
+ "https://arxiv.org/pdf/2402.12374",
+ "https://aclanthology.org/2023.ccl-2.7.pdf"
+ ]
+}
\ No newline at end of file
diff --git a/slm/pipelines/examples/structured_index/query.py b/slm/pipelines/examples/structured_index/query.py
new file mode 100644
index 000000000000..00488e4dffc4
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/query.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from arguments import RetrievalArguments
+from src.structured_index import StructuredIndexer
+from src.utils import load_data
+
+from paddlenlp.trainer import PdArgumentParser
+
+if __name__ == "__main__":
+ parser = PdArgumentParser(RetrievalArguments)
+ (args,) = parser.parse_args_into_dataclasses()
+
+ structured_indexer = StructuredIndexer(log_dir=args.log_dir)
+
+    if args.query_text is None:
+        queries_dict = load_data(args.query_filepath, mode="Searching")
+    else:
+        assert isinstance(args.query_text, str)
+ queries_dict = {"0": args.query_text}
+
+ structured_indexer.search(
+ queries_dict=queries_dict,
+ output_dir=args.search_result_dir,
+ model_name_or_path=args.encode_model_name_or_path,
+ top_k=args.top_k,
+ embedding_batch_size=args.embedding_batch_size,
+ )
diff --git a/slm/pipelines/examples/structured_index/requirements.txt b/slm/pipelines/examples/structured_index/requirements.txt
new file mode 100644
index 000000000000..343739772537
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/requirements.txt
@@ -0,0 +1,7 @@
+faiss-gpu==1.7.2
+paddlenlp==3.0.0b2
+tqdm
+numpy
+PyMuPDF # provides the fitz module needed for PDF parsing
+beautifulsoup4 # required by src/load.py for HTML parsing
+paddleocr==2.7.3
\ No newline at end of file
diff --git a/slm/pipelines/examples/structured_index/src/index.py b/slm/pipelines/examples/structured_index/src/index.py
new file mode 100644
index 000000000000..9718089bb9e0
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/src/index.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import math
+import os
+import pickle
+from typing import Dict, List, Sequence
+
+import faiss
+import numpy as np
+from tqdm import tqdm
+
+from .utils import encode, load_data, load_model
+
+
+class Indexer:
+ def __init__(self, model_name_or_path: str = "BAAI/bge-large-en-v1.5", url=None):
+ self.model_name = model_name_or_path
+ self.model, self.tokenizer = load_model(model_name_or_path=model_name_or_path)
+
+ def __call__(self, file_path: str) -> Dict:
+ data = load_data(file_path=file_path, mode="Indexing")
+ return self.index_doc(data=data, file_path=file_path)
+
+ def get_corpus_embedding(self, corpus: Sequence[str], batch_size: int = 128) -> np.ndarray:
+ """
+ Generate embeddings for a corpus of text by processing it in batches.
+
+ Args:
+ corpus (Sequence[str]): A sequence of text strings to be embedded.
+ batch_size (int, optional): The number of text strings to process in each batch. Defaults to 128.
+
+ Returns:
+ np.ndarray: A numpy array containing the embeddings for the entire corpus.
+ """
+        logging.info("Getting embedding")
+        embeddings_list = []
+        for i in tqdm(range(math.ceil(len(corpus) / batch_size))):
+            batch_embeddings = encode(
+                sentences=corpus[i * batch_size : (i + 1) * batch_size],
+                model=self.model,
+                tokenizer=self.tokenizer,
+                convert_to_numpy=True,
+            )
+            embeddings_list.append(batch_embeddings)
+        # Concatenate once at the end rather than per batch to avoid repeated copying.
+        return np.concatenate(embeddings_list, axis=0)
+
+ def build_engine(self, corpus_embeddings: np.ndarray) -> faiss.IndexFlatIP:
+ """
+ Build a FAISS index using the provided corpus embeddings.
+
+ Args:
+ corpus_embeddings (np.ndarray): A numpy array containing the embeddings of the corpus.
+
+ Returns:
+ faiss.IndexFlatIP: A FAISS index built using the inner product metric.
+ """
+ embedding_dim = corpus_embeddings.shape[1]
+ index = faiss.IndexFlatIP(embedding_dim)
+ index.add(corpus_embeddings.astype("float32"))
+ return index
+
+ def index_doc(
+ self,
+ file_path: str,
+ data: Dict,
+ save: bool = True,
+ save_dir: str = "data/index",
+ ) -> Dict:
+ """
+ Index a document by extracting its content, generating embeddings, and optionally saving the results.
+
+ Args:
+ file_path (str): The path to the file being indexed.
+ data (Dict): A dictionary containing the document's data, including nodes with content and summary.
+ save (bool, optional): Whether to save the generated embeddings and metadata. Defaults to True.
+ save_dir (str, optional): The directory where the saved files will be stored. Defaults to "data/index".
+
+ Returns:
+ Dict: The original data dictionary, potentially modified during the indexing process.
+ """
+ corpus: List[str] = []
+ info: List[dict] = []
+ for node in data["nodes"]:
+ if "summary" in node and len(node["summary"]) > 0:
+ content = node["summary"]
+ else:
+ content = node["content"]
+ corpus.append(content)
+ level = node["level"] if "level" in node else 0
+ info.append({"content": content, "source": data["source"], "level": level})
+ corpus_embeddings: np.ndarray = self.get_corpus_embedding(corpus)
+
+ filename = os.path.splitext(os.path.basename(file_path))[0]
+ if save:
+ save_dir = os.path.join(save_dir, self.model_name)
+ os.makedirs(save_dir, exist_ok=True)
+ np.save(os.path.join(save_dir, f"{filename}.npy"), corpus_embeddings)
+ with open(os.path.join(save_dir, f"{filename}.pkl"), "wb") as f:
+ pickle.dump(info, f)
+ logging.info(f"Index cache and Corpus saved to {save_dir}/{filename}.*")
+
+ return data
+
+ def _search(
+ self,
+ index: faiss.IndexFlatIP,
+ query_embeddings: np.ndarray,
+ top_k: int,
+ batch_size: int = 4000,
+ ) -> List[Dict]:
+ """
+ retrieves the top_k hits for each query embedding
+ Calling index.search()
+ """
+ hits = []
+ for i in range(math.ceil(len(query_embeddings) / batch_size)):
+ q_emb_matrix = query_embeddings[i * batch_size : (i + 1) * batch_size]
+ res_dist, res_p_id = index.search(q_emb_matrix.astype("float32"), top_k)
+ assert len(res_dist) == len(q_emb_matrix)
+ assert len(res_p_id) == len(q_emb_matrix)
+
+ for i in range(len(q_emb_matrix)):
+ passages = []
+ assert len(res_p_id[i]) == len(res_dist[i])
+ for j in range(min(top_k, len(res_p_id[i]))):
+ pid = res_p_id[i][j]
+ score = res_dist[i][j]
+ passages.append({"corpus_id": int(pid), "score": float(score)})
+ hits.append(passages)
+ return hits
diff --git a/slm/pipelines/examples/structured_index/src/load.py b/slm/pipelines/examples/structured_index/src/load.py
new file mode 100644
index 000000000000..9d0f6e14e188
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/src/load.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import os
+from typing import Dict, List, Optional
+
+
+class Loader:
+ @classmethod
+    def load_file(cls, filepath: str) -> Optional[Dict]:
+ """
+ Load a file based on its extension using the appropriate loader method.
+
+ Args:
+ filepath (str): The path to the file to be loaded.
+
+ Returns:
+ Dict: A dictionary containing the loaded document data.
+ None: If the file type is unsupported.
+ """
+ loader = {".pdf": cls.load_pdf, ".json": cls.load_json, ".html": cls.load_html}
+ _, ext = os.path.splitext(os.path.basename(filepath))
+ if ext in loader:
+ doc = loader[ext](filepath=filepath)
+ return doc
+ else:
+ logging.warning(f"Unsupported file type: {filepath}")
+ return None
+
+ @classmethod
+ def load_dir(cls, input_dir: str, output_dir: str, save: bool = True):
+ """
+ Load all files in a directory and optionally save their content as JSON files.
+
+ Args:
+ input_dir (str): The directory containing the files to be loaded.
+ output_dir (str): The directory where the loaded content will be saved as JSON files.
+ save (bool, optional): Whether to save the loaded content. Defaults to True.
+ """
+ for root, _, files in os.walk(input_dir):
+ for file in files:
+ filepath = os.path.join(root, file)
+ filename, _ = os.path.splitext(os.path.basename(file))
+                output_filepath = os.path.join(output_dir, f"{filename}.json")
+                if os.path.exists(output_filepath) and os.path.getsize(output_filepath) > 0:
+                    logging.info(f"File already exists: {output_filepath}")
+                    continue
+                doc = cls.load_file(filepath=filepath)
+                if save and doc is not None:
+                    with open(output_filepath, "w") as f:
+                        json.dump(doc, f, ensure_ascii=False, indent=4)
+                    logging.info(f"Saved to: {output_filepath}")
+
+ @classmethod
+ def load_pdf(cls, filepath: str) -> Dict:
+ """
+ Load a PDF file, extract images from each page, perform OCR on the images, and return the extracted text content.
+
+ Args:
+ filepath (str): The path to the PDF file to be loaded.
+
+ Returns:
+ Dict: A dictionary containing the extracted text content from the PDF.
+ """
+
+ logging.info(f"Loading PDF: {filepath}")
+
+ from paddleocr import PaddleOCR
+
+        # The recognition language of PaddleOCR can be switched via the `lang` parameter,
+        # e.g. `ch`, `en`, `fr`, `german`, `korean`, `japan`.
+        PAGE_NUM = 10  # Fix the number of pages to recognize up front, so later PDF reads use settings consistent with recognition
+        # ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=PAGE_NUM)  # CPU variant; the first run downloads and loads the model into memory
+        ocr = PaddleOCR(
+            use_angle_cls=True, lang="ch", page_num=PAGE_NUM, use_gpu=1
+        )  # GPU variant; to run on CPU, comment this line and uncomment the one above
+ result = ocr.ocr(filepath, cls=True)
+
+ doc = {"source": filepath, "nodes": []}
+ for idx in range(len(result)):
+ res = result[idx]
+ if res is None:
+                # Skip empty pages to avoid a TypeError on NoneType results
+ logging.info(f"Empty page {idx+1} detected in {filepath}, skip it.")
+ continue
+ for line in res:
+ assert isinstance(line, list) and len(line) == 2
+ assert isinstance(line[1], tuple) and len(line[1]) == 2
+ assert isinstance(line[1][0], str) and isinstance(line[1][1], float)
+ doc["nodes"].append(
+ {
+ "id": len(doc["nodes"]),
+ "content": line[1][0],
+ "confidence": line[1][1],
+ }
+ )
+ return doc
+
+ @classmethod
+ def load_json(cls, filepath: str) -> dict:
+ with open(filepath, "r") as f:
+ data = json.load(f)
+ return data
+
+ @classmethod
+ def load_html(cls, filepath: str) -> dict:
+ """
+ Load an HTML file, extract text content, and return it as a structured dictionary.
+
+ Args:
+ filepath (str): The path to the HTML file to be loaded.
+
+ Returns:
+ dict: A dictionary containing the extracted text content from the HTML file.
+ """
+ from bs4 import BeautifulSoup
+
+ doc = {"source": filepath, "nodes": []}
+ with open(filepath, "r", encoding="utf-8") as f:
+ html = f.read()
+ soup = BeautifulSoup(html, "html.parser")
+ html_text = soup.get_text()
+
+ node_list: List[Dict] = []
+ for t in html_text.splitlines():
+ content = t.strip()
+ if len(content) <= 0:
+ continue
+ node_list.append(
+ {
+ "id": len(node_list),
+ "content": content,
+ }
+ )
+ doc["nodes"] = node_list
+
+ return doc
diff --git a/slm/pipelines/examples/structured_index/src/parse.py b/slm/pipelines/examples/structured_index/src/parse.py
new file mode 100644
index 000000000000..d677e1634d99
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/src/parse.py
@@ -0,0 +1,128 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, List, Optional
+
+from .utils import get_messages, get_response, load_data, load_model
+
+
+class DocumentStructureParser:
+ def __init__(
+ self,
+ model_name_or_path: str = "Qwen/Qwen2-7B-Instruct",
+        url: Optional[str] = None,
+ ):
+ self.model_name = model_name_or_path
+ if url is None:
+ self.model, self.tokenizer = load_model(model_name_or_path=model_name_or_path)
+ self.url = url
+ self.prompt_template = """**Structured document text**
+Insert # into the text, using the number of # to represent the hierarchical level of the content, while keeping all the original information.
+The structured results are output in the form of markdown.
+```markdown
+```
+
+**Input text:**
+[input]
+**Structured Text:**
+"""
+
+ def __call__(self, file_path: str) -> Dict:
+ data = load_data(file_path=file_path, mode="Parsing")
+ return self.parse_doc(data=data)
+
+ def extract_text(self, response: str) -> str:
+ """
+ Extract text content from a markdown code block within a given response string.
+
+ Args:
+ response (str): The response string containing markdown code blocks.
+
+ Returns:
+ str: The extracted text content from the markdown code block.
+ """
+        start_index = response.rfind("```markdown") + len("```markdown")
+        end_index = response.rfind("```")
+        if end_index < start_index:  # no closing fence; keep everything after the opening fence
+            end_index = len(response)
+        content = response[start_index:end_index]
+        return content.strip()
+
+ def parse_doc(self, data: Dict, max_text_length=512, repeat_length=2048) -> Dict:
+ """
+ Parse a document by extracting its content, formatting it according to a prompt template, and structuring the content into a hierarchical format.
+
+ Args:
+ data (Dict): A dictionary containing the document's data, including nodes with content.
+ max_text_length (int, optional): The maximum length of text to consider from each node. Defaults to 512.
+ repeat_length (int, optional): The length at which the text should be repeated and inserted into the prompt template. Defaults to 2048.
+
+ Returns:
+ Dict: A dictionary containing the parsed and structured content of the document.
+ """
+ contents = [item["content"] for item in data["nodes"]]
+ content = "".join(c[:max_text_length].strip() for c in contents)
+ content = self.prompt_template.replace("[input]", content)
+ if len(content) <= repeat_length:
+ content = content.replace("Insert # into the text,", "Repeat the text and insert # into the text,")
+
+ messages = [
+ {
+ "role": "user",
+ "content": content,
+ }
+ ]
+ logging.debug(messages[0]["content"])
+ if self.url is None:
+ predicted_str = get_response(messages=messages, tokenizer=self.tokenizer, model=self.model)
+ else:
+ predicted_str = get_messages(
+ messages=messages,
+ model_name=self.model_name,
+ url=self.url,
+ temperature=0.0,
+ )[0]
+ logging.debug(f"parse result =\n{predicted_str}")
+ predicted_str = self.extract_text(response=predicted_str)
+
+ predicted_str_list: List[str] = predicted_str.splitlines()
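+        # Rebuild the discourse tree from the markdown headings: the number of leading "#"
+        # characters gives a node's level, and cur_dictionary acts as a stack of ancestor
+        # node ids used to find each node's parent.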
+ cur_dictionary = []
+ node_id = -1
+ data["nodes"] = []
+ for line in predicted_str_list:
+ line = line.strip()
+ if len(line) == 0:
+ continue
+ node_id += 1
+ level = -1
+ while line.startswith("#"):
+ level += 1
+ line = line[1:]
+ content = line.strip()
+ if level == -1:
+ level = len(cur_dictionary)
+ data["nodes"].append({"id": node_id, "content": content, "level": level})
+
+ while len(cur_dictionary) > level:
+ cur_dictionary.pop()
+ parent_id = cur_dictionary[-1] if len(cur_dictionary) > 0 else -1
+ cur_dictionary.append(node_id)
+ assert len(cur_dictionary) <= level + 1, f"len={len(cur_dictionary)}, level={level}"
+
+ if parent_id > -1:
+ data["nodes"][node_id]["parent_id"] = parent_id
+ if "child_ids" not in data["nodes"][parent_id]:
+ data["nodes"][parent_id]["child_ids"] = []
+ data["nodes"][parent_id]["child_ids"].append(node_id)
+
+ return data
diff --git a/slm/pipelines/examples/structured_index/src/structured_index.py b/slm/pipelines/examples/structured_index/src/structured_index.py
new file mode 100644
index 000000000000..2409f3170590
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/src/structured_index.py
@@ -0,0 +1,228 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import os
+import pickle
+from datetime import datetime
+from typing import Dict, List, Optional
+
+import faiss
+import numpy as np
+
+from .index import Indexer
+from .load import Loader
+from .parse import DocumentStructureParser
+from .summarize import Summarizer
+
+
+class StructuredIndexer(object):
+ def __init__(self, log_dir: Optional[str] = None):
+ self.set_logging(log_dir=log_dir)
+
+ def set_logging(self, log_dir: Optional[str] = None):
+ """
+ Configures the logging settings for the application.
+
+ Args:
+ log_dir (Optional[str]): The directory where log files will be stored. If None, the default directory '.logs' is used.
+
+ Returns:
+ None
+ """
+ if log_dir is None:
+ log_dir = ".logs"
+ os.makedirs(log_dir, exist_ok=True)
+        log_filename = os.path.join(log_dir, f"{datetime.now().strftime('%Y%m%d%H%M%S')}.log")
+
+ level = logging.DEBUG
+ logging.basicConfig(
+ level=level,
+ format="%(asctime)s [%(levelname)s] at %(filename)s,%(lineno)d: %(message)s",
+ datefmt="%Y-%m-%d(%a)%H:%M:%S",
+ filename=log_filename,
+ filemode="w",
+ )
+
+ console = logging.StreamHandler()
+ console.setLevel(level)
+ formatter = logging.Formatter("[%(levelname)-8s] %(message)s")
+ console.setFormatter(formatter)
+ logging.getLogger().addHandler(console)
+
+ def prepare_dir(self, input_dir: str, mode: str, output_dir: str = None, save: bool = True):
+ """
+ Prepares the input directory for processing and optionally creates an output directory.
+
+ Args:
+ input_dir (str): The path to the input directory.
+ mode (str): The mode in which the files are being processed (e.g., "read", "write").
+ output_dir (str, optional): The path to the output directory. Defaults to None.
+ save (bool, optional): Whether to create the output directory if it doesn't exist. Defaults to True.
+
+ Raises:
+ NotADirectoryError: If the `input_dir` is not a directory.
+ FileNotFoundError: If the `input_dir` does not exist.
+
+ Returns:
+ None
+
+ Logs:
+ - An info message indicating the mode of operation and the input directory.
+ - An error message if the `input_dir` is not a directory.
+ - An error message if the `input_dir` does not exist.
+ """
+        if not os.path.exists(input_dir):
+            logging.error(f"Path {input_dir} does not exist.")
+            raise FileNotFoundError(f"Path {input_dir} does not exist.")
+        if not os.path.isdir(input_dir):
+            logging.error(f"Path {input_dir} is not a directory.")
+            raise NotADirectoryError(f"Path {input_dir} is not a directory.")
+ logging.info(f"{mode} files in {input_dir}")
+ if save and output_dir is not None:
+ os.makedirs(output_dir, exist_ok=True)
+
+ def search(
+ self,
+ queries_dict: Dict,
+ input_dir: str = "data/index",
+ output_dir: str = "data/search_result",
+ model_name_or_path: str = "BAAI/bge-large-en-v1.5",
+ top_k: int = 10,
+ embedding_batch_size: int = 128,
+ save: bool = True,
+ ) -> List[Dict]:
+ """
+ Perform a semantic search on a pre-indexed corpus using a specified embedding model.
+
+ Args:
+ queries_dict (Dict): A dictionary where keys are query identifiers and values are the actual queries.
+ input_dir (str, optional): The directory containing the pre-indexed corpus and embeddings. Defaults to "data/index".
+ output_dir (str, optional): The directory where the search results will be saved. Defaults to "data/search_result".
+ model_name_or_path (str, optional): The name or path of the embedding model to use. Defaults to "BAAI/bge-large-en-v1.5".
+ top_k (int, optional): The number of top results to retrieve for each query. Defaults to 10.
+ embedding_batch_size (int, optional): The batch size for embedding the queries. Defaults to 128.
+ save (bool, optional): Whether to save the search results to a file. Defaults to True.
+
+ Returns:
+ List[Dict]: A list of dictionaries containing the search results for each query.
+
+ Raises:
+ AssertionError: If the number of index cache files does not match the number of corpus files, or if the corpus embeddings do not match in shape.
+
+ Notes:
+ - The input directory should contain the pre-computed embeddings and corresponding corpus files.
+ - The output directory will be created if it does not exist.
+ - The search results are saved in a JSON file with a timestamp in the filename.
+ """
+ input_dir = os.path.join(input_dir, model_name_or_path) # The embedding model must be the same
+ output_dir = os.path.join(output_dir, model_name_or_path)
+ self.prepare_dir(input_dir=input_dir, output_dir=output_dir, save=True, mode="Searching")
+
+ indexcachefiles = []
+ corpusfiles = []
+ for root, _, files in os.walk(input_dir):
+ for file in files:
+ if file.endswith(".npy"):
+ indexcachefiles.append(os.path.join(root, file))
+ document_name = os.path.splitext(os.path.basename(file))[0]
+ corpusfiles.append(os.path.join(root, document_name + ".pkl"))
+ logging.info(f"Loading Index cache from {indexcachefiles[-1]}")
+ logging.info(f"Loading Corpus from {corpusfiles[-1]}")
+ assert len(indexcachefiles) == len(corpusfiles) and len(indexcachefiles) > 0
+
+ indexer = Indexer(model_name_or_path=model_name_or_path)
+
+ corpus_infos: List[Dict] = []
+ for file in corpusfiles:
+ corpus_infos.extend(pickle.load(open(file, "rb")))
+ logging.info(f"corpus/index quantity: {len(corpus_infos)}")
+ corpus_embeddings = np.load(indexcachefiles[0])
+ for file in indexcachefiles[1:]:
+ temp_corpus_embeddings = np.load(file)
+ assert corpus_embeddings.shape[1] == temp_corpus_embeddings.shape[1]
+ corpus_embeddings = np.concatenate((corpus_embeddings, temp_corpus_embeddings), axis=0)
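+        # The .pkl entries and .npy rows were collected over the same files in the same
+        # order, so corpus_infos[i] stays aligned with row i of corpus_embeddings.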
+ index: faiss.IndexFlatIP = indexer.build_engine(corpus_embeddings)
+ assert len(corpus_infos) == index.ntotal
+
+ queries = list(queries_dict.values())
+ keys = list(queries_dict.keys())
+ query_embeddings = indexer.get_corpus_embedding(queries, batch_size=embedding_batch_size)
+ hits = indexer._search(index, query_embeddings, top_k)
+ assert len(hits) == len(queries)
+
+        for passages in hits:
+            for p_dict in passages:
+                cid = p_dict["corpus_id"]
+                if 0 <= cid < len(corpus_infos):
+                    p_dict["content"] = corpus_infos[cid]["content"]
+                    p_dict["source"] = corpus_infos[cid]["source"]
+                    p_dict["level"] = corpus_infos[cid]["level"]
+ if save:
+ result_dict = {}
+ for i in range(len(hits)):
+ hit_dict = {"query": queries[i], "hits": hits[i]}
+ assert queries_dict[keys[i]] == queries[i]
+ result_dict[keys[i]] = hit_dict
+ queryfilename = f'query_{datetime.now().strftime("%Y%m%d%H%M%S")}.json'
+ with open(os.path.join(output_dir, queryfilename), "w") as f:
+ json.dump(result_dict, f, ensure_ascii=False, indent=4)
+ logging.info(f"Saved to: {queryfilename}")
+ return result_dict
+
+ def pipeline(
+ self,
+ filepath: str,
+ parse_model_name_or_path: str = "Qwen/Qwen2-7B-Instruct",
+        parse_model_url: Optional[str] = None,
+ summarize_model_name_or_path: str = "Qwen/Qwen2-7B-Instruct",
+        summarize_model_url: Optional[str] = None,
+ encode_model_name_or_path: str = "BAAI/bge-large-en-v1.5",
+ ):
+ """
+ Process a document through a series of steps including loading, parsing, summarizing and indexing to get Structured Index.
+
+ Args:
+ filepath (str): The path to the document file to be processed.
+ parse_model_name_or_path (str, optional): The name or path of the model used for document parsing. Defaults to "Qwen/Qwen2-7B-Instruct".
+ parse_model_url (str, optional): The URL of the model used for document parsing. Defaults to None.
+ summarize_model_name_or_path (str, optional): The name or path of the model used for document summarization. Defaults to "Qwen/Qwen2-7B-Instruct".
+ summarize_model_url (str, optional): The URL of the model used for document summarization. Defaults to None.
+ encode_model_name_or_path (str, optional): The name or path of the model used for document encoding (indexing). Defaults to "BAAI/bge-large-en-v1.5".
+
+ Returns:
+ dict: A dictionary containing the processed document data, or None if the document could not be loaded.
+
+ Steps:
+ 1. Load the document from the specified file path.
+ 2. Parse the document structure using the specified parsing model.
+ 3. Summarize the document content using the specified summarization model.
+ 4. Index the document using the specified encoding model.
+
+ Notes:
+ - The document is processed in a sequence of steps: loading, parsing, summarizing, and indexing.
+ - Each step uses a different model specified by the respective model name or path and URL.
+ - If the document cannot be loaded, the function returns None.
+ """
+ doc = Loader.load_file(filepath=filepath)
+ if doc is None:
+ return None
+ processor = DocumentStructureParser(model_name_or_path=parse_model_name_or_path, url=parse_model_url)
+ doc = processor.parse_doc(data=doc)
+ processor = Summarizer(model_name_or_path=summarize_model_name_or_path, url=summarize_model_url)
+ doc = processor.summarize_doc(data=doc)
+ processor = Indexer(model_name_or_path=encode_model_name_or_path)
+ doc = processor.index_doc(file_path=filepath, data=doc)
+ return doc
diff --git a/slm/pipelines/examples/structured_index/src/summarize.py b/slm/pipelines/examples/structured_index/src/summarize.py
new file mode 100644
index 000000000000..b1ab0b710c0b
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/src/summarize.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, List, Optional
+
+from tqdm import tqdm
+
+from .utils import get_messages, get_response, load_data, load_model
+
+
+class Summarizer:
+ def __init__(
+ self,
+ model_name_or_path: str = "Qwen/Qwen2-7B-Instruct",
+        url: Optional[str] = None,
+ ):
+ self.model_name = model_name_or_path
+ if url is None:
+ self.model, self.tokenizer = load_model(model_name_or_path=model_name_or_path)
+ self.url = url
+        self.prompt_template = """# Generate text summaries based on text content and child node summaries
+**Task Description:**
+You will receive a text and a summary of multiple child nodes that the text may depend on. Your task is to generate a concise and accurate summary for this text, no more than three sentences.
+**Tips:**
+1. Understand the text and child nodes: Carefully read the text itself and the summary of the child nodes (if any), and understand the logical relationship between them. Generate a summary with a focus on the text itself, noting that there may be missing spaces between words in the English text.
+2. Integrate information: Integrate key information from the child node summary (if any) into the text summary to ensure the completeness and accuracy of the information.
+3. Concise and clear: Use concise language to avoid redundant information and highlight key points. No more than two sentences.
+4. Language consistency: The language used in the summary should be consistent with the language of the text itself.
+5. Compliant format: Output format is <summary>...</summary>. If unable to generate a summary, simply output an empty summary <summary></summary>.
+**Input:**
+<text>
+Artificial intelligence (AI) is changing our lives. From autonomous vehicle to smart homes, AI has been used more and more widely.
+</text>
+<child_summary>
+Autonomous vehicle use AI technology to achieve driverless driving and improve road safety and traffic efficiency.
+The smart home system achieves automation control through AI technology, improving the convenience and comfort of life.
+</child_summary>
+**Output:**
+<summary>
+AI is changing our lives. Its applications include autonomous vehicle that improve road safety and traffic efficiency, and smart home systems that improve the convenience and comfort of life.
+</summary>
+**Input:**
+<text>
+[text]
+</text>
+<child_summary>
+[child_summary]
+</child_summary>
+**Output:**"""
+
+ def __call__(self, file_path: str) -> Dict:
+ data = load_data(file_path=file_path, mode="Summarizing")
+ return self.summarize_doc(data=data)
+
+    def extract_summary(self, response: str) -> str:
+        """
+        Extract the summary content from a given response string enclosed within <summary></summary> tags.
+
+        Args:
+            response (str): The response string containing the summary within <summary></summary> tags.
+
+        Returns:
+            str: The extracted summary content.
+        """
+        start_index = response.rfind("<summary>") + len("<summary>")
+        end_index = response.rfind("</summary>")
+        content = response[start_index:end_index]
+        return content.strip()
+
+ def calc_topologic(self, nodes: List[Dict]) -> List[int]:
+ """
+ Calculate the topological order of nodes based on their parent-child relationships.
+
+ Args:
+ nodes (List[Dict]): A list of dictionaries representing nodes, each containing an 'id' and potentially a 'parent_id'.
+
+ Returns:
+ List[int]: A list of node IDs in topological order.
+ """
+ topologic = []
+ pred = [0 for _ in range(len(nodes))]
+ for node in nodes:
+ if "parent_id" in node and node["parent_id"] >= 0:
+ pred[node["parent_id"]] += 1
+ while len(topologic) < len(nodes):
+ for i in range(len(pred)):
+ if pred[i] == 0:
+ topologic.append(i)
+ pred[i] = -1
+ if "parent_id" in nodes[i]:
+ pred[nodes[i]["parent_id"]] -= 1
+ return topologic
+
+ def summarize_doc(self, data: Dict) -> Dict:
+ """
+ Summarize a document by generating summaries for each node based on their content and child summaries.
+
+ Args:
+ data (Dict): A dictionary containing the document's data, including nodes with content and child relationships.
+
+ Returns:
+ Dict: A dictionary containing the summarized content for each node.
+ """
+ topologic = self.calc_topologic(nodes=data["nodes"]) # Calculate topological order
+ logging.debug(f"Topologic: {topologic}")
+ assert len(topologic) == len(data["nodes"])
+ for cur_id in tqdm(topologic):
+ cur_node = data["nodes"][cur_id]
+ assert "summary" not in cur_node
+ assert cur_id == cur_node["id"]
+
+ message_content = self.prompt_template
+ message_content = message_content.replace("[text]", cur_node["content"])
+ child_summary = ""
+ if "child_ids" in cur_node:
+ for child_id in cur_node["child_ids"]:
+ assert "summary" in data["nodes"][child_id]
+ child_summary += data["nodes"][child_id]["summary"] + "\n"
+ message_content = message_content.replace("[child_summary]", child_summary.strip())
+ # logging.debug(f"Message content: {message_content}")
+ messages = [
+ {
+ "role": "user",
+ "content": message_content,
+ }
+ ]
+ if self.url is None:
+ response = get_response(messages=messages, tokenizer=self.tokenizer, model=self.model)
+ else:
+ response = get_messages(
+ messages=messages,
+ model_name=self.model_name,
+ url=self.url,
+ temperature=0.0,
+ )[0]
+ data["nodes"][cur_id]["summary"] = self.extract_summary(response=response)
+ logging.debug(f"Summary for node {cur_id}: {data['nodes'][cur_id]['summary']}\n")
+
+ return data
diff --git a/slm/pipelines/examples/structured_index/src/utils.py b/slm/pipelines/examples/structured_index/src/utils.py
new file mode 100644
index 000000000000..5a45a004efb2
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/src/utils.py
@@ -0,0 +1,189 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import os
+import time
+from typing import Dict, List, Optional, Sequence, Union
+
+import numpy as np
+import paddle
+import requests
+
+from paddlenlp.transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+
+
+def load_data(file_path: str, mode: str) -> Dict:
+ """
+ Load data from a JSON file in StructuredIndexer and return it as a dictionary.
+
+ Args:
+ file_path (str): The path to the JSON file to be loaded.
+ mode (str): A string indicating the mode (e.g., "read", "load") for logging purposes.
+
+ Returns:
+ Dict: A dictionary containing the data loaded from the JSON file.
+
+ Raises:
+ ValueError: If the provided path is not a file or if the file is not a JSON file.
+ FileNotFoundError: If the file does not exist.
+ """
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"File {file_path} does not exist")
+    if not os.path.isfile(file_path):
+        raise ValueError(f"{file_path} is not a file")
+    if not file_path.endswith(".json"):
+        raise ValueError(f"File {file_path} is not a json file")
+ logging.info(f"{mode} file {file_path}")
+ with open(file_path, "r") as f:
+ data = json.load(f)
+ return data
+
+
+def load_model(model_name_or_path: str):
+ """
+ Load a model and its tokenizer from a specified path or model name.
+
+ Args:
+ model_name_or_path (str): The path to the model or the name of the model to be loaded.
+
+ Returns:
+ Tuple[AutoModelForCausalLM, AutoTokenizer]: A tuple containing the loaded model and its tokenizer.
+
+ Raises:
+ RuntimeError: If the model fails to load from the specified path or model name.
+ """
+ device = "gpu" if paddle.device.cuda.device_count() >= 1 else "cpu"
+ paddle.device.set_device(device)
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+ try:
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, dtype="float16")
+ except Exception:
+ try:
+ model = AutoModel.from_pretrained(model_name_or_path, dtype="float16")
+ except Exception as e:
+ raise RuntimeError(f"Failed to load model from {model_name_or_path}: {e}")
+ model.eval()
+ logging.info(f"Model {model_name_or_path} Loaded to {device}")
+ logging.debug(f"{model.config}")
+ return model, tokenizer
+
+
+def encode(
+ sentences: Sequence[str], tokenizer, model, convert_to_numpy: bool = True
+) -> Union[np.ndarray, paddle.Tensor]:
+ """
+ Encode a sequence of sentences into embeddings using a specified model and tokenizer.
+
+ Args:
+ sentences (Sequence[str]): A sequence of sentences to be encoded.
+ tokenizer: The tokenizer used to preprocess the sentences.
+ model: The model used to generate embeddings.
+ convert_to_numpy (bool, optional): Whether to convert the embeddings to a numpy array. Defaults to True.
+
+ Returns:
+ Union[np.ndarray, paddle.Tensor]: The embeddings of the sentences, either as a numpy array or a PaddlePaddle tensor.
+ """
+ model.eval()
+ encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors="pd")
+ with paddle.no_grad():
+ model_output = model(**encoded_input)
+        # CLS pooling: take the first token's hidden state as the sentence embedding
+        sentence_embeddings = model_output[0][:, 0]
+ sentence_embeddings = paddle.nn.functional.normalize(x=sentence_embeddings, p=2, axis=1)
+ if convert_to_numpy and isinstance(sentence_embeddings, paddle.Tensor):
+ sentence_embeddings_np: np.ndarray = sentence_embeddings.cpu().numpy()
+ return sentence_embeddings_np
+ return sentence_embeddings
+
+
+def get_response(messages: List[Dict], tokenizer, model, max_new_tokens=1024) -> str:
+ """
+ Generate a response using a specified model and tokenizer based on the input messages.
+
+ Args:
+ messages (List[Dict]): A list of dictionaries containing the input messages.
+ tokenizer: The tokenizer used to preprocess the input messages.
+ model: The model used to generate the response.
+ max_new_tokens (int, optional): The maximum number of tokens to generate in the response. Defaults to 1024.
+
+ Returns:
+ str: The generated response.
+ """
+ # logging.debug(messages[0]['content'])
+ inputs = tokenizer(messages[0]["content"], return_tensors="pd")
+ outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
+ response: List[str] = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
+ assert isinstance(response, list)
+ response: str = response[0]
+ # logging.debug(f"response =\n{response}")
+ return response.strip()
+
+
+def get_messages(
+ messages: List[Dict],
+ model_name: str = "Qwen2.5-72B-Instruct-GPTQ-Int4",
+    url: Optional[str] = None,
+    api_key: Optional[str] = None,
+    temperature: Optional[float] = None,
+    n: int = 1,
+    max_tokens: Optional[int] = None,
+    max_new_tokens: Optional[int] = None,
+ stream: bool = False,
+ debug: bool = False,
+) -> List[str]:
+ """
+ Send a request to a remote model API to generate responses based on the input messages.
+ """
+ headers = {"Content-Type": "application/json"}
+ if model_name.startswith("gpt"):
+ if api_key is None:
+ raise ValueError("api_key is None")
+ headers["Authorization"] = f"Bearer {api_key}"
+
+ data = {
+ "model": model_name,
+ "messages": messages,
+ "n": n,
+ "stream": stream,
+ }
+ if temperature is not None:
+ data["temperature"] = temperature
+ if max_tokens is not None:
+ data["max_tokens"] = max_tokens
+ if max_new_tokens is not None:
+ data["max_new_tokens"] = max_new_tokens
+
+ if debug:
+ print(f"data={data}")
+
+ message = ""
+ if not isinstance(message, Dict) or "choices" not in message:
+ repeat_index = 0
+ while not isinstance(message, Dict) or "choices" not in message:
+ if repeat_index > 5:
+ raise ConnectionError(f"{url} Error\nmessage=\n{message}")
+ if debug:
+ print(f"message=\n{message}")
+ time.sleep(5)
+ response = requests.post(url, json=data, headers=headers)
+ if debug:
+ print(f"response=\n{response.text}")
+ message = json.loads(response.text)
+ repeat_index += 1
+ if not isinstance(message, Dict) or "choices" not in message:
+ raise ConnectionError(f"{url} Error\nmessage=\n{message}")
+ if len(message["choices"]) != n:
+ raise ValueError(f"{model_name} response num error")
+ return [message["choices"][i]["message"]["content"] for i in range(n)]