-
-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FIX] KNN: Fix crash when Mahanalobis metric is used #1475
Conversation
Current coverage is 88.19% (diff: 100%)@@ master #1475 diff @@
==========================================
Files 77 77
Lines 7613 7617 +4
Methods 0 0
Messages 0 0
Branches 0 0
==========================================
+ Hits 6714 6718 +4
Misses 899 899
Partials 0 0
|
@@ -11,5 +11,7 @@ class KNNLearner(SklLearner): | |||
def __init__(self, n_neighbors=5, metric="euclidean", weights="uniform", | |||
algorithm='auto', | |||
preprocessors=None): | |||
if metric == "mahalanobis": | |||
algorithm = "brute" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a comment explaining why this is necessary. For the record, this works for me:
>>> from sklearn.neighbors import KNeighborsClassifier
>>> nn = KNeighborsClassifier(metric='mahalanobis', algorithm='auto')
>>> X = np.random.random((10, 2))
>>> y = (np.random.random(10) > .5).astype(int)
>>> nn.fit(X, y)
>>> nn.predict(X)
array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For me as well. But i.e. X = np.random.random((12, 2)) doesn't.
|
||
|
||
class TestKNNLearner(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
cls.iris = Table('iris') | ||
cls.learn = KNNLearner() | ||
cls.learn = KNNLearner |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would skip assigning classes in to cls.learn and cls.learn_reg and just use their names in tests.
self.learn() is not much shorter than KNNLearner() (in fact it has the same number of characters)
KNN learners (cls and reg) need additional parameter (metric_params) when Mahalanobis distance metric is used. Since both learners have the same parameters, new base class was created.
No description provided.