Skip to content

Commit

Permalink
fix c_ops compatibility (#1450)
Browse files Browse the repository at this point in the history
* fix _C_ops import compatibility: fall back to the legacy op module and emit a warning when running under paddlepaddle versions older than 2.2.1
  • Loading branch information
Steffy-zxf authored Dec 13, 2021
1 parent d1140bb commit 583f724
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions paddlenlp/experimental/faster_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib

import paddle
import paddle.fluid.core as core
import paddle.nn as nn
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import in_dygraph_mode
from paddlenlp.utils.downloader import get_path_from_url
from paddlenlp.transformers import BertTokenizer, ErnieTokenizer, RobertaTokenizer
from paddle import _C_ops
from paddlenlp.utils.log import logger

__all__ = ["to_tensor", "to_vocab_buffer", "FasterTokenizer"]

Expand Down Expand Up @@ -77,6 +79,15 @@ class FasterTokenizer(nn.Layer):

    def __init__(self, vocab, do_lower_case=False, is_split_into_words=False):
        """Build a FasterTokenizer layer around the C++ ``faster_tokenizer`` op.

        Args:
            vocab: token-to-id mapping; converted to an op buffer via
                ``to_vocab_buffer`` and registered on the layer.
            do_lower_case (bool): passed through to the tokenizer op.
            is_split_into_words (bool): passed through to the tokenizer op.
        """
        super(FasterTokenizer, self).__init__()

        # Paddle >= 2.2.1 exposes the C ops as ``paddle._C_ops``; on older
        # versions that import fails, so fall back to the legacy location
        # ``paddle.fluid.core.ops`` after warning the user to upgrade.
        # NOTE(review): the broad ``except Exception`` treats *any* import
        # failure as "old paddle" — presumably intentional best-effort
        # compatibility; confirm no other error is being masked here.
        try:
            self.mod = importlib.import_module("paddle._C_ops")
        except Exception as e:
            logger.warning(
                f"The paddlepaddle version is {paddle.__version__}, not the latest. "
                "Please upgrade the paddlepaddle package (>= 2.2.1).")
            self.mod = importlib.import_module("paddle.fluid.core.ops")

        vocab_buffer = to_vocab_buffer(vocab, "vocab")
        # persistable=True so the vocab buffer is saved/loaded with the
        # layer's state dict (it is a buffer, not a trainable parameter).
        self.register_buffer("vocab", vocab_buffer, persistable=True)

Expand All @@ -94,11 +105,12 @@ def forward(self,
if text_pair is not None:
if isinstance(text_pair, list) or isinstance(text_pair, tuple):
text_pair = to_tensor(list(text_pair))
input_ids, seg_ids = _C_ops.faster_tokenizer(
input_ids, seg_ids = self.mod.faster_tokenizer(
self.vocab, text, text_pair, "do_lower_case",
self.do_lower_case, "max_seq_len", max_seq_len,
"pad_to_max_seq_len", pad_to_max_seq_len, "is_split_into_words",
self.is_split_into_words)

return input_ids, seg_ids

attrs = {
Expand Down

0 comments on commit 583f724

Please sign in to comment.