From 26998f13682ce34190d548fa902a7b65f7fc691c Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Fri, 21 Jun 2024 21:49:38 +0200 Subject: [PATCH] New attempt at finding best effort binary search window --- benches/chunk_size.rs | 111 ++++++++++++++++++---------------------- src/chunk_size.rs | 6 ++- src/splitter.rs | 116 +++++++++++++++++++++++++++++++++--------- 3 files changed, 145 insertions(+), 88 deletions(-) diff --git a/benches/chunk_size.rs b/benches/chunk_size.rs index 83d407b8..f488069f 100644 --- a/benches/chunk_size.rs +++ b/benches/chunk_size.rs @@ -1,9 +1,11 @@ #![allow(missing_docs)] -use std::path::PathBuf; +use std::{fs, path::PathBuf}; +use ahash::AHashMap; use cached_path::Cache; use divan::AllocProfiler; +use once_cell::sync::Lazy; #[global_allocator] static ALLOC: AllocProfiler = AllocProfiler::system(); @@ -30,16 +32,46 @@ fn download_file_to_cache(src: &str) -> PathBuf { .unwrap() } +const TEXT_FILENAMES: &[&str] = &["romeo_and_juliet", "room_with_a_view"]; +const MARKDOWN_FILENAMES: &[&str] = &["commonmark_spec"]; +const CODE_FILENAMES: &[&str] = &["hashbrown_set_rs"]; + +static FILES: Lazy> = Lazy::new(|| { + let mut m = AHashMap::new(); + for &name in TEXT_FILENAMES { + m.insert( + name, + fs::read_to_string(format!("tests/inputs/text/{name}.txt")).unwrap(), + ); + } + for &name in MARKDOWN_FILENAMES { + m.insert( + name, + fs::read_to_string(format!("tests/inputs/markdown/{name}.md")).unwrap(), + ); + } + for &name in CODE_FILENAMES { + m.insert( + name, + fs::read_to_string(format!("tests/inputs/code/{name}.txt")).unwrap(), + ); + } + m +}); + +static BERT_TOKENIZER: Lazy = Lazy::new(|| { + let vocab_path = download_file_to_cache( + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", + ); + rust_tokenizers::tokenizer::BertTokenizer::from_file(vocab_path, false, false).unwrap() +}); + #[divan::bench_group] mod text { - use std::fs; - use divan::{black_box_drop, counter::BytesCount, Bencher}; use text_splitter::{ChunkConfig, ChunkSizer, TextSplitter}; - use crate::CHUNK_SIZES; - - const TEXT_FILENAMES: &[&str] = &["romeo_and_juliet", "room_with_a_view"]; + use crate::{CHUNK_SIZES, FILES, TEXT_FILENAMES}; fn bench(bencher: Bencher<'_, '_>, filename: &str, gen_splitter: G) where @@ -47,12 +79,7 @@ mod text { S: ChunkSizer, { bencher - .with_inputs(|| { - ( - gen_splitter(), - fs::read_to_string(format!("tests/inputs/text/{filename}.txt")).unwrap(), - ) - }) + .with_inputs(|| (gen_splitter(), FILES.get(filename).unwrap().clone())) .input_counter(|(_, text)| BytesCount::of_str(text)) .bench_values(|(splitter, text)| { splitter.chunks(&text).for_each(black_box_drop); @@ -86,18 +113,10 @@ mod text { #[cfg(feature = "rust-tokenizers")] #[divan::bench(args = TEXT_FILENAMES, consts = CHUNK_SIZES)] fn rust_tokenizers(bencher: Bencher<'_, '_>, filename: &str) { - use crate::download_file_to_cache; + use crate::BERT_TOKENIZER; bench(bencher, filename, || { - let vocab_path = download_file_to_cache( - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", - ); - TextSplitter::new( - ChunkConfig::new(N).with_sizer( - rust_tokenizers::tokenizer::BertTokenizer::from_file(vocab_path, false, false) - .unwrap(), - ), - ) + TextSplitter::new(ChunkConfig::new(N).with_sizer(&*BERT_TOKENIZER)) }); } } @@ -105,14 +124,10 @@ mod text { #[cfg(feature = "markdown")] #[divan::bench_group] mod markdown { - use std::fs; - use divan::{black_box_drop, counter::BytesCount, Bencher}; use text_splitter::{ChunkConfig, ChunkSizer, 
MarkdownSplitter}; - use crate::CHUNK_SIZES; - - const MARKDOWN_FILENAMES: &[&str] = &["commonmark_spec"]; + use crate::{CHUNK_SIZES, FILES, MARKDOWN_FILENAMES}; fn bench(bencher: Bencher<'_, '_>, filename: &str, gen_splitter: G) where @@ -120,12 +135,7 @@ mod markdown { S: ChunkSizer, { bencher - .with_inputs(|| { - ( - gen_splitter(), - fs::read_to_string(format!("tests/inputs/markdown/{filename}.md")).unwrap(), - ) - }) + .with_inputs(|| (gen_splitter(), FILES.get(filename).unwrap().clone())) .input_counter(|(_, text)| BytesCount::of_str(text)) .bench_values(|(splitter, text)| { splitter.chunks(&text).for_each(black_box_drop); @@ -159,18 +169,10 @@ mod markdown { #[cfg(feature = "rust-tokenizers")] #[divan::bench(args = MARKDOWN_FILENAMES, consts = CHUNK_SIZES)] fn rust_tokenizers(bencher: Bencher<'_, '_>, filename: &str) { - use crate::download_file_to_cache; + use crate::BERT_TOKENIZER; bench(bencher, filename, || { - let vocab_path = download_file_to_cache( - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", - ); - MarkdownSplitter::new( - ChunkConfig::new(N).with_sizer( - rust_tokenizers::tokenizer::BertTokenizer::from_file(vocab_path, false, false) - .unwrap(), - ), - ) + MarkdownSplitter::new(ChunkConfig::new(N).with_sizer(&*BERT_TOKENIZER)) }); } } @@ -178,14 +180,10 @@ mod markdown { #[cfg(feature = "code")] #[divan::bench_group] mod code { - use std::fs; - use divan::{black_box_drop, counter::BytesCount, Bencher}; use text_splitter::{ChunkConfig, ChunkSizer, CodeSplitter}; - use crate::CHUNK_SIZES; - - const CODE_FILENAMES: &[&str] = &["hashbrown_set_rs"]; + use crate::{CHUNK_SIZES, CODE_FILENAMES, FILES}; fn bench(bencher: Bencher<'_, '_>, filename: &str, gen_splitter: G) where @@ -193,12 +191,7 @@ mod code { S: ChunkSizer, { bencher - .with_inputs(|| { - ( - gen_splitter(), - fs::read_to_string(format!("tests/inputs/code/{filename}.txt")).unwrap(), - ) - }) + .with_inputs(|| (gen_splitter(), FILES.get(filename).unwrap().clone())) .input_counter(|(_, text)| BytesCount::of_str(text)) .bench_values(|(splitter, text)| { splitter.chunks(&text).for_each(black_box_drop); @@ -240,18 +233,12 @@ mod code { #[cfg(feature = "rust-tokenizers")] #[divan::bench(args = CODE_FILENAMES, consts = CHUNK_SIZES)] fn rust_tokenizers(bencher: Bencher<'_, '_>, filename: &str) { - use crate::download_file_to_cache; + use crate::BERT_TOKENIZER; bench(bencher, filename, || { - let vocab_path = download_file_to_cache( - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", - ); CodeSplitter::new( tree_sitter_rust::language(), - ChunkConfig::new(N).with_sizer( - rust_tokenizers::tokenizer::BertTokenizer::from_file(vocab_path, false, false) - .unwrap(), - ), + ChunkConfig::new(N).with_sizer(&*BERT_TOKENIZER), ) .unwrap() }); diff --git a/src/chunk_size.rs b/src/chunk_size.rs index cb3b139b..075d8b7f 100644 --- a/src/chunk_size.rs +++ b/src/chunk_size.rs @@ -434,8 +434,9 @@ where &mut self, offset: usize, levels_with_first_chunk: impl Iterator, - ) -> Option { + ) -> (Option, Option) { let mut semantic_level = None; + let mut max_offset = None; // We assume that larger levels are also longer. We can skip lower levels if going to a higher level would result in a shorter text let levels_with_first_chunk = @@ -451,13 +452,14 @@ where let chunk_size = self.check_capacity(offset, str, false); // If this no longer fits, we use the level we are at. 
if chunk_size.fits.is_gt() { + max_offset = Some(offset + str.len()); break; } // Otherwise break up the text with the next level semantic_level = Some(level); } - semantic_level + (semantic_level, max_offset) } /// Clear the cached values. Once we've moved the cursor, diff --git a/src/splitter.rs b/src/splitter.rs index d1b9d347..3f4cbf21 100644 --- a/src/splitter.rs +++ b/src/splitter.rs @@ -248,6 +248,8 @@ where semantic_split: SemanticSplitRanges, /// Original text to iterate over and generate chunks from text: &'text str, + /// Average number of sections in a chunk for each level + chunk_stats: ChunkStats, } impl<'sizer, 'text: 'sizer, Sizer, Level> TextChunks<'text, 'sizer, Sizer, Level> @@ -271,6 +273,7 @@ where prev_item_end: 0, semantic_split: SemanticSplitRanges::new(offsets), text, + chunk_stats: ChunkStats::new(), } } @@ -289,6 +292,9 @@ where self.update_cursor(end); let chunk = self.text.get(start..end)?; + + self.chunk_stats.update_max_chunk_size(end - start); + // Trim whitespace if user requested it Some(self.chunk_sizer.trim_chunk(start, chunk)) } @@ -426,7 +432,7 @@ where let remaining_text = self.text.get(self.cursor..).unwrap(); - let semantic_level = self.chunk_sizer.find_correct_level( + let (semantic_level, mut max_offset) = self.chunk_sizer.find_correct_level( self.cursor, self.semantic_split .levels_in_remaining_text(self.cursor) @@ -438,14 +444,14 @@ where }), ); - let mut sections = if let Some(semantic_level) = semantic_level { - Either::Left( - self.semantic_split - .semantic_chunks(self.cursor, remaining_text, semantic_level) - .filter(|(_, str)| !str.is_empty()), - ) + let sections = if let Some(semantic_level) = semantic_level { + Either::Left(self.semantic_split.semantic_chunks( + self.cursor, + remaining_text, + semantic_level, + )) } else { - let semantic_level = self.chunk_sizer.find_correct_level( + let (semantic_level, fallback_max_offset) = self.chunk_sizer.find_correct_level( self.cursor, FallbackLevel::iter().filter_map(|level| { level @@ -455,32 +461,45 @@ where }), ); + max_offset = match (fallback_max_offset, max_offset) { + (Some(fallback), Some(max)) => Some(fallback.min(max)), + (fallback, max) => fallback.or(max), + }; + + let fallback_level = semantic_level.unwrap_or(FallbackLevel::Char); + Either::Right( - semantic_level - .unwrap_or(FallbackLevel::Char) + fallback_level .sections(remaining_text) - .map(|(offset, text)| (self.cursor + offset, text)) - .filter(|(_, str)| !str.is_empty()), + .map(|(offset, text)| (self.cursor + offset, text)), ) }; + let mut sections = sections + .take_while(move |(offset, _)| max_offset.map_or(true, |max| *offset <= max)) + .filter(|(_, str)| !str.is_empty()); + // Start filling up the next sections. Since calculating the size of the chunk gets more expensive // the farther we go, we conservatively check for a smaller range to do the later binary search in. 
let mut low = 0; let mut prev_equals: Option = None; - let max_chunk_size = self.chunk_config.capacity().max(); - let mut num_sections = 3; + let max = self.chunk_config.capacity().max(); + let mut target_offset = self.chunk_stats.max_chunk_size.unwrap_or(max); loop { let prev_num = self.next_sections.len(); - // Default to at least several items for the binary search - self.next_sections - .extend((0..num_sections).map_while(|_| sections.next())); + for (offset, str) in sections.by_ref() { + self.next_sections.push((offset, str)); + if offset + str.len() > (self.cursor.saturating_add(target_offset)) { + break; + } + } let new_num = self.next_sections.len(); // If we've iterated through the whole iterator, break here. - if new_num - prev_num < num_sections { + if new_num - prev_num == 0 { break; } + // Check if the last item fits if let Some(&(offset, str)) = self.next_sections.last() { let text_end = offset + str.len(); @@ -489,13 +508,22 @@ where self.text.get(self.cursor..text_end).expect("Invalid range"), false, ); - // Average size of each section - let average_size = (chunk_size.size() / num_sections).max(1); - num_sections = (max_chunk_size / average_size + 1) - .saturating_sub(num_sections) - .max(1); - match chunk_size.fits() { + let fits = chunk_size.fits(); + if fits.is_le() { + let final_offset = offset + str.len() - self.cursor; + let size = chunk_size.size().max(1); + let diff = (max - size).max(1); + let avg_size = (final_offset / size) + 1; + + target_offset = final_offset.saturating_add( + diff.saturating_mul(avg_size) + .max(final_offset / 10) + .saturating_add(1), + ); + } + + match fits { Ordering::Less => { // We know we can go higher low = new_num.saturating_sub(1); @@ -555,10 +583,50 @@ where } } +/// Keeps track of the average size of chunks as we go +#[derive(Debug, Default)] +struct ChunkStats { + /// The size of the biggest chunk we've seen, if we have seen at least one + max_chunk_size: Option, +} + +impl ChunkStats { + fn new() -> Self { + Self::default() + } + + /// Update statistics after the chunk has been produced + fn update_max_chunk_size(&mut self, size: usize) { + self.max_chunk_size = self.max_chunk_size.map(|s| s.max(size)).or(Some(size)); + } +} + #[cfg(test)] mod tests { use super::*; + #[test] + fn chunk_stats_empty() { + let stats = ChunkStats::new(); + assert_eq!(stats.max_chunk_size, None); + } + + #[test] + fn chunk_stats_one() { + let mut stats = ChunkStats::new(); + stats.update_max_chunk_size(10); + assert_eq!(stats.max_chunk_size, Some(10)); + } + + #[test] + fn chunk_stats_multiple() { + let mut stats = ChunkStats::new(); + stats.update_max_chunk_size(10); + stats.update_max_chunk_size(20); + stats.update_max_chunk_size(30); + assert_eq!(stats.max_chunk_size, Some(30)); + } + impl SemanticLevel for usize {} #[test]
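
A minimal, standalone sketch of the windowing idea above, outside the patch itself: instead of pulling a fixed number of sections per iteration, seed a byte-offset target from the largest chunk produced so far and collect sections until that offset is passed, so the later binary search only has to scan a best-effort window. The names here (`ChunkStatsSketch`, `fill_window`) are hypothetical and only mirror the approach; they are not the crate's API.

// Standalone sketch; hypothetical names, not the text-splitter API.

/// Tracks the largest chunk size (in bytes) produced so far, if any.
#[derive(Debug, Default)]
struct ChunkStatsSketch {
    max_chunk_size: Option<usize>,
}

impl ChunkStatsSketch {
    /// Record the size of a chunk after it has been produced.
    fn update(&mut self, size: usize) {
        self.max_chunk_size = Some(self.max_chunk_size.map_or(size, |prev| prev.max(size)));
    }
}

/// Collect candidate `(byte_offset, text)` sections until one ends past
/// `cursor + target_offset`, i.e. until the window is "full enough".
fn fill_window<'a>(
    sections: &mut impl Iterator<Item = (usize, &'a str)>,
    cursor: usize,
    target_offset: usize,
    window: &mut Vec<(usize, &'a str)>,
) {
    for (offset, text) in sections {
        window.push((offset, text));
        if offset + text.len() > cursor.saturating_add(target_offset) {
            break;
        }
    }
}

fn main() {
    let text = "one two three four five six seven eight nine ten";
    // Pretend every whitespace-separated word is a "section" with its byte offset.
    let mut sections = text.split_inclusive(' ').scan(0usize, |offset, word| {
        let start = *offset;
        *offset += word.len();
        Some((start, word))
    });

    // Seed the target from a previously produced chunk size (falling back to a
    // configured capacity if no chunk has been produced yet).
    let mut stats = ChunkStatsSketch::default();
    stats.update(16);
    let target_offset = stats.max_chunk_size.unwrap_or(32);

    let mut window = Vec::new();
    fill_window(&mut sections, 0, target_offset, &mut window);
    // The binary search for the largest fitting chunk now only scans `window`,
    // whose size is driven by byte offsets rather than a fixed section count.
    println!("{} candidate sections in the window", window.len());
}

The patch additionally grows the target after each fitting probe (scaling by the remaining capacity and the observed bytes-per-size ratio), so the window widens when sections turn out to be cheap and stays small when they are expensive to measure.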