Simplify and optimize _byte_pair_merge
We're already calculating the min during construction of the initial pairs. Because of this, the minimum calculation is moved to the end of the merge loop.

Since we've already filtered out single tokens, we can safely exit once the parts length is small enough.
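A hedged sketch of what this buys in practice: the toy driver below exercises the new _byte_pair_merge from the diff that follows. The `Rank = u32` alias matches the crate, but the two-pair vocabulary and the `main` wrapper are made up for illustration; "abcd" itself is deliberately absent from `ranks`, since callers only reach this function after a whole-piece lookup fails.

    use std::collections::HashMap;

    type Rank = u32;

    fn main() {
        // Made-up vocabulary: a lower rank merges earlier.
        let ranks: HashMap<Vec<u8>, Rank> =
            [(b"ab".to_vec(), 0), (b"cd".to_vec(), 1)].into_iter().collect();

        // The piece "abcd" is not in `ranks`, mirroring the caller's contract.
        let piece = b"abcd";

        // The function returns the surviving boundaries as (start, rank) pairs;
        // here the starts are [0, 2, 4], i.e. the piece splits into "ab" and "cd".
        let parts = _byte_pair_merge(&ranks, piece);
        for pair in parts.windows(2) {
            println!("token bytes: {:?}", &piece[pair[0].0..pair[1].0]);
        }
    }

On this input the construction loop already finds the global minimum ("ab", rank 0), so the merge loop starts working immediately instead of first scanning for a minimum.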
Lőrinc committed Jan 15, 2024
1 parent 7398253 commit 8f5dd7d
Showing 1 changed file with 24 additions and 65 deletions.
src/lib.rs
@@ -19,78 +19,37 @@ fn _byte_pair_merge(
     ranks: &HashMap<Vec<u8>, Rank>,
     piece: &[u8],
 ) -> Vec<(usize, Rank)> {
-    // This is a vector of (start, rank).
-    // The rank is of the byte pair starting at position start.
-    // The rank of the last item in the vector is not a valid value.
-    let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
-
-    let get_rank = {
-        #[inline(always)]
-        |parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
-            if (start_idx + skip + 2) < parts.len() {
-                ranks
-                    .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
-                    .copied()
-            } else {
-                None
-            }
-        }
+    let get_rank = |parts: &Vec<(usize, _)>, start_idx: usize, end_idx: usize| {
+        *parts.get(end_idx)
+            .map(|e| parts.get(start_idx).unwrap().0..e.0)
+            .and_then(|r| piece.get(r))
+            .filter(|p| p.len() < piece.len())
+            .and_then(|p| ranks.get(p))
+            .unwrap_or(&Rank::MAX)
     };
 
-    // We look up the ranks once in the beginning and iteratively update
-    // them during each merge, which reduces the number of rank lookups.
-    for i in 0..parts.len() - 2 {
-        match get_rank(&parts, i, 0) {
-            Some(rank) => {
-                // Rank::MAX is a sentinel value and cannot be a valid rank
-                debug_assert!(rank != Rank::MAX);
-                parts[i].1 = rank;
-            }
-            None => {
-                continue;
-            }
-        };
+    let (mut min_rank_index, mut min_rank) = (0, Rank::MAX);
+    let mut parts = Vec::with_capacity(piece.len() + 1);
+    for i in 0..piece.len() + 1 {
+        let part = (i, *piece.get(i..i + 2).and_then(|p| ranks.get(p)).unwrap_or(&Rank::MAX));
+        if part.1 < min_rank {
+            (min_rank_index, min_rank) = part;
+        }
+        parts.push(part);
     }
 
-    // If you have n parts and m merges, this does O(mn) work.
-    // We could do something with a heap and do O(m log n) work.
-    // It is important to consider that n is often small (<100), and as such
-    // the cache-locality benefits outweigh the algorithmic complexity downsides
-    // of the `parts` vector data structure above.
-
-    // Note that we hash bytes, not token pairs. As long as we train BPE the way we
-    // currently do, this is equivalent. An easy way to break this would be to decouple
-    // merge priority from token index or to prevent specific token merges.
-    loop {
-        if parts.len() == 1 {
-            break;
-        }
-
-        // Rank::MAX is a sentinel rank value allowing us to
-        // take the min more quickly
-        let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
-        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
-            if rank < min_rank.0 {
-                min_rank = (rank, i);
-            }
+    while parts.len() > 3 && min_rank != Rank::MAX {
+        if min_rank_index > 0 {
+            parts[min_rank_index - 1].1 = get_rank(&parts, min_rank_index - 1, min_rank_index + 2);
         }
+        parts[min_rank_index].1 = get_rank(&parts, min_rank_index, min_rank_index + 3);
+        parts.remove(min_rank_index + 1);
 
-        if min_rank.0 != Rank::MAX {
-            let i = min_rank.1;
-
-            // NOTE: We are about to remove parts[i + 1]. We do not do it
-            // yet because there are cache-locality benefits to updating
-            // parts[i] and parts[i-1] before removing, which could thrash
-            // the cache. Thus, we update the rank calculation by skipping over
-            // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
-            parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
-            if i > 0 {
-                parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
+        (min_rank_index, min_rank) = (0, parts[0].1);
+        for i in 1..parts.len() - 2 {
+            if parts[i].1 < min_rank {
+                (min_rank_index, min_rank) = (i, parts[i].1);
             }
-
-            parts.remove(i + 1);
-        } else {
-            break;
         }
     }
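For context on how these boundaries are consumed: the caller maps each adjacent pair of boundaries back to a token id. Below is a sketch of that caller, close to the crate's byte_pair_encode, though the exact signature in this revision is assumed rather than verified.

    pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
        assert!(piece.len() > 1);
        // Each window [start, end) of the surviving boundaries is one final token.
        // The real vocabularies contain all 256 single bytes, so the lookup cannot fail.
        _byte_pair_merge(ranks, piece)
            .windows(2)
            .map(|part| ranks[&piece[part[0].0..part[1].0]])
            .collect()
    }

The assert reflects the single-token filtering the commit message mentions, and the filter inside get_rank guarantees that a span equal to the whole piece never merges; together they leave at least two tokens, i.e. at least three boundaries, which is what makes the `while parts.len() > 3` early exit safe.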

