Skip to content

Commit

Permalink
add short-circuiting checks to reduce occurrences of bisection search
Browse files Browse the repository at this point in the history
  • Loading branch information
jlogan03 committed Nov 13, 2023
1 parent 3323409 commit 21b4af1
Showing 1 changed file with 36 additions and 13 deletions.
49 changes: 36 additions & 13 deletions interpn/src/multilinear_rectilinear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pub struct RectilinearGridInterpolator<'a, T: Float, const MAXDIMS: usize> {
/// Cumulative products of higher dimensions, used for indexing
dimprod: [usize; MAXDIMS],

inds: [usize; MAXDIMS],

/// Values at each point, size prod(dims)
vals: &'a [T],
}
Expand Down Expand Up @@ -43,18 +45,21 @@ where
acc *= dims[ndims - i - 1];
});

let inds = [0; MAXDIMS];

Self {
grids,
dims,
dimprod,
inds,
vals,
}
}

/// Interpolate on interleaved list of points.
/// Assumes C-style ordering of points ([x0, y0], [x0, y1], ..., [x0, yn], [x1, y0], ...).
#[inline(always)]
pub fn interp(&self, x: &[T], out: &mut [T]) {
pub fn interp(&mut self, x: &[T], out: &mut [T]) {
let n = out.len();
let ndims = self.grids.len();
assert!(x.len() % ndims == 0, "Dimension mismatch");
Expand All @@ -78,7 +83,7 @@ where
/// * If the index along any dimension exceeds the maximum representable
/// integer value within the value type `T`
#[inline(always)]
pub fn interp_one(&self, x: &[T]) -> T {
pub fn interp_one(&mut self, x: &[T]) -> T {
// Check sizes
let ndims = self.grids.len();
assert!(x.len() == ndims && ndims <= MAXDIMS, "Dimension mismatch");
Expand All @@ -87,7 +92,7 @@ where
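// The lower-corner indices are kept on the interpolator struct (`self.inds`)
// so that each lookup can start from the index found by the previous one.
// The cost is that the interpolation methods must take `&mut self`; the
// remaining scratch arrays below are still allocated locally per call.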
let inds = &mut [0_usize; MAXDIMS][..ndims]; // Indices of lower corner of hypercube
// let inds = &mut [0_usize; MAXDIMS][..ndims]; // Indices of lower corner of hypercube
let ioffs = &mut [false; MAXDIMS][..ndims]; // Offset index for selected vertex
let sat = &mut [0_u8; MAXDIMS][..ndims]; // Saturated-low flag

Expand All @@ -105,15 +110,15 @@ where

// Populate lower corner
for i in 0..ndims {
(inds[i], sat[i]) = self.get_loc(x[i], i)
(self.inds[i], sat[i]) = self.get_loc(x[i], i, self.inds[i])
}

// Check if any dimension is saturated.
// This gives a ~15% overall speedup for points on the interior.
let any_dims_saturated = (0..ndims).any(|j| sat[j] != 0);

// Calculate the total volume of this cell
let cell_vol = self.get_cell(inds, steps);
let cell_vol = self.get_cell(&self.inds, steps);

// Traverse vertices, summing contributions to the interpolated value.
//
Expand Down Expand Up @@ -144,7 +149,7 @@ where
// saturating to the bound if the resulting index would be outside.
for j in 0..ndims {
k += self.dimprod[j]
* (inds[j] + ioffs[j] as usize).min(self.dims[j].saturating_sub(1));
* (self.inds[j] + ioffs[j] as usize).min(self.dims[j].saturating_sub(1));
}

// Get the value at this vertex
Expand All @@ -153,7 +158,7 @@ where
// Accumulate the volume of the prism formed by the
// observation location and the opposing vertex
for j in 0..ndims {
let iloc = inds[j] + !ioffs[j] as usize; // Index of location of opposite vertex
let iloc = self.inds[j] + !ioffs[j] as usize; // Index of location of opposite vertex
let loc = self.grids[j][iloc]; // Loc. of opposite vertex
dxs[j] = loc;
}
Expand Down Expand Up @@ -256,8 +261,10 @@ where
/// Unfortunately, using a repr(u8) enum for the saturation flag
/// causes a significant perf hit.
#[inline(always)]
fn get_loc(&self, v: T, dim: usize) -> (usize, u8) {
fn get_loc(&self, v: T, dim: usize, guess: usize) -> (usize, u8) {
let grid = self.grids[dim];
let saturation: u8; // Saturated low/high/not at all
let iloc: isize; // Signed integer index location of this point

// Bisection search to find location on the grid.
//
Expand All @@ -270,7 +277,23 @@ where
//
// This process accounts for essentially the entire difference in
// performance between this method and the regular-grid method.
let iloc = self.grids[dim].partition_point(|x| *x < v) as isize - 1;
//
// Before bisecting, try hard to avoid the search entirely: check for
// extrapolation past either end of the grid, then test the caller-supplied
// rolling guess (the index found for this dimension on the previous lookup),
// which drastically improves perf for batch runs.
if v < grid[0] {
iloc = -1;
}
else if v > *grid.last().unwrap() {
iloc = grid.len() as isize - 1
}
else if grid[guess] < v && grid[(guess + 1).min(grid.len() - 1)] >= v {
iloc = guess as isize;
}
else {
// If all else fails, do the actual binary search
iloc = grid.partition_point(|x| *x < v) as isize - 1;
}

let dimmax = self.dims[dim].saturating_sub(2); // maximum index for lower corner

Expand Down Expand Up @@ -363,8 +386,8 @@ mod test {

let grids = [&x[..], &y[..]];

let interpolator: RectilinearGridInterpolator<'_, _, 2> =
RectilinearGridInterpolator::new(&grids, &z[..]);
let interpolator: &mut RectilinearGridInterpolator<'_, _, 2> =
&mut RectilinearGridInterpolator::new(&grids, &z[..]);

// Check values at every incident vertex
xy.iter().zip(z.iter()).for_each(|(xyi, zi)| {
Expand Down Expand Up @@ -402,8 +425,8 @@ mod test {

let grids = [&x[..], &y[..]];

let interpolator: RectilinearGridInterpolator<'_, _, 2> =
RectilinearGridInterpolator::new(&grids, &z[..]);
let interpolator: &mut RectilinearGridInterpolator<'_, _, 2> =
&mut RectilinearGridInterpolator::new(&grids, &z[..]);

interpolator.interp(&xy[..], &mut out);

Expand Down

0 comments on commit 21b4af1

Please sign in to comment.