-
Notifications
You must be signed in to change notification settings - Fork 0
/
NearestNeighborsMatcher.h
101 lines (83 loc) · 3.06 KB
/
NearestNeighborsMatcher.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#ifndef NEARESTNEIGHBORSMATCHER_H
#define NEARESTNEIGHBORSMATCHER_H
#include "TrainingDatum.h"
#include <QObject>
#include <QMap>
#include <QVector>
#include <algorithm>
#include <functional>
#include <utility>
template <typename T>
class NearestNeighborsMatcher : public QObject
{
public:
NearestNeighborsMatcher(QObject* parent = nullptr) : QObject(parent), _neighborCount(5)
{
}
void clear()
{
_trainingData.clear();
}
void setNeighborCount(uint neighborCount)
{
_neighborCount = neighborCount;
}
uint neighborCount() const
{
return _neighborCount;
}
void setDistanceFunction(std::function<qreal(T, T)> distanceFunction)
{
_distanceFunction = distanceFunction;
}
void addTrainingData(const TrainingDatum<T>& trainingData)
{
_trainingData.push_back(trainingData);
}
int classifyDataPoint(const T& untrainedDataPoint) const
{
int result = -1;
//Note: Boolean check on distance function means that the function is valid
//We also want to make sure that we have at least one point more than the number of neighbors to check against.
if (_trainingData.count() > _neighborCount && _distanceFunction)
{
QVector<QPair<int, qreal>> distancesWithClass;
int index = 0;
for (auto trainingDatum : _trainingData)
{
qreal distance = _distanceFunction(untrainedDataPoint, trainingDatum.observation());
auto pair = qMakePair(trainingDatum.classification(), distance);
distancesWithClass.push_back(pair);
}
std::sort(distancesWithClass.begin(), distancesWithClass.end(), [](const QPair<int, qreal>& firstPair, const QPair<int, qreal>& secondPair)
{
return firstPair.second < secondPair.second;
});
QMap<int, QPair<int, int>> totals;
index = 0;
auto distanceWithClass = distancesWithClass.begin();
while (distanceWithClass != distancesWithClass.end() && index < _neighborCount)
{
if (totals.contains((*distanceWithClass).first))
{
totals[(*distanceWithClass).first].second++;
}
else
{
totals[(*distanceWithClass).first] = qMakePair((*distanceWithClass).first, 1);
}
++distanceWithClass;
++index;
}
auto totalsAsVector = totals.values();
std::sort(totalsAsVector.begin(), totalsAsVector.end(), [](const QPair<int, int>& firstPair, const QPair<int, int>& secondPair)
{
return firstPair.second > secondPair.second;
});
result = totalsAsVector.first().first;
}
return result;
}
private:
uint _neighborCount;
std::function<qreal(T, T)> _distanceFunction;
QVector<TrainingDatum<T>> _trainingData;
};
#endif // NEARESTNEIGHBORSMATCHER_H