Skip to content

Commit

Permalink
Apply label parsing improvement to memory index (#524)
Browse files Browse the repository at this point in the history
* Fix compile issue

* Fix issue
  • Loading branch information
Sanhaoji2 authored Mar 7, 2024
1 parent 9bb0cf0 commit 416c661
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 12 deletions.
2 changes: 2 additions & 0 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
DISKANN_DLLEXPORT size_t load_delete_set(const std::string &filename);
#endif

size_t search_string_range(const std::string& str, char ch, size_t start, size_t end);

private:
// Distance functions
Metric _dist_metric = diskann::L2;
Expand Down
100 changes: 88 additions & 12 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2103,17 +2103,34 @@ void Index<T, TagT, LabelT>::convert_pts_label_to_bitmask(std::vector<std::vecto
template <typename T, typename TagT, typename LabelT>
void Index<T, TagT, LabelT>::parse_label_file_in_bitset(const std::string& label_file, size_t& num_points, size_t num_labels)
{
std::ifstream infile(label_file);
std::ifstream infile(label_file, std::ios::binary);
if (infile.fail())
{
throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1);
}
infile.seekg(0, std::ios::end);
size_t file_size = infile.tellg();

std::string buffer(file_size, ' ');

infile.seekg(0, std::ios::beg);
infile.read(&buffer[0], file_size);
infile.close();

std::string line, token;
unsigned line_cnt = 0;

while (std::getline(infile, line))
size_t cur_pos = 0;
size_t next_pos = 0;
while (cur_pos < file_size && cur_pos != std::string::npos)
{
next_pos = buffer.find('\n', cur_pos);
if (next_pos == std::string::npos)
{
break;
}

cur_pos = next_pos + 1;

line_cnt++;
}

Expand All @@ -2124,23 +2141,68 @@ void Index<T, TagT, LabelT>::parse_label_file_in_bitset(const std::string& label
infile.seekg(0, std::ios::beg);
line_cnt = 0;

while (std::getline(infile, line))
std::string label_str;
cur_pos = 0;
next_pos = 0;
while (cur_pos < file_size && cur_pos != std::string::npos)
{
std::istringstream iss(line);
std::vector<LabelT> lbls(0);
getline(iss, token, '\t');
std::istringstream new_iss(token);
while (getline(new_iss, token, ','))
next_pos = buffer.find('\n', cur_pos);
if (next_pos == std::string::npos)
{
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
LabelT token_as_num = std::stoul(token);
break;
}

size_t lbl_pos = cur_pos;
size_t next_lbl_pos = 0;
while (lbl_pos < next_pos && lbl_pos != std::string::npos)
{
next_lbl_pos = search_string_range(buffer, ',', lbl_pos, next_pos);
if (next_lbl_pos == std::string::npos) // the last label in the whole file
{
next_lbl_pos = next_pos;
}

if (next_lbl_pos > next_pos) // the last label in one line
{
next_lbl_pos = next_pos;
}

label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos);
if (label_str[label_str.length() - 1] == '\t')
{
label_str.erase(label_str.length() - 1);
}

LabelT token_as_num = (LabelT)std::stoul(label_str);
simple_bitmask bm(_bitmask_buf.get_bitmask(line_cnt), _bitmask_buf._bitmask_size);
bm.set(token_as_num);
_labels.insert(token_as_num);

lbl_pos = next_lbl_pos + 1;
}

cur_pos = next_pos + 1;

line_cnt++;
}

//while (std::getline(infile, line))
//{
// std::istringstream iss(line);
// std::vector<LabelT> lbls(0);
// getline(iss, token, '\t');
// std::istringstream new_iss(token);
// while (getline(new_iss, token, ','))
// {
// token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
// token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
// LabelT token_as_num = std::stoul(token);
// simple_bitmask bm(_bitmask_buf.get_bitmask(line_cnt), _bitmask_buf._bitmask_size);
// bm.set(token_as_num);
// _labels.insert(token_as_num);
// }
// line_cnt++;
//}
num_points = (size_t)line_cnt;
diskann::cout << "Identified " << _labels.size() << " distinct label(s)" << std::endl;
}
Expand Down Expand Up @@ -3504,6 +3566,20 @@ void Index<T, TagT, LabelT>::search_with_optimized_layout(const T *query, size_t
}
}

// Finds the first occurrence of `ch` in the half-open index range
// [start, end) of `str`.
//
// Returns the index of the match, or std::string::npos when `ch` does not
// occur in that range. `end` is clamped to str.size(), and an empty or
// inverted range (start >= end) yields npos, so the scan never reads past
// the end of the string.
template <typename T, typename TagT, typename LabelT>
size_t Index<T, TagT, LabelT>::search_string_range(const std::string& str, char ch, size_t start, size_t end)
{
    // Clamp the upper bound so that an oversized `end` cannot cause an
    // out-of-range read.
    const size_t limit = (end < str.size()) ? end : str.size();

    // `<` rather than `!=`: with start > end the scan must stop immediately
    // instead of walking forward until the index wraps around.
    for (size_t pos = start; pos < limit; pos++)
    {
        if (str[pos] == ch)
        {
            return pos;
        }
    }

    return std::string::npos;
}

/* Internals of the library */
template <typename T, typename TagT, typename LabelT> const float Index<T, TagT, LabelT>::INDEX_GROWTH_FACTOR = 1.5f;

Expand Down

0 comments on commit 416c661

Please sign in to comment.