Skip to content

Commit

Permalink
Support per query filter (#279)
Browse files Browse the repository at this point in the history
* Transferring Varun's chagges from external fork with squash merge

* generating multiple gt's for each filter label + search with multiple filter labels (code cleanup)

* supporting no-filter + one filter label + filter label file (multiple filters) while computing GT

* generating multiple gt's + refactoring code for readability & cleanliness

* adding more tests for filtered search

* updating pr-test to test filtered cases

* lowering recall requirement for disk index

* transferred functions to filter_utils 

* adding more test for build and search without universal label

* adding one_per_point distribution to generate_synthetic_labels + cleaning up artifacts after compute gt+ removing minor errors

* refactoring search_disk_index to use a query filter vector
---------

Co-authored-by: patelyash <patelyash@microsoft.com>
Co-authored-by: Varun Sivashankar <t-varunsi@microsoft.com>
  • Loading branch information
3 people authored and jinwei14 committed Mar 29, 2023
1 parent 162d1ea commit 331b574
Show file tree
Hide file tree
Showing 14 changed files with 1,792 additions and 622 deletions.
65 changes: 56 additions & 9 deletions .github/workflows/pr-test.yml

Large diffs are not rendered by default.

212 changes: 212 additions & 0 deletions include/filter_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#include <algorithm>
#include <fcntl.h>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <random>
#include <set>
#include <tuple>
#include <string>
#include <tsl/robin_map.h>
#include <tsl/robin_set.h>
#ifdef __APPLE__
#else
#include <malloc.h>
#endif

#ifdef _WINDOWS
#include <Windows.h>
typedef HANDLE FileHandle;
#else
#include <unistd.h>
typedef int FileHandle;
#endif

#ifndef _WINDOWS
#include <sys/uio.h>
#endif

#include "cached_io.h"
#include "common_includes.h"
#include "memory_mapper.h"
#include "utils.h"
#include "windows_customizations.h"

// custom types (for readability)
typedef tsl::robin_set<std::string> label_set;
typedef std::string path;

// structs for returning multiple items from a function
typedef std::tuple<std::vector<label_set>, tsl::robin_map<std::string, _u32>, tsl::robin_set<std::string>>
parse_label_file_return_values;
typedef std::tuple<std::vector<std::vector<_u32>>, _u64> load_label_index_return_values;

namespace diskann
{
template <typename T>
DISKANN_DLLEXPORT void generate_label_indices(path input_data_path, path final_index_path_prefix, label_set all_labels,
unsigned R, unsigned L, float alpha, unsigned num_threads);

DISKANN_DLLEXPORT load_label_index_return_values load_label_index(path label_index_path, _u32 label_number_of_points);

DISKANN_DLLEXPORT parse_label_file_return_values parse_label_file(path label_data_path, std::string universal_label);

template <typename T>
DISKANN_DLLEXPORT tsl::robin_map<std::string, std::vector<_u32>> generate_label_specific_vector_files_compat(
path input_data_path, tsl::robin_map<std::string, _u32> labels_to_number_of_points,
std::vector<label_set> point_ids_to_labels, label_set all_labels);

/*
* For each label, generates a file containing all vectors that have said label.
* Also copies data from original bin file to new dimension-aligned file.
*
* Utilizes POSIX functions mmap and writev in order to minimize memory
* overhead, so we include an STL version as well.
*
* Each data file is saved under the following format:
* input_data_path + "_" + label
*/
template <typename T>
inline tsl::robin_map<std::string, std::vector<_u32>> generate_label_specific_vector_files(
path input_data_path, tsl::robin_map<std::string, _u32> labels_to_number_of_points,
std::vector<label_set> point_ids_to_labels, label_set all_labels)
{
auto file_writing_timer = std::chrono::high_resolution_clock::now();
diskann::MemoryMapper input_data(input_data_path);
char *input_start = input_data.getBuf();

_u32 number_of_points, dimension;
std::memcpy(&number_of_points, input_start, sizeof(_u32));
std::memcpy(&dimension, input_start + sizeof(_u32), sizeof(_u32));
const _u32 VECTOR_SIZE = dimension * sizeof(T);
const size_t METADATA = 2 * sizeof(_u32);
if (number_of_points != point_ids_to_labels.size())
{
std::cerr << "Error: number of points in labels file and data file differ." << std::endl;
throw;
}

tsl::robin_map<std::string, iovec *> label_to_iovec_map;
tsl::robin_map<std::string, _u32> label_to_curr_iovec;
tsl::robin_map<std::string, std::vector<_u32>> label_id_to_orig_id;

// setup iovec list for each label
for (const auto &lbl : all_labels)
{
iovec *label_iovecs = (iovec *)malloc(labels_to_number_of_points[lbl] * sizeof(iovec));
if (label_iovecs == nullptr)
{
throw;
}
label_to_iovec_map[lbl] = label_iovecs;
label_to_curr_iovec[lbl] = 0;
label_id_to_orig_id[lbl].reserve(labels_to_number_of_points[lbl]);
}

// each point added to corresponding per-label iovec list
for (_u32 point_id = 0; point_id < number_of_points; point_id++)
{
char *curr_point = input_start + METADATA + (VECTOR_SIZE * point_id);
iovec curr_iovec;

curr_iovec.iov_base = curr_point;
curr_iovec.iov_len = VECTOR_SIZE;
for (const auto &lbl : point_ids_to_labels[point_id])
{
*(label_to_iovec_map[lbl] + label_to_curr_iovec[lbl]) = curr_iovec;
label_to_curr_iovec[lbl]++;
label_id_to_orig_id[lbl].push_back(point_id);
}
}

// write each label iovec to resp. file
for (const auto &lbl : all_labels)
{
int label_input_data_fd;
path curr_label_input_data_path(input_data_path + "_" + lbl);
_u32 curr_num_pts = labels_to_number_of_points[lbl];

label_input_data_fd =
open(curr_label_input_data_path.c_str(), O_CREAT | O_WRONLY | O_TRUNC | O_APPEND, (mode_t)0644);
if (label_input_data_fd == -1)
throw;

// write metadata
_u32 metadata[2] = {curr_num_pts, dimension};
int return_value = write(label_input_data_fd, metadata, sizeof(_u32) * 2);
if (return_value == -1)
{
throw;
}

// limits on number of iovec structs per writev means we need to perform
// multiple writevs
size_t i = 0;
while (curr_num_pts > IOV_MAX)
{
return_value = writev(label_input_data_fd, (label_to_iovec_map[lbl] + (IOV_MAX * i)), IOV_MAX);
if (return_value == -1)
{
close(label_input_data_fd);
throw;
}
curr_num_pts -= IOV_MAX;
i += 1;
}
return_value = writev(label_input_data_fd, (label_to_iovec_map[lbl] + (IOV_MAX * i)), curr_num_pts);
if (return_value == -1)
{
close(label_input_data_fd);
throw;
}

free(label_to_iovec_map[lbl]);
close(label_input_data_fd);
}

std::chrono::duration<double> file_writing_time = std::chrono::high_resolution_clock::now() - file_writing_timer;
std::cout << "generated " << all_labels.size() << " label-specific vector files for index building in time "
<< file_writing_time.count() << "\n"
<< std::endl;

return label_id_to_orig_id;
}

inline std::vector<uint32_t> loadTags(const std::string &tags_file, const std::string &base_file)
{
const bool tags_enabled = tags_file.empty() ? false : true;
std::vector<uint32_t> location_to_tag;
if (tags_enabled)
{
size_t tag_file_ndims, tag_file_npts;
std::uint32_t *tag_data;
diskann::load_bin<std::uint32_t>(tags_file, tag_data, tag_file_npts, tag_file_ndims);
if (tag_file_ndims != 1)
{
diskann::cerr << "tags file error" << std::endl;
throw diskann::ANNException("tag file error", -1, __FUNCSIG__, __FILE__, __LINE__);
}

// check if the point count match
size_t base_file_npts, base_file_ndims;
diskann::get_bin_metadata(base_file, base_file_npts, base_file_ndims);
if (base_file_npts != tag_file_npts)
{
diskann::cerr << "point num in tags file mismatch" << std::endl;
throw diskann::ANNException("point num in tags file mismatch", -1, __FUNCSIG__, __FILE__, __LINE__);
}

location_to_tag.assign(tag_data, tag_data + tag_file_npts);
delete[] tag_data;
}
return location_to_tag;
}

} // namespace diskann
79 changes: 79 additions & 0 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ typedef int FileHandle;

#define BUFFER_SIZE_FOR_CACHED_IO (_u64)1024 * (_u64)1048576

#define PBSTR "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"
#define PBWIDTH 60

inline bool file_exists(const std::string &name, bool dirCheck = false)
{
int val;
Expand Down Expand Up @@ -693,6 +696,16 @@ inline uint64_t save_bin(const std::string &filename, T *data, size_t npts, size
diskann::cout << "Finished writing bin." << std::endl;
return bytes_written;
}

inline void print_progress(double percentage)
{
int val = (int)(percentage * 100);
int lpad = (int)(percentage * PBWIDTH);
int rpad = PBWIDTH - lpad;
printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, "");
fflush(stdout);
}

// load_aligned_bin functions START

template <typename T>
Expand Down Expand Up @@ -1015,6 +1028,72 @@ template <typename T> inline void normalize(T *arr, size_t dim)
}
}

inline std::vector<std::string> read_file_to_vector_of_strings(const std::string &filename, bool unique = false)
{
std::vector<std::string> result;
std::set<std::string> elementSet;
if (filename != "")
{
std::ifstream file(filename);
if (file.fail())
{
throw diskann::ANNException(std::string("Failed to open file ") + filename, -1);
}
std::string line;
while (std::getline(file, line))
{
if (line.empty())
{
break;
}
if (line.find(',') != std::string::npos)
{
std::cerr << "Every query must have exactly one filter" << std::endl;
exit(-1);
}
if (!line.empty() && (line.back() == '\r' || line.back() == '\n'))
{
line.erase(line.size() - 1);
}
if (!elementSet.count(line))
{
result.push_back(line);
}
if (unique)
{
elementSet.insert(line);
}
}
file.close();
}
else
{
throw diskann::ANNException(std::string("Failed to open file. filename can not be blank"), -1);
}
return result;
}

inline void clean_up_artifacts(tsl::robin_set<std::string> paths_to_clean, tsl::robin_set<std::string> path_suffixes)
{
try
{
for (const auto &path : paths_to_clean)
{
for (const auto &suffix : path_suffixes)
{
std::string curr_path_to_clean(path + "_" + suffix);
if (std::remove(curr_path_to_clean.c_str()) != 0)
diskann::cout << "Warning: Unable to remove file :" << curr_path_to_clean << std::endl;
}
}
diskann::cout << "Cleaned all artifacts" << std::endl;
}
catch (const std::exception &e)
{
diskann::cout << "Warning: Unable to clean all artifacts" << std::endl;
}
}

#ifdef _WINDOWS
#include <intrin.h>
#include <Psapi.h>
Expand Down
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ else()
set(CPP_SOURCES ann_exception.cpp disk_utils.cpp distance.cpp index.cpp
linux_aligned_file_reader.cpp math_utils.cpp natural_number_map.cpp
natural_number_set.cpp memory_mapper.cpp partition.cpp pq.cpp
pq_flash_index.cpp scratch.cpp logger.cpp utils.cpp)
pq_flash_index.cpp scratch.cpp logger.cpp utils.cpp filter_utils.cpp)
if (RESTAPI)
list(APPEND CPP_SOURCES restapi/search_wrapper.cpp restapi/server.cpp)
endif()
Expand Down
2 changes: 1 addition & 1 deletion src/dll/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#Licensed under the MIT license.

add_library(${PROJECT_NAME} SHARED dllmain.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp
../windows_aligned_file_reader.cpp ../distance.cpp ../memory_mapper.cpp ../index.cpp ../math_utils.cpp ../disk_utils.cpp
../windows_aligned_file_reader.cpp ../distance.cpp ../memory_mapper.cpp ../index.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp
../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp)

set(TARGET_DIR "$<$<CONFIG:Debug>:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$<CONFIG:Release>:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>")
Expand Down
Loading

0 comments on commit 331b574

Please sign in to comment.