bbs producer/consumer threading (facebookresearch#2901)
Summary:
Pull Request resolved: facebookresearch#2901

This diff allows each GPU to work independently: a hot centroid (e.g. out-of-distribution queries that hit one centroid heavily) blocks only the GPU that is processing it, while the others continue to pick up work independently.
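In outline, the new scheme is a bounded work queue feeding independent consumers. A minimal, self-contained sketch of the pattern (`process` and `buckets` are illustrative stand-ins, not names from this diff):

```python
import threading
from queue import Queue

def process(bucket, gpu_id):      # stand-in for per-GPU search work
    print(f"gpu {gpu_id} handled bucket {bucket}")

buckets = range(8)                # stand-in for the prepared buckets
work = Queue(2)                   # bounded: the producer cannot run far ahead

def gpu_worker(gpu_id):
    while True:
        bucket = work.get()
        if bucket is None:        # shutdown sentinel
            work.put(None)        # re-enqueue so the other workers exit too
            break
        process(bucket, gpu_id)   # a slow bucket stalls only this worker

workers = [threading.Thread(target=gpu_worker, args=(i,)) for i in range(4)]
for w in workers:
    w.start()
for b in buckets:
    work.put(b)
work.put(None)                    # one sentinel is enough
for w in workers:
    w.join()
```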

Reviewed By: mdouze

Differential Revision: D46521298

fbshipit-source-id: 171cb06cce8b2d16b7bd744799b105b3cd525be3
algoriddle authored and facebook-github-bot committed Jun 14, 2023
1 parent d8a6350 commit 092606b
Showing 2 changed files with 157 additions and 108 deletions.
251 changes: 150 additions & 101 deletions contrib/big_batch_search.py
@@ -8,6 +8,10 @@
 import os
 from multiprocessing.pool import ThreadPool
 import threading
+import _thread
+from queue import Queue
+import traceback
+import datetime
 
 import numpy as np
 import faiss
@@ -60,14 +64,21 @@ def toc(self):
 
     def report(self, l):
         if self.verbose == 1 or (
-                l > 1000 and time.time() < self.t_display + 1.0):
+            self.verbose == 2 and (
+                l > 1000 and time.time() < self.t_display + 1.0
+            )
+        ):
             return
+        t = time.time() - self.t0
         print(
-            f"[{time.time()-self.t0:.1f} s] list {l}/{self.index.nlist} "
+            f"[{t:.1f} s] list {l}/{self.index.nlist} "
             f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
             f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
-            f"wait {self.t_accu[4]:.3f}",
-            end="\r", flush=True
+            f"wait {self.t_accu[4]:.3f} "
+            f"eta {datetime.timedelta(seconds=t*self.index.nlist/(l+1)-t)} "
+            f"mem {faiss.get_mem_usage_kb()}",
+            end="\r" if self.verbose <= 2 else "\n",
+            flush=True,
         )
         self.t_display = time.time()
 
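The new `eta` field in `report()` extrapolates the elapsed time linearly over all inverted lists. A small sketch of the arithmetic with example numbers (not from the diff):

```python
import datetime

t, l, nlist = 120.0, 399, 4000  # elapsed seconds, current list, total lists
eta = datetime.timedelta(seconds=t * nlist / (l + 1) - t)
print(eta)  # 0:18:00 — projected total runtime minus time already spent
```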
@@ -141,24 +152,25 @@ def add_results_to_heap(self, q_subset, D, list_ids, I):
     def sizes_in_checkpoint(self):
         return (self.xq.shape, self.index.nprobe, self.index.nlist)
 
-    def write_checkpoint(self, fname, cur_list_no):
+    def write_checkpoint(self, fname, completed):
         # write to temp file then move to final file
         tmpname = fname + ".tmp"
-        pickle.dump(
-            {
-                "sizes": self.sizes_in_checkpoint(),
-                "cur_list_no": cur_list_no,
-                "rh": (self.rh.D, self.rh.I),
-            }, open(tmpname, "wb"), -1
-        )
+        with open(tmpname, "wb") as f:
+            pickle.dump(
+                {
+                    "sizes": self.sizes_in_checkpoint(),
+                    "completed": completed,
+                    "rh": (self.rh.D, self.rh.I),
+                }, f, -1)
         os.replace(tmpname, fname)
 
     def read_checkpoint(self, fname):
-        ckp = pickle.load(open(fname, "rb"))
+        with open(fname, "rb") as f:
+            ckp = pickle.load(f)
         assert ckp["sizes"] == self.sizes_in_checkpoint()
         self.rh.D[:] = ckp["rh"][0]
         self.rh.I[:] = ckp["rh"][1]
-        return ckp["cur_list_no"]
+        return ckp["completed"]
 
 
 class BlockComputer:
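The reworked `write_checkpoint` stays crash-safe by writing to a temp file and renaming it over the old checkpoint; `os.replace` is atomic on POSIX, so an interrupted run never leaves a truncated checkpoint behind. A standalone sketch of the idiom (payload and path are hypothetical):

```python
import os
import pickle

def atomic_dump(obj, fname):
    tmpname = fname + ".tmp"
    with open(tmpname, "wb") as f:
        pickle.dump(obj, f, -1)   # -1: highest pickle protocol
    os.replace(tmpname, fname)    # atomic swap: old file stays valid until now

atomic_dump({"completed": {0, 1, 2}}, "/tmp/bbs.ckpt")  # hypothetical path
```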
@@ -225,11 +237,11 @@ def big_batch_search(
         verbose=0,
         threaded=0,
         use_float16=False,
-        prefetch_threads=8,
-        computation_threads=0,
+        prefetch_threads=1,
+        computation_threads=1,
         q_assign=None,
         checkpoint=None,
-        checkpoint_freq=64,
+        checkpoint_freq=7200,
         start_list=0,
         end_list=None,
         crash_at=-1
@@ -251,7 +263,7 @@
     threaded=0: sequential execution
     threaded=1: prefetch next bucket while computing the current one
-    threaded>1: prefetch this many buckets at a time.
+    threaded=2: prefetch prefetch_threads buckets at a time.
 
     compute_threads>1: the knn function will get an additional thread_no that
         tells which worker should handle this.
 
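A hedged usage sketch of the updated entry point, mirroring the call pattern in the tests below; the dataset here is random toy data:

```python
import faiss
import numpy as np
from faiss.contrib import big_batch_search

xb = np.random.rand(10000, 32).astype("float32")
xq = np.random.rand(100, 32).astype("float32")
index = faiss.index_factory(32, "IVF64,Flat")
index.train(xb)
index.add(xb)
index.nprobe = 4

D, I = big_batch_search.big_batch_search(
    index, xq, 10,
    method="knn_function",
    threaded=2,                  # producer/consumer pipeline
    prefetch_threads=1,          # bucket-preparation workers
    computation_threads=1,       # search workers, e.g. one per GPU
    checkpoint=None,             # or a path, plus checkpoint_freq in seconds
)
```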
@@ -311,12 +323,13 @@
     if end_list is None:
         end_list = index.nlist
 
+    completed = set()
     if checkpoint is not None:
         assert (start_list, end_list) == (0, index.nlist)
         if os.path.exists(checkpoint):
             print("recovering checkpoint", checkpoint)
-            start_list = bbs.read_checkpoint(checkpoint)
-            print("  start at list", start_list)
+            completed = bbs.read_checkpoint(checkpoint)
+            print("  already completed", len(completed))
         else:
             print("no checkpoint: starting from scratch")
 
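Because buckets now finish out of order, the checkpoint stores a set of completed centroids instead of a resume cursor; recovery simply skips whatever is already done. A tiny sketch of that skip logic (illustrative values):

```python
completed = {0, 3, 7}        # e.g. recovered via bbs.read_checkpoint(...)
nlist = 10
pending = [i for i in range(nlist) if i not in completed]
print(pending)               # [1, 2, 4, 5, 6, 8, 9] — only these are re-run
```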
@@ -363,94 +376,130 @@ def add_results_and_prefetch(to_add, l):
         bbs.add_results_to_heap(*to_add)
         pool.close()
     else:
-        # run by batches with parallel prefetch and parallel comp
-        list_step = threaded
-        assert start_list % list_step == 0
-
-        if prefetch_threads == 0:
-            prefetch_map = map
-        else:
-            prefetch_pool = ThreadPool(prefetch_threads)
-            prefetch_map = prefetch_pool.map
-
-        if computation_threads > 0:
-            comp_pool = ThreadPool(computation_threads)
-
-        def add_results_and_prefetch_batch(to_add, l):
-            def add_results(to_add):
-                for ta in to_add:  # this one cannot be run in parallel...
-                    if ta is not None:
-                        bbs.add_results_to_heap(*ta)
-            if prefetch_threads == 0:
-                add_results(to_add)
-            else:
-                add_a = prefetch_pool.apply_async(add_results, (to_add, ))
-            next_lists = range(l, min(l + list_step, index.nlist))
-            res = list(prefetch_map(bbs.prepare_bucket, next_lists))
-            if prefetch_threads > 0:
-                add_a.get()
-            return res
-
-        # used only when computation_threads > 1
-        thread_id_to_seq_lock = threading.Lock()
-        thread_id_to_seq = {}
-
-        def do_comp(bucket):
-            (q_subset, xq_l, list_ids, xb_l) = bucket
+        def task_manager_thread(
+            task,
+            pool_size,
+            start_task,
+            end_task,
+            completed,
+            output_queue,
+            input_queue,
+        ):
             try:
-                tid = thread_id_to_seq[threading.get_ident()]
-            except KeyError:
-                with thread_id_to_seq_lock:
-                    tid = len(thread_id_to_seq)
-                    thread_id_to_seq[threading.get_ident()] = tid
-            D, I = comp.block_search(xq_l, xb_l, list_ids, k, thread_id=tid)
-            return q_subset, D, list_ids, I
-
-        prefetched_buckets = add_results_and_prefetch_batch([], start_list)
-        to_add = []
-        pool = ThreadPool(1)
-        prefetched_buckets_a = None
-
-        # loop over inverted lists
-        for l in range(start_list, end_list, list_step):
-            bbs.report(l)
-            buckets = prefetched_buckets
-            prefetched_buckets_a = pool.apply_async(
-                add_results_and_prefetch_batch, (to_add, l + list_step))
-
-            bbs.start_t_accu()
-
-            to_add = []
-            if computation_threads == 0:
-                for q_subset, xq_l, list_ids, xb_l in buckets:
-                    D, I = comp.block_search(xq_l, xb_l, list_ids, k)
-                    to_add.append((q_subset, D, list_ids, I))
-            else:
-                to_add = list(comp_pool.map(do_comp, buckets))
-
-            bbs.stop_t_accu(2)
-
+                with ThreadPool(pool_size) as pool:
+                    res = [pool.apply_async(
+                        task,
+                        args=(i, output_queue, input_queue))
+                        for i in range(start_task, end_task)
+                        if i not in completed]
+                    for r in res:
+                        r.get()
+                    pool.close()
+                    pool.join()
+                output_queue.put(None)
+            except:
+                traceback.print_exc()
+                _thread.interrupt_main()
+                raise
+
+        def task_manager(*args):
+            task_manager = threading.Thread(
+                target=task_manager_thread,
+                args=args,
+            )
+            task_manager.daemon = True
+            task_manager.start()
+            return task_manager
+
+        def prepare_task(task_id, output_queue, input_queue=None):
+            try:
+                # print(f"Prepare start: {task_id}")
+                q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(task_id)
+                output_queue.put((task_id, q_subset, xq_l, list_ids, xb_l))
+                # print(f"Prepare end: {task_id}")
+            except:
+                traceback.print_exc()
+                _thread.interrupt_main()
+                raise
+
+        def compute_task(task_id, output_queue, input_queue):
+            try:
+                # print(f"Compute start: {task_id}")
+                t_wait = 0
+                while True:
+                    t0 = time.time()
+                    input_value = input_queue.get()
+                    t_wait += time.time() - t0
+                    if input_value is None:
+                        # signal for other compute tasks
+                        input_queue.put(None)
+                        break
+                    centroid, q_subset, xq_l, list_ids, xb_l = input_value
+                    # print(f'Compute work start: task {task_id}, centroid {centroid}')
+                    t0 = time.time()
+                    if computation_threads > 1:
+                        D, I = comp.block_search(
+                            xq_l, xb_l, list_ids, k, thread_id=task_id
+                        )
+                    else:
+                        D, I = comp.block_search(xq_l, xb_l, list_ids, k)
+                    t_compute = time.time() - t0
+                    # print(f'Compute work end: task {task_id}, centroid {centroid}')
+                    t0 = time.time()
+                    output_queue.put(
+                        (centroid, t_wait, t_compute, q_subset, D, list_ids, I)
+                    )
+                    t_wait = time.time() - t0
+                # print(f"Compute end: {task_id}")
+            except:
+                traceback.print_exc()
+                _thread.interrupt_main()
+                raise
+
+        prepare_to_compute_queue = Queue(2)
+        compute_to_main_queue = Queue(2)
+
+        compute_task_manager = task_manager(
+            compute_task,
+            computation_threads,
+            0,
+            computation_threads,
+            set(),
+            compute_to_main_queue,
+            prepare_to_compute_queue,
+        )
+        prepare_task_manager = task_manager(
+            prepare_task,
+            prefetch_threads,
+            start_list,
+            end_list,
+            completed,
+            prepare_to_compute_queue,
+            None,
+        )
+
+        t_checkpoint = time.time()
+        while True:
+            value = compute_to_main_queue.get()
+            if not value:
+                break
+            centroid, t_wait, t_compute, q_subset, D, list_ids, I = value
             # to test checkpointing
-            if l == crash_at:
+            if centroid == crash_at:
                 1 / 0
-
-            bbs.start_t_accu()
-            prefetched_buckets = prefetched_buckets_a.get()
-            bbs.stop_t_accu(4)
-
+            bbs.t_accu[2] += t_compute
+            bbs.t_accu[4] += t_wait
+            bbs.add_results_to_heap(q_subset, D, list_ids, I)
+            completed.add(centroid)
+            bbs.report(centroid)
             if checkpoint is not None:
-                if (l // list_step) % checkpoint_freq == 0:
-                    print("writing checkpoint %s" % l)
-                    bbs.write_checkpoint(checkpoint, l)
+                if time.time() - t_checkpoint > checkpoint_freq:
+                    print("writing checkpoint")
+                    bbs.write_checkpoint(checkpoint, completed)
+                    t_checkpoint = time.time()
 
-        # flush add
-        for ta in to_add:
-            bbs.add_results_to_heap(*ta)
-        pool.close()
-        if prefetch_threads != 0:
-            prefetch_pool.close()
-        if computation_threads != 0:
-            comp_pool.close()
+        prepare_task_manager.join()
+        compute_task_manager.join()
 
     bbs.tic("finalize heap")
     bbs.rh.finalize()
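Worker errors in the hunk above are propagated with `_thread.interrupt_main()`: a failing daemon thread prints its traceback and raises `KeyboardInterrupt` in the main thread, so the search aborts instead of hanging on a dead queue. A standalone sketch of the idiom:

```python
import _thread
import threading
import time
import traceback

def worker():
    try:
        raise RuntimeError("simulated worker failure")
    except:
        traceback.print_exc()
        _thread.interrupt_main()  # delivers KeyboardInterrupt to the main thread

threading.Thread(target=worker, daemon=True).start()
try:
    time.sleep(5)                 # stand-in for the main result loop
except KeyboardInterrupt:
    print("worker failed; shutting down")
```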
14 changes: 7 additions & 7 deletions tests/test_contrib.py
@@ -9,6 +9,7 @@
 import platform
 import os
 import random
+import tempfile
 
 from faiss.contrib import datasets
 from faiss.contrib import inspect_tools
@@ -507,7 +508,7 @@ def do_test(self, factory_string, metric=faiss.METRIC_L2):
         Dref, Iref = index.search(ds.get_queries(), k)
         # faiss.omp_set_num_threads(1)
         for method in ("pairwise_distances", "knn_function", "index"):
-            for threaded in 0, 1, 3, 8:
+            for threaded in 0, 1, 2:
                 Dnew, Inew = big_batch_search.big_batch_search(
                     index, ds.get_queries(),
                     k, method=method,
@@ -537,16 +538,15 @@ def test_checkpoint(self):
         index.nprobe = 5
         Dref, Iref = index.search(ds.get_queries(), k)
 
-        r = random.randrange(1 << 60)
-        checkpoint = "/tmp/test_big_batch_checkpoint.%d" % r
+        checkpoint = tempfile.mktemp()
         try:
             # First big batch search
             try:
                 Dnew, Inew = big_batch_search.big_batch_search(
                     index, ds.get_queries(),
                     k, method="knn_function",
-                    threaded=4,
-                    checkpoint=checkpoint, checkpoint_freq=4,
+                    threaded=2,
+                    checkpoint=checkpoint, checkpoint_freq=0.1,
                     crash_at=20
                 )
             except ZeroDivisionError:
@@ -557,8 +557,8 @@ def test_checkpoint(self):
             Dnew, Inew = big_batch_search.big_batch_search(
                 index, ds.get_queries(),
                 k, method="knn_function",
-                threaded=4,
-                checkpoint=checkpoint, checkpoint_freq=4
+                threaded=2,
+                checkpoint=checkpoint, checkpoint_freq=5
             )
             self.assertLess((Inew != Iref).sum() / Iref.size, 1e-4)
             np.testing.assert_almost_equal(Dnew, Dref, decimal=4)
