diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index 777a0acc..72bf67be 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -15,19 +15,19 @@ jobs: matrix: # cibuildwheel builds linux wheels inside a manylinux container # it also takes care of procuring the correct python version for us - os: [ubuntu-latest, windows-latest, macos-13] - python-version: [38, 39, 310, 311, 312] + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: [39, 310, 311, 312, 313] steps: - uses: actions/checkout@v4 - - uses: pypa/cibuildwheel@v2.18.0 + - uses: pypa/cibuildwheel@v2.21.2 env: CIBW_BUILD: "cp${{ matrix.python-version}}-*" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: - name: dist + name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} path: ./wheelhouse/*.whl build_wheels_aarch64: @@ -37,7 +37,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: [38, 39, 310, 311, 312] + python-version: [39, 310, 311, 312, 313] steps: - uses: actions/checkout@v4 @@ -48,16 +48,17 @@ jobs: platforms: arm64 - name: Build wheels - uses: pypa/cibuildwheel@v2.18.0 + uses: pypa/cibuildwheel@v2.21.2 env: CIBW_BUILD: "cp${{ matrix.python-version}}-*" CIBW_ARCHS: aarch64 CIBW_BUILD_VERBOSITY: 3 # https://github.com/rust-lang/cargo/issues/10583 CIBW_ENVIRONMENT_LINUX: PATH="$PATH:$HOME/.cargo/bin" CARGO_NET_GIT_FETCH_WITH_CLI=true - - uses: actions/upload-artifact@v3 + + - uses: actions/upload-artifact@v4 with: - name: dist + name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} path: ./wheelhouse/*.whl build_sdist: @@ -65,7 +66,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 name: Install Python with: python-version: "3.9" diff --git a/CHANGELOG.md b/CHANGELOG.md index f94795ea..c70420b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,11 +2,26 @@ This is the changelog for the open source version of tiktoken. +## [v0.8.0] + +- Support for `o1-` and `chatgpt-4o-` models +- Build wheels for Python 3.13 +- Add possessive quantifiers to limit backtracking in regular expressions, thanks to @l0rinc! +- Provide a better error message and type for invalid token decode +- Permit tuples in type hints +- Better error message for passing invalid input to `get_encoding` +- Better error messages during plugin loading +- Add a `__version__` attribute +- Update versions of `pyo3`, `regex`, `fancy-regex` +- Drop support for Python 3.8 + ## [v0.7.0] + - Support for `gpt-4o` - Performance improvements ## [v0.6.0] + - Optimise regular expressions for a 20% performance improvement, thanks to @paplorinc! - Add `text-embedding-3-*` models to `encoding_for_model` - Check content hash for downloaded files @@ -16,14 +31,17 @@ This is the changelog for the open source version of tiktoken. Thank you to @paplorinc, @mdwelsh, @Praneet460! ## [v0.5.2] + - Build wheels for Python 3.12 - Update version of PyO3 to allow multiple imports - Avoid permission errors when using default cache logic ## [v0.5.1] + - Add `encoding_name_for_model`, undo some renames to variables that are implementation details ## [v0.5.0] + - Add `tiktoken._educational` submodule to better document how byte pair encoding works - Ensure `encoding_for_model` knows about several new models - Add `decode_with_offets` @@ -32,23 +50,28 @@ Thank you to @paplorinc, @mdwelsh, @Praneet460! 
- Update versions of dependencies ## [v0.4.0] + - Add `decode_batch` and `decode_bytes_batch` - Improve error messages and handling ## [v0.3.3] + - `tiktoken` will now make a best effort attempt to replace surrogate pairs with the corresponding - Unicode character and will replace lone surrogates with the Unicode replacement character. + Unicode character and will replace lone surrogates with the Unicode replacement character. ## [v0.3.2] + - Add encoding for GPT-4 ## [v0.3.1] + - Build aarch64 wheels - Make `blobfile` an optional dependency Thank you to @messense for the environment variable that makes cargo not OOM under emulation! ## [v0.3.0] + - Improve performance by 5-20%; thank you to @nistath! - Add `gpt-3.5-turbo` models to `encoding_for_model` - Add prefix matching to `encoding_for_model` to better support future model versions @@ -57,16 +80,19 @@ Thank you to @messense for the environment variable that makes cargo not OOM und - Add packaging metadata ## [v0.2.0] -- Add ``tiktoken.encoding_for_model`` to get the encoding for a specific model + +- Add `tiktoken.encoding_for_model` to get the encoding for a specific model - Improve portability of caching logic Thank you to @fritzo, @arvid220u, @khanhvu207, @henriktorget for various small corrections ## [v0.1.2] + - Avoid use of `blobfile` for public files - Add support for Python 3.8 - Add py.typed - Improve the public tests ## [v0.1.1] + - Initial release diff --git a/Cargo.toml b/Cargo.toml index 52881a05..2eed0c16 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tiktoken" -version = "0.7.0" +version = "0.8.0" edition = "2021" rust-version = "1.57.0" @@ -9,7 +9,7 @@ name = "_tiktoken" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "0.20.0", features = ["extension-module"] } +pyo3 = { version = "0.22.2", default-features = false, features = ["extension-module", "macros"] } # tiktoken dependencies fancy-regex = "0.13.0" diff --git a/README.md b/README.md index 124d5828..ad8d88b1 100644 --- a/README.md +++ b/README.md @@ -128,3 +128,4 @@ setup( Then simply `pip install ./my_tiktoken_extension` and you should be able to use your custom encodings! Make sure **not** to use an editable install. 
+ diff --git a/pyproject.toml b/pyproject.toml index 7cc7cb10..1eb81ce6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,13 @@ [project] name = "tiktoken" -version = "0.7.0" +version = "0.8.0" description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" readme = "README.md" license = {file = "LICENSE"} authors = [{name = "Shantanu Jain"}, {email = "shantanu@openai.com"}] dependencies = ["regex>=2022.1.18", "requests>=2.26.0"] optional-dependencies = {blobfile = ["blobfile>=2"]} -requires-python = ">=3.8" +requires-python = ">=3.9" [project.urls] homepage = "https://github.com/openai/tiktoken" @@ -24,7 +24,7 @@ build-verbosity = 1 linux.before-all = "curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y" linux.environment = { PATH = "$PATH:$HOME/.cargo/bin" } -macos.before-all = "rustup target add aarch64-apple-darwin" +macos.before-all = "rustup target add aarch64-apple-darwin x86_64-apple-darwin" skip = [ "*-manylinux_i686", diff --git a/src/lib.rs b/src/lib.rs index 46712ecd..0203acfc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,12 +6,11 @@ use std::num::NonZeroU64; use std::thread; use fancy_regex::Regex; -use fancy_regex::RegexBuilder; use pyo3::exceptions; use pyo3::prelude::*; -use pyo3::pyclass; -use pyo3::PyResult; +use pyo3::pybacked::PyBackedStr; use pyo3::types::{PyBytes, PyList, PyTuple}; +use pyo3::PyResult; use rustc_hash::FxHashMap as HashMap; type Rank = u32; @@ -75,8 +74,10 @@ fn _byte_pair_merge(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec<(usize, } pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap, Rank>) -> Vec { - assert!(piece.len() > 1); - _byte_pair_merge(&ranks, &piece) + if piece.len() == 1 { + return vec![ranks[piece]]; + } + _byte_pair_merge(ranks, piece) .windows(2) .map(|part| ranks[&piece[part[0].0..part[1].0]]) .collect() @@ -84,7 +85,7 @@ pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap, Rank>) -> Vec(piece: &'a [u8], ranks: &HashMap, Rank>) -> Vec<&'a [u8]> { assert!(piece.len() > 1); - _byte_pair_merge(&ranks, &piece) + _byte_pair_merge(ranks, piece) .windows(2) .map(|part| &piece[part[0].0..part[1].0]) .collect() @@ -138,14 +139,25 @@ fn hash_current_thread() -> usize { // that works great for our use case of avoiding collisions in our array. Unfortunately, // it's private. 
However, there are only so many ways you can layout a u64, so just transmute // https://github.com/rust-lang/rust/issues/67939 - const _: [u8; 8] = [0; std::mem::size_of::()]; + const _: [u8; 8] = [0; std::mem::size_of::()]; const _: [u8; 8] = [0; std::mem::size_of::()]; let x = unsafe { - std::mem::transmute::(thread::current().id()).0 + std::mem::transmute::(thread::current().id()).0 }; u64::from(x) as usize } +#[derive(Debug, Clone)] +struct DecodeKeyError { + token: Rank, +} + +impl std::fmt::Display for DecodeKeyError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Invalid token for decoding: {}", self.token) + } +} + const MAX_NUM_THREADS: usize = 128; #[pyclass] @@ -171,16 +183,19 @@ impl CoreBPE { &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS] } - fn _decode_native(&self, tokens: &[Rank]) -> Vec { + fn _decode_native(&self, tokens: &[Rank]) -> Result, DecodeKeyError> { let mut ret = Vec::with_capacity(tokens.len() * 2); - for token in tokens { - let token_bytes = self - .decoder - .get(token) - .unwrap_or_else(|| &self.special_tokens_decoder[token]); + for &token in tokens { + let token_bytes = match self.decoder.get(&token) { + Some(bytes) => bytes, + None => self + .special_tokens_decoder + .get(&token) + .ok_or(DecodeKeyError { token })?, + }; ret.extend(token_bytes); } - ret + Ok(ret) } fn _encode_ordinary_native(&self, text: &str) -> Vec { @@ -307,7 +322,9 @@ impl CoreBPE { let (mut tokens, last_piece_token_len) = self._increase_last_piece_token_len(tokens, last_piece_token_len); - let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); + let unstable_bytes = self + ._decode_native(&tokens[tokens.len() - last_piece_token_len..]) + .unwrap(); tokens.truncate(tokens.len() - last_piece_token_len); // TODO: we should try harder to find additional stable tokens @@ -418,7 +435,7 @@ impl CoreBPE { special_tokens_encoder: HashMap, pattern: &str, ) -> PyResult { - let regex = RegexBuilder::new(pattern).backtrack_limit(10_000).build() + let regex = Regex::new(pattern) .map_err(|e| PyErr::new::(e.to_string()))?; let special_regex = { @@ -468,8 +485,27 @@ impl CoreBPE { py.allow_threads(|| self._encode_ordinary_native(text)) } - fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec { - py.allow_threads(|| self._encode_native(text, &allowed_special).0) + fn encode(&self, py: Python, text: &str, allowed_special: HashSet) -> Vec { + py.allow_threads(|| { + let allowed_special: HashSet<&str> = + allowed_special.iter().map(|s| s.as_ref()).collect(); + self._encode_native(text, &allowed_special).0 + }) + } + + fn encode_to_tiktoken_buffer( + &self, + py: Python, + text: &str, + allowed_special: HashSet, + ) -> Py { + let tokens = py.allow_threads(|| { + let allowed_special: HashSet<&str> = + allowed_special.iter().map(|s| s.as_ref()).collect(); + self._encode_native(text, &allowed_special).0 + }); + let buffer = TiktokenBuffer { tokens }; + buffer.into_py(py) } fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec { @@ -486,14 +522,17 @@ impl CoreBPE { // Somewhat niche, but this may not be correct if we'd have had a regex // split between the valid UTF-8 and the invalid bytes, which is why this // method is private - let mut unstable_bytes = - self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); + let mut unstable_bytes = self + ._decode_native(&tokens[tokens.len() - last_piece_token_len..]) + .unwrap(); unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]); 
tokens.truncate(tokens.len() - last_piece_token_len); match self.encoder.get(&unstable_bytes) { Some(token) => tokens.push(*token), - None => tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder)), + None => { + tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder)) + } } } tokens @@ -506,12 +545,19 @@ impl CoreBPE { &self, py: Python, text: &str, - allowed_special: HashSet<&str>, + allowed_special: HashSet, ) -> Py { - let (tokens, completions) = - py.allow_threads(|| self._encode_unstable_native(text, &allowed_special)); - let py_completions = - PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..]))); + let (tokens, completions) = py.allow_threads(|| { + let allowed_special: HashSet<&str> = + allowed_special.iter().map(|s| s.as_ref()).collect(); + self._encode_unstable_native(text, &allowed_special) + }); + let py_completions = PyList::new_bound( + py, + completions + .iter() + .map(|seq| PyList::new_bound(py, &seq[..])), + ); (tokens, py_completions).into_py(py) } @@ -538,17 +584,19 @@ impl CoreBPE { // Decoding // ==================== - fn decode_bytes(&self, py: Python, tokens: Vec) -> Py { - let bytes = py.allow_threads(|| self._decode_native(&tokens)); - PyBytes::new(py, &bytes).into() + fn decode_bytes(&self, py: Python, tokens: Vec) -> Result, PyErr> { + match py.allow_threads(|| self._decode_native(&tokens)) { + Ok(bytes) => Ok(PyBytes::new_bound(py, &bytes).into()), + Err(e) => Err(pyo3::exceptions::PyKeyError::new_err(format!("{}", e))), + } } fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult> { if let Some(bytes) = self.decoder.get(&token) { - return Ok(PyBytes::new(py, bytes).into()); + return Ok(PyBytes::new_bound(py, bytes).into()); } if let Some(bytes) = self.special_tokens_decoder.get(&token) { - return Ok(PyBytes::new(py, bytes).into()); + return Ok(PyBytes::new_bound(py, bytes).into()); } Err(PyErr::new::(token.to_string())) } @@ -560,29 +608,83 @@ impl CoreBPE { fn token_byte_values(&self, py: Python) -> Vec> { self.sorted_token_bytes .iter() - .map(|x| PyBytes::new(py, x).into()) + .map(|x| PyBytes::new_bound(py, x).into()) .collect() } } +#[pyclass] +struct TiktokenBuffer { + tokens: Vec, +} + +#[pymethods] +impl TiktokenBuffer { + // Based on https://github.com/PyO3/pyo3/blob/v0.22.2/tests/test_buffer_protocol.rs#L25 + unsafe fn __getbuffer__( + slf: Bound<'_, Self>, + view: *mut pyo3::ffi::Py_buffer, + flags: std::os::raw::c_int, + ) -> PyResult<()> { + if view.is_null() { + return Err(pyo3::exceptions::PyBufferError::new_err("View is null")); + } + if (flags & pyo3::ffi::PyBUF_WRITABLE) == pyo3::ffi::PyBUF_WRITABLE { + return Err(pyo3::exceptions::PyBufferError::new_err( + "Object is not writable", + )); + } + + (*view).obj = slf.clone().into_any().into_ptr(); + + let data = &slf.borrow().tokens; + (*view).buf = data.as_ptr() as *mut std::os::raw::c_void; + (*view).len = (data.len() * std::mem::size_of::()) as isize; + (*view).readonly = 1; + (*view).itemsize = std::mem::size_of::() as isize; + (*view).format = if (flags & pyo3::ffi::PyBUF_FORMAT) == pyo3::ffi::PyBUF_FORMAT { + let msg = std::ffi::CString::new("I").unwrap(); + msg.into_raw() + } else { + std::ptr::null_mut() + }; + (*view).ndim = 1; + (*view).shape = if (flags & pyo3::ffi::PyBUF_ND) == pyo3::ffi::PyBUF_ND { + &mut (*view).len + } else { + std::ptr::null_mut() + }; + (*view).strides = if (flags & pyo3::ffi::PyBUF_STRIDES) == pyo3::ffi::PyBUF_STRIDES { + &mut (*view).itemsize + } else { + std::ptr::null_mut() + }; + (*view).suboffsets = 
std::ptr::null_mut(); + (*view).internal = std::ptr::null_mut(); + + Ok(()) + } + + unsafe fn __releasebuffer__(&self, view: *mut pyo3::ffi::Py_buffer) { + std::mem::drop(std::ffi::CString::from_raw((*view).format)); + } +} + #[pymodule] -fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> { +fn _tiktoken(_py: Python, m: &Bound) -> PyResult<()> { m.add_class::()?; Ok(()) } #[cfg(test)] mod tests { - use fancy_regex::RegexBuilder; + use rustc_hash::FxHashMap as HashMap; use crate::{byte_pair_split, Rank}; fn setup_ranks() -> HashMap, Rank> { - HashMap::from_iter([ - (b"ab".to_vec(), 0), - (b"cd".to_vec(), 1), - ]) + HashMap::from_iter([(b"ab".to_vec(), 0), (b"cd".to_vec(), 1)]) } #[test] @@ -598,16 +700,4 @@ mod tests { let res = byte_pair_split(b"abab", &ranks); assert_eq!(res, vec![b"ab", b"ab"]); } - - #[test] - fn test_effect_of_backtrack_limit() { - let regex = RegexBuilder::new(r"(a|b|ab)*(?=c)") - .backtrack_limit(10) - .build() - .expect("Failed to build regex") - .clone(); - - let input = "ab".repeat(100) + "c"; - assert!(regex.is_match(&input).is_err(), "Should throw"); - } } diff --git a/tests/test_encoding.py b/tests/test_encoding.py index 0e02b47a..3cf83776 100644 --- a/tests/test_encoding.py +++ b/tests/test_encoding.py @@ -11,22 +11,6 @@ from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES -@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) -def test_extremely_big_encoding(make_enc: Callable[[], tiktoken.Encoding]): - enc = make_enc() - for c in ["^", "0", "a", "'s", " ", "\n"]: - print(f"Validating `{c}`") - - big_value = c * 10_000 - assert big_value == enc.decode(enc.encode(big_value)) - - big_value = " " + big_value - assert big_value == enc.decode(enc.encode(big_value)) - - big_value = big_value + "\n" - assert big_value == enc.decode(enc.encode(big_value)) - - def test_simple(): enc = tiktoken.get_encoding("gpt2") assert enc.encode("hello world") == [31373, 995] @@ -40,7 +24,7 @@ def test_simple(): for enc_name in tiktoken.list_encoding_names(): enc = tiktoken.get_encoding(enc_name) - for token in range(10_000): + for token in range(min(10_000, enc.max_token_value - 1)): assert enc.encode_single_token(enc.decode_single_token_bytes(token)) == token @@ -107,6 +91,20 @@ def test_encode_surrogate_pairs(): assert enc.encode("\ud83d") == enc.encode("�") +@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) +def test_catastrophically_repetitive(make_enc: Callable[[], tiktoken.Encoding]): + enc = make_enc() + for c in ["^", "0", "a", "'s", " ", "\n"]: + big_value = c * 10_000 + assert big_value == enc.decode(enc.encode(big_value)) + + big_value = " " + big_value + assert big_value == enc.decode(enc.encode(big_value)) + + big_value = big_value + "\n" + assert big_value == enc.decode(enc.encode(big_value)) + + # ==================== # Roundtrip # ==================== diff --git a/tests/test_misc.py b/tests/test_misc.py index a2b00f67..7da53897 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -13,6 +13,10 @@ def test_encoding_for_model(): assert enc.name == "p50k_edit" enc = tiktoken.encoding_for_model("gpt-3.5-turbo-0301") assert enc.name == "cl100k_base" + enc = tiktoken.encoding_for_model("gpt-4") + assert enc.name == "cl100k_base" + enc = tiktoken.encoding_for_model("gpt-4o") + assert enc.name == "o200k_base" def test_optional_blobfile_dependency(): diff --git a/tests/test_pickle.py b/tests/test_pickle.py new file mode 100644 index 00000000..49481ae1 --- /dev/null +++ b/tests/test_pickle.py @@ -0,0 +1,23 @@ +import tiktoken + + +def 
test_pickle(): + import pickle + + enc_old = tiktoken.get_encoding("r50k_base") + enc_new = pickle.loads(pickle.dumps(enc_old)) + assert enc_old.encode("hello world") == enc_new.encode("hello world") + + enc_old = tiktoken.Encoding( + name="custom_enc", + pat_str=enc_old._pat_str, + mergeable_ranks=enc_old._mergeable_ranks, + special_tokens={"<|pickle|>": 100_000}, + ) + enc_new = pickle.loads(pickle.dumps(enc_old)) + assert enc_old.encode("hello world") == enc_new.encode("hello world") + assert ( + enc_old.encode("<|pickle|>", allowed_special="all") + == enc_new.encode("<|pickle|>", allowed_special="all") + == [100_000] + ) diff --git a/tiktoken/__init__.py b/tiktoken/__init__.py index 3a531b18..9ed615e3 100644 --- a/tiktoken/__init__.py +++ b/tiktoken/__init__.py @@ -4,3 +4,5 @@ from .model import encoding_name_for_model as encoding_name_for_model from .registry import get_encoding as get_encoding from .registry import list_encoding_names as list_encoding_names + +__version__ = "0.8.0" diff --git a/tiktoken/_educational.py b/tiktoken/_educational.py index a6f5fccd..317e7756 100644 --- a/tiktoken/_educational.py +++ b/tiktoken/_educational.py @@ -1,6 +1,8 @@ """This is an educational implementation of the byte pair encoding algorithm.""" + +from __future__ import annotations + import collections -from typing import Optional import regex @@ -18,7 +20,7 @@ def __init__(self, *, pat_str: str, mergeable_ranks: dict[bytes, int]) -> None: self._decoder = {token: token_bytes for token_bytes, token in mergeable_ranks.items()} self._pat = regex.compile(pat_str) - def encode(self, text: str, visualise: Optional[str] = "colour") -> list[int]: + def encode(self, text: str, visualise: str | None = "colour") -> list[int]: """Encodes a string into tokens. >>> enc.encode("hello world") @@ -79,7 +81,7 @@ def from_tiktoken(encoding): def bpe_encode( - mergeable_ranks: dict[bytes, int], input: bytes, visualise: Optional[str] = "colour" + mergeable_ranks: dict[bytes, int], input: bytes, visualise: str | None = "colour" ) -> list[int]: parts = [bytes([b]) for b in input] while True: @@ -115,7 +117,7 @@ def bpe_encode( def bpe_train( - data: str, vocab_size: int, pat_str: str, visualise: Optional[str] = "colour" + data: str, vocab_size: int, pat_str: str, visualise: str | None = "colour" ) -> dict[bytes, int]: # First, add tokens for each individual byte value if vocab_size < 2**8: @@ -207,7 +209,7 @@ def train_simple_encoding(): gpt2_pattern = ( r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" ) - with open(__file__, "r") as f: + with open(__file__) as f: data = f.read() enc = SimpleBytePairEncoding.train(data, vocab_size=600, pat_str=gpt2_pattern) diff --git a/tiktoken/core.py b/tiktoken/core.py index 3e32d4ce..6dcdc327 100644 --- a/tiktoken/core.py +++ b/tiktoken/core.py @@ -2,7 +2,7 @@ import functools from concurrent.futures import ThreadPoolExecutor -from typing import AbstractSet, Collection, Literal, NoReturn, Optional, Union +from typing import AbstractSet, Collection, Literal, NoReturn, Sequence import regex @@ -17,7 +17,7 @@ def __init__( pat_str: str, mergeable_ranks: dict[bytes, int], special_tokens: dict[str, int], - explicit_n_vocab: Optional[int] = None, + explicit_n_vocab: int | None = None, ): """Creates an Encoding object. 
@@ -76,8 +76,8 @@ def encode( self, text: str, *, - allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006 - disallowed_special: Union[Literal["all"], Collection[str]] = "all", + allowed_special: Literal["all"] | AbstractSet[str] = set(), # noqa: B006 + disallowed_special: Literal["all"] | Collection[str] = "all", ) -> list[int]: """Encodes a string into tokens. @@ -116,10 +116,6 @@ def encode( if match := _special_token_regex(disallowed_special).search(text): raise_disallowed_special_token(match.group()) - # https://github.com/PyO3/pyo3/pull/3632 - if isinstance(allowed_special, frozenset): - allowed_special = set(allowed_special) - try: return self._core_bpe.encode(text, allowed_special) except UnicodeEncodeError: @@ -151,8 +147,8 @@ def encode_batch( text: list[str], *, num_threads: int = 8, - allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006 - disallowed_special: Union[Literal["all"], Collection[str]] = "all", + allowed_special: Literal["all"] | AbstractSet[str] = set(), # noqa: B006 + disallowed_special: Literal["all"] | Collection[str] = "all", ) -> list[list[int]]: """Encodes a list of strings into tokens, in parallel. @@ -180,8 +176,8 @@ def encode_with_unstable( self, text: str, *, - allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006 - disallowed_special: Union[Literal["all"], Collection[str]] = "all", + allowed_special: Literal["all"] | AbstractSet[str] = set(), # noqa: B006 + disallowed_special: Literal["all"] | Collection[str] = "all", ) -> tuple[list[int], list[list[int]]]: """Encodes a string into stable tokens and possible completion sequences. @@ -213,7 +209,7 @@ def encode_with_unstable( return self._core_bpe.encode_with_unstable(text, allowed_special) - def encode_single_token(self, text_or_bytes: Union[str, bytes]) -> int: + def encode_single_token(self, text_or_bytes: str | bytes) -> int: """Encodes text corresponding to a single token to its token value. NOTE: this will encode all special tokens. @@ -233,7 +229,7 @@ def encode_single_token(self, text_or_bytes: Union[str, bytes]) -> int: # Decoding # ==================== - def decode_bytes(self, tokens: list[int]) -> bytes: + def decode_bytes(self, tokens: Sequence[int]) -> bytes: """Decodes a list of tokens into bytes. ``` @@ -243,7 +239,7 @@ def decode_bytes(self, tokens: list[int]) -> bytes: """ return self._core_bpe.decode_bytes(tokens) - def decode(self, tokens: list[int], errors: str = "replace") -> str: + def decode(self, tokens: Sequence[int], errors: str = "replace") -> str: """Decodes a list of tokens into a string. WARNING: the default behaviour of this function is lossy, since decoded bytes are not @@ -271,7 +267,7 @@ def decode_single_token_bytes(self, token: int) -> bytes: """ return self._core_bpe.decode_single_token_bytes(token) - def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]: + def decode_tokens_bytes(self, tokens: Sequence[int]) -> list[bytes]: """Decodes a list of tokens into a list of bytes. Useful for visualising tokenisation. @@ -280,7 +276,7 @@ def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]: """ return [self.decode_single_token_bytes(token) for token in tokens] - def decode_with_offsets(self, tokens: list[int]) -> tuple[str, list[int]]: + def decode_with_offsets(self, tokens: Sequence[int]) -> tuple[str, list[int]]: """Decodes a list of tokens into a string and a list of offsets. Each offset is the index into text corresponding to the start of each token. 
@@ -306,14 +302,16 @@ def decode_with_offsets(self, tokens: list[int]) -> tuple[str, list[int]]: return text, offsets def decode_batch( - self, batch: list[list[int]], *, errors: str = "replace", num_threads: int = 8 + self, batch: Sequence[Sequence[int]], *, errors: str = "replace", num_threads: int = 8 ) -> list[str]: """Decodes a batch (list of lists of tokens) into a list of strings.""" decoder = functools.partial(self.decode, errors=errors) with ThreadPoolExecutor(num_threads) as e: return list(e.map(decoder, batch)) - def decode_bytes_batch(self, batch: list[list[int]], *, num_threads: int = 8) -> list[bytes]: + def decode_bytes_batch( + self, batch: Sequence[Sequence[int]], *, num_threads: int = 8 + ) -> list[bytes]: """Decodes a batch (list of lists of tokens) into a list of bytes.""" with ThreadPoolExecutor(num_threads) as e: return list(e.map(self.decode_bytes, batch)) @@ -343,7 +341,7 @@ def n_vocab(self) -> int: # Private # ==================== - def _encode_single_piece(self, text_or_bytes: Union[str, bytes]) -> list[int]: + def _encode_single_piece(self, text_or_bytes: str | bytes) -> list[int]: """Encodes text corresponding to bytes without a regex split. NOTE: this will not encode any special tokens. diff --git a/tiktoken/load.py b/tiktoken/load.py index cc0a6a6d..8434c234 100644 --- a/tiktoken/load.py +++ b/tiktoken/load.py @@ -6,7 +6,6 @@ import os import tempfile import uuid -from typing import Optional import requests @@ -32,7 +31,7 @@ def check_hash(data: bytes, expected_hash: str) -> bool: return actual_hash == expected_hash -def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes: +def read_file_cached(blobpath: str, expected_hash: str | None = None) -> bytes: user_specified_cache = True if "TIKTOKEN_CACHE_DIR" in os.environ: cache_dir = os.environ["TIKTOKEN_CACHE_DIR"] @@ -85,8 +84,8 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> byte def data_gym_to_mergeable_bpe_ranks( vocab_bpe_file: str, encoder_json_file: str, - vocab_bpe_hash: Optional[str] = None, - encoder_json_hash: Optional[str] = None, + vocab_bpe_hash: str | None = None, + encoder_json_hash: str | None = None, ) -> dict[bytes, int]: # NB: do not add caching to this function rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "] @@ -140,9 +139,7 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n") -def load_tiktoken_bpe( - tiktoken_bpe_file: str, expected_hash: Optional[str] = None -) -> dict[bytes, int]: +def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: str | None = None) -> dict[bytes, int]: # NB: do not add caching to this function contents = read_file_cached(tiktoken_bpe_file, expected_hash) return { diff --git a/tiktoken/model.py b/tiktoken/model.py index 6ecd7232..681b9131 100644 --- a/tiktoken/model.py +++ b/tiktoken/model.py @@ -5,7 +5,9 @@ # TODO: these will likely be replaced by an API endpoint MODEL_PREFIX_TO_ENCODING: dict[str, str] = { + "o1-": "o200k_base", # chat + "chatgpt-4o-": "o200k_base", "gpt-4o-": "o200k_base", # e.g., gpt-4o-2024-05-13 "gpt-4-": "cl100k_base", # e.g., gpt-4-0314, etc., plus gpt-4-32k "gpt-3.5-turbo-": "cl100k_base", # e.g, gpt-3.5-turbo-0301, -0401, etc. 
diff --git a/tiktoken/registry.py b/tiktoken/registry.py index a753ce67..17c4574f 100644 --- a/tiktoken/registry.py +++ b/tiktoken/registry.py @@ -4,18 +4,19 @@ import importlib import pkgutil import threading -from typing import Any, Callable, Optional, Sequence +from typing import Any, Callable, Sequence import tiktoken_ext +import tiktoken from tiktoken.core import Encoding _lock = threading.RLock() ENCODINGS: dict[str, Encoding] = {} -ENCODING_CONSTRUCTORS: Optional[dict[str, Callable[[], dict[str, Any]]]] = None +ENCODING_CONSTRUCTORS: dict[str, Callable[[], dict[str, Any]]] | None = None -@functools.lru_cache() +@functools.lru_cache def _available_plugin_modules() -> Sequence[str]: # tiktoken_ext is a namespace package # submodules inside tiktoken_ext will be inspected for ENCODING_CONSTRUCTORS attributes @@ -36,23 +37,33 @@ def _find_constructors() -> None: return ENCODING_CONSTRUCTORS = {} - for mod_name in _available_plugin_modules(): - mod = importlib.import_module(mod_name) - try: - constructors = mod.ENCODING_CONSTRUCTORS - except AttributeError as e: - raise ValueError( - f"tiktoken plugin {mod_name} does not define ENCODING_CONSTRUCTORS" - ) from e - for enc_name, constructor in constructors.items(): - if enc_name in ENCODING_CONSTRUCTORS: + try: + for mod_name in _available_plugin_modules(): + mod = importlib.import_module(mod_name) + try: + constructors = mod.ENCODING_CONSTRUCTORS + except AttributeError as e: raise ValueError( - f"Duplicate encoding name {enc_name} in tiktoken plugin {mod_name}" - ) - ENCODING_CONSTRUCTORS[enc_name] = constructor + f"tiktoken plugin {mod_name} does not define ENCODING_CONSTRUCTORS" + ) from e + for enc_name, constructor in constructors.items(): + if enc_name in ENCODING_CONSTRUCTORS: + raise ValueError( + f"Duplicate encoding name {enc_name} in tiktoken plugin {mod_name}" + ) + ENCODING_CONSTRUCTORS[enc_name] = constructor + except Exception: + # Ensure we idempotently raise errors + ENCODING_CONSTRUCTORS = None + raise + + def get_encoding(encoding_name: str) -> Encoding: + if not isinstance(encoding_name, str): + raise ValueError(f"Expected a string in get_encoding, got {type(encoding_name)}") + if encoding_name in ENCODINGS: return ENCODINGS[encoding_name] @@ -66,7 +77,9 @@ def get_encoding(encoding_name: str) -> Encoding: if encoding_name not in ENCODING_CONSTRUCTORS: raise ValueError( - f"Unknown encoding {encoding_name}. 
Plugins found: {_available_plugin_modules()}" + f"Unknown encoding {encoding_name}.\n" + f"Plugins found: {_available_plugin_modules()}\n" + f"tiktoken version: {tiktoken.__version__} (are you on latest?)" ) constructor = ENCODING_CONSTRUCTORS[encoding_name] diff --git a/tiktoken_ext/openai_public.py b/tiktoken_ext/openai_public.py index 449ec068..f2599f16 100644 --- a/tiktoken_ext/openai_public.py +++ b/tiktoken_ext/openai_public.py @@ -9,7 +9,9 @@ # The pattern in the original GPT-2 release is: # r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" # This is equivalent, but executes faster: -_legacy_splitter_regex = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s""" +r50k_pat_str = ( + r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s""" +) def gpt2(): @@ -22,7 +24,7 @@ def gpt2(): return { "name": "gpt2", "explicit_n_vocab": 50257, - "pat_str": _legacy_splitter_regex, + "pat_str": r50k_pat_str, "mergeable_ranks": mergeable_ranks, "special_tokens": {ENDOFTEXT: 50256}, } @@ -36,7 +38,7 @@ def r50k_base(): return { "name": "r50k_base", "explicit_n_vocab": 50257, - "pat_str": _legacy_splitter_regex, + "pat_str": r50k_pat_str, "mergeable_ranks": mergeable_ranks, "special_tokens": {ENDOFTEXT: 50256}, } @@ -50,7 +52,7 @@ def p50k_base(): return { "name": "p50k_base", "explicit_n_vocab": 50281, - "pat_str": _legacy_splitter_regex, + "pat_str": r50k_pat_str, "mergeable_ranks": mergeable_ranks, "special_tokens": {ENDOFTEXT: 50256}, } @@ -64,7 +66,7 @@ def p50k_edit(): special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283} return { "name": "p50k_edit", - "pat_str": _legacy_splitter_regex, + "pat_str": r50k_pat_str, "mergeable_ranks": mergeable_ranks, "special_tokens": special_tokens, } @@ -95,11 +97,10 @@ def o200k_base(): "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken", expected_hash="446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d", ) - special_tokens = { - ENDOFTEXT: 199999, - ENDOFPROMPT: 200018, - } - # This regex could be made more efficient + special_tokens = {ENDOFTEXT: 199999, ENDOFPROMPT: 200018} + # This regex could be made more efficient. If I was the one working on this encoding, I would + # have done a few other things differently too, e.g. I think you can allocate tokens more + # efficiently across languages. pat_str = "|".join( [ r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
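
For reviewers, here is a minimal usage sketch of the user-visible behaviour this diff introduces. It is an illustration only, not part of the patch: the model string `o1-preview` and the exact printed messages are assumptions; the authoritative behaviour is the changed code above.

```python
# Sketch of the 0.8.0 surface added in this diff (assumes tiktoken 0.8.0 is installed).
import tiktoken

# New __version__ attribute exported from tiktoken/__init__.py
print(tiktoken.__version__)  # expected: "0.8.0"

# "o1-" and "chatgpt-4o-" prefixes now resolve to o200k_base via MODEL_PREFIX_TO_ENCODING
# ("o1-preview" is just an example model name matching the new prefix)
enc = tiktoken.encoding_for_model("o1-preview")
assert enc.name == "o200k_base"

# Type hints now permit any Sequence[int], e.g. a tuple, for decoding
tokens = tuple(enc.encode("hello world"))
assert enc.decode(tokens) == "hello world"

# Decoding an unknown token now raises a KeyError (DecodeKeyError on the Rust side)
# instead of panicking
try:
    enc.decode_bytes([2**31 - 1])
except KeyError as e:
    print("invalid token:", e)

# Passing a non-string to get_encoding now fails with a clear ValueError
try:
    tiktoken.get_encoding(42)  # type: ignore[arg-type]
except ValueError as e:
    print(e)
```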