Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 21, 2024
1 parent 7c4bcbf commit 9bd9c41
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions straxen/itp_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ def __init__(self, points, values, neighbours_to_use=None, array_valued=False, r
def __call__(self, points):
points = np.asarray(points)
if self.rotated:
assert points.shape[1] == 2, "InterpolateAndExtrapolate roated expects points of dimension 2"
points = np.array(straxen.rotate_perp_wires(points[:,0], points[:,1])).T
assert (
points.shape[1] == 2
), "InterpolateAndExtrapolate roated expects points of dimension 2"
points = np.array(straxen.rotate_perp_wires(points[:, 0], points[:, 1])).T

# kdtree doesn't grok NaNs
# Start with all Nans, then overwrite for the finite points
Expand Down Expand Up @@ -191,7 +193,9 @@ def itp_fun(positions):
itp_fun = self._regular_grid_interpolator(csys, map_data, array_valued, **kwargs)

elif method == "WeightedNearestNeighbors":
itp_fun = self._weighted_nearest_neighbors(csys, map_data, array_valued, rotated, **kwargs)
itp_fun = self._weighted_nearest_neighbors(
csys, map_data, array_valued, rotated, **kwargs
)

elif method == "Linear1Din2D":
assert rotated, "Linear1Din2D interpolate maps should be in rotated coordinates"
Expand All @@ -213,8 +217,7 @@ def __call__(self, *args, map_name="map"):

@staticmethod
def _linear_1d_in_2d(csys, map_data, array_valued, rotated=False, **kwargs):
"""Linear interpolator along the x-axis and nearest value along the y-axis
"""
"""Linear interpolator along the x-axis and nearest value along the y-axis."""
dimensions = len(csys[0])
grid = [np.unique(csys[:, i]) for i in range(dimensions)]

Expand All @@ -232,12 +235,20 @@ def interp(X, Y):
yp = (np.searchsorted(grid[1], y) - 1)[0]

# interpolate linearly over x along the y axis (along the wires) for each PMT
return np.array([[compiled_interp(xp, grid[0], map_data[:,yp,i]) for i in range(map_data.shape[-1])] for xp in x])
return np.array(
[
[
compiled_interp(xp, grid[0], map_data[:, yp, i])
for i in range(map_data.shape[-1])
]
for xp in x
]
)

def arg_formated_interp(positions):
if isinstance(positions, list):
positions = np.array(positions)
return interp(positions[:, 0], positions[:, 1])
if isinstance(positions, list):
positions = np.array(positions)
return interp(positions[:, 0], positions[:, 1])

return arg_formated_interp

Expand Down Expand Up @@ -284,7 +295,9 @@ def _weighted_nearest_neighbors(csys, map_data, array_valued, rotated=False, **k
else:
map_data = map_data.flatten()
kwargs = straxen.filter_kwargs(InterpolateAndExtrapolate, kwargs)
return InterpolateAndExtrapolate(csys, map_data, array_valued=array_valued, rotated=rotated, **kwargs)
return InterpolateAndExtrapolate(
csys, map_data, array_valued=array_valued, rotated=rotated, **kwargs
)

def scale_coordinates(self, scaling_factor, map_name="map"):
"""Scales the coordinate system by the specified factor.
Expand Down

0 comments on commit 9bd9c41

Please sign in to comment.