From 87a41fae4abfddeea0f9e73876fe561606a9fead Mon Sep 17 00:00:00 2001
From: Tarek Mahmoud Sayed <10833894+tarekgh@users.noreply.github.com>
Date: Mon, 9 Sep 2024 15:02:38 -0700
Subject: [PATCH] Fix decoding special tokens in SentencePiece tokenizer
 (#7233)

* Fix decoding special tokens in SentencePiece tokenizer

* Apply the change to the other Decode method overload
---
 .../Model/SentencePieceBpeTokenizer.cs        | 177 +++++++++++-------
 .../LlamaTests.cs                             |  15 ++
 2 files changed, 124 insertions(+), 68 deletions(-)

diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeTokenizer.cs
index 25ef31b6f3..45a58c84a4 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeTokenizer.cs
@@ -1549,19 +1549,6 @@ revMerge is null ||
             }

             ValueStringBuilder sb = new(stackalloc char[256]);

-            if (enumerator.Current == BeginningOfSentenceId)
-            {
-                if (considerSpecialTokens)
-                {
-                    sb.Append(BeginningOfSentenceToken);
-                }
-
-                // escape prefix control tokens.
-                if (!enumerator.MoveNext())
-                {
-                    return sb.Length == 0 ? string.Empty : sb.ToString();
-                }
-            }
             int bytesCount = -1;
             byte[]? bytesPoolArray = null;
@@ -1575,6 +1562,9 @@ revMerge is null ||

                 while (enumerator.Current < _byteCodeToIdOffset)
                 {
+                    // It is possible to list some special tokens before the byte tokens in the tokenizer's data.
+                    TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb);
+
                     // Skip control tokens.
                     if (!enumerator.MoveNext())
                     {
@@ -1588,16 +1578,20 @@ revMerge is null ||
                 }
                 else if (_vocabReverse.TryGetValue(enumerator.Current, out string? token))
                 {
-                    sb.Append(token);
+                    AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex);
+                }
+                else
+                {
+                    TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb);
                 }
             }
             else if (_vocabReverse.TryGetValue(enumerator.Current, out string? token))
             {
                 AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex);
             }
-            else if (considerSpecialTokens && _specialTokensReverse is not null && _specialTokensReverse.TryGetValue(enumerator.Current, out string? specialToken))
+            else
             {
-                sb.Append(specialToken);
+                TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb);
             }

             char[]? charPoolArray = null;
@@ -1610,6 +1604,10 @@ revMerge is null ||
                     {
                         FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb);
                     }
+
+                    // It is possible to list some special tokens before the byte tokens in the tokenizer's data.
+                    TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb);
+
                     continue;
                 }
@@ -1642,9 +1640,9 @@ revMerge is null ||
                 {
                     AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex);
                 }
-                else if (considerSpecialTokens && _specialTokensReverse is not null && _specialTokensReverse.TryGetValue(enumerator.Current, out string? specialToken))
+                else
                 {
-                    sb.Append(specialToken);
+                    TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb);
                 }
             }
         }
@@ -1736,6 +1734,31 @@ static void AppendTokenWithCheckingPrefix(bool addDummyPrefix, bool treatWhitesp

                 prefixRemoved = true;
             }
+
+            static void TryDecodeAsSpecialToken(SentencePieceBpeTokenizer tokenizer, int id, bool considerSpecialTokens, ref ValueStringBuilder sb)
+            {
+                if (!considerSpecialTokens)
+                {
+                    return;
+                }
+
+                if (id == tokenizer.BeginningOfSentenceId)
+                {
+                    sb.Append(tokenizer.BeginningOfSentenceToken);
+                }
+                else if (id == tokenizer.EndOfSentenceId)
+                {
+                    sb.Append(tokenizer.EndOfSentenceToken);
+                }
+                else if (id == tokenizer.UnknownId)
+                {
+                    sb.Append(tokenizer.UnknownToken);
+                }
+                else if (tokenizer._specialTokensReverse?.TryGetValue(id, out string? specialToken) is true)
+                {
+                    sb.Append(specialToken);
+                }
+            }
         }

         /// <summary>
@@ -1776,29 +1799,6 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool

             Span<char> buffer = destination;

-            if (enumerator.Current == BeginningOfSentenceId)
-            {
-                if (considerSpecialTokens)
-                {
-                    if (buffer.Length < BeginningOfSentenceToken.Length)
-                    {
-                        return OperationStatus.DestinationTooSmall;
-                    }
-
-                    BeginningOfSentenceToken.AsSpan().CopyTo(buffer);
-                    buffer = buffer.Slice(BeginningOfSentenceToken.Length);
-                    charsWritten += BeginningOfSentenceToken.Length;
-                }
-
-                idsConsumed++;
-
-                // escape prefix control tokens.
-                if (!enumerator.MoveNext())
-                {
-                    return OperationStatus.Done;
-                }
-            }
-
             int bytesCount = -1;
             byte[]? bytesPoolArray = null;
             bool prefixRemoved = false;
@@ -1808,9 +1808,15 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
             if (enumerator.Current <= _maxByteId)
             {
                 // First token is a byte token.
-
                 while (enumerator.Current < _byteCodeToIdOffset)
                 {
+                    OperationStatus status = TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, buffer, ref charsWritten);
+                    if (status != OperationStatus.Done)
+                    {
+                        return status;
+                    }
+
+                    buffer = destination.Slice(charsWritten);
+
                     // Skip control tokens.
                     idsConsumed++;
                     if (!enumerator.MoveNext())
                     {
@@ -1833,6 +1839,16 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
                         return OperationStatus.DestinationTooSmall;
                     }
                 }
+                else
+                {
+                    OperationStatus status = TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, buffer, ref charsWritten);
+                    if (status != OperationStatus.Done)
+                    {
+                        return status;
+                    }
+
+                    idsConsumed++;
+                }
             }
             else if (_vocabReverse.TryGetValue(enumerator.Current, out string? token))
             {
@@ -1841,24 +1857,16 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
                     return OperationStatus.DestinationTooSmall;
                 }
             }
-            else if (_specialTokensReverse is not null && _specialTokensReverse.TryGetValue(enumerator.Current, out string? specialToken))
+            else
            {
-                if (considerSpecialTokens)
+                OperationStatus status = TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, buffer, ref charsWritten);
+                if (status != OperationStatus.Done)
                 {
-                    if (buffer.Length < specialToken.Length)
-                    {
-                        return OperationStatus.DestinationTooSmall;
-                    }
-
-                    specialToken.AsSpan().CopyTo(buffer);
-                    charsWritten += specialToken.Length;
+                    return status;
                 }
+
                 idsConsumed++;
             }
-            else
-            {
-                return OperationStatus.InvalidData;
-            }

             char[]? charPoolArray = null;
@@ -1876,6 +1884,12 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
                         }
                     }

+                    OperationStatus status = TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, buffer, ref charsWritten);
+                    if (status != OperationStatus.Done)
+                    {
+                        return status;
+                    }
+
                     idsConsumed++;
                     continue;
                 }
@@ -1918,24 +1932,16 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
                         return OperationStatus.DestinationTooSmall;
                     }
                 }
-                else if (_specialTokensReverse is not null && _specialTokensReverse.TryGetValue(enumerator.Current, out string? specialToken))
+                else
                 {
-                    if (considerSpecialTokens)
+                    OperationStatus status = TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, buffer, ref charsWritten);
+                    if (status != OperationStatus.Done)
                     {
-                        if (buffer.Length < specialToken.Length)
-                        {
-                            return OperationStatus.DestinationTooSmall;
-                        }
-
-                        specialToken.AsSpan().CopyTo(buffer);
-                        charsWritten += specialToken.Length;
+                        return status;
                     }
+
                     idsConsumed++;
                 }
-                else
-                {
-                    return OperationStatus.InvalidData;
-                }
             }
         }
@@ -1973,6 +1979,41 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool

             return OperationStatus.Done;

+            static OperationStatus TryDecodeAsSpecialToken(SentencePieceBpeTokenizer tokenizer, int id, bool considerSpecialTokens, Span<char> buffer, ref int charsWritten)
+            {
+                string? specialToken = null;
+
+                if (id == tokenizer.BeginningOfSentenceId)
+                {
+                    specialToken = tokenizer.BeginningOfSentenceToken;
+                }
+                else if (id == tokenizer.EndOfSentenceId)
+                {
+                    specialToken = tokenizer.EndOfSentenceToken;
+                }
+                else if (id == tokenizer.UnknownId)
+                {
+                    specialToken = tokenizer.UnknownToken;
+                }
+                else if (!tokenizer._specialTokensReverse?.TryGetValue(id, out specialToken) is true)
+                {
+                    return OperationStatus.InvalidData;
+                }
+
+                if (considerSpecialTokens && specialToken is not null)
+                {
+                    if (buffer.Length < specialToken!.Length)
+                    {
+                        return OperationStatus.DestinationTooSmall;
+                    }
+
+                    specialToken.AsSpan().CopyTo(buffer);
+                    charsWritten += specialToken.Length;
+                }
+
+                return OperationStatus.Done;
+            }
+
             static bool FlushBytes(ref int bytesCount, ref byte[]? bytesPoolArray, ref char[]? charPoolArray, Span<char> buffer, ref int charsWritten, ref int idsConsumed)
             {
                 Debug.Assert(bytesCount >= 1);
diff --git a/test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs b/test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs
index b703bdf587..6d7178ac2d 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs
@@ -376,6 +376,21 @@ public void TestLlamaTokenizerProperties(Tokenizer llamaTokenizer)
             TokenizerTests.TestTokenLimits(llamaTokenizer);
         }

+        /// <summary>
+        /// Test that the special token with a small id is decoded correctly.
+        /// </summary>
+        [Theory]
+        [MemberData(nameof(LlamaTokenizersListData))]
+        public void TestDecodeSpecialTokenWithSmallId(LlamaTokenizer llamaTokenizer)
+        {
+            Assert.Equal(llamaTokenizer.EndOfSentenceToken, llamaTokenizer.Decode([llamaTokenizer.EndOfSentenceId], considerSpecialTokens: true));
+            Span<char> destinationBuffer = stackalloc char[llamaTokenizer.EndOfSentenceToken.Length];
+            Assert.Equal(OperationStatus.Done, llamaTokenizer.Decode([llamaTokenizer.EndOfSentenceId], destinationBuffer, considerSpecialTokens: true, out int idsConsumed, out int charactersWritten));
+            Assert.Equal(llamaTokenizer.EndOfSentenceToken.Length, charactersWritten);
+            Assert.Equal(llamaTokenizer.EndOfSentenceToken, destinationBuffer.ToString());
+            Assert.Equal(1, idsConsumed);
+        }
+
         [Fact]
         public void TestSentencePieceNormalizer()
         {
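---

For illustration, a minimal sketch of the behavior this patch enables, mirroring the test above. The "tokenizer.model" path and the printed "</s>" literal are assumptions (any Llama-style SentencePiece model file works; "</s>" is the usual Llama end-of-sentence token); LlamaTokenizer.Create and both Decode overloads are the public Microsoft.ML.Tokenizers API exercised by the new test.

    // Sketch: decoding special-token ids directly. Assumes "tokenizer.model"
    // is a SentencePiece model file on disk (hypothetical path).
    using System;
    using System.Buffers;
    using System.IO;
    using Microsoft.ML.Tokenizers;

    class SpecialTokenDecodeDemo
    {
        static void Main()
        {
            using Stream modelStream = File.OpenRead("tokenizer.model");
            LlamaTokenizer tokenizer = LlamaTokenizer.Create(modelStream);

            // Special ids such as EndOfSentenceId sit below the byte-token range
            // (id 2 in the standard Llama vocabulary). Before this patch the decoder
            // skipped them as control tokens; now they decode to their token text.
            string text = tokenizer.Decode([tokenizer.EndOfSentenceId], considerSpecialTokens: true);
            Console.WriteLine(text); // typically "</s>"

            // Allocation-free overload: reports status plus ids consumed and chars written.
            Span<char> destination = stackalloc char[tokenizer.EndOfSentenceToken.Length];
            OperationStatus status = tokenizer.Decode(
                [tokenizer.EndOfSentenceId], destination, considerSpecialTokens: true,
                out int idsConsumed, out int charsWritten);
            Console.WriteLine($"{status}, idsConsumed={idsConsumed}: {destination.Slice(0, charsWritten).ToString()}");
        }
    }

Note the design choice visible in the Span-based helper: it distinguishes "id is not a special token at all" (OperationStatus.InvalidData) from "special token suppressed because considerSpecialTokens is false" (OperationStatus.Done with nothing written), which is why the caller re-slices the buffer from charsWritten after each call.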