diff --git a/slm/pipelines/examples/structured_index/README.md b/slm/pipelines/examples/structured_index/README.md
new file mode 100644
index 000000000000..48f371fd88bc
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/README.md
@@ -0,0 +1,181 @@
+# Hierarchical Document Indexing
+
+## Method
+
+1. Load: load the PDF or HTML documents to be processed into the pipeline.
+2. Parse: use a large language model to parse the discourse structure of each document, re-segment the text by semantics, and extract the document's discourse structure tree.
+3. Summarize: following the discourse structure tree, generate summaries bottom-up over the parsed document, producing summaries at different levels of information.
+4. Index: embed these multi-level summary fragments into a dense retrieval vector space with a text encoder, building a hierarchical text index. The index captures not only local details but also higher-level global information, so it can recall information at multiple granularities and serve the different information needs in user queries.
+
+## Installation
+
+### Requirements
+
+The GPU build of [PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/conda/linux-conda.html) is recommended. Taking the CUDA 11.7 build of Paddle as an example, the install command is:
+
+```bash
+conda install paddlepaddle-gpu==2.6.2 cudatoolkit=11.7 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/ -c conda-forge
+```
+Install the remaining dependencies:
+```bash
+pip install -r requirements.txt
+```
+
+### Data Preparation
+
+- Source documents: the corpus to be indexed hierarchically; see the examples under `data/source`. Each document is a single file, and PDF and HTML are currently supported.
+The script `data/source/download.sh` downloads the sample documents:
+```bash
+apt install jq -y  # install the jq tool; requires system privileges, skip if already installed
+cd data/source
+bash download.sh
+```
+- Query file: user query texts in JSON format, one entry per query as `query_id: query_text`; see the sample query file `data/query.json`.
+
+
+## Usage
+
+### Index Construction
+
+Build a hierarchical index for a single document file:
+```bash
+python construct_index.py \
+--source data/source/2308.12950.pdf \
+--parse_model_name_or_path Qwen/Qwen2-72B-Instruct \
+--summarize_model_name_or_path Qwen/Qwen2-72B-Instruct \
+--encode_model_name_or_path BAAI/bge-large-en-v1.5 \
+--log_dir .logs
+```
+
+Build hierarchical indexes for all document files under a directory:
+```bash
+python construct_index.py \
+--source data/source \
+--parse_model_name_or_path Qwen/Qwen2-72B-Instruct \
+--summarize_model_name_or_path Qwen/Qwen2-72B-Instruct \
+--encode_model_name_or_path BAAI/bge-large-en-v1.5 \
+--log_dir .logs
+```
+
+Adjustable parameters:
+- `source`: the directory containing all source files to index, or a single source file to index
+
+- `parse_model_name_or_path`: name or path of the model used for discourse structure parsing (parse)
+
+- `parse_model_url`: URL of the parse model; omit this argument if it is not needed
+
+- `summarize_model_name_or_path`: name or path of the model used for hierarchical summarization (summarize)
+
+- `summarize_model_url`: URL of the summarize model; omit this argument if it is not needed
+
+- `encode_model_name_or_path`: name or path of the text encoder
+
+- `log_dir`: directory where log files are saved
+
+The hierarchical index is saved under `data/index/{encode_model_name_or_path}`. Each source document has two cache files there that are used for retrieval: a `.pkl` file containing the hierarchical summary texts, and a `.npy` file containing the corresponding summary embedding vectors.
+For example, the hierarchical index cache built for `data/source/CodeLlama.pdf` consists of `index/BAAI/bge-large-en-v1.5/CodeLlama.npy` and `index/BAAI/bge-large-en-v1.5/CodeLlama.pkl`.
+
+### Retrieval
+
+Retrieve query-relevant summary fragments from the hierarchical index and output the retrieval results.
+
+Query multiple texts from a file:
+```bash
+python query.py \
+--search_result_dir data/search_result \
+--encode_model_name_or_path BAAI/bge-large-en-v1.5 \
+--log_dir .logs \
+--query_filepath data/query.json \
+--top_k 5 \
+--embedding_batch_size 128
+```
+
+Query a single text directly:
+```bash
+python query.py \
+--search_result_dir data/search_result \
+--encode_model_name_or_path BAAI/bge-large-en-v1.5 \
+--log_dir .logs \
+--query_text "What is the relationship between CodeLlama and Llama?" \
+--top_k 5 \
+--embedding_batch_size 1
+```
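+
+Both scripts are thin wrappers around the `StructuredIndexer` class. The sketch below shows the equivalent programmatic flow (model names and paths are the repository defaults shown above; adjust them to your setup):
+
+```python
+from src.structured_index import StructuredIndexer
+
+indexer = StructuredIndexer(log_dir=".logs")
+
+# Build the hierarchical index for one document: load -> parse -> summarize -> encode
+indexer.pipeline(
+    filepath="data/source/2308.12950.pdf",
+    parse_model_name_or_path="Qwen/Qwen2-72B-Instruct",
+    summarize_model_name_or_path="Qwen/Qwen2-72B-Instruct",
+    encode_model_name_or_path="BAAI/bge-large-en-v1.5",
+)
+
+# Retrieve from the index; results are returned and also written to disk
+results = indexer.search(
+    queries_dict={"0": "What is the relationship between CodeLlama and Llama?"},
+    output_dir="data/search_result",
+    model_name_or_path="BAAI/bge-large-en-v1.5",
+    top_k=5,
+    embedding_batch_size=1,
+)
+```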
+
+Adjustable parameters:
+- `search_result_dir`: directory where the retrieval results are saved
+
+- `encode_model_name_or_path`: name or path of the text encoder
+
+- `query_filepath`: path to the query file; if given, it must be a JSON file containing a query dictionary
+
+- `query_text`: a single query text; if given, it must be a string
+
+- `top_k`: return the top_k results for each query
+
+- `embedding_batch_size`: batch size used when encoding the queries
+
+- `log_dir`: directory where log files are saved
+
+Retrieval results are saved under `{search_result_dir}/{encode_model_name_or_path}`. Each result file there corresponds to one invocation and may contain several queries: every run produces a `query_{timestamp}.json` file recording the results, in which each query is uniquely identified by its query ID. When the query is passed via `query_text`, the query ID is set to `"0"`.
+
+For example, the retrieval result of the single query above is:
+```json
+{
+    "0": {
+        "query": "What is the relationship between CodeLlama and Llama?",
+        "hits": [
+            {
+                "corpus_id": 122,
+                "score": 0.7032119035720825,
+                "content": "CoDE LLAMA is a family of large language models for code, based on LLAMA 2, designed for state-of-the-art performance in programming tasks, including infilling, large context handling, and zero-shot instruction-following, with a focus on safety and alignment.",
+                "source": "data/source/2308.12950.pdf",
+                "level": 0
+            },
+            {
+                "corpus_id": 127,
+                "score": 0.6490256786346436,
+                "content": "CoDE LLAMA models are general-purpose code generation tools, with specialized versions like CoDE LLAMA -PyTHON for Python code and CoDE LLAMA -INsTRUCT for understanding and executing instructions.",
+                "source": "data/source/2308.12950.pdf",
+                "level": 3
+            },
+            {
+                "corpus_id": 128,
+                "score": 0.6398724317550659,
+                "content": "CoDE LLAMA -PyTHON is specialized for Python code generation, while CoDE LLAMA -INsTRUCT models are designed to understand and execute instructions.",
+                "source": "data/source/2308.12950.pdf",
+                "level": 4
+            },
+            {
+                "corpus_id": 161,
+                "score": 0.6116989254951477,
+                "content": "CoDE LLAMA models are designed for real-world applications, excelling in infilling and large context handling, and they achieve state-of-the-art performance on code generation benchmarks while ensuring safety and alignment.",
+                "source": "data/source/2308.12950.pdf",
+                "level": 2
+            },
+            {
+                "corpus_id": 129,
+                "score": 0.6056838631629944,
+                "content": "CoDE LLAMA -INsTRUCT are instruction-following models designed to understand and execute instructions.",
+                "source": "data/source/2308.12950.pdf",
+                "level": 5
+            }
+        ]
+    }
+}
+```
+Each query's retrieval result has the following format:
+```
+query ID: {
+    "query": query text,
+    "hits": [
+        {
+            "corpus_id": index of this passage within the whole corpus,
+            "score": similarity score,
+            "content": summary content of the passage,
+            "source": path of the passage's source document,
+            "level": information granularity level of the passage in the source document; 0 is the top level, and larger numbers mean finer granularity
+        },
+        ...
+    ]
+}
+```
\ No newline at end of file
diff --git a/slm/pipelines/examples/structured_index/arguments.py b/slm/pipelines/examples/structured_index/arguments.py
new file mode 100644
index 000000000000..6665da3f3145
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/arguments.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
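+
+# These dataclasses are consumed by PdArgumentParser in construct_index.py and
+# query.py; every field below is exposed as a command-line flag of the same name.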
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+
+@dataclass
+class StructuredIndexerArguments:
+    """
+    Arguments for StructuredIndexer.
+    """
+
+    log_dir: str = field(default=".logs", metadata={"help": "log directory"})
+
+
+@dataclass
+class StructuredIndexerEncodeArguments(StructuredIndexerArguments):
+    """
+    Arguments for encoding the corpus in StructuredIndexer.
+    """
+
+    encode_model_name_or_path: str = field(
+        default="BAAI/bge-large-en-v1.5", metadata={"help": "encode model name or path"}
+    )
+
+
+@dataclass
+class StructuredIndexerPipelineArguments(StructuredIndexerEncodeArguments):
+    """
+    Arguments for building the StructuredIndex pipeline for a single corpus file.
+    """
+
+    source: str = field(default="data/source", metadata={"help": "source file or directory"})
+    parse_model_name_or_path: str = field(
+        default="Qwen/Qwen2-7B-Instruct", metadata={"help": "parse model name or path"}
+    )
+    # The URL fields default to None, so they are typed as Optional[str]
+    parse_model_url: Optional[str] = field(default=None, metadata={"help": "parse model url if you use api"})
+    summarize_model_name_or_path: str = field(
+        default="Qwen/Qwen2-7B-Instruct",
+        metadata={"help": "summarize model name or path"},
+    )
+    summarize_model_url: Optional[str] = field(default=None, metadata={"help": "summarize model url if you use api"})
+
+
+@dataclass
+class RetrievalArguments(StructuredIndexerEncodeArguments):
+    """
+    Arguments for StructuredIndex retrieval.
+    """
+
+    search_result_dir: str = field(default="search_result", metadata={"help": "search result directory"})
+    query_filepath: str = field(default="query.json", metadata={"help": "query file path"})
+    query_text: Optional[str] = field(default=None, metadata={"help": "query text"})
+    top_k: int = field(default=5, metadata={"help": "top k results for each query"})
+    embedding_batch_size: int = field(default=128, metadata={"help": "embedding batch size for queries"})
diff --git a/slm/pipelines/examples/structured_index/construct_index.py b/slm/pipelines/examples/structured_index/construct_index.py
new file mode 100644
index 000000000000..f90f40ddd74b
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/construct_index.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
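+
+# Usage (see README for details):
+#   python construct_index.py --source data/source/2308.12950.pdf \
+#       --parse_model_name_or_path Qwen/Qwen2-72B-Instruct \
+#       --summarize_model_name_or_path Qwen/Qwen2-72B-Instruct \
+#       --encode_model_name_or_path BAAI/bge-large-en-v1.5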
+ +import os + +from arguments import StructuredIndexerPipelineArguments +from src.structured_index import StructuredIndexer + +from paddlenlp.trainer import PdArgumentParser + +if __name__ == "__main__": + parser = PdArgumentParser(StructuredIndexerPipelineArguments) + (args,) = parser.parse_args_into_dataclasses() + + structured_indexer = StructuredIndexer(log_dir=args.log_dir) + assert os.path.exists(args.source) + if os.path.isfile(args.source): + structured_indexer.pipeline( + filepath=args.source, + parse_model_name_or_path=args.parse_model_name_or_path, + parse_model_url=args.parse_model_url, + summarize_model_name_or_path=args.summarize_model_name_or_path, + summarize_model_url=args.summarize_model_url, + encode_model_name_or_path=args.encode_model_name_or_path, + ) + else: + for root, _, files in os.walk(args.source): + for file in files: + filepath = os.path.join(root, file) + structured_indexer.pipeline( + filepath=filepath, + parse_model_name_or_path=args.parse_model_name_or_path, + parse_model_url=args.parse_model_url, + summarize_model_name_or_path=args.summarize_model_name_or_path, + summarize_model_url=args.summarize_model_url, + encode_model_name_or_path=args.encode_model_name_or_path, + ) diff --git a/slm/pipelines/examples/structured_index/data/query.json b/slm/pipelines/examples/structured_index/data/query.json new file mode 100644 index 000000000000..62a0ce01f349 --- /dev/null +++ b/slm/pipelines/examples/structured_index/data/query.json @@ -0,0 +1,7 @@ +{ + "0" : "What is big model alignment?", + "1" : "What are the benefits of aligning large models?", + "2" : "How to improve the decoding speed of large language model inference?", + "3" : "What is the difference between CodeLlama and Llama?", + "4" : "What is Grouped Multiple-Degradation Restoration with Image Degradation Similarity?" 
+}
\ No newline at end of file
diff --git a/slm/pipelines/examples/structured_index/data/source/download.sh b/slm/pipelines/examples/structured_index/data/source/download.sh
new file mode 100644
index 000000000000..eaa2af6a721f
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/data/source/download.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+
+download() {
+    local url=$1
+    local ext=$2
+
+    # Take the basename of the URL as the local filename
+    local filename=$(basename "$url")
+
+    # Append the extension if the filename does not already end with it
+    if [[ "$filename" != *".$ext" ]]; then
+        filename="$filename.$ext"
+    fi
+
+    # Download the file (-L follows redirects, e.g. from arxiv.org)
+    echo "Downloading $url as $filename"
+    curl -L -o "$filename" "$url"
+}
+
+# Read the JSON file and list every extension and its URLs
+json_file="source_url.json"
+exts=$(jq -r 'keys[]' "$json_file")
+
+# Download the files for each extension
+for ext in $exts; do
+    urls=$(jq -r --arg ext "$ext" '.[$ext][]' "$json_file")
+    for url in $urls; do
+        download "$url" "$ext"
+    done
+done
\ No newline at end of file
diff --git a/slm/pipelines/examples/structured_index/data/source/source_url.json b/slm/pipelines/examples/structured_index/data/source/source_url.json
new file mode 100644
index 000000000000..cc30537f6907
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/data/source/source_url.json
@@ -0,0 +1,10 @@
+{
+    "pdf": [
+        "https://arxiv.org/pdf/2406.15877",
+        "https://arxiv.org/pdf/2407.12273",
+        "https://arxiv.org/pdf/2308.12950",
+        "https://arxiv.org/pdf/1810.04805",
+        "https://arxiv.org/pdf/2402.12374",
+        "https://aclanthology.org/2023.ccl-2.7.pdf"
+    ]
+}
\ No newline at end of file
diff --git a/slm/pipelines/examples/structured_index/query.py b/slm/pipelines/examples/structured_index/query.py
new file mode 100644
index 000000000000..00488e4dffc4
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/query.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
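+
+# Usage (see README for details):
+#   python query.py --query_text "What is the relationship between CodeLlama and Llama?" \
+#       --encode_model_name_or_path BAAI/bge-large-en-v1.5 --top_k 5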
+
+from arguments import RetrievalArguments
+from src.structured_index import StructuredIndexer
+from src.utils import load_data
+
+from paddlenlp.trainer import PdArgumentParser
+
+if __name__ == "__main__":
+    parser = PdArgumentParser(RetrievalArguments)
+    (args,) = parser.parse_args_into_dataclasses()
+
+    structured_indexer = StructuredIndexer(log_dir=args.log_dir)
+
+    if args.query_text is None:
+        queries_dict = load_data(args.query_filepath, mode="Searching")
+    else:
+        assert isinstance(args.query_text, str)
+        queries_dict = {"0": args.query_text}
+
+    structured_indexer.search(
+        queries_dict=queries_dict,
+        output_dir=args.search_result_dir,
+        model_name_or_path=args.encode_model_name_or_path,
+        top_k=args.top_k,
+        embedding_batch_size=args.embedding_batch_size,
+    )
diff --git a/slm/pipelines/examples/structured_index/requirements.txt b/slm/pipelines/examples/structured_index/requirements.txt
new file mode 100644
index 000000000000..343739772537
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/requirements.txt
@@ -0,0 +1,8 @@
+faiss-gpu==1.7.2
+paddlenlp==3.0.0b2
+tqdm
+numpy
+# PyMuPDF provides the fitz module used for PDF handling (the PyPI package named "fitz" is unrelated)
+PyMuPDF
+beautifulsoup4
+paddleocr==2.7.3
\ No newline at end of file
diff --git a/slm/pipelines/examples/structured_index/src/index.py b/slm/pipelines/examples/structured_index/src/index.py
new file mode 100644
index 000000000000..9718089bb9e0
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/src/index.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import math
+import os
+import pickle
+from typing import Dict, List, Sequence
+
+import faiss
+import numpy as np
+from tqdm import tqdm
+
+from .utils import encode, load_data, load_model
+
+
+class Indexer:
+    def __init__(self, model_name_or_path: str = "BAAI/bge-large-en-v1.5", url=None):
+        self.model_name = model_name_or_path
+        self.model, self.tokenizer = load_model(model_name_or_path=model_name_or_path)
+
+    def __call__(self, file_path: str) -> Dict:
+        data = load_data(file_path=file_path, mode="Indexing")
+        return self.index_doc(data=data, file_path=file_path)
+
+    def get_corpus_embedding(self, corpus: Sequence[str], batch_size: int = 128) -> np.ndarray:
+        """
+        Generate embeddings for a corpus of text by processing it in batches.
+
+        Args:
+            corpus (Sequence[str]): A sequence of text strings to be embedded.
+            batch_size (int, optional): The number of text strings to process in each batch. Defaults to 128.
+
+        Returns:
+            np.ndarray: A numpy array containing the embeddings for the entire corpus.
+ """ + logging.info("Getting embedding") + for i in tqdm(range(math.ceil(len(corpus) / batch_size))): + corpus_embeddings = encode( + sentences=corpus[i * batch_size : (i + 1) * batch_size], + model=self.model, + tokenizer=self.tokenizer, + convert_to_numpy=True, + ) + if i == 0: + corpus_embeddings_list = corpus_embeddings + else: + corpus_embeddings_list = np.concatenate((corpus_embeddings_list, corpus_embeddings), axis=0) + return corpus_embeddings_list + + def build_engine(self, corpus_embeddings: np.ndarray) -> faiss.IndexFlatIP: + """ + Build a FAISS index using the provided corpus embeddings. + + Args: + corpus_embeddings (np.ndarray): A numpy array containing the embeddings of the corpus. + + Returns: + faiss.IndexFlatIP: A FAISS index built using the inner product metric. + """ + embedding_dim = corpus_embeddings.shape[1] + index = faiss.IndexFlatIP(embedding_dim) + index.add(corpus_embeddings.astype("float32")) + return index + + def index_doc( + self, + file_path: str, + data: Dict, + save: bool = True, + save_dir: str = "data/index", + ) -> Dict: + """ + Index a document by extracting its content, generating embeddings, and optionally saving the results. + + Args: + file_path (str): The path to the file being indexed. + data (Dict): A dictionary containing the document's data, including nodes with content and summary. + save (bool, optional): Whether to save the generated embeddings and metadata. Defaults to True. + save_dir (str, optional): The directory where the saved files will be stored. Defaults to "data/index". + + Returns: + Dict: The original data dictionary, potentially modified during the indexing process. + """ + corpus: List[str] = [] + info: List[dict] = [] + for node in data["nodes"]: + if "summary" in node and len(node["summary"]) > 0: + content = node["summary"] + else: + content = node["content"] + corpus.append(content) + level = node["level"] if "level" in node else 0 + info.append({"content": content, "source": data["source"], "level": level}) + corpus_embeddings: np.ndarray = self.get_corpus_embedding(corpus) + + filename = os.path.splitext(os.path.basename(file_path))[0] + if save: + save_dir = os.path.join(save_dir, self.model_name) + os.makedirs(save_dir, exist_ok=True) + np.save(os.path.join(save_dir, f"{filename}.npy"), corpus_embeddings) + with open(os.path.join(save_dir, f"{filename}.pkl"), "wb") as f: + pickle.dump(info, f) + logging.info(f"Index cache and Corpus saved to {save_dir}/{filename}.*") + + return data + + def _search( + self, + index: faiss.IndexFlatIP, + query_embeddings: np.ndarray, + top_k: int, + batch_size: int = 4000, + ) -> List[Dict]: + """ + retrieves the top_k hits for each query embedding + Calling index.search() + """ + hits = [] + for i in range(math.ceil(len(query_embeddings) / batch_size)): + q_emb_matrix = query_embeddings[i * batch_size : (i + 1) * batch_size] + res_dist, res_p_id = index.search(q_emb_matrix.astype("float32"), top_k) + assert len(res_dist) == len(q_emb_matrix) + assert len(res_p_id) == len(q_emb_matrix) + + for i in range(len(q_emb_matrix)): + passages = [] + assert len(res_p_id[i]) == len(res_dist[i]) + for j in range(min(top_k, len(res_p_id[i]))): + pid = res_p_id[i][j] + score = res_dist[i][j] + passages.append({"corpus_id": int(pid), "score": float(score)}) + hits.append(passages) + return hits diff --git a/slm/pipelines/examples/structured_index/src/load.py b/slm/pipelines/examples/structured_index/src/load.py new file mode 100644 index 000000000000..9d0f6e14e188 --- /dev/null +++ 
diff --git a/slm/pipelines/examples/structured_index/src/load.py b/slm/pipelines/examples/structured_index/src/load.py
new file mode 100644
index 000000000000..9d0f6e14e188
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/src/load.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import os
+from typing import Dict, List
+
+
+class Loader:
+    @classmethod
+    def load_file(cls, filepath: str) -> Dict:
+        """
+        Load a file based on its extension using the appropriate loader method.
+
+        Args:
+            filepath (str): The path to the file to be loaded.
+
+        Returns:
+            Dict: A dictionary containing the loaded document data.
+            None: If the file type is unsupported.
+        """
+        loader = {".pdf": cls.load_pdf, ".json": cls.load_json, ".html": cls.load_html}
+        _, ext = os.path.splitext(os.path.basename(filepath))
+        if ext in loader:
+            doc = loader[ext](filepath=filepath)
+            return doc
+        else:
+            logging.warning(f"Unsupported file type: {filepath}")
+            return None
+
+    @classmethod
+    def load_dir(cls, input_dir: str, output_dir: str, save: bool = True):
+        """
+        Load all files in a directory and optionally save their content as JSON files.
+
+        Args:
+            input_dir (str): The directory containing the files to be loaded.
+            output_dir (str): The directory where the loaded content will be saved as JSON files.
+            save (bool, optional): Whether to save the loaded content. Defaults to True.
+        """
+        for root, _, files in os.walk(input_dir):
+            for file in files:
+                filepath = os.path.join(root, file)
+                filename, _ = os.path.splitext(os.path.basename(file))
+                output_filepath = os.path.join(output_dir, f"{filename}.json")
+                if os.path.exists(output_filepath):
+                    if os.path.getsize(output_filepath) > 0:
+                        logging.info(f"File already exists: {output_filepath}")
+                        continue
+                doc = cls.load_file(filepath=filepath)
+                if save and doc is not None:
+                    with open(output_filepath, "w") as f:
+                        json.dump(doc, f, ensure_ascii=False, indent=4)
+                    logging.info(f"Saved to: {output_filepath}")
+
+    @classmethod
+    def load_pdf(cls, filepath: str) -> Dict:
+        """
+        Load a PDF file, extract images from each page, perform OCR on the images, and return the extracted text content.
+
+        Args:
+            filepath (str): The path to the PDF file to be loaded.
+
+        Returns:
+            Dict: A dictionary containing the extracted text content from the PDF.
+        """
+
+        logging.info(f"Loading PDF: {filepath}")
+
+        from paddleocr import PaddleOCR
+
+        # PaddleOCR supports multiple languages, switched via the `lang` argument,
+        # e.g. `ch`, `en`, `fr`, `german`, `korean`, `japan`
+        PAGE_NUM = 10  # Set the page limit up front as a global, so the PDF is opened with the same page count used for recognition
+        # ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=PAGE_NUM)  # CPU variant; needs to run only once to download and load the model into memory
+        ocr = PaddleOCR(
+            use_angle_cls=True, lang="ch", page_num=PAGE_NUM, use_gpu=1
+        )  # GPU variant; to run on CPU, comment out this line and uncomment the one above
+        result = ocr.ocr(filepath, cls=True)
+
+        doc = {"source": filepath, "nodes": []}
+        for idx in range(len(result)):
+            res = result[idx]
+            if res is None:
+                # Skip empty pages to avoid a TypeError on NoneType results
+                logging.info(f"Empty page {idx+1} detected in {filepath}, skip it.")
+                continue
+            for line in res:
+                assert isinstance(line, list) and len(line) == 2
+                assert isinstance(line[1], tuple) and len(line[1]) == 2
+                assert isinstance(line[1][0], str) and isinstance(line[1][1], float)
+                doc["nodes"].append(
+                    {
+                        "id": len(doc["nodes"]),
+                        "content": line[1][0],
+                        "confidence": line[1][1],
+                    }
+                )
+        return doc
+
+    @classmethod
+    def load_json(cls, filepath: str) -> Dict:
+        with open(filepath, "r") as f:
+            data = json.load(f)
+        return data
+
+    @classmethod
+    def load_html(cls, filepath: str) -> Dict:
+        """
+        Load an HTML file, extract text content, and return it as a structured dictionary.
+
+        Args:
+            filepath (str): The path to the HTML file to be loaded.
+
+        Returns:
+            Dict: A dictionary containing the extracted text content from the HTML file.
+        """
+        from bs4 import BeautifulSoup
+
+        doc = {"source": filepath, "nodes": []}
+        with open(filepath, "r", encoding="utf-8") as f:
+            html = f.read()
+        soup = BeautifulSoup(html, "html.parser")
+        html_text = soup.get_text()
+
+        node_list: List[Dict] = []
+        for t in html_text.splitlines():
+            content = t.strip()
+            if len(content) <= 0:
+                continue
+            node_list.append(
+                {
+                    "id": len(node_list),
+                    "content": content,
+                }
+            )
+        doc["nodes"] = node_list
+
+        return doc
diff --git a/slm/pipelines/examples/structured_index/src/parse.py b/slm/pipelines/examples/structured_index/src/parse.py
new file mode 100644
index 000000000000..d677e1634d99
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/src/parse.py
@@ -0,0 +1,128 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, List
+
+from .utils import get_messages, get_response, load_data, load_model
+
+
+class DocumentStructureParser:
+    def __init__(
+        self,
+        model_name_or_path: str = "Qwen/Qwen2-7B-Instruct",
+        url: str = None,
+    ):
+        self.model_name = model_name_or_path
+        if url is None:
+            self.model, self.tokenizer = load_model(model_name_or_path=model_name_or_path)
+        self.url = url
+        self.prompt_template = """**Structured document text**
+Insert # into the text, using the number of # to represent the hierarchical level of the content, while keeping all the original information.
+The structured results are output in the form of markdown.
+```markdown
+```
+
+**Input text:**
+[input]
+**Structured Text:**
+"""
+
+    def __call__(self, file_path: str) -> Dict:
+        data = load_data(file_path=file_path, mode="Parsing")
+        return self.parse_doc(data=data)
+
+    def extract_text(self, response: str) -> str:
+        """
+        Extract text content from a markdown code block within a given response string.
+
+        Args:
+            response (str): The response string containing markdown code blocks.
+
+        Returns:
+            str: The extracted text content from the markdown code block.
+        """
+        start_index = response.rfind("```markdown") + len("```markdown")
+        end_index = response.rfind("```")
+        content = response[start_index:end_index]
+        return content.strip()
+
+    def parse_doc(self, data: Dict, max_text_length=512, repeat_length=2048) -> Dict:
+        """
+        Parse a document by extracting its content, formatting it according to a prompt template, and structuring the content into a hierarchical format.
+
+        Args:
+            data (Dict): A dictionary containing the document's data, including nodes with content.
+            max_text_length (int, optional): The maximum length of text to consider from each node. Defaults to 512.
+            repeat_length (int, optional): The prompt length below which the model is asked to repeat the text while inserting #. Defaults to 2048.
+
+        Returns:
+            Dict: A dictionary containing the parsed and structured content of the document.
+        """
+        contents = [item["content"] for item in data["nodes"]]
+        content = "".join(c[:max_text_length].strip() for c in contents)
+        content = self.prompt_template.replace("[input]", content)
+        if len(content) <= repeat_length:
+            content = content.replace("Insert # into the text,", "Repeat the text and insert # into the text,")
+
+        messages = [
+            {
+                "role": "user",
+                "content": content,
+            }
+        ]
+        logging.debug(messages[0]["content"])
+        if self.url is None:
+            predicted_str = get_response(messages=messages, tokenizer=self.tokenizer, model=self.model)
+        else:
+            predicted_str = get_messages(
+                messages=messages,
+                model_name=self.model_name,
+                url=self.url,
+                temperature=0.0,
+            )[0]
+        logging.debug(f"parse result =\n{predicted_str}")
+        predicted_str = self.extract_text(response=predicted_str)
+
+        predicted_str_list: List[str] = predicted_str.splitlines()
+        # cur_dictionary holds the id of the most recent node at each level,
+        # i.e. the current path from the root down to the latest node
+        cur_dictionary = []
+        node_id = -1
+        data["nodes"] = []
+        for line in predicted_str_list:
+            line = line.strip()
+            if len(line) == 0:
+                continue
+            node_id += 1
+            # Count leading '#' characters: one '#' is level 0, two are level 1, ...
+            level = -1
+            while line.startswith("#"):
+                level += 1
+                line = line[1:]
+            content = line.strip()
+            if level == -1:
+                # Plain text without '#' becomes a child of the deepest open heading
+                level = len(cur_dictionary)
+            data["nodes"].append({"id": node_id, "content": content, "level": level})
+
+            while len(cur_dictionary) > level:
+                cur_dictionary.pop()
+            parent_id = cur_dictionary[-1] if len(cur_dictionary) > 0 else -1
+            cur_dictionary.append(node_id)
+            assert len(cur_dictionary) <= level + 1, f"len={len(cur_dictionary)}, level={level}"
+
+            if parent_id > -1:
+                data["nodes"][node_id]["parent_id"] = parent_id
+                if "child_ids" not in data["nodes"][parent_id]:
+                    data["nodes"][parent_id]["child_ids"] = []
+                data["nodes"][parent_id]["child_ids"].append(node_id)
+
+        return data
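+
+
+# Illustration (hypothetical model output): a structured response of
+#   # Title
+#   ## Section
+#   Body text
+# yields nodes with levels 0, 1 and 2, where "Section" becomes a child of
+# "Title" and "Body text" becomes a child of "Section".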
diff --git a/slm/pipelines/examples/structured_index/src/structured_index.py b/slm/pipelines/examples/structured_index/src/structured_index.py
new file mode 100644
index 000000000000..2409f3170590
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/src/structured_index.py
@@ -0,0 +1,228 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import os
+import pickle
+from datetime import datetime
+from typing import Dict, List, Optional
+
+import faiss
+import numpy as np
+
+from .index import Indexer
+from .load import Loader
+from .parse import DocumentStructureParser
+from .summarize import Summarizer
+
+
+class StructuredIndexer(object):
+    def __init__(self, log_dir: Optional[str] = None):
+        self.set_logging(log_dir=log_dir)
+
+    def set_logging(self, log_dir: Optional[str] = None):
+        """
+        Configures the logging settings for the application.
+
+        Args:
+            log_dir (Optional[str]): The directory where log files will be stored. If None, the default directory '.logs' is used.
+
+        Returns:
+            None
+        """
+        if log_dir is None:
+            log_dir = ".logs"
+        os.makedirs(log_dir, exist_ok=True)
+        # Use a timestamp without spaces or colons so the filename is portable
+        log_filename = os.path.join(log_dir, datetime.now().strftime("%Y%m%d_%H%M%S") + ".log")
+
+        level = logging.DEBUG
+        logging.basicConfig(
+            level=level,
+            format="%(asctime)s [%(levelname)s] at %(filename)s,%(lineno)d: %(message)s",
+            datefmt="%Y-%m-%d(%a)%H:%M:%S",
+            filename=log_filename,
+            filemode="w",
+        )
+
+        console = logging.StreamHandler()
+        console.setLevel(level)
+        formatter = logging.Formatter("[%(levelname)-8s] %(message)s")
+        console.setFormatter(formatter)
+        logging.getLogger().addHandler(console)
+
+    def prepare_dir(self, input_dir: str, mode: str, output_dir: str = None, save: bool = True):
+        """
+        Prepares the input directory for processing and optionally creates an output directory.
+
+        Args:
+            input_dir (str): The path to the input directory.
+            mode (str): The mode in which the files are being processed (e.g., "read", "write").
+            output_dir (str, optional): The path to the output directory. Defaults to None.
+            save (bool, optional): Whether to create the output directory if it doesn't exist. Defaults to True.
+
+        Raises:
+            FileNotFoundError: If the `input_dir` does not exist.
+            NotADirectoryError: If the `input_dir` is not a directory.
+
+        Returns:
+            None
+
+        Logs:
+            - An info message indicating the mode of operation and the input directory.
+            - An error message if the `input_dir` does not exist or is not a directory.
+        """
+        # Check existence first; otherwise a missing path would be misreported as "not a directory"
+        if not os.path.exists(input_dir):
+            logging.error(f"Path {input_dir} does not exist.")
+            raise FileNotFoundError(f"Path {input_dir} does not exist.")
+        if not os.path.isdir(input_dir):
+            logging.error(f"Path {input_dir} is not a directory.")
+            raise NotADirectoryError(f"Path {input_dir} is not a directory.")
+        logging.info(f"{mode} files in {input_dir}")
+        if save and output_dir is not None:
+            os.makedirs(output_dir, exist_ok=True)
+
+    def search(
+        self,
+        queries_dict: Dict,
+        input_dir: str = "data/index",
+        output_dir: str = "data/search_result",
+        model_name_or_path: str = "BAAI/bge-large-en-v1.5",
+        top_k: int = 10,
+        embedding_batch_size: int = 128,
+        save: bool = True,
+    ) -> Dict:
+        """
+        Perform a semantic search on a pre-indexed corpus using a specified embedding model.
+
+        Args:
+            queries_dict (Dict): A dictionary where keys are query identifiers and values are the actual queries.
+            input_dir (str, optional): The directory containing the pre-indexed corpus and embeddings. Defaults to "data/index".
+            output_dir (str, optional): The directory where the search results will be saved. Defaults to "data/search_result".
+            model_name_or_path (str, optional): The name or path of the embedding model to use. Defaults to "BAAI/bge-large-en-v1.5".
+            top_k (int, optional): The number of top results to retrieve for each query. Defaults to 10.
+            embedding_batch_size (int, optional): The batch size for embedding the queries. Defaults to 128.
+            save (bool, optional): Whether to save the search results to a file. Defaults to True.
+
+        Returns:
+            Dict: A dictionary mapping each query ID to its search results.
+
+        Raises:
+            AssertionError: If the number of index cache files does not match the number of corpus files, or if the corpus embeddings do not match in shape.
+
+        Notes:
+            - The input directory should contain the pre-computed embeddings and corresponding corpus files.
+            - The output directory will be created if it does not exist.
+            - The search results are saved in a JSON file with a timestamp in the filename.
+        """
+        input_dir = os.path.join(input_dir, model_name_or_path)  # The embedding model must match the one used at indexing time
+        output_dir = os.path.join(output_dir, model_name_or_path)
+        self.prepare_dir(input_dir=input_dir, output_dir=output_dir, save=True, mode="Searching")
+
+        indexcachefiles = []
+        corpusfiles = []
+        for root, _, files in os.walk(input_dir):
+            for file in files:
+                if file.endswith(".npy"):
+                    indexcachefiles.append(os.path.join(root, file))
+                    document_name = os.path.splitext(os.path.basename(file))[0]
+                    corpusfiles.append(os.path.join(root, document_name + ".pkl"))
+                    logging.info(f"Loading Index cache from {indexcachefiles[-1]}")
+                    logging.info(f"Loading Corpus from {corpusfiles[-1]}")
+        assert len(indexcachefiles) == len(corpusfiles) and len(indexcachefiles) > 0
+
+        indexer = Indexer(model_name_or_path=model_name_or_path)
+
+        corpus_infos: List[Dict] = []
+        for file in corpusfiles:
+            corpus_infos.extend(pickle.load(open(file, "rb")))
+        logging.info(f"corpus/index quantity: {len(corpus_infos)}")
+        corpus_embeddings = np.load(indexcachefiles[0])
+        for file in indexcachefiles[1:]:
+            temp_corpus_embeddings = np.load(file)
+            assert corpus_embeddings.shape[1] == temp_corpus_embeddings.shape[1]
+            corpus_embeddings = np.concatenate((corpus_embeddings, temp_corpus_embeddings), axis=0)
+        index: faiss.IndexFlatIP = indexer.build_engine(corpus_embeddings)
+        assert len(corpus_infos) == index.ntotal
+
+        queries = list(queries_dict.values())
+        keys = list(queries_dict.keys())
+        query_embeddings = indexer.get_corpus_embedding(queries, batch_size=embedding_batch_size)
+        hits = indexer._search(index, query_embeddings, top_k)
+        assert len(hits) == len(queries)
+
+        # Attach the corpus metadata to each hit (avoid shadowing the builtin `id`)
+        for passages in hits:
+            for p_dict in passages:
+                cid = p_dict["corpus_id"]
+                if 0 <= cid < len(corpus_infos):
+                    p_dict["content"] = corpus_infos[cid]["content"]
+                    p_dict["source"] = corpus_infos[cid]["source"]
+                    p_dict["level"] = corpus_infos[cid]["level"]
+        # Build the result dictionary unconditionally so it is defined even when save=False
+        result_dict = {}
+        for i in range(len(hits)):
+            hit_dict = {"query": queries[i], "hits": hits[i]}
+            assert queries_dict[keys[i]] == queries[i]
+            result_dict[keys[i]] = hit_dict
+        if save:
+            queryfilename = f'query_{datetime.now().strftime("%Y%m%d%H%M%S")}.json'
+            with open(os.path.join(output_dir, queryfilename), "w") as f:
+                json.dump(result_dict, f, ensure_ascii=False, indent=4)
+            logging.info(f"Saved to: {os.path.join(output_dir, queryfilename)}")
+        return result_dict
+
+    def pipeline(
+        self,
+        filepath: str,
+        parse_model_name_or_path: str = "Qwen/Qwen2-7B-Instruct",
+        parse_model_url: str = None,
+        summarize_model_name_or_path: str = "Qwen/Qwen2-7B-Instruct",
+        summarize_model_url: str = None,
+        encode_model_name_or_path: str = "BAAI/bge-large-en-v1.5",
+    ):
+        """
+        Process a document through a series of steps, including loading, parsing, summarizing, and indexing, to build its structured index.
+
+        Args:
+            filepath (str): The path to the document file to be processed.
+            parse_model_name_or_path (str, optional): The name or path of the model used for document parsing. Defaults to "Qwen/Qwen2-7B-Instruct".
+            parse_model_url (str, optional): The URL of the model used for document parsing. Defaults to None.
+            summarize_model_name_or_path (str, optional): The name or path of the model used for document summarization. Defaults to "Qwen/Qwen2-7B-Instruct".
+            summarize_model_url (str, optional): The URL of the model used for document summarization. Defaults to None.
+            encode_model_name_or_path (str, optional): The name or path of the model used for document encoding (indexing). Defaults to "BAAI/bge-large-en-v1.5".
+
+        Returns:
+            dict: A dictionary containing the processed document data, or None if the document could not be loaded.
+
+        Steps:
+            1. Load the document from the specified file path.
+            2. Parse the document structure using the specified parsing model.
+            3. Summarize the document content using the specified summarization model.
+            4. Index the document using the specified encoding model.
+        """
+        doc = Loader.load_file(filepath=filepath)
+        if doc is None:
+            return None
+        processor = DocumentStructureParser(model_name_or_path=parse_model_name_or_path, url=parse_model_url)
+        doc = processor.parse_doc(data=doc)
+        processor = Summarizer(model_name_or_path=summarize_model_name_or_path, url=summarize_model_url)
+        doc = processor.summarize_doc(data=doc)
+        processor = Indexer(model_name_or_path=encode_model_name_or_path)
+        doc = processor.index_doc(file_path=filepath, data=doc)
+        return doc
diff --git a/slm/pipelines/examples/structured_index/src/summarize.py b/slm/pipelines/examples/structured_index/src/summarize.py
new file mode 100644
index 000000000000..b1ab0b710c0b
--- /dev/null
+++ b/slm/pipelines/examples/structured_index/src/summarize.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, List
+
+from tqdm import tqdm
+
+from .utils import get_messages, get_response, load_data, load_model
+
+
+class Summarizer:
+    def __init__(
+        self,
+        model_name_or_path: str = "Qwen/Qwen2-7B-Instruct",
+        url: str = None,
+    ):
+        self.model_name = model_name_or_path
+        if url is None:
+            self.model, self.tokenizer = load_model(model_name_or_path=model_name_or_path)
+        self.url = url
+        # The <text>, <child_summary> and <summary> tags delimit the prompt's
+        # input and output fields; extract_summary relies on the <summary> tags.
+        self.prompt_template = """# Generate text summaries based on text content and sub node summaries
+**Task Description:**
+You will receive a text and a summary of multiple child nodes that the text may depend on.
+Your task is to generate a concise and accurate summary for this text, no more than three sentences.
+**Tips:**
+1. Understand the text and child nodes: Carefully read the text itself and the summaries of the child nodes (if any), and understand the logical relationship between them. Generate the summary with a focus on the text itself, noting that there may be missing spaces between words in English text.
+2. Integrate information: Integrate key information from the child node summaries (if any) into the text summary to ensure the completeness and accuracy of the information.
+3. Concise and clear: Use concise language, avoid redundant information, and highlight key points.
+4. Language consistency: The language of the summary should be consistent with the language of the text itself.
+5. Compliant format: Output in the format <summary>...</summary>. If unable to generate a summary, simply output an empty <summary></summary>.
+**Input:**
+<text>
+Artificial intelligence (AI) is changing our lives. From autonomous vehicles to smart homes, AI is being used more and more widely.
+</text>
+<child_summary>
+Autonomous vehicles use AI technology to achieve driverless driving and improve road safety and traffic efficiency.
+The smart home system achieves automation control through AI technology, improving the convenience and comfort of life.
+</child_summary>
+**Output:**
+<summary>
+AI is changing our lives. Its applications include autonomous vehicles that improve road safety and traffic efficiency, and smart home systems that improve the convenience and comfort of life.
+</summary>
+**Input:**
+<text>
+[text]
+</text>
+<child_summary>
+[child_summary]
+</child_summary>
+**Output:**"""
+
+    def __call__(self, file_path: str) -> Dict:
+        data = load_data(file_path=file_path, mode="Summarizing")
+        return self.summarize_doc(data=data)
+
+    def extract_summary(self, response: str) -> str:
+        """
+        Extract the summary content from a given response string enclosed within <summary> tags.
+
+        Args:
+            response (str): The response string containing the summary within <summary> tags.
+
+        Returns:
+            str: The extracted summary content.
+        """
+        start_index = response.rfind("<summary>") + len("<summary>")
+        end_index = response.rfind("</summary>")
+        content = response[start_index:end_index]
+        return content.strip()
+
+    def calc_topologic(self, nodes: List[Dict]) -> List[int]:
+        """
+        Calculate a topological order of the nodes based on their parent-child relationships.
+
+        Args:
+            nodes (List[Dict]): A list of dictionaries representing nodes, each containing an 'id' and potentially a 'parent_id'.
+
+        Returns:
+            List[int]: A list of node IDs in topological order.
+        """
+        # pred[i] counts the children of node i that have not been ordered yet,
+        # so leaves come first and every parent appears after all of its children
+        # (the bottom-up order needed for hierarchical summarization)
+        topologic = []
+        pred = [0 for _ in range(len(nodes))]
+        for node in nodes:
+            if "parent_id" in node and node["parent_id"] >= 0:
+                pred[node["parent_id"]] += 1
+        while len(topologic) < len(nodes):
+            for i in range(len(pred)):
+                if pred[i] == 0:
+                    topologic.append(i)
+                    pred[i] = -1
+                    if "parent_id" in nodes[i]:
+                        pred[nodes[i]["parent_id"]] -= 1
+        return topologic
+
+    def summarize_doc(self, data: Dict) -> Dict:
+        """
+        Summarize a document by generating summaries for each node based on their content and child summaries.
+
+        Args:
+            data (Dict): A dictionary containing the document's data, including nodes with content and child relationships.
+
+        Returns:
+            Dict: A dictionary containing the summarized content for each node.
+ """ + topologic = self.calc_topologic(nodes=data["nodes"]) # Calculate topological order + logging.debug(f"Topologic: {topologic}") + assert len(topologic) == len(data["nodes"]) + for cur_id in tqdm(topologic): + cur_node = data["nodes"][cur_id] + assert "summary" not in cur_node + assert cur_id == cur_node["id"] + + message_content = self.prompt_template + message_content = message_content.replace("[text]", cur_node["content"]) + child_summary = "" + if "child_ids" in cur_node: + for child_id in cur_node["child_ids"]: + assert "summary" in data["nodes"][child_id] + child_summary += data["nodes"][child_id]["summary"] + "\n" + message_content = message_content.replace("[child_summary]", child_summary.strip()) + # logging.debug(f"Message content: {message_content}") + messages = [ + { + "role": "user", + "content": message_content, + } + ] + if self.url is None: + response = get_response(messages=messages, tokenizer=self.tokenizer, model=self.model) + else: + response = get_messages( + messages=messages, + model_name=self.model_name, + url=self.url, + temperature=0.0, + )[0] + data["nodes"][cur_id]["summary"] = self.extract_summary(response=response) + logging.debug(f"Summary for node {cur_id}: {data['nodes'][cur_id]['summary']}\n") + + return data diff --git a/slm/pipelines/examples/structured_index/src/utils.py b/slm/pipelines/examples/structured_index/src/utils.py new file mode 100644 index 000000000000..5a45a004efb2 --- /dev/null +++ b/slm/pipelines/examples/structured_index/src/utils.py @@ -0,0 +1,189 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import time +from typing import Dict, List, Sequence, Union + +import numpy as np +import paddle +import requests + +from paddlenlp.transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + + +def load_data(file_path: str, mode: str) -> Dict: + """ + Load data from a JSON file in StructuredIndexer and return it as a dictionary. + + Args: + file_path (str): The path to the JSON file to be loaded. + mode (str): A string indicating the mode (e.g., "read", "load") for logging purposes. + + Returns: + Dict: A dictionary containing the data loaded from the JSON file. + + Raises: + ValueError: If the provided path is not a file or if the file is not a JSON file. + FileNotFoundError: If the file does not exist. + """ + if not os.path.isfile(file_path): + raise ValueError(f"{file_path} is not a file") + if not file_path.endswith(".json"): + raise ValueError(f"File {file_path} is not a json file") + if not os.path.exists(file_path): + raise FileNotFoundError(f"File {file_path} does not exist") + logging.info(f"{mode} file {file_path}") + with open(file_path, "r") as f: + data = json.load(f) + return data + + +def load_model(model_name_or_path: str): + """ + Load a model and its tokenizer from a specified path or model name. + + Args: + model_name_or_path (str): The path to the model or the name of the model to be loaded. 
+
+    Returns:
+        Tuple[AutoModelForCausalLM, AutoTokenizer]: A tuple containing the loaded model and its tokenizer.
+
+    Raises:
+        RuntimeError: If the model fails to load from the specified path or model name.
+    """
+    device = "gpu" if paddle.device.cuda.device_count() >= 1 else "cpu"
+    paddle.device.set_device(device)
+    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+    # Try a causal LM head first; fall back to a plain encoder model
+    try:
+        model = AutoModelForCausalLM.from_pretrained(model_name_or_path, dtype="float16")
+    except Exception:
+        try:
+            model = AutoModel.from_pretrained(model_name_or_path, dtype="float16")
+        except Exception as e:
+            raise RuntimeError(f"Failed to load model from {model_name_or_path}: {e}")
+    model.eval()
+    logging.info(f"Model {model_name_or_path} Loaded to {device}")
+    logging.debug(f"{model.config}")
+    return model, tokenizer
+
+
+def encode(
+    sentences: Sequence[str], tokenizer, model, convert_to_numpy: bool = True
+) -> Union[np.ndarray, paddle.Tensor]:
+    """
+    Encode a sequence of sentences into embeddings using a specified model and tokenizer.
+
+    Args:
+        sentences (Sequence[str]): A sequence of sentences to be encoded.
+        tokenizer: The tokenizer used to preprocess the sentences.
+        model: The model used to generate embeddings.
+        convert_to_numpy (bool, optional): Whether to convert the embeddings to a numpy array. Defaults to True.
+
+    Returns:
+        Union[np.ndarray, paddle.Tensor]: The embeddings of the sentences, either as a numpy array or a PaddlePaddle tensor.
+    """
+    model.eval()
+    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors="pd")
+    with paddle.no_grad():
+        model_output = model(**encoded_input)
+        # CLS pooling: take the first token's hidden state as the sentence embedding
+        sentence_embeddings = model_output[0][:, 0]
+        sentence_embeddings = paddle.nn.functional.normalize(x=sentence_embeddings, p=2, axis=1)
+    if convert_to_numpy and isinstance(sentence_embeddings, paddle.Tensor):
+        sentence_embeddings_np: np.ndarray = sentence_embeddings.cpu().numpy()
+        return sentence_embeddings_np
+    return sentence_embeddings
+
+
+def get_response(messages: List[Dict], tokenizer, model, max_new_tokens=1024) -> str:
+    """
+    Generate a response using a specified model and tokenizer based on the input messages.
+
+    Args:
+        messages (List[Dict]): A list of dictionaries containing the input messages.
+        tokenizer: The tokenizer used to preprocess the input messages.
+        model: The model used to generate the response.
+        max_new_tokens (int, optional): The maximum number of tokens to generate in the response. Defaults to 1024.
+
+    Returns:
+        str: The generated response.
+    """
+    inputs = tokenizer(messages[0]["content"], return_tensors="pd")
+    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
+    # generate() returns a tuple whose first element holds the generated ids
+    responses: List[str] = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
+    assert isinstance(responses, list)
+    return responses[0].strip()
+
+
+def get_messages(
+    messages: List[Dict],
+    model_name: str = "Qwen2.5-72B-Instruct-GPTQ-Int4",
+    url: str = None,
+    api_key: str = None,
+    temperature: float = None,
+    n: int = 1,
+    max_tokens: int = None,
+    max_new_tokens: int = None,
+    stream: bool = False,
+    debug: bool = False,
+) -> List[str]:
+    """
+    Send a request to a remote model API to generate responses based on the input messages.
+ """ + headers = {"Content-Type": "application/json"} + if model_name.startswith("gpt"): + if api_key is None: + raise ValueError("api_key is None") + headers["Authorization"] = f"Bearer {api_key}" + + data = { + "model": model_name, + "messages": messages, + "n": n, + "stream": stream, + } + if temperature is not None: + data["temperature"] = temperature + if max_tokens is not None: + data["max_tokens"] = max_tokens + if max_new_tokens is not None: + data["max_new_tokens"] = max_new_tokens + + if debug: + print(f"data={data}") + + message = "" + if not isinstance(message, Dict) or "choices" not in message: + repeat_index = 0 + while not isinstance(message, Dict) or "choices" not in message: + if repeat_index > 5: + raise ConnectionError(f"{url} Error\nmessage=\n{message}") + if debug: + print(f"message=\n{message}") + time.sleep(5) + response = requests.post(url, json=data, headers=headers) + if debug: + print(f"response=\n{response.text}") + message = json.loads(response.text) + repeat_index += 1 + if not isinstance(message, Dict) or "choices" not in message: + raise ConnectionError(f"{url} Error\nmessage=\n{message}") + if len(message["choices"]) != n: + raise ValueError(f"{model_name} response num error") + return [message["choices"][i]["message"]["content"] for i in range(n)]