diff --git a/legacy/model_zoo/ernie-1.0/data_tools/helpers.cpp b/legacy/model_zoo/ernie-1.0/data_tools/helpers.cpp index 07ae7ccdf6dd..e9aefb9be357 100644 --- a/legacy/model_zoo/ernie-1.0/data_tools/helpers.cpp +++ b/legacy/model_zoo/ernie-1.0/data_tools/helpers.cpp @@ -32,7 +32,7 @@ using namespace std; const int32_t LONG_SENTENCE_LEN = 512; -void build_blending_indices(py::array_t& dataset_index, +void build_blending_indices(py::array_t& dataset_index, py::array_t& dataset_sample_index, const py::array_t& weights, const int32_t num_datasets, @@ -73,7 +73,7 @@ void build_blending_indices(py::array_t& dataset_index, } // Populate the indices. - dataset_index_ptr[sample_idx] = static_cast(max_error_index); + dataset_index_ptr[sample_idx] = static_cast(max_error_index); dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index]; // Update the total samples. diff --git a/paddlenlp/data/blendable_dataset.py b/paddlenlp/data/blendable_dataset.py index c84eb4038642..6a1100b2be1e 100644 --- a/paddlenlp/data/blendable_dataset.py +++ b/paddlenlp/data/blendable_dataset.py @@ -13,6 +13,7 @@ # limitations under the License. import hashlib +import importlib.metadata import os import time @@ -45,8 +46,18 @@ def __init__(self, datasets, weights, size, share_folder, *, data_cache_path=Non # Build indicies. def _build_indices(): start_time = time.time() - assert num_datasets < 255 - dataset_index = np.zeros(self.size, dtype=np.uint8) + + tool_helpers_version = importlib.metadata.version("tool_helpers") + if tool_helpers_version > "0.1.1": + assert ( + num_datasets < 32767 + ), f"Detect num_datasets({num_datasets})>=32767. Currently, num_datasets should be less than 32767." + dataset_index = np.zeros(self.size, dtype=np.int16) + else: + assert ( + num_datasets < 255 + ), f"Detect num_datasets:({num_datasets})>=255. When 'tool_helpers<=0.1.1', num_datasets should be less than 255. To support num_datasets greater than 255, please upgrade `tool_helpers>=0.1.2`." + dataset_index = np.zeros(self.size, dtype=np.uint8) dataset_sample_index = np.zeros(self.size, dtype=np.int64) from tool_helpers import helpers