Skip to content

Commit

Permalink
first working version of corner-region extrapolation
Browse files Browse the repository at this point in the history
  • Loading branch information
jlogan03 committed Nov 12, 2023
1 parent 7443228 commit 8b1da30
Showing 1 changed file with 161 additions and 37 deletions.
198 changes: 161 additions & 37 deletions interpn/src/multilinear_regular.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,27 +103,43 @@ where
// ergonomically unpleasant.
// Also notably, storing the index offsets as bool instead of usize
// reduces memory overhead, but has not effect on throughput rate.
let inds: &mut [usize] = &mut [0_usize; MAXDIMS][..ndims]; // Indices of lower corner of hypercube
let origin: &mut [usize] = &mut [0_usize; MAXDIMS][..ndims]; // Indices of lower corner of hypercube
let ioffs = &mut [false; MAXDIMS][..ndims]; // Offset index for selected vertex
let sat = &mut [0_u8; MAXDIMS][..ndims]; // Saturated-low flag
let sat = &mut [0_u8; MAXDIMS][..ndims]; // Saturation none/high/low flags for each dim

// let opinds: &mut [usize] = &mut [0_usize; MAXDIMS][..ndims]; // Indices of opposite vertex
// let opextsat = &mut [0_u8; MAXDIMS][..ndims]; // Opposite vertex sat flags

let dxs = &mut [T::zero(); MAXDIMS][..ndims]; // Sub-cell volume storage
let extrapdxs = &mut [T::zero(); MAXDIMS][..ndims]; // Extrapolated distances

// Whether the opposite vertex is on the saturated bound
// on each dimension
let opsat = &mut [false; MAXDIMS][..ndims];

// Whether the current vertex is on the saturated bound
// on each dimension
let thissat = &mut [false; MAXDIMS][..ndims];

// Populate lower corner and saturation flag for each dimension
for i in 0..ndims {
(inds[i], sat[i]) = self.get_loc(x[i], i)
(origin[i], sat[i]) = self.get_loc(x[i], i)
}

// Check if any dimension is saturated.
// This gives a ~15% overall speedup for points on the interior.
let maybe_neg = (0..ndims).any(|j| sat[j] != 0);
let any_dims_saturated = (0..ndims).any(|j| sat[j] != 0);

// Traverse vertices, summing contributions to the interpolated value.
//
// This visits the 2^ndims elements of the cartesian product
// of `[0, 1] x ... x [0, 1]` using O(ndims) storage.
// of `[0, 1] x ... x [0, 1]` without simultaneously actualizing them in storage.
let mut interped = T::zero();
let nverts = 2_usize.pow(ndims as u32);
for i in 0..nverts {
let mut k: usize = 0; // index of the value for this vertex in self.vals
let mut sign = T::one(); // sign of the contribution from this vertex
let mut extrapvol = T::zero();

// Every 2^nth vertex, flip which side of the cube we are examining
// in the nth dimension.
Expand All @@ -142,43 +158,134 @@ where
// Accumulate the index into the value array,
// saturating to the bound if the resulting index would be outside.
for j in 0..ndims {
k += self.dimprod[j] * (inds[j] + ioffs[j] as usize).min(self.dims[j] - 1);
k += self.dimprod[j]
* (origin[j] + ioffs[j] as usize).min(self.dims[j].saturating_sub(1));
}

// Get the value at this vertex
let v = self.vals[k];

// Accumulate the volume of the prism formed by the
// observation location and the opposing vertex
let mut vol = T::one();

for j in 0..ndims {
let iloc = inds[j] + !ioffs[j] as usize; // Index of location of opposite vertex
let iloc = origin[j] + !ioffs[j] as usize; // Index of location of opposite vertex
let loc = self.starts[j] + self.steps[j] * T::from(iloc).unwrap(); // Loc. of opposite vertex
let dx = x[j] - loc; // Delta position from opposite vertex to obs. loc
vol = vol * dx;
dxs[j] = loc; // Use dxs[j] as storage for float locs
}

// Determine the sign of the contribution.
// For observation points outside the grid, negate the contribution
// of the inner points on the dimensions that are saturated, in order
// to naturally transition to extrapolation.
// If there are any saturated points, check if the current vertex
// is on a saturated dimension. If it is, return
if maybe_neg {
let neg =
(0..ndims).any(|j| (((!ioffs[j]) && sat[j] == 2) || (ioffs[j] && sat[j] == 1)));
for j in 0..ndims {
dxs[j] = x[j] - dxs[j]; // Make the actual delta-locs
}

for j in 0..ndims {
dxs[j] = dxs[j].abs();
}

// Clip maximum dx for some cases to handle multidimensional extrapolation
if any_dims_saturated {
// For which dimensions is the opposite vertex on a saturated bound?
(0..ndims).for_each(|j| {
opsat[j] = (!ioffs[j] && sat[j] == 2) || (ioffs[j] && sat[j] == 1)
});

// For how many total dimensions is the opposite vertex on a saturated bound?
let opsatcount = opsat.iter().fold(0, |acc, x| acc + *x as usize);

// If the opposite vertex is on exactly one saturated bound, negate its contribution
// let neg = (0..ndims).any(|j| (((!ioffs[j]) && sat[j] == 2) || (ioffs[j] && sat[j] == 1)));
let neg = opsatcount == 1;
if neg {
interped = interped + v.neg() * vol.abs();
continue;
sign = sign.neg();
}

// If the opposite vertex is on _more_ than one saturated bound,
// it should be clipped on multiple axes which, if the clipping
// were implemented in a general constructive geometry way, would
// result in a zero volume. Since we only deal in the difference
// in position between vertices and the observation point, our
// clipping method would not properly set this volume to zero,
// and we need to implement that behavior with explicit logic.
let zeroed = opsatcount > 1;
if zeroed {
sign = T::zero();
}

// If the opposite vertex is on exactly one saturated bound,
// allow the dx on that dimension to be as large as needed,
// but clip the dx on other saturated dimensions so that we
// don't produce an overlapping partition in outside-corner regions.
if neg {
for j in 0..ndims {
let is_saturated = sat[j] != 0;
if is_saturated && !opsat[j] {
println!("{i} clipping axis {j} due to opposite vertex");
let dxb4 = dxs[j];
dxs[j] = dxs[j].min(self.steps[j]);
println!(
"{i} dx before: {:?} after: {:?}",
<f64 as NumCast>::from(dxb4).unwrap(),
<f64 as NumCast>::from(dxs[j]).unwrap()
);
}
}
}

// For which dimensions is the current vertex on a saturated bound?
(0..ndims).for_each(|j| {
thissat[j] = (ioffs[j] && sat[j] == 2) || (!ioffs[j] && sat[j] == 1)
});

// For how many total dimensions is the current vertex on a saturated bound?
let thissatcount = thissat.iter().fold(0, |acc, x| acc + *x as usize);

// Subtract the extrapolated volume from the contribution for this vertex
// if it is on multiple saturated bounds.
// Put differently - find the part of the volume that is scaling non-linearly
// in the coordinates, and bookkeep it to be removed entirely later.
if thissatcount > 1 {
// Copy forward the original dxs, extrapolated or not
(0..ndims).for_each(|j| extrapdxs[j] = dxs[j]);
// For extrapolated dimensions, take just the extrapolated distance
(0..ndims).for_each(|j| if thissat[j] { extrapdxs[j] = dxs[j] - self.steps[j] });
// Evaluate the extrapolated corner volume
extrapvol = extrapdxs.iter().fold(T::one(), |acc, x| acc * *x);
}

}

println!(
"{i} dxs {:?}",
dxs.iter()
.map(|xi| <f64 as NumCast>::from(*xi).unwrap())
.collect::<Vec<f64>>()
);
println!(
"{i} steps {:?}",
self.steps[..ndims]
.iter()
.map(|xi| <f64 as NumCast>::from(*xi).unwrap())
.collect::<Vec<f64>>()
);

let vol = (dxs.iter().fold(T::one(), |acc, x| acc * *x).abs() - extrapvol) * sign;

// Add contribution from this vertex, leaving the division
// by the total cell volume for later to save a few flops
interped = interped + v * vol.abs();
interped = interped + v * vol;

println!(
"{i} {ioffs:?} {sat:?} {} {} {} {}",
<f64 as NumCast>::from(v).unwrap(),
<f64 as NumCast>::from(vol).unwrap(),
<f64 as NumCast>::from(sign).unwrap(),
<f64 as NumCast>::from(interped / self.vol).unwrap(),
);
}

println!(
"Final {}",
<f64 as NumCast>::from(interped / self.vol).unwrap()
);
interped / self.vol
}

Expand Down Expand Up @@ -352,30 +459,45 @@ mod test {
// Make a function that is linear in both dimensions
// and should behave reasonably well under extrapolation in one
// dimension at a time, but not necessarily when extrapolating in both at once.
let z: Vec<f64> = grid.iter().map(|xyi| xyi[0] * xyi[1]).collect();
let zgrid: Vec<f64> = grid.iter().map(|xyi| xyi[0] * xyi[1]).collect();
// let z1: Vec<f64> = grid.iter().map(|xyi| xyi[0] + xyi[1]).collect();

// Make some grids to extrapolate
// High/low x
let xe1 = vec![-1.0; ny];
let xe2 = vec![11.0; ny];
let ye1 = linspace(-5.0, 5.0, ny);
let xye1: Vec<f64> = xe1.iter().interleave(ye1.iter()).map(|xi| *xi).collect();
let ze1: Vec<f64> = (0..ny).map(|i| xe1[i] + ye1[i]).collect();
let ze1: Vec<f64> = (0..ny).map(|i| xe1[i] * ye1[i]).collect();
let xye2: Vec<f64> = xe2.iter().interleave(ye1.iter()).map(|xi| *xi).collect();
let ze2: Vec<f64> = (0..ny).map(|i| xe2[i] + ye1[i]).collect();
let ze2: Vec<f64> = (0..ny).map(|i| xe2[i] * ye1[i]).collect();
// High/low y
let ye2 = vec![-6.0; nx];
let ye3 = vec![6.0; nx];
let xe3 = linspace(0.0, 10.0, nx);
let xye3: Vec<f64> = xe3.iter().interleave(ye2.iter()).map(|xi| *xi).collect();
let xye4: Vec<f64> = xe3.iter().interleave(ye3.iter()).map(|xi| *xi).collect();
let ze3: Vec<f64> = (0..nx).map(|i| xe3[i] + ye2[i]).collect();
let ze4: Vec<f64> = (0..nx).map(|i| xe3[i] + ye3[i]).collect();
let ze3: Vec<f64> = (0..nx).map(|i| xe3[i] * ye2[i]).collect();
let ze4: Vec<f64> = (0..nx).map(|i| xe3[i] * ye3[i]).collect();
// High/low corners and all over the place
let xw = linspace(-1.0, 11.0, 2);
let yw = linspace(-6.0, 6.0, 2);
let xyw: Vec<f64> = meshgrid(vec![&xw, &yw]).iter().flatten().map(|xx| *xx).collect();
let zw: Vec<f64> = (0..xyw.len() / 2).map(|i| xyw[2 * i] + xyw[2 * i + 1]).collect();
let xw = linspace(-10.0, 11.0, 100);
let yw = linspace(-7.0, 6.0, 100);
// let yw = vec![-6.0; 1];
let xyw: Vec<f64> = meshgrid(vec![&xw, &yw])
.iter()
.flatten()
.map(|xx| *xx)
.collect();
// let zw: Vec<f64> = (0..xyw.len() / 2)
// .map(|i| xyw[2 * i] * xyw[2 * i + 1])
// .collect();
let zw: Vec<f64> = (0..xyw.len() / 2)
.map(|i| xyw[2 * i] + xyw[2 * i + 1])
.collect();
let zgrid1: Vec<f64> = grid.iter().map(|xyi| xyi[0] + xyi[1]).collect();
// let zw: Vec<f64> = (0..xyw.len() / 2)
// .map(|i| xyw[2 * i])
// .collect();

let mut out = vec![0.0; nx.max(ny).max(zw.len())];

Expand All @@ -384,23 +506,25 @@ mod test {
let steps = [x[1] - x[0], y[1] - y[0]];

// Check extrapolating low x
interpn(&xye1, &mut out[..ny], &z, &dims, &starts, &steps);
interpn(&xye1, &mut out[..ny], &zgrid, &dims, &starts, &steps);
(0..ze1.len()).for_each(|i| assert!((out[i] - ze1[i]).abs() < 1e-12));

// Check extrapolating high x
interpn(&xye2, &mut out[..ny], &z, &dims, &starts, &steps);
interpn(&xye2, &mut out[..ny], &zgrid, &dims, &starts, &steps);
(0..ze2.len()).for_each(|i| assert!((out[i] - ze2[i]).abs() < 1e-12));

// Check extrapolating low y
interpn(&xye3, &mut out[..nx], &z, &dims, &starts, &steps);
interpn(&xye3, &mut out[..nx], &zgrid, &dims, &starts, &steps);
(0..ze3.len()).for_each(|i| assert!((out[i] - ze3[i]).abs() < 1e-12));

// Check extrapolating high y
interpn(&xye4, &mut out[..nx], &z, &dims, &starts, &steps);
interpn(&xye4, &mut out[..nx], &zgrid, &dims, &starts, &steps);
(0..ze4.len()).for_each(|i| assert!((out[i] - ze4[i]).abs() < 1e-12));

// Check interpolating off grid on the interior
interpn(&xyw, &mut out[..zw.len()], &z, &dims, &starts, &steps);
println!("\nCorners");
println!("xyw {xyw:?}");
interpn(&xyw, &mut out[..zw.len()], &zgrid1, &dims, &starts, &steps);
(0..zw.len()).for_each(|i| assert!((out[i] - zw[i]).abs() < 1e-12));
}
}

0 comments on commit 8b1da30

Please sign in to comment.