diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 9230bdd9f..7dc324cfa 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -1201,10 +1201,45 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons } #ifdef USE_BING_INFRA -bool getNextCompletedRequest(std::shared_ptr &reader, IOContext &ctx, int &completedIndex) +bool getNextCompletedRequest(std::shared_ptr &reader, IOContext &ctx, size_t size, int &completedIndex) { - reader->wait(ctx, completedIndex); - return completedIndex != -1; + if ((*ctx.m_pRequests)[0].m_callback) + { + bool waitsRemaining = false; + long completeCount = ctx.m_completeCount; + do + { + for (int i = 0; i < size; i++) + { + auto ithStatus = (*ctx.m_pRequestsStatus)[i]; + if (ithStatus == IOContext::Status::READ_SUCCESS) + { + completedIndex = i; + return true; + } + else if (ithStatus == IOContext::Status::READ_WAIT) + { + waitsRemaining = true; + } + } + + // if we didn't find one in READ_SUCCESS, wait for one to complete. + if (waitsRemaining) + { + WaitOnAddress(&ctx.m_completeCount, &completeCount, sizeof(completeCount), 100); + // this assumes the knowledge of the reader behavior (implicit + // contract). need better factoring? + } + } while (waitsRemaining); + + completedIndex = -1; + return false; + } + else + { + reader->wait(ctx, completedIndex); + return completedIndex != -1; + } } #endif @@ -1513,7 +1548,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t long requestCount = static_cast(frontier_read_reqs.size()); // If we issued read requests and if a read is complete or there are // reads in wait state, then enter the while loop. - while (requestCount > 0 && getNextCompletedRequest(reader, ctx, completedIndex)) + while (requestCount > 0 && getNextCompletedRequest(reader, ctx, requestCount, completedIndex)) { assert(completedIndex >= 0); auto &frontier_nhood = frontier_nhoods[completedIndex];