Skip to content

Commit

Permalink
New python interface, build setup, apps and unit tests (#308)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Dax Pryce <daxpryce@microsoft.com>

* Adding some diagnostics to a pr build in an attempt to see what is going on with our systems prior to running our streaming/incremental tests

* fix cast error and add some status prints to in-mem-dynamic app

* Adding unit tests for both memory and disk index builder methods

* After the refactor and polish of the API was left half done, I also left half a jillion bugs in the library. At least I'm confident that build_memory_index and StaticMemoryIndex work in some cases, whereas before they barely were getting off the ground

* Sanity checks of static index (not comprehensive coverage), and tombstone file for test_dynamic_memory_index

* Argument range checks of some of the static memory index values.

* fixes for dynamic index in python interface (#334)

* create separate default number of frozen points for dynamic indices

* consolidate works

* remove superfluous param from dynamic index

* remove superfluous param from dynamic index

* batch insert and args modification to apps

* batch insert and args modification to apps

* typo

* Committing the updated unit tests. At least the initial sanity checks of StaticMemory are done

* Fixing an error in the static memory index ctor

* Formatting python with black

* Have to disable initial load with DynamicMemoryIndex, as there is no way to build a memory index with an associated tags file yet, making it impossible to load an index without tags

* Working on unit tests and need to pull harsha's changes

* I think I aligned this such that we can execute it via command line with the right behaviors

* Providing rest of parameters build_memory_index requires

* For some reason argparse is allowing a bunch of blank space to come in on arguments and they need stripped. It also needs to be using the right types.

* Recall test now works

* More unit tests for dynamic memory index

* Adding different range check for alpha, as the values are only really that realistic between 1 and 2. Below 1 is an error, and above 2 we'll probably make a warning going forward

* Storing this while I cut a new branch and walk back some work for a future branch

* Undoing the auto load of the dynamic index until I can debug why my tag vector files cause an error in diskann

* Updating the documentation for the python bindings. It's a lot closer than it was.

* Fixing a unit test

* add timers to dynamic apps (#337)

* add timers to dynamic apps

* clang format

* np.uintc vs. int for dtype of tags

* fixes to types in dynamic app

* cast tags to np.uintc array

* more timers

* added example code in comments in app file

* round elapsed

* fix typo

* fix typo

---------

Co-authored-by: Harsha Vardhan Simhadri <harsha-simhadri@users.noreply.github.com>
Co-authored-by: harsha vardhan simhadri <harsha.v.simhadri@gmail.com>
  • Loading branch information
3 people authored Apr 27, 2023
1 parent 45a5409 commit 38d8c44
Show file tree
Hide file tree
Showing 30 changed files with 2,667 additions and 789 deletions.
6 changes: 4 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,10 @@ else()
endif()

add_subdirectory(src)
add_subdirectory(tests)
add_subdirectory(tests/utils)
if (NOT PYBIND)
add_subdirectory(tests)
add_subdirectory(tests/utils)
endif()

if (MSVC)
message(STATUS "The ${PROJECT_NAME}.sln has been created, opened it from VisualStudio to build Release or Debug configurations.\n"
Expand Down
1 change: 0 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,3 @@ recursive-include python *
recursive-include windows *
prune python/tests
recursive-include src *
recursive-include tests *
4 changes: 2 additions & 2 deletions include/defaults.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ namespace defaults
{
const float ALPHA = 1.2f;
const uint32_t NUM_THREADS = 0;
const uint32_t NUM_ROUNDS = 2;
const uint32_t MAX_OCCLUSION_SIZE = 750;
const uint32_t FILTER_LIST_SIZE = 0;
const uint32_t NUM_FROZEN_POINTS = 0;
const uint32_t NUM_FROZEN_POINTS_STATIC = 0;
const uint32_t NUM_FROZEN_POINTS_DYNAMIC = 1;
// following constants should always be specified, but are useful as a
// sensible default at cli / python boundaries
const uint32_t MAX_DEGREE = 64;
Expand Down
11 changes: 6 additions & 5 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,15 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas

// Batch build from a file. Optionally pass tags vector.
DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load,
IndexWriteParameters &parameters, const std::vector<TagT> &tags = std::vector<TagT>());
const IndexWriteParameters &parameters,
const std::vector<TagT> &tags = std::vector<TagT>());

// Batch build from a file. Optionally pass tags file.
DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load,
IndexWriteParameters &parameters, const char *tag_filename);
const IndexWriteParameters &parameters, const char *tag_filename);

// Batch build from a data array, which must pad vectors to aligned_dim
DISKANN_DLLEXPORT void build(const T *data, const size_t num_points_to_load, IndexWriteParameters &parameters,
DISKANN_DLLEXPORT void build(const T *data, const size_t num_points_to_load, const IndexWriteParameters &parameters,
const std::vector<TagT> &tags);

// Filtered Support
Expand Down Expand Up @@ -215,7 +216,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas

// Use after _data and _nd have been populated
// Acquire exclusive _update_lock before calling
void build_with_data_populated(IndexWriteParameters &parameters, const std::vector<TagT> &tags);
void build_with_data_populated(const IndexWriteParameters &parameters, const std::vector<TagT> &tags);

// generates 1 frozen point that will never be deleted from the graph
// This is not visible to the user
Expand Down Expand Up @@ -261,7 +262,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
void inter_insert(uint32_t n, std::vector<uint32_t> &pruned_list, InMemQueryScratch<T> *scratch);

// Acquire exclusive _update_lock before calling
void link(IndexWriteParameters &parameters);
void link(const IndexWriteParameters &parameters);

// Acquire exclusive _tag_lock and _delete_lock before calling
int reserve_location();
Expand Down
24 changes: 8 additions & 16 deletions include/parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,16 @@ class IndexWriteParameters
const bool saturate_graph;
const uint32_t max_occlusion_size; // C
const float alpha;
const uint32_t num_rounds;
const uint32_t num_threads;
const uint32_t filter_list_size; // Lf
const uint32_t num_frozen_points;

private:
IndexWriteParameters(const uint32_t search_list_size, const uint32_t max_degree, const bool saturate_graph,
const uint32_t max_occlusion_size, const float alpha, const uint32_t num_rounds,
const uint32_t num_threads, const uint32_t filter_list_size, const uint32_t num_frozen_points)
const uint32_t max_occlusion_size, const float alpha, const uint32_t num_threads,
const uint32_t filter_list_size, const uint32_t num_frozen_points)
: search_list_size(search_list_size), max_degree(max_degree), saturate_graph(saturate_graph),
max_occlusion_size(max_occlusion_size), alpha(alpha), num_rounds(num_rounds), num_threads(num_threads),
max_occlusion_size(max_occlusion_size), alpha(alpha), num_threads(num_threads),
filter_list_size(filter_list_size), num_frozen_points(num_frozen_points)
{
}
Expand Down Expand Up @@ -70,21 +69,15 @@ class IndexWriteParametersBuilder
return *this;
}

IndexWriteParametersBuilder &with_num_rounds(const uint32_t num_rounds)
{
_num_rounds = num_rounds;
return *this;
}

IndexWriteParametersBuilder &with_num_threads(const uint32_t num_threads)
{
_num_threads = num_threads;
_num_threads = num_threads == 0 ? omp_get_num_threads() : num_threads;
return *this;
}

IndexWriteParametersBuilder &with_filter_list_size(const uint32_t filter_list_size)
{
_filter_list_size = filter_list_size;
_filter_list_size = filter_list_size == 0 ? _search_list_size : filter_list_size;
return *this;
}

Expand All @@ -97,13 +90,13 @@ class IndexWriteParametersBuilder
IndexWriteParameters build() const
{
return IndexWriteParameters(_search_list_size, _max_degree, _saturate_graph, _max_occlusion_size, _alpha,
_num_rounds, _num_threads, _filter_list_size, _num_frozen_points);
_num_threads, _filter_list_size, _num_frozen_points);
}

IndexWriteParametersBuilder(const IndexWriteParameters &wp)
: _search_list_size(wp.search_list_size), _max_degree(wp.max_degree),
_max_occlusion_size(wp.max_occlusion_size), _saturate_graph(wp.saturate_graph), _alpha(wp.alpha),
_num_rounds(wp.num_rounds), _filter_list_size(wp.filter_list_size), _num_frozen_points(wp.num_frozen_points)
_filter_list_size(wp.filter_list_size), _num_frozen_points(wp.num_frozen_points)
{
}
IndexWriteParametersBuilder(const IndexWriteParametersBuilder &) = delete;
Expand All @@ -115,10 +108,9 @@ class IndexWriteParametersBuilder
uint32_t _max_occlusion_size{defaults::MAX_OCCLUSION_SIZE};
bool _saturate_graph{defaults::SATURATE_GRAPH};
float _alpha{defaults::ALPHA};
uint32_t _num_rounds{defaults::NUM_ROUNDS};
uint32_t _num_threads{defaults::NUM_THREADS};
uint32_t _filter_list_size{defaults::FILTER_LIST_SIZE};
uint32_t _num_frozen_points{defaults::NUM_FROZEN_POINTS};
uint32_t _num_frozen_points{defaults::NUM_FROZEN_POINTS_STATIC};
};

} // namespace diskann
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ requires = [
"cmake>=3.22",
"numpy>=1.21",
"wheel",
"ninja"
]
build-backend = "setuptools.build_meta"

Expand Down
28 changes: 28 additions & 0 deletions python/apps/cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import argparse
import utils


if __name__ == "__main__":
    # Command-line entry point: read a vector file, cluster its points, and
    # write the points back out reordered by the computed permutation.
    parser = argparse.ArgumentParser(
        prog="cluster", description="kmeans cluster points in a file"
    )

    parser.add_argument("-d", "--data_type", required=True)
    parser.add_argument("-i", "--indexdata_file", required=True)
    parser.add_argument("-k", "--num_clusters", type=int, required=True)
    args = parser.parse_args()

    # Fix: the original called the bare names `get_bin_metadata` and
    # `indexdata_file`, neither of which is defined in this script (NameError).
    # The helper lives in `utils` and the path comes from the parsed arguments.
    npts, ndims = utils.get_bin_metadata(args.indexdata_file)

    # NOTE(review): args.data_type is forwarded as the raw CLI string here,
    # whereas the other apps pass a numpy dtype to bin_to_numpy - confirm that
    # utils.bin_to_numpy accepts this form.
    data = utils.bin_to_numpy(args.data_type, args.indexdata_file)

    offsets, permutation = utils.cluster_and_permute(
        args.data_type, npts, ndims, data, args.num_clusters
    )

    # Reorder the rows of `data` according to the returned permutation.
    permuted_data = data[permutation]

    # The output is written next to the input, with a ".cluster" suffix.
    utils.numpy_to_bin(permuted_data, args.indexdata_file + ".cluster")
112 changes: 112 additions & 0 deletions python/apps/in-mem-dynamic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import argparse

import diskannpy
import numpy as np
import utils

def insert_and_search(
    dtype_str,
    indexdata_file,
    querydata_file,
    Lb,
    graph_degree,
    K,
    Ls,
    num_insert_threads,
    num_search_threads,
    gt_file,
):
    """Exercise a DynamicMemoryIndex: insert, delete half, re-insert, search.

    Steps, each followed by a timing print:
      1. Insert every vector of ``indexdata_file`` with tag ``row + 1``.
      2. Mark a random half of the tags as deleted.
      3. Consolidate the deletions.
      4. Re-insert the deleted vectors under their original tags.
      5. Batch-search all queries from ``querydata_file``.

    If ``gt_file`` is not the empty string, recall@K is computed against it
    and printed.

    Args:
        dtype_str: "float", "int8" or "uint8"; selects the numpy dtype.
        indexdata_file: path of the binary file holding the vectors to index.
        querydata_file: path of the binary file holding the query vectors.
        Lb: passed to the DynamicMemoryIndex constructor (CLI ``--Lbuild``).
        graph_degree: passed to the DynamicMemoryIndex constructor.
        K: number of results requested per query; also used for recall.
        Ls: passed to ``batch_search`` (CLI ``--Lsearch``).
        num_insert_threads: thread count passed to ``batch_insert``.
        num_search_threads: thread count passed to ``batch_search``.
        gt_file: ground-truth file path, or "" to skip the recall report.

    Raises:
        ValueError: if ``dtype_str`` is not one of the supported names.
    """
    npts, ndims = utils.get_bin_metadata(indexdata_file)

    # One lookup table instead of three copies of the same construction code.
    dtype_by_name = {"float": np.float32, "int8": np.int8, "uint8": np.uint8}
    if dtype_str not in dtype_by_name:
        raise ValueError("data_type must be float, int8 or uint8")
    dtype = dtype_by_name[dtype_str]

    index = diskannpy.DynamicMemoryIndex(
        "l2", dtype, ndims, npts, Lb, graph_degree
    )
    queries = utils.bin_to_numpy(dtype, querydata_file)
    data = utils.bin_to_numpy(dtype, indexdata_file)

    timer = utils.timer()
    # Tags are 1-based: row i of `data` gets tag i + 1. np.arange builds the
    # same array the original filled element by element in a Python loop.
    all_tags = np.arange(1, npts + 1, dtype=np.uintc)
    index.batch_insert(data, all_tags, num_insert_threads)
    print('batch_insert complete in', timer.elapsed(), 's')

    # Choose half of the tags, without repetition, to delete and re-insert.
    delete_tags = np.random.choice(
        all_tags,
        size=int(0.5 * npts),
        replace=False
    )
    for tag in delete_tags:
        index.mark_deleted(tag)
    print('mark deletion completed in', timer.elapsed(), 's')

    index.consolidate_delete()
    print('consolidation completed in', timer.elapsed(), 's')

    # Tag t belongs to row t - 1, so this selects the vectors just deleted.
    deleted_data = data[delete_tags - 1, :]

    index.batch_insert(deleted_data, delete_tags, num_insert_threads)
    print('re-insertion completed in', timer.elapsed(), 's')

    result_tags, dists = index.batch_search(queries, K, Ls, num_search_threads)
    print('Batch searched', queries.shape[0], ' queries in ', timer.elapsed(), 's')

    # Convert the returned 1-based tags back to 0-based row ids for recall.
    res_ids = result_tags - 1
    if gt_file != "":
        recall = utils.calculate_recall_from_gt_file(K, res_ids, gt_file)
        print(f"recall@{K} is {recall}")


if __name__ == "__main__":
    # Command-line driver: parse the options and forward them to
    # insert_and_search by keyword so the mapping is explicit.
    cli_parser = argparse.ArgumentParser(
        prog="in-mem-dynamic",
        description="Inserts points dynamically in a clustered order and search from vectors in a file.",
    )

    # Input files and element type.
    cli_parser.add_argument("-d", "--data_type", required=True)
    cli_parser.add_argument("-i", "--indexdata_file", required=True)
    cli_parser.add_argument("-q", "--querydata_file", required=True)
    # Build / search tuning knobs.
    cli_parser.add_argument("-Lb", "--Lbuild", default=50, type=int)
    cli_parser.add_argument("-Ls", "--Lsearch", default=50, type=int)
    cli_parser.add_argument("-R", "--graph_degree", default=32, type=int)
    cli_parser.add_argument("-TI", "--num_insert_threads", default=8, type=int)
    cli_parser.add_argument("-TS", "--num_search_threads", default=8, type=int)
    cli_parser.add_argument("-K", default=10, type=int)
    # Optional ground truth; the empty default disables the recall report.
    cli_parser.add_argument("--gt_file", default="")
    opts = cli_parser.parse_args()

    insert_and_search(
        dtype_str=opts.data_type,
        indexdata_file=opts.indexdata_file,
        querydata_file=opts.querydata_file,
        Lb=opts.Lbuild,
        graph_degree=opts.graph_degree,
        K=opts.K,
        Ls=opts.Lsearch,
        num_insert_threads=opts.num_insert_threads,
        num_search_threads=opts.num_search_threads,
        gt_file=opts.gt_file,
    )

# An ingest optimized example with SIFT1M
# python3 ~/DiskANN/python/apps/in-mem-dynamic.py -d float \
# -i sift_base.fbin -q sift_query.fbin --gt_file gt100_base \
# -Lb 10 -R 30 -Ls 200
101 changes: 101 additions & 0 deletions python/apps/in-mem-static.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import argparse
from xml.dom.pulldom import default_bufsize

import diskannpy
import numpy as np
import utils


def build_and_search(
    dtype_str,
    index_directory,
    indexdata_file,
    querydata_file,
    Lb,
    graph_degree,
    K,
    Ls,
    num_threads,
    gt_file,
    index_prefix
):
    """Build a static in-memory index from a file, load it, and search it.

    The index is built with ``diskannpy.build_memory_index`` (l2 metric,
    alpha 1.2, PQ/OPQ disabled), opened as a ``StaticMemoryIndex``, and every
    query in ``querydata_file`` is searched for its ``K`` nearest neighbours.
    If ``gt_file`` is not the empty string, recall@K is computed against it
    and printed.

    Args:
        dtype_str: "float", "int8" or "uint8"; selects the numpy dtype.
        index_directory: directory passed to both build and load calls.
        indexdata_file: path of the binary file holding the vectors to index.
        querydata_file: path of the binary file holding the query vectors.
        Lb: build-time ``complexity`` passed to ``build_memory_index``.
        graph_degree: ``graph_degree`` passed to ``build_memory_index``.
        K: number of neighbours requested per query; also used for recall.
        Ls: search complexity (initial value for the index and per-batch).
        num_threads: thread count used for build, load and search.
        gt_file: ground-truth file path, or "" to skip the recall report.
        index_prefix: file-name prefix passed to both build and load calls.

    Raises:
        ValueError: if ``dtype_str`` is not one of the supported names.
    """
    if dtype_str == "float":
        dtype = np.single
    elif dtype_str == "int8":
        dtype = np.byte
    elif dtype_str == "uint8":
        dtype = np.ubyte
    else:
        raise ValueError("data_type must be float, int8 or uint8")

    # build index
    diskannpy.build_memory_index(
        data=indexdata_file,
        metric="l2",
        vector_dtype=dtype,
        index_directory=index_directory,
        complexity=Lb,
        graph_degree=graph_degree,
        num_threads=num_threads,
        index_prefix=index_prefix,
        alpha=1.2,
        use_pq_build=False,
        num_pq_bytes=8,
        use_opq=False,
    )

    # ready search object
    index = diskannpy.StaticMemoryIndex(
        metric="l2",
        vector_dtype=dtype,
        data_path=indexdata_file,
        index_directory=index_directory,
        num_threads=num_threads,  # this can be different at search time if you would like
        initial_search_complexity=Ls,
        index_prefix=index_prefix
    )

    queries = utils.bin_to_numpy(dtype, querydata_file)

    # Fix: request K neighbours. The original hard-coded 10 here while the
    # recall below is computed for K, so any -K other than 10 was mismatched.
    ids, dists = index.batch_search(queries, K, Ls, num_threads)

    if gt_file != "":
        recall = utils.calculate_recall_from_gt_file(K, ids, gt_file)
        print(f"recall@{K} is {recall}")


if __name__ == "__main__":
    # Command-line driver: parse the options and forward them to
    # build_and_search by keyword so the mapping is explicit.
    cli_parser = argparse.ArgumentParser(
        prog="in-mem-static",
        description="Static in-memory build and search from vectors in a file",
    )

    # Input/output locations and element type.
    cli_parser.add_argument("-d", "--data_type", required=True)
    cli_parser.add_argument("-id", "--index_directory", required=False, default=".")
    cli_parser.add_argument("-i", "--indexdata_file", required=True)
    cli_parser.add_argument("-q", "--querydata_file", required=True)
    # Build / search tuning knobs.
    cli_parser.add_argument("-Lb", "--Lbuild", default=50, type=int)
    cli_parser.add_argument("-Ls", "--Lsearch", default=50, type=int)
    cli_parser.add_argument("-R", "--graph_degree", default=32, type=int)
    cli_parser.add_argument("-T", "--num_threads", default=8, type=int)
    cli_parser.add_argument("-K", default=10, type=int)
    # Optional ground truth; the empty default disables the recall report.
    cli_parser.add_argument("--gt_file", default="")
    cli_parser.add_argument("-ip", "--index_prefix", required=False, default="ann")
    opts = cli_parser.parse_args()

    # The three path-like arguments have surrounding whitespace removed
    # before use; the remaining values are passed through unchanged.
    build_and_search(
        dtype_str=opts.data_type,
        index_directory=opts.index_directory.strip(),
        indexdata_file=opts.indexdata_file.strip(),
        querydata_file=opts.querydata_file.strip(),
        Lb=opts.Lbuild,
        graph_degree=opts.graph_degree,
        K=opts.K,
        Ls=opts.Lsearch,
        num_threads=opts.num_threads,
        gt_file=opts.gt_file,
        index_prefix=opts.index_prefix,
    )
Loading

0 comments on commit 38d8c44

Please sign in to comment.