-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #26 from Nayjest/tokenize_remote_models
Tiktoken usage for estimating the number of tokens in a prompt / response, and fitting semantic search results to a target token count
- Loading branch information
Showing
12 changed files
with
191 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import logging | ||
|
||
import tiktoken | ||
import requests.exceptions | ||
from ._env import env | ||
|
||
|
||
class CantLoadTikTokenEncoding(RuntimeError):
    """Raised when a tiktoken encoding cannot be loaded (unknown name or connection error)."""
    ...
|
||
|
||
def _resolve_tiktoken_encoding(
    for_model: str | None = None, encoding: str | tiktoken.Encoding | None = None
) -> tiktoken.Encoding:
    """
    Resolve a tiktoken encoding from a model name, an encoding name,
    or the application configuration.

    Args:
        for_model: LLM model name to look the encoding up by.
        encoding: Encoding name or an already-constructed Encoding instance.
            Mutually exclusive with for_model.

    Returns:
        The resolved tiktoken.Encoding.

    Raises:
        CantLoadTikTokenEncoding: if the fallback encoding cannot be loaded.
    """
    assert (
        for_model is None or encoding is None
    ), "You may specify encoding or for_model(LLM), but not both"
    if isinstance(encoding, tiktoken.Encoding):
        # Already resolved; nothing to do.
        return encoding
    if for_model is None and encoding is None:
        # No explicit hints given: fall back to application configuration.
        if env().config.TIKTOKEN_ENCODING:
            return _resolve_tiktoken_encoding(encoding=env().config.TIKTOKEN_ENCODING)
        for_model = (
            env().config.LLM_DEFAULT_ARGS.get("model", None) or env().config.MODEL
        )
    if for_model:
        try:
            return tiktoken.encoding_for_model(for_model)
        except (KeyError, requests.exceptions.ConnectionError):
            # Unknown model name, or the model->encoding mapping could not
            # be fetched; fall through to the default encoding below.
            logging.warning(
                f"Can't resolve tiktoken encoding for '{for_model}'. "
                f"Default encoding will be used."
            )
    encoding = encoding or "cl100k_base"
    try:
        return tiktoken.get_encoding(encoding)
    except (ValueError, requests.exceptions.ConnectionError) as e:
        # Fixed typo in the error message: "tiktok" -> "tiktoken".
        raise CantLoadTikTokenEncoding(
            f"Can't load tiktoken encoding '{encoding}'"
        ) from e
|
||
|
||
def encode(
    string: str, for_model: str = None, encoding: str | tiktoken.Encoding = None
) -> list[int]:
    """Convert a text string into a list of LLM token ids."""
    resolved = _resolve_tiktoken_encoding(for_model, encoding)
    return resolved.encode(string)
|
||
|
||
def num_tokens_from_string(
    string: str, for_model: str = None, encoding: str | tiktoken.Encoding = None
) -> int:
    """Count how many LLM tokens a text string contains."""
    tokens = encode(string, for_model=for_model, encoding=encoding)
    return len(tokens)
|
||
|
||
def fit_to_token_size(
    docs: list[str],
    max_tokens: int,
    min_documents: int | None = None,
    for_model: str | None = None,
    encoding: str | tiktoken.Encoding | None = None,
) -> tuple[list[str], int]:
    """
    Truncate a list of documents so their total token count fits max_tokens.

    Documents are kept in their original order. When min_documents is given,
    the first min_documents items are always retained, even if they already
    exceed the token budget.

    Args:
        docs: Documents to fit, in priority order.
        max_tokens: Token budget for the kept documents.
        min_documents: Minimum number of documents to keep regardless of budget.
        for_model: LLM model name used to pick the tokenizer encoding.
        encoding: Explicit encoding name or instance (alternative to for_model).

    Returns:
        Tuple of (kept documents, number of removed documents).
    """
    # Resolve once and count tokens directly on the Encoding instance,
    # avoiding a redundant resolution round-trip for every document.
    encoding = _resolve_tiktoken_encoding(for_model, encoding)
    total_tokens = 0
    for i, doc in enumerate(docs):
        total_tokens += len(encoding.encode(doc))
        if min_documents and i < min_documents:
            # Guaranteed minimum: never trim below min_documents items.
            continue
        if total_tokens > max_tokens:
            kept = docs[:i]
            return kept, len(docs) - len(kept)
    return docs, 0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
-r min.txt | ||
chromadb>=0.4.18,<0.6 | ||
anthropic>=0.19.1,<=0.25.8 | ||
google-generativeai>=0.7.2,<1 | ||
vertexai>=1.60.0,<2 | ||
transformers>=4.43.3,<5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import microcore as mc | ||
from microcore import SearchResult | ||
|
||
|
||
def test_fit_vector_search_to_tokens():
    """Search results can be trimmed to a token budget, keeping the closest docs."""
    mc.texts.clear("test_collection")
    documents = [str(n) for n in range(10)]
    mc.texts.save_many("test_collection", documents)
    found = mc.texts.search("test_collection", "qwe", n_results=10)
    # Every stored document should come back from the search
    assert sorted(found) == documents

    trimmed = found.fit_to_token_size(3)
    # A 3-token budget keeps exactly 3 of these single-token documents
    assert len(trimmed) == 3
    assert any(item in documents for item in trimmed)

    # The retained items must be the nearest ones by distance
    three_smallest = sorted(item.distance for item in found)[:3]
    assert sorted(item.distance for item in trimmed) == three_smallest

    assert trimmed[0].num_tokens() == 1
|
||
|
||
def test_fit_vector_search_to_tokens_min_docs():
    """min_documents acts as a lower bound that can override the token budget."""
    mc.texts.clear("test_collection")
    documents = [str(n) for n in range(10)]
    mc.texts.save_many("test_collection", documents)
    # Budget of 3 tokens but a minimum of 4 documents -> 4 are kept
    hits = mc.texts.search("test_collection", "qwe", n_results=10)
    assert len(hits.fit_to_token_size(3, 4)) == 4
    # Budget of 5 tokens with a minimum of 3 -> the budget wins, 5 are kept
    hits = mc.texts.search("test_collection", "qwe", n_results=10)
    assert len(hits.fit_to_token_size(5, 3)) == 5
|
||
|
||
def test_num_tokens():
    """Token counting works via an explicit encoding and via a model name."""
    three_words = SearchResult("apple pineapple orange")
    assert three_words.num_tokens(encoding='cl100k_base') >= 3
    greeting = SearchResult("Hi")
    assert greeting.num_tokens(for_model='gpt-4') <= 2