Skip to content

Commit

Permalink
Sync codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
hauntsaninja committed Feb 9, 2024
1 parent 55c8d83 commit 89153d7
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 16 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@

This is the changelog for the open source version of tiktoken.

## [v0.6.0]
- Optimise regular expressions for a 20% performance improvement
- Add `text-embedding-3-*` models to `encoding_for_model`
- Check content hash for downloaded files
- Allow pickling `Encoding` objects. Registered `Encoding` objects will be pickled by reference
- Work around a PyO3 bug affecting frozenset conversion

Thank you to @paplorinc, @mdwelsh, @Praneet460!

## [v0.5.2]
- Build wheels for Python 3.12
- Update version of PyO3 to allow multiple imports
Expand Down
5 changes: 1 addition & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tiktoken"
version = "0.5.2"
version = "0.6.0"
edition = "2021"
rust-version = "1.57.0"

Expand All @@ -16,6 +16,3 @@ fancy-regex = "0.11.0"
regex = "1.8.3"
rustc-hash = "1.1.0"
bstr = "1.5.0"

[profile.release]
incremental = true
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "tiktoken"
version = "0.5.2"
version = "0.6.0"
description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models"
readme = "README.md"
license = {file = "LICENSE"}
Expand Down
24 changes: 24 additions & 0 deletions tiktoken/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ def encode(
if match := _special_token_regex(disallowed_special).search(text):
raise_disallowed_special_token(match.group())

# https://github.com/PyO3/pyo3/pull/3632
if isinstance(allowed_special, frozenset):
allowed_special = set(allowed_special)

try:
return self._core_bpe.encode(text, allowed_special)
except UnicodeEncodeError:
Expand Down Expand Up @@ -364,6 +368,26 @@ def _encode_only_native_bpe(self, text: str) -> list[int]:
def _encode_bytes(self, text: bytes) -> list[int]:
    """Encode raw bytes directly into token ids.

    Thin wrapper that delegates to the underlying core BPE object.
    """
    core = self._core_bpe
    return core._encode_bytes(text)

def __getstate__(self) -> object:
    """Return pickle state: the bare name for a registered encoding, otherwise
    the full construction parameters as a dict."""
    import tiktoken.registry

    # Optimisation: if this instance is the one held in the registry, pickle it
    # by reference (just its name) rather than serialising all of its data.
    registered = tiktoken.registry.ENCODINGS.get(self.name)
    if registered is self:
        return self.name
    return dict(
        name=self.name,
        pat_str=self._pat_str,
        mergeable_ranks=self._mergeable_ranks,
        special_tokens=self._special_tokens,
    )

def __setstate__(self, value: object) -> None:
    """Restore pickled state: either a registered-encoding name (str) or a dict
    of construction parameters."""
    import tiktoken.registry

    if not isinstance(value, str):
        # Full parameter dict: rebuild this instance via the normal constructor.
        self.__init__(**value)
        return
    # A bare name means the encoding was pickled by reference: adopt the state
    # of the instance the registry hands back for that name.
    self.__dict__ = tiktoken.registry.get_encoding(value).__dict__


@functools.lru_cache(maxsize=128)
Expand Down
29 changes: 18 additions & 11 deletions tiktoken/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ def read_file(blobpath: str) -> bytes:
return resp.content


def check_hash(data: bytes, hash: str) -> bool:
    """Return True iff the SHA-256 hex digest of ``data`` equals ``hash``.

    NOTE(review): the parameter name shadows the builtin ``hash()``; harmless
    within this function, but kept only because renaming would break keyword
    callers.
    """
    return hashlib.sha256(data).hexdigest() == hash
def check_hash(data: bytes, expected_hash: str) -> bool:
    """Return whether the SHA-256 hex digest of ``data`` matches ``expected_hash``."""
    digest = hashlib.sha256(data)
    return digest.hexdigest() == expected_hash


def read_file_cached(blobpath: str, expected_hash: Optional[str]=None) -> bytes:
def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes:
user_specified_cache = True
if "TIKTOKEN_CACHE_DIR" in os.environ:
cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
Expand All @@ -52,13 +52,15 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str]=None) -> bytes:
if os.path.exists(cache_path):
with open(cache_path, "rb") as f:
data = f.read()
if expected_hash and not check_hash(data, expected_hash):
raise ValueError(
f"Hash mismatch for cached data from {blobpath} (expected {expected_hash}). "
f"Please delete the cache file at {cache_path} and try again."
)
if expected_hash is None or check_hash(data, expected_hash):
return data

# the cached file does not match the hash, remove it and re-fetch
try:
os.remove(cache_path)
except OSError:
pass

contents = read_file(blobpath)
if expected_hash and not check_hash(contents, expected_hash):
raise ValueError(
Expand All @@ -81,7 +83,10 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str]=None) -> bytes:


def data_gym_to_mergeable_bpe_ranks(
vocab_bpe_file: str, encoder_json_file: str, vocab_bpe_hash: Optional[str]=None, encoder_json_hash: Optional[str]=None
vocab_bpe_file: str,
encoder_json_file: str,
vocab_bpe_hash: Optional[str] = None,
encoder_json_hash: Optional[str] = None,
) -> dict[bytes, int]:
# NB: do not add caching to this function
rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
Expand Down Expand Up @@ -135,7 +140,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No
f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")


def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: Optional[str]=None) -> dict[bytes, int]:
def load_tiktoken_bpe(
tiktoken_bpe_file: str, expected_hash: Optional[str] = None
) -> dict[bytes, int]:
# NB: do not add caching to this function
contents = read_file_cached(tiktoken_bpe_file, expected_hash)
return {
Expand Down

0 comments on commit 89153d7

Please sign in to comment.