Skip to content

Commit

Permalink
Better NaN handling (facebookresearch#2986)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#2986

A NaN vector is a vector with at least one NaN (not-a-number) entry.
After discussion in the Faiss team we decided that:
- training should throw an exception on NaN vectors
- added NaN vectors should be ignored (never returned)
- searched NaN vectors should return only -1s

This diff implements this for a few common index types + adds relevant tests.

Reviewed By: algoriddle

Differential Revision: D48031390

fbshipit-source-id: 99e7786582e91950e3a53c1d8bcffdd00b6afd24
  • Loading branch information
mdouze authored and facebook-github-bot committed Aug 4, 2023
1 parent a4ddb18 commit a3fbf2d
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 108 deletions.
25 changes: 14 additions & 11 deletions faiss/impl/HNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <faiss/impl/HNSW.h>

#include <cmath>
#include <string>

#include <faiss/impl/AuxIndexStructures.h>
Expand Down Expand Up @@ -542,12 +543,11 @@ int search_from_candidates(
for (int i = 0; i < candidates.size(); i++) {
idx_t v1 = candidates.ids[i];
float d = candidates.dis[i];
FAISS_ASSERT(v1 >= 0);
assert(v1 >= 0);
if (!sel || sel->is_member(v1)) {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, d, v1);
} else if (d < D[0]) {
faiss::maxheap_replace_top(nres, D, I, d, v1);
if (d < D[0]) {
faiss::maxheap_replace_top(k, D, I, d, v1);
nres++;
}
}
vt.set(v1);
Expand Down Expand Up @@ -612,10 +612,9 @@ int search_from_candidates(

auto add_to_heap = [&](const size_t idx, const float dis) {
if (!sel || sel->is_member(idx)) {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, dis, idx);
} else if (dis < D[0]) {
faiss::maxheap_replace_top(nres, D, I, dis, idx);
if (dis < D[0]) {
faiss::maxheap_replace_top(k, D, I, dis, idx);
nres++;
}
}
candidates.push(idx, dis);
Expand Down Expand Up @@ -668,7 +667,7 @@ int search_from_candidates(
stats.n3 += ndis;
}

return nres;
return std::min(nres, k);
}

std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
Expand Down Expand Up @@ -816,6 +815,11 @@ HNSWStats HNSW::search(
// greedy search on upper levels
storage_idx_t nearest = entry_point;
float d_nearest = qdis(nearest);
if (!std::isfinite(d_nearest)) {
// means either the query or the entry point are NaN: in
// both cases we can only return -1 as a result
return stats;
}

for (int level = max_level; level >= 1; level--) {
greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
Expand All @@ -826,7 +830,6 @@ HNSWStats HNSW::search(
MinimaxHeap candidates(ef);

candidates.push(nearest, d_nearest);

search_from_candidates(
*this, qdis, k, I, D, candidates, vt, stats, 0, 0, params);
} else {
Expand Down
7 changes: 4 additions & 3 deletions faiss/impl/ResultHandler.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,8 @@ struct SingleBestResultHandler {
/// begin results for query # i
void begin(const size_t current_idx) {
this->current_idx = current_idx;
min_dis = HUGE_VALF;
min_idx = 0;
min_dis = C::neutral();
min_idx = -1;
}

/// add one result for query i
Expand All @@ -472,7 +472,8 @@ struct SingleBestResultHandler {
this->i1 = i1;

for (size_t i = i0; i < i1; i++) {
this->dis_tab[i] = HUGE_VALF;
this->dis_tab[i] = C::neutral();
this->ids_tab[i] = -1;
}
}

Expand Down
5 changes: 5 additions & 0 deletions faiss/impl/ScalarQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,11 @@ void ScalarQuantizer::set_derived_sizes() {
}

void ScalarQuantizer::train(size_t n, const float* x) {
for (size_t i = 0; i < n * d; i++) {
FAISS_THROW_IF_NOT_MSG(
std::isfinite(x[i]), "training data contains NaN or Inf");
}

int bit_per_dim = qtype == QT_4bit_uniform ? 4
: qtype == QT_4bit ? 4
: qtype == QT_6bit ? 6
Expand Down
182 changes: 182 additions & 0 deletions tests/test_error_reporting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""This script tests a few failure cases of Faiss and whether they are handled
properly."""

import numpy as np
import unittest
import faiss

from common_faiss_tests import get_dataset_2
from faiss.contrib.datasets import SyntheticDataset


class TestValidIndexParams(unittest.TestCase):

def test_IndexIVFPQ(self):
d = 32
nb = 1000
nt = 1500
nq = 200

(xt, xb, xq) = get_dataset_2(d, nt, nb, nq)

coarse_quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(coarse_quantizer, d, 32, 8, 8)
index.cp.min_points_per_centroid = 5 # quiet warning
index.train(xt)
index.add(xb)

# invalid nprobe
index.nprobe = 0
k = 10
self.assertRaises(RuntimeError, index.search, xq, k)

# invalid k
index.nprobe = 4
k = -10
self.assertRaises(AssertionError, index.search, xq, k)

# valid params
index.nprobe = 4
k = 10
D, nns = index.search(xq, k)

self.assertEqual(D.shape[0], nq)
self.assertEqual(D.shape[1], k)

def test_IndexFlat(self):
d = 32
nb = 1000
nt = 0
nq = 200

(xt, xb, xq) = get_dataset_2(d, nt, nb, nq)
index = faiss.IndexFlat(d, faiss.METRIC_L2)

index.add(xb)

# invalid k
k = -5
self.assertRaises(AssertionError, index.search, xq, k)

# valid k
k = 5
D, I = index.search(xq, k)

self.assertEqual(D.shape[0], nq)
self.assertEqual(D.shape[1], k)


class TestReconsException(unittest.TestCase):

def test_recons_exception(self):

d = 64 # dimension
nb = 1000
rs = np.random.RandomState(1234)
xb = rs.rand(nb, d).astype('float32')
nlist = 10
quantizer = faiss.IndexFlatL2(d) # the other index
index = faiss.IndexIVFFlat(quantizer, d, nlist)
index.train(xb)
index.add(xb)
index.make_direct_map()

index.reconstruct(9)

self.assertRaises(
RuntimeError,
index.reconstruct, 100001
)

def test_reconstuct_after_add(self):
index = faiss.index_factory(10, 'IVF5,SQfp16')
index.train(faiss.randn((100, 10), 123))
index.add(faiss.randn((100, 10), 345))
index.make_direct_map()
index.add(faiss.randn((100, 10), 678))

# should not raise an exception
index.reconstruct(5)
print(index.ntotal)
index.reconstruct(150)


class TestNaN(unittest.TestCase):
""" NaN values handling is transparent: they don't produce results
but should not crash. The tests below cover a few common index types.
"""

def do_test_train(self, factory_string):
""" NaN and Inf should raise an exception at train time """
ds = SyntheticDataset(32, 200, 20, 10)
index = faiss.index_factory(ds.d, factory_string)
# try to train with NaNs
xt = ds.get_train().copy()
xt[:, ::4] = np.nan
self.assertRaises(RuntimeError, index.train, xt)

def test_train_IVFSQ(self):
self.do_test_train("IVF10,SQ8")

def test_train_IVFPQ(self):
self.do_test_train("IVF10,PQ4np")

def test_train_SQ(self):
self.do_test_train("SQ8")

def do_test_add(self, factory_string):
""" stored NaNs should not be returned at search time """
ds = SyntheticDataset(32, 200, 20, 10)
index = faiss.index_factory(ds.d, factory_string)
if not index.is_trained:
index.train(ds.get_train())
xb = ds.get_database()
xb[12, 3] = np.nan
index.add(xb)
D, I = index.search(ds.get_queries(), 20)
self.assertTrue(np.where(I == 12)[0].size == 0)

def test_add_Flat(self):
self.do_test_add("Flat")

def test_add_HNSW(self):
self.do_test_add("HNSW32,Flat")

def xx_test_add_SQ8(self):
# this is expected to fail because:
# in ASAN mode, the float NaN -> int conversion crashes
# in opt mode it works but there is no way to encode the NaN,
# so the value cannot be ignored.
self.do_test_add("SQ8")

def test_add_IVFFlat(self):
self.do_test_add("IVF10,Flat")

def do_test_search(self, factory_string):
""" NaN query vectors should return -1 """
ds = SyntheticDataset(32, 200, 20, 10)
index = faiss.index_factory(ds.d, factory_string)
if not index.is_trained:
index.train(ds.get_train())
index.add(ds.get_database())
xq = ds.get_queries()
xq[7, 3] = np.nan
D, I = index.search(ds.get_queries(), 20)
self.assertTrue(np.all(I[7] == -1))

def test_search_Flat(self):
self.do_test_search("Flat")

def test_search_HNSW(self):
self.do_test_search("HNSW32,Flat")

def test_search_IVFFlat(self):
self.do_test_search("IVF10,Flat")

def test_search_SQ(self):
self.do_test_search("SQ8")
Loading

0 comments on commit a3fbf2d

Please sign in to comment.