Skip to content

Commit

Permalink
Merge branch 'main' of github.com:cmuparlay/ParlayANN into main
Browse files Browse the repository at this point in the history
  • Loading branch information
Magdalen Dobson committed Mar 28, 2024
2 parents 2e13121 + 990ea77 commit 627bb2b
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 2 deletions.
103 changes: 103 additions & 0 deletions algorithms/utils/check_range_recall.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// This code is part of the Problem Based Benchmark Suite (PBBS)
// Copyright (c) 2011 Guy Blelloch and the PBBS team
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights (to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be included
// in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

#pragma once

#include <algorithm>
#include <set>

#include "beamSearch.h"
#include "csvfile.h"
#include "parse_results.h"
#include "parlay/parallel.h"
#include "parlay/primitives.h"
#include "types.h"
#include "stats.h"

template<typename Point, typename PointRange, typename indexType>
void checkRangeRecall(
Graph<indexType> &G,
PointRange &Base_Points,
PointRange &Query_Points,
RangeGroundTruth<indexType> GT,
RangeParams RP,
long start_point) {


parlay::sequence<parlay::sequence<indexType>> all_rr;

parlay::internal::timer t;
float query_time;
stats<indexType> QueryStats(Query_Points.size());

all_rr = RangeSearch<Point, PointRange, indexType>(Query_Points, G, Base_Points, QueryStats, start_point, RP);
query_time = t.next_time();


float pointwise_recall = 0.0;
float reported_results = 0.0;
float total_results = 0.0;
float num_nonzero = 0.0;

//since distances are exact, just have to cross-check number of results
size_t n = Query_Points.size();
int numCorrect = 0;
for (indexType i = 0; i < n; i++) {
float num_reported_results = all_rr[i].size();
float num_actual_results = GT[i].size();
reported_results += num_reported_results;
total_results += num_actual_results;
if(num_actual_results != 0) {pointwise_recall += num_reported_results/num_actual_results; num_nonzero++;}
}

pointwise_recall /= num_nonzero;
float cumulative_recall = reported_results/total_results;

float QPS = Query_Points.size() / query_time;
auto stats_ = {QueryStats.dist_stats(), QueryStats.visited_stats()};

std::cout << "For ";
RP.print();
std::cout << ", Pointwise Recall = " << pointwise_recall << ", Cumulative Recall = " << cumulative_recall << ", QPS = " << QPS << std::endl;


}


template<typename Point, typename PointRange, typename indexType>
void range_search_wrapper(Graph<indexType> &G, PointRange &Base_Points,
PointRange &Query_Points,
RangeGroundTruth<indexType> GT, double rad,
indexType start_point=0){

std::vector<long> beams;

beams = {10, 20, 30, 40, 50, 100, 1000, 2000, 3000};

for(long b: beams){
RangeParams RP(rad, b);
checkRangeRecall<Point, PointRange, indexType>(G, Base_Points, Query_Points, GT, RP, start_point);
}



}
3 changes: 1 addition & 2 deletions data_tools/compute_range_groundtruth.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ int main(int argc, char* argv[]) {
PointRange<uint8_t, Euclidian_Point<uint8_t>> B = PointRange<uint8_t, Euclidian_Point<uint8_t>>(bFile);
PointRange<uint8_t, Euclidian_Point<uint8_t>> Q = PointRange<uint8_t, Euclidian_Point<uint8_t>>(qFile);
answers = compute_range_groundtruth<PointRange<uint8_t, Euclidian_Point<uint8_t>>>(B, Q, r);
write_nonzero_elts<PointRange<uint8_t, Euclidian_Point<uint8_t>>, uint8_t>(answers, Q, std::string(gFile));
} else if(df == "mips"){
PointRange<uint8_t, Mips_Point<uint8_t>> B = PointRange<uint8_t, Mips_Point<uint8_t>>(bFile);
PointRange<uint8_t, Mips_Point<uint8_t>> Q = PointRange<uint8_t, Mips_Point<uint8_t>>(qFile);
Expand All @@ -151,7 +150,7 @@ int main(int argc, char* argv[]) {
answers = compute_range_groundtruth<PointRange<int8_t, Mips_Point<int8_t>>>(B, Q, r);
}
}
// write_rangeres(answers, std::string(gFile));
write_rangeres(answers, std::string(gFile));


return 0;
Expand Down

0 comments on commit 627bb2b

Please sign in to comment.