diff --git a/examples/text_correction/ernie-csc/predict_sighan.py b/examples/text_correction/ernie-csc/predict_sighan.py index ad5ef4d9e877..1b0212e80763 100644 --- a/examples/text_correction/ernie-csc/predict_sighan.py +++ b/examples/text_correction/ernie-csc/predict_sighan.py @@ -94,9 +94,9 @@ def do_predict(args): max_seq_length=args.max_seq_length, is_test=True) batchify_fn = lambda samples, fn=Tuple( - Pad(axis=0, pad_val=tokenizer.pad_token_id), # input - Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment - Pad(axis=0, pad_val=pinyin_vocab.token_to_idx[pinyin_vocab.pad_token]), # pinyin + Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'), # input + Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'), # segment + Pad(axis=0, pad_val=pinyin_vocab.token_to_idx[pinyin_vocab.pad_token], dtype='int64'), # pinyin Stack(axis=0, dtype='int64'), # length ): [data for data in fn(samples)] diff --git a/paddlenlp/taskflow/text_correction.py b/paddlenlp/taskflow/text_correction.py index a32e1945b30c..be0ad94e9cc9 100644 --- a/paddlenlp/taskflow/text_correction.py +++ b/paddlenlp/taskflow/text_correction.py @@ -105,9 +105,9 @@ def __init__(self, task, model, **kwargs): self._pypinyin = pypinyin self._max_seq_length = 128 self._batchify_fn = lambda samples, fn=Tuple( - Pad(axis=0, pad_val=self._tokenizer.pad_token_id), # input - Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id), # segment - Pad(axis=0, pad_val=self._pinyin_vocab.token_to_idx[self._pinyin_vocab.pad_token]), # pinyin + Pad(axis=0, pad_val=self._tokenizer.pad_token_id, dtype='int64'), # input + Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id, dtype='int64'), # segment + Pad(axis=0, pad_val=self._pinyin_vocab.token_to_idx[self._pinyin_vocab.pad_token], dtype='int64'), # pinyin Stack(axis=0, dtype='int64'), # length ): [data for data in fn(samples)] self._num_workers = self.kwargs[