Skip to content

Commit

Permalink
Add string_view to dictionary for fast lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
kpu committed Nov 28, 2021
1 parent a20c0d2 commit ffee8e4
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 15 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
cmake_minimum_required(VERSION 2.8.9)
project(fasttext)

set(CMAKE_CXX_STANDARD 17)

# The version number.
set (fasttext_VERSION_MAJOR 0)
set (fasttext_VERSION_MINOR 1)
Expand Down
66 changes: 58 additions & 8 deletions src/dictionary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ Dictionary::Dictionary(std::shared_ptr<Args> args, std::istream& in)
load(in);
}

int32_t Dictionary::find(const std::string& w) const {
int32_t Dictionary::find(const std::string_view w) const {
return find(w, hash(w));
}

int32_t Dictionary::find(const std::string& w, uint32_t h) const {
int32_t Dictionary::find(const std::string_view w, uint32_t h) const {
int32_t word2intsize = word2int_.size();
int32_t id = h % word2intsize;
while (word2int_[id] != -1 && words_[word2int_[id]].word != w) {
Expand Down Expand Up @@ -126,12 +126,12 @@ bool Dictionary::discard(int32_t id, real rand) const {
return rand > pdiscard_[id];
}

int32_t Dictionary::getId(const std::string& w, uint32_t h) const {
int32_t Dictionary::getId(const std::string_view w, uint32_t h) const {
int32_t id = find(w, h);
return word2int_[id];
}

int32_t Dictionary::getId(const std::string& w) const {
int32_t Dictionary::getId(const std::string_view w) const {
int32_t h = find(w);
return word2int_[h];
}
Expand All @@ -142,7 +142,7 @@ entry_type Dictionary::getType(int32_t id) const {
return words_[id].type;
}

entry_type Dictionary::getType(const std::string& w) const {
entry_type Dictionary::getType(const std::string_view w) const {
return (w.find(args_->label) == 0) ? entry_type::label : entry_type::word;
}

Expand All @@ -160,7 +160,7 @@ std::string Dictionary::getWord(int32_t id) const {
// Since all fasttext models that were already released were trained
// using signed char, we fixed the hash function to make models
// compatible whatever compiler is used.
uint32_t Dictionary::hash(const std::string& str) const {
uint32_t Dictionary::hash(const std::string_view str) const {
uint32_t h = 2166136261;
for (size_t i = 0; i < str.size(); i++) {
h = h ^ uint32_t(int8_t(str[i]));
Expand Down Expand Up @@ -324,11 +324,16 @@ void Dictionary::addWordNgrams(

void Dictionary::addSubwords(
std::vector<int32_t>& line,
const std::string& token,
const std::string_view token,
int32_t wid) const {
if (wid < 0) { // out of vocab
if (token != EOS) {
computeSubwords(BOW + token + EOW, line);
std::string concat;
concat.reserve(BOW.size() + token.size() + EOW.size());
concat += BOW;
concat.append(token.data(), token.size());
concat += EOW;
computeSubwords(concat, line);
}
} else {
if (args_->maxn <= 0) { // in vocab w/o subwords
Expand Down Expand Up @@ -406,6 +411,51 @@ int32_t Dictionary::getLine(
return ntokens;
}

namespace {
bool readWordNoNewline(std::string_view& in, std::string_view& word) {
const std::string_view spaces(" \n\r\t\v\f\0");
std::string_view::size_type begin = in.find_first_not_of(spaces);
if (begin == std::string_view::npos) {
in.remove_prefix(in.size());
return false;
}
in.remove_prefix(begin);
word = in.substr(0, in.find_first_of(spaces));
in.remove_prefix(word.size());
return true;
}
} // namespace

int32_t Dictionary::getStringNoNewline(
std::string_view in,
std::vector<int32_t>& words,
std::vector<int32_t>& labels) const {
std::vector<int32_t> word_hashes;
std::string_view token;
int32_t ntokens = 0;

words.clear();
labels.clear();
while (readWordNoNewline(in, token)) {
uint32_t h = hash(token);
int32_t wid = getId(token, h);
entry_type type = wid < 0 ? getType(token) : getType(wid);

ntokens++;
if (type == entry_type::word) {
addSubwords(words, token, wid);
word_hashes.push_back(h);
} else if (type == entry_type::label && wid >= 0) {
labels.push_back(wid - nwords_);
}
if (token == EOS) {
break;
}
}
addWordNgrams(words, word_hashes, args_->wordNgrams);
return ntokens;
}

void Dictionary::pushHash(std::vector<int32_t>& hashes, int32_t id) const {
if (pruneidx_size_ == 0 || id < 0) {
return;
Expand Down
17 changes: 10 additions & 7 deletions src/dictionary.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <ostream>
#include <random>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

Expand All @@ -36,13 +37,13 @@ class Dictionary {
static const int32_t MAX_VOCAB_SIZE = 30000000;
static const int32_t MAX_LINE_SIZE = 1024;

int32_t find(const std::string&) const;
int32_t find(const std::string&, uint32_t h) const;
int32_t find(const std::string_view) const;
int32_t find(const std::string_view, uint32_t h) const;
void initTableDiscard();
void initNgrams();
void reset(std::istream&) const;
void pushHash(std::vector<int32_t>&, int32_t) const;
void addSubwords(std::vector<int32_t>&, const std::string&, int32_t) const;
void addSubwords(std::vector<int32_t>&, const std::string_view, int32_t) const;

std::shared_ptr<Args> args_;
std::vector<int32_t> word2int_;
Expand Down Expand Up @@ -71,10 +72,10 @@ class Dictionary {
int32_t nwords() const;
int32_t nlabels() const;
int64_t ntokens() const;
int32_t getId(const std::string&) const;
int32_t getId(const std::string&, uint32_t h) const;
int32_t getId(const std::string_view) const;
int32_t getId(const std::string_view, uint32_t h) const;
entry_type getType(int32_t) const;
entry_type getType(const std::string&) const;
entry_type getType(const std::string_view) const;
bool discard(int32_t, real) const;
std::string getWord(int32_t) const;
const std::vector<int32_t>& getSubwords(int32_t) const;
Expand All @@ -87,7 +88,7 @@ class Dictionary {
const std::string&,
std::vector<int32_t>&,
std::vector<std::string>* substrings = nullptr) const;
uint32_t hash(const std::string& str) const;
uint32_t hash(const std::string_view str) const;
void add(const std::string&);
bool readWord(std::istream&, std::string&) const;
void readFromFile(std::istream&);
Expand All @@ -99,6 +100,8 @@ class Dictionary {
const;
int32_t getLine(std::istream&, std::vector<int32_t>&, std::minstd_rand&)
const;
int32_t getStringNoNewline(std::string_view, std::vector<int32_t>&,
std::vector<int32_t>&) const;
void threshold(int64_t, int64_t);
void prune(std::vector<int32_t>&);
bool isPruned() {
Expand Down

0 comments on commit ffee8e4

Please sign in to comment.