Remove need for ChunkSize in public interface
benbrandt committed Jun 21, 2024
1 parent 887f672 commit c11dad3
Showing 11 changed files with 173 additions and 131 deletions.
21 changes: 20 additions & 1 deletion CHANGELOG.md
@@ -12,7 +12,26 @@

#### Rust

- `ChunkSize::from_offsets` was removed. This was only used to create an internal optimization, which turned out to not be very accurate anyway. It often required in tokenization implementations to do more work to calculate the size as well, which is no longer necessary. It should be simple to convert to the `ChunkSize::from_size` method (and likely simplify your code as well), which is now the only way to create a `ChunkSize`.
- `ChunkSize` has been removed. This was a holdover from a previous internal optimization, which turned out to not be very accurate anyway.
- This makes implementing a custom `ChunkSizer` much easier: you now only need to return the size of the chunk as a `usize`. Previously, tokenization implementations often had to do extra work to calculate the size, which is no longer necessary. A short migration sketch follows the example below.

##### Before

```rust
pub trait ChunkSizer {
// Required method
fn chunk_size(&self, chunk: &str, capacity: &ChunkCapacity) -> ChunkSize;
}
```

##### After

```rust
pub trait ChunkSizer {
// Required method
fn size(&self, chunk: &str) -> usize;
}
```
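
As an illustration, here is a minimal sketch of a custom sizer against the new trait. The `WordSizer` type is hypothetical, and the `ChunkConfig::with_sizer`/`TextSplitter::new` wiring is assumed to follow the crate's existing constructors:

```rust
use text_splitter::{ChunkConfig, ChunkSizer, TextSplitter};

/// Hypothetical sizer that measures chunks in whitespace-separated words.
struct WordSizer;

impl ChunkSizer for WordSizer {
    // Only the raw size is returned; capacity checks now happen internally.
    fn size(&self, chunk: &str) -> usize {
        chunk.split_whitespace().count()
    }
}

fn main() {
    // Assumed wiring: adjust if the current constructors differ.
    let splitter = TextSplitter::new(ChunkConfig::new(10).with_sizer(WordSizer));
    let chunks: Vec<&str> = splitter
        .chunks("Split this text into chunks of at most ten words each.")
        .collect();
    println!("{chunks:?}");
}
```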

## v0.13.3

13 changes: 5 additions & 8 deletions bindings/python/src/lib.rs
@@ -13,8 +13,8 @@ use pyo3::{
pybacked::PyBackedStr,
};
use text_splitter::{
Characters, ChunkCapacity, ChunkCapacityError, ChunkConfig, ChunkConfigError, ChunkSize,
ChunkSizer, CodeSplitter, CodeSplitterError, MarkdownSplitter, TextSplitter,
Characters, ChunkCapacity, ChunkCapacityError, ChunkConfig, ChunkConfigError, ChunkSizer,
CodeSplitter, CodeSplitterError, MarkdownSplitter, TextSplitter,
};
use tiktoken_rs::{get_bpe_from_model, CoreBPE};
use tokenizers::Tokenizer;
@@ -88,16 +88,13 @@ struct CustomCallback(PyObject);

impl ChunkSizer for CustomCallback {
/// Determine the size of a given chunk to use for validation
fn chunk_size(&self, chunk: &str, capacity: &ChunkCapacity) -> ChunkSize {
fn size(&self, chunk: &str) -> usize {
Python::with_gil(|py| {
let size = self
.0
self.0
.call_bound(py, (chunk,), None)
.unwrap()
.extract::<usize>(py)
.unwrap();

ChunkSize::from_size(size, capacity)
.unwrap()
})
}
}
125 changes: 80 additions & 45 deletions src/chunk_size.rs
@@ -130,6 +130,13 @@ impl ChunkCapacity {
Ordering::Equal
}
}

/// Generates a chunk size object based on the size provided from a sizer
/// Calculates and stores whether or not it fits within the capacity
#[must_use]
fn chunk_size(&self, size: usize) -> ChunkSize {
ChunkSize::new(self.fits(size), size)
}
}

impl From<usize> for ChunkCapacity {
@@ -198,14 +205,9 @@ pub struct ChunkSize {
}

impl ChunkSize {
/// Generate a chunk size from a given size. Will not be able to compute the
/// max byte offset that fits within the capacity.
#[must_use]
pub fn from_size(size: usize, capacity: &ChunkCapacity) -> Self {
Self {
fits: capacity.fits(size),
size,
}
fn new(fits: Ordering, size: usize) -> Self {
Self { fits, size }
}

/// Determine whether the chunk size fits within the capacity or not
@@ -224,7 +226,7 @@ impl ChunkSize {
/// Determines the size of a given chunk.
pub trait ChunkSizer {
/// Determine the size of a given chunk to use for validation
fn chunk_size(&self, chunk: &str, capacity: &ChunkCapacity) -> ChunkSize;
fn size(&self, chunk: &str) -> usize;
}

/// Indicates there was an error with the chunk configuration.
@@ -409,7 +411,7 @@ where

*cache
.entry(offset..(offset + chunk.len()))
.or_insert_with(|| self.chunk_config.sizer.chunk_size(chunk, &capacity))
.or_insert_with(|| capacity.chunk_size(self.chunk_config.sizer.size(chunk)))
}

/// Check if the chunk is within the capacity. Chunk should be trimmed if necessary beforehand.
@@ -479,34 +481,51 @@ mod tests {
let chunk = "12345";

assert_eq!(
Characters.chunk_size(chunk, &4.into()).fits,
ChunkCapacity::from(4)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Greater
);
assert_eq!(
Characters.chunk_size(chunk, &5.into()).fits,
ChunkCapacity::from(5)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Equal
);
assert_eq!(Characters.chunk_size(chunk, &6.into()).fits, Ordering::Less);
assert_eq!(
ChunkCapacity::from(6)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Less
);
}

#[test]
fn check_chunk_capacity_for_range() {
let chunk = "12345";

assert_eq!(
Characters.chunk_size(chunk, &(0..0).into()).fits,
ChunkCapacity::from(0..0)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Greater
);
assert_eq!(
Characters.chunk_size(chunk, &(0..5).into()).fits,
ChunkCapacity::from(0..5)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Greater
);
assert_eq!(
Characters.chunk_size(chunk, &(5..6).into()).fits,
ChunkCapacity::from(5..6)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Equal
);
assert_eq!(
Characters.chunk_size(chunk, &(6..100).into()).fits,
ChunkCapacity::from(6..100)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Less
);
}
@@ -516,15 +535,21 @@
let chunk = "12345";

assert_eq!(
Characters.chunk_size(chunk, &(0..).into()).fits,
ChunkCapacity::from(0..)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Equal
);
assert_eq!(
Characters.chunk_size(chunk, &(5..).into()).fits,
ChunkCapacity::from(5..)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Equal
);
assert_eq!(
Characters.chunk_size(chunk, &(6..).into()).fits,
ChunkCapacity::from(6..)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Less
);
}
@@ -534,7 +559,9 @@
let chunk = "12345";

assert_eq!(
Characters.chunk_size(chunk, &(..).into()).fits,
ChunkCapacity::from(..)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Equal
);
}
@@ -544,19 +571,27 @@
let chunk = "12345";

assert_eq!(
Characters.chunk_size(chunk, &(0..=4).into()).fits,
ChunkCapacity::from(0..=4)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Greater
);
assert_eq!(
Characters.chunk_size(chunk, &(5..=6).into()).fits,
ChunkCapacity::from(5..=6)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Equal
);
assert_eq!(
Characters.chunk_size(chunk, &(4..=5).into()).fits,
ChunkCapacity::from(4..=5)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Equal
);
assert_eq!(
Characters.chunk_size(chunk, &(6..=100).into()).fits,
ChunkCapacity::from(6..=100)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Less
);
}
@@ -566,15 +601,21 @@
let chunk = "12345";

assert_eq!(
Characters.chunk_size(chunk, &(..0).into()).fits,
ChunkCapacity::from(..0)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Greater
);
assert_eq!(
Characters.chunk_size(chunk, &(..5).into()).fits,
ChunkCapacity::from(..5)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Greater
);
assert_eq!(
Characters.chunk_size(chunk, &(..6).into()).fits,
ChunkCapacity::from(..6)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Equal
);
}
@@ -584,15 +625,21 @@
let chunk = "12345";

assert_eq!(
Characters.chunk_size(chunk, &(..=4).into()).fits,
ChunkCapacity::from(..=4)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Greater
);
assert_eq!(
Characters.chunk_size(chunk, &(..=5).into()).fits,
ChunkCapacity::from(..=5)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Equal
);
assert_eq!(
Characters.chunk_size(chunk, &(..=6).into()).fits,
ChunkCapacity::from(..=6)
.chunk_size(Characters.size(chunk))
.fits,
Ordering::Equal
);
}
@@ -604,9 +651,9 @@

impl ChunkSizer for CountingSizer {
// Return character version, but count calls
fn chunk_size(&self, chunk: &str, capacity: &ChunkCapacity) -> ChunkSize {
fn size(&self, chunk: &str) -> usize {
self.calls.fetch_add(1, atomic::Ordering::SeqCst);
Characters.chunk_size(chunk, capacity)
Characters.size(chunk)
}
}

@@ -668,18 +715,6 @@ mod tests {
);
}

#[test]
fn test_chunk_size_from_size() {
let chunk_size = ChunkSize::from_size(10, &10.into());
assert_eq!(
ChunkSize {
fits: Ordering::Equal,
size: 10,
},
chunk_size
);
}

#[test]
fn basic_chunk_config() {
let config = ChunkConfig::new(10);
@@ -700,7 +735,7 @@
struct BasicSizer;

impl ChunkSizer for BasicSizer {
fn chunk_size(&self, _chunk: &str, _capacity: &ChunkCapacity) -> ChunkSize {
fn size(&self, _chunk: &str) -> usize {
unimplemented!()
}
}
11 changes: 5 additions & 6 deletions src/chunk_size/characters.rs
@@ -1,4 +1,4 @@
use crate::{ChunkCapacity, ChunkSize, ChunkSizer};
use crate::ChunkSizer;

/// Used for splitting a piece of text into chunks based on the number of
/// characters in each chunk.
@@ -14,8 +14,8 @@ pub struct Characters;

impl ChunkSizer for Characters {
/// Determine the size of a given chunk to use for validation.
fn chunk_size(&self, chunk: &str, capacity: &ChunkCapacity) -> ChunkSize {
ChunkSize::from_size(chunk.chars().count(), capacity)
fn size(&self, chunk: &str) -> usize {
chunk.chars().count()
}
}

@@ -25,8 +25,7 @@ mod tests {

#[test]
fn returns_size() {
let capacity = 10;
let offsets = Characters.chunk_size("eé", &capacity.into());
assert_eq!(offsets, ChunkSize::from_size(2, &capacity.into()));
let offsets = Characters.size("eé");
assert_eq!(offsets, 2);
}
}