Skip to content

Commit

Permalink
COMPAT: ensure we pass numpy array to cKDtree and combat with sklearn…
Browse files Browse the repository at this point in the history
… 1.3.0 (#107)

* COMPAT: ensure we pass numpy array to cKDtree

* sklearn compat

* pin older sklearn in minimal
  • Loading branch information
martinfleis authored Jul 2, 2023
1 parent 44ce06b commit 308f0df
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
2 changes: 1 addition & 1 deletion ci/envs/38-minimal.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies:
- libpysal=4.5
- py-opencv=4.6
# tests
- scikit-learn
- scikit-learn==1.2
- shapely
- geopandas
- pytest
Expand Down
17 changes: 13 additions & 4 deletions pointpats/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,20 @@ def build_best_tree(coordinates, metric):
coordinates = numpy.asarray(coordinates)
tree = spatial.cKDTree
try:
import sklearn
from sklearn.neighbors import KDTree, BallTree
from packaging.version import Version

if metric in KDTree.valid_metrics:
if Version(sklearn.__version__) >= Version("1.3.0"):
kdtree_valid_metrics = KDTree.valid_metrics()
balltree_valid_metrics = BallTree.valid_metrics()
else:
kdtree_valid_metrics = KDTree.valid_metrics
balltree_valid_metrics = BallTree.valid_metrics

if metric in kdtree_valid_metrics:
tree = lambda coordinates: KDTree(coordinates, metric=metric)
elif metric in BallTree.valid_metrics:
elif metric in balltree_valid_metrics:
tree = lambda coordinates: BallTree(coordinates, metric=metric)
elif callable(metric):
warnings.warn(
Expand All @@ -323,8 +332,8 @@ def build_best_tree(coordinates, metric):
else:
raise KeyError(
f"Metric {metric} not found in set of available types."
f"BallTree metrics: {BallTree.valid_metrics}, and"
f"scikit KDTree metrics: {KDTree.valid_metrics}."
f"BallTree metrics: {balltree_valid_metrics}, and"
f"scikit KDTree metrics: {kdtree_valid_metrics}."
)
except ModuleNotFoundError as e:
if metric not in ("l2", "euclidean"):
Expand Down
4 changes: 2 additions & 2 deletions pointpats/pointpattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,9 @@ def knn_other(self, other, k=1):
if k < 1:
raise ValueError('k must be at least 1')
try:
nn = self.tree.query(other.points, k=k)
nn = self.tree.query(np.asarray(other.points), k=k)
except:
nn = self.tree.query(other, k=k)
nn = self.tree.query(np.asarray(other), k=k)
return nn[1], nn[0]

def explode(self, mark):
Expand Down

0 comments on commit 308f0df

Please sign in to comment.