Fix decoding special tokens in SentencePiece tokenizer (#7233)
* Fix decoding special tokens in SentencePiece tokenizer

* Apply the change to the other Decode method overload
tarekgh authored Sep 9, 2024
1 parent 4e364e4 commit 87a41fa
Showing 2 changed files with 124 additions and 68 deletions.
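The root cause: in Llama-style SentencePiece models the special ids (<unk> = 0, <s> = 1, </s> = 2) sit below the byte-token id range, so the decoder's control-token skip silently dropped them. Below is a minimal repro sketch, not part of the commit; the model path and the typical id values are assumptions about a standard Llama tokenizer.model:

using System.IO;
using Microsoft.ML.Tokenizers;

// Load a Llama-style SentencePiece model; "tokenizer.model" is a placeholder path.
using Stream modelStream = File.OpenRead("tokenizer.model");
LlamaTokenizer tokenizer = LlamaTokenizer.Create(modelStream);

// EndOfSentenceId is typically 2, below the byte-token ids, so the old code
// skipped it as a control token and returned ""; with this fix the token text comes back.
string text = tokenizer.Decode([tokenizer.EndOfSentenceId], considerSpecialTokens: true);
// text == tokenizer.EndOfSentenceToken, i.e. "</s>"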
src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeTokenizer.cs (109 additions, 68 deletions)

@@ -1549,19 +1549,6 @@ revMerge is null ||
             }
 
             ValueStringBuilder sb = new(stackalloc char[256]);
-            if (enumerator.Current == BeginningOfSentenceId)
-            {
-                if (considerSpecialTokens)
-                {
-                    sb.Append(BeginningOfSentenceToken);
-                }
-
-                // escape prefix control tokens.
-                if (!enumerator.MoveNext())
-                {
-                    return sb.Length == 0 ? string.Empty : sb.ToString();
-                }
-            }
 
             int bytesCount = -1;
             byte[]? bytesPoolArray = null;
@@ -1575,6 +1562,9 @@ revMerge is null ||
 
                 while (enumerator.Current < _byteCodeToIdOffset)
                 {
+                    // Some special tokens may be listed before the byte tokens in the tokenizer's data.
+                    TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb);
+
                     // Skip control tokens.
                     if (!enumerator.MoveNext())
                     {
@@ -1588,16 +1578,20 @@ revMerge is null ||
                     }
                     else if (_vocabReverse.TryGetValue(enumerator.Current, out string? token))
                     {
-                        sb.Append(token);
+                        AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex);
                     }
+                    else
+                    {
+                        TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb);
+                    }
                 }
                 else if (_vocabReverse.TryGetValue(enumerator.Current, out string? token))
                 {
                     AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex);
                 }
-                else if (considerSpecialTokens && _specialTokensReverse is not null && _specialTokensReverse.TryGetValue(enumerator.Current, out string? specialToken))
+                else
                 {
-                    sb.Append(specialToken);
+                    TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb);
                 }
 
                 char[]? charPoolArray = null;
@@ -1610,6 +1604,10 @@ revMerge is null ||
                     {
                         FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb);
                     }
+
+                    // Some special tokens may be listed before the byte tokens in the tokenizer's data.
+                    TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb);
+
                     continue;
                 }
 
@@ -1642,9 +1640,9 @@ revMerge is null ||
                 {
                     AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex);
                 }
-                else if (considerSpecialTokens && _specialTokensReverse is not null && _specialTokensReverse.TryGetValue(enumerator.Current, out string? specialToken))
+                else
                 {
-                    sb.Append(specialToken);
+                    TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb);
                 }
             }
         }
@@ -1736,6 +1734,31 @@ static void AppendTokenWithCheckingPrefix(bool addDummyPrefix, bool treatWhitesp
 
                 prefixRemoved = true;
             }
+
+            static void TryDecodeAsSpecialToken(SentencePieceBpeTokenizer tokenizer, int id, bool considerSpecialTokens, ref ValueStringBuilder sb)
+            {
+                if (!considerSpecialTokens)
+                {
+                    return;
+                }
+
+                if (id == tokenizer.BeginningOfSentenceId)
+                {
+                    sb.Append(tokenizer.BeginningOfSentenceToken);
+                }
+                else if (id == tokenizer.EndOfSentenceId)
+                {
+                    sb.Append(tokenizer.EndOfSentenceToken);
+                }
+                else if (id == tokenizer.UnknownId)
+                {
+                    sb.Append(tokenizer.UnknownToken);
+                }
+                else if (tokenizer._specialTokensReverse?.TryGetValue(id, out string? specialToken) is true)
+                {
+                    sb.Append(specialToken);
+                }
+            }
         }
 
         /// <summary>
@@ -1776,29 +1799,6 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
 
             Span<char> buffer = destination;
 
-            if (enumerator.Current == BeginningOfSentenceId)
-            {
-                if (considerSpecialTokens)
-                {
-                    if (buffer.Length < BeginningOfSentenceToken.Length)
-                    {
-                        return OperationStatus.DestinationTooSmall;
-                    }
-
-                    BeginningOfSentenceToken.AsSpan().CopyTo(buffer);
-                    buffer = buffer.Slice(BeginningOfSentenceToken.Length);
-                    charsWritten += BeginningOfSentenceToken.Length;
-                }
-
-                idsConsumed++;
-
-                // escape prefix control tokens.
-                if (!enumerator.MoveNext())
-                {
-                    return OperationStatus.Done;
-                }
-            }
-
             int bytesCount = -1;
             byte[]? bytesPoolArray = null;
             bool prefixRemoved = false;
@@ -1808,9 +1808,15 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
                 if (enumerator.Current <= _maxByteId)
                 {
                     // First token is a byte token.
-
                     while (enumerator.Current < _byteCodeToIdOffset)
                     {
+                        OperationStatus status = TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, buffer, ref charsWritten);
+                        if (status != OperationStatus.Done)
+                        {
+                            return status;
+                        }
+                        buffer = destination.Slice(charsWritten);
+
                         // Skip control tokens.
                         idsConsumed++;
                         if (!enumerator.MoveNext())
@@ -1833,6 +1839,16 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
                             return OperationStatus.DestinationTooSmall;
                         }
                     }
+                    else
+                    {
+                        OperationStatus status = TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, buffer, ref charsWritten);
+                        if (status != OperationStatus.Done)
+                        {
+                            return status;
+                        }
+
+                        idsConsumed++;
+                    }
                 }
                 else if (_vocabReverse.TryGetValue(enumerator.Current, out string? token))
                 {
@@ -1841,24 +1857,16 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
                         return OperationStatus.DestinationTooSmall;
                     }
                 }
-                else if (_specialTokensReverse is not null && _specialTokensReverse.TryGetValue(enumerator.Current, out string? specialToken))
+                else
                 {
-                    if (considerSpecialTokens)
+                    OperationStatus status = TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, buffer, ref charsWritten);
+                    if (status != OperationStatus.Done)
                     {
-                        if (buffer.Length < specialToken.Length)
-                        {
-                            return OperationStatus.DestinationTooSmall;
-                        }
-
-                        specialToken.AsSpan().CopyTo(buffer);
-                        charsWritten += specialToken.Length;
+                        return status;
                     }
+
                     idsConsumed++;
                 }
-                else
-                {
-                    return OperationStatus.InvalidData;
-                }
 
                 char[]? charPoolArray = null;
 
@@ -1876,6 +1884,12 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
                         }
                     }
 
+                    OperationStatus status = TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, buffer, ref charsWritten);
+                    if (status != OperationStatus.Done)
+                    {
+                        return status;
+                    }
+
                     idsConsumed++;
                     continue;
                 }
@@ -1918,24 +1932,16 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
                         return OperationStatus.DestinationTooSmall;
                     }
                 }
-                else if (_specialTokensReverse is not null && _specialTokensReverse.TryGetValue(enumerator.Current, out string? specialToken))
+                else
                 {
-                    if (considerSpecialTokens)
+                    OperationStatus status = TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, buffer, ref charsWritten);
+                    if (status != OperationStatus.Done)
                     {
-                        if (buffer.Length < specialToken.Length)
-                        {
-                            return OperationStatus.DestinationTooSmall;
-                        }
-
-                        specialToken.AsSpan().CopyTo(buffer);
-                        charsWritten += specialToken.Length;
+                        return status;
                     }
+
                     idsConsumed++;
                 }
-                else
-                {
-                    return OperationStatus.InvalidData;
-                }
             }
         }
 
@@ -1973,6 +1979,41 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
 
             return OperationStatus.Done;
 
+            static OperationStatus TryDecodeAsSpecialToken(SentencePieceBpeTokenizer tokenizer, int id, bool considerSpecialTokens, Span<char> buffer, ref int charsWritten)
+            {
+                string? specialToken = null;
+
+                if (id == tokenizer.BeginningOfSentenceId)
+                {
+                    specialToken = tokenizer.BeginningOfSentenceToken;
+                }
+                else if (id == tokenizer.EndOfSentenceId)
+                {
+                    specialToken = tokenizer.EndOfSentenceToken;
+                }
+                else if (id == tokenizer.UnknownId)
+                {
+                    specialToken = tokenizer.UnknownToken;
+                }
+                else if (!tokenizer._specialTokensReverse?.TryGetValue(id, out specialToken) is true)
+                {
+                    return OperationStatus.InvalidData;
+                }
+
+                if (considerSpecialTokens && specialToken is not null)
+                {
+                    if (buffer.Length < specialToken!.Length)
+                    {
+                        return OperationStatus.DestinationTooSmall;
+                    }
+
+                    specialToken.AsSpan().CopyTo(buffer);
+                    charsWritten += specialToken.Length;
+                }
+
+                return OperationStatus.Done;
+            }
+
             static bool FlushBytes(ref int bytesCount, ref byte[]? bytesPoolArray, ref char[]? charPoolArray, Span<char> buffer, ref int charsWritten, ref int idsConsumed)
             {
                 Debug.Assert(bytesCount >= 1);
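A companion sketch for the span-based overload patched above, continuing from the tokenizer created in the earlier snippet (add using System.Buffers; for OperationStatus):

// Decode the same special id into a caller-supplied buffer. With the fix, a
// too-small destination surfaces as DestinationTooSmall instead of the special
// token being silently skipped.
Span<char> destination = stackalloc char[16];
OperationStatus status = tokenizer.Decode(
    [tokenizer.EndOfSentenceId],
    destination,
    considerSpecialTokens: true,
    out int idsConsumed,
    out int charsWritten);
// status == OperationStatus.Done, idsConsumed == 1,
// destination[..charsWritten] equals "</s>"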
test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs (15 additions, 0 deletions)

@@ -376,6 +376,21 @@ public void TestLlamaTokenizerProperties(Tokenizer llamaTokenizer)
             TokenizerTests.TestTokenLimits(llamaTokenizer);
         }
 
+        /// <summary>
+        /// Test that a special token with a small id is decoded correctly.
+        /// </summary>
+        [Theory]
+        [MemberData(nameof(LlamaTokenizersListData))]
+        public void TestDecodeSpecialTokenWithSmallId(LlamaTokenizer llamaTokenizer)
+        {
+            Assert.Equal(llamaTokenizer.EndOfSentenceToken, llamaTokenizer.Decode([llamaTokenizer.EndOfSentenceId], considerSpecialTokens: true));
+            Span<char> destinationBuffer = stackalloc char[llamaTokenizer.EndOfSentenceToken.Length];
+            Assert.Equal(OperationStatus.Done, llamaTokenizer.Decode([llamaTokenizer.EndOfSentenceId], destinationBuffer, considerSpecialTokens: true, out int idsConsumed, out int charactersWritten));
+            Assert.Equal(llamaTokenizer.EndOfSentenceToken.Length, charactersWritten);
+            Assert.Equal(llamaTokenizer.EndOfSentenceToken, destinationBuffer.ToString());
+            Assert.Equal(1, idsConsumed);
+        }
+
         [Fact]
         public void TestSentencePieceNormalizer()
         {
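For completeness, the suppressed path under the same assumptions as the earlier snippets: when considerSpecialTokens is false, the fixed decoder consumes the special id without emitting its text.

string suppressed = tokenizer.Decode([tokenizer.EndOfSentenceId], considerSpecialTokens: false);
// suppressed.Length == 0: the id is consumed, no special-token text is written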
