diff --git a/CHANGELOG.md b/CHANGELOG.md index c074bcce..8f8e88ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,15 @@ This is the changelog for the open source version of tiktoken. +## [v0.6.0] +- Optimise regular expressions for a 20% performance improvement +- Add `text-embedding-3-*` models to `encoding_for_model` +- Check content hash for downloaded files +- Allow pickling `Encoding` objects. Registered `Encoding` will be pickled by reference +- Workaround PyO3 bug for frozenset conversion + +Thank you to @paplorinc, @mdwelsh, @Praneet460! + ## [v0.5.2] - Build wheels for Python 3.12 - Update version of PyO3 to allow multiple imports diff --git a/Cargo.toml b/Cargo.toml index fb69db74..14588580 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tiktoken" -version = "0.5.2" +version = "0.6.0" edition = "2021" rust-version = "1.57.0" @@ -16,6 +16,3 @@ fancy-regex = "0.11.0" regex = "1.8.3" rustc-hash = "1.1.0" bstr = "1.5.0" - -[profile.release] -incremental = true diff --git a/pyproject.toml b/pyproject.toml index b48941d6..47aada31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "tiktoken" -version = "0.5.2" +version = "0.6.0" description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" readme = "README.md" license = {file = "LICENSE"} diff --git a/tiktoken/core.py b/tiktoken/core.py index aa72c9d7..3e32d4ce 100644 --- a/tiktoken/core.py +++ b/tiktoken/core.py @@ -116,6 +116,10 @@ def encode( if match := _special_token_regex(disallowed_special).search(text): raise_disallowed_special_token(match.group()) + # https://github.com/PyO3/pyo3/pull/3632 + if isinstance(allowed_special, frozenset): + allowed_special = set(allowed_special) + try: return self._core_bpe.encode(text, allowed_special) except UnicodeEncodeError: @@ -364,6 +368,26 @@ def _encode_only_native_bpe(self, text: str) -> list[int]: def _encode_bytes(self, text: bytes) -> list[int]: return self._core_bpe._encode_bytes(text) + def __getstate__(self) -> object: + import tiktoken.registry + + # As an optimisation, pickle registered encodings by reference + if self is tiktoken.registry.ENCODINGS.get(self.name): + return self.name + return { + "name": self.name, + "pat_str": self._pat_str, + "mergeable_ranks": self._mergeable_ranks, + "special_tokens": self._special_tokens, + } + + def __setstate__(self, value: object) -> None: + import tiktoken.registry + + if isinstance(value, str): + self.__dict__ = tiktoken.registry.get_encoding(value).__dict__ + return + self.__init__(**value) @functools.lru_cache(maxsize=128) diff --git a/tiktoken/load.py b/tiktoken/load.py index 45729b11..cc0a6a6d 100644 --- a/tiktoken/load.py +++ b/tiktoken/load.py @@ -27,12 +27,12 @@ def read_file(blobpath: str) -> bytes: return resp.content -def check_hash(data: bytes, hash: str) -> bool: - data_hash = hashlib.sha256(data).hexdigest() - return data_hash == hash +def check_hash(data: bytes, expected_hash: str) -> bool: + actual_hash = hashlib.sha256(data).hexdigest() + return actual_hash == expected_hash -def read_file_cached(blobpath: str, expected_hash: Optional[str]=None) -> bytes: +def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes: user_specified_cache = True if "TIKTOKEN_CACHE_DIR" in os.environ: cache_dir = os.environ["TIKTOKEN_CACHE_DIR"] @@ -52,13 +52,15 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str]=None) -> bytes: if os.path.exists(cache_path): with open(cache_path, "rb") as f: data = f.read() - if expected_hash and not check_hash(data, expected_hash): - raise ValueError( - f"Hash mismatch for cached data from {blobpath} (expected {expected_hash}). " - f"Please delete the cache file at {cache_path} and try again." - ) + if expected_hash is None or check_hash(data, expected_hash): return data + # the cached file does not match the hash, remove it and re-fetch + try: + os.remove(cache_path) + except OSError: + pass + contents = read_file(blobpath) if expected_hash and not check_hash(contents, expected_hash): raise ValueError( @@ -81,7 +83,10 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str]=None) -> bytes: def data_gym_to_mergeable_bpe_ranks( - vocab_bpe_file: str, encoder_json_file: str, vocab_bpe_hash: Optional[str]=None, encoder_json_hash: Optional[str]=None + vocab_bpe_file: str, + encoder_json_file: str, + vocab_bpe_hash: Optional[str] = None, + encoder_json_hash: Optional[str] = None, ) -> dict[bytes, int]: # NB: do not add caching to this function rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "] @@ -135,7 +140,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n") -def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: Optional[str]=None) -> dict[bytes, int]: +def load_tiktoken_bpe( + tiktoken_bpe_file: str, expected_hash: Optional[str] = None +) -> dict[bytes, int]: # NB: do not add caching to this function contents = read_file_cached(tiktoken_bpe_file, expected_hash) return {