Skip to content
This repository has been archived by the owner on Apr 23, 2024. It is now read-only.

Commit

Permalink
Merge pull request #45 from xbelonogov/add-dropout
Browse files Browse the repository at this point in the history
Add dropout
  • Loading branch information
Dmitry Yutkin authored Nov 19, 2019
2 parents 662c0f4 + 2114307 commit 8d72e38
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 34 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ Class constructor. Loads the trained model.
Class `youtokentome.BPE` has the following methods:
#### encode
```python
encode(self, sentences, output_type=yttm.OutputType.ID, bos=False, eos=False, reverse=False)
encode(self, sentences, output_type=yttm.OutputType.ID, bos=False, eos=False, reverse=False, dropout_prob=0)
```

**Args:**
Expand All @@ -117,6 +117,7 @@ encode(self, sentences, output_type=yttm.OutputType.ID, bos=False, eos=False, re
* `bos`: bool, if True then token “beginning of sentence” will be added
* `eos`: bool, if True then token “end of sentence” will be added
* `reverse`: bool, if True the output sequence of tokens will be reversed
* `dropout_prob`: float, BPE-dropout probability (the probability of a merge being dropped). Must be in the range [0, 1].


**Returns:** If `output_type` is equal to `youtokentome.OutputType.ID` or `youtokentome.OutputType.SUBWORD`
Expand Down Expand Up @@ -258,6 +259,7 @@ Options:
--eos Add tab 'end of sentence'.
--reverse Reverse output sequence of tokens.
--stream Process each line before reading the next one.
--dropout_prob BPE-dropout probability (the probability of a merge being dropped). [default: 0]
--help Show this message and exit.
```

Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

setup(
name="youtokentome",
version="1.0.3",
version="1.0.4",
packages=find_packages(),
description="Unsupervised text tokenizer focused on computational efficiency",
long_description=LONG_DESCRIPTION,
Expand All @@ -49,4 +49,5 @@
"Programming Language :: C++",
],
ext_modules=cythonize(extensions),
)
)

103 changes: 89 additions & 14 deletions youtokentome/cpp/bpe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <string>
#include <thread>
#include <vector>
#include <random>
#include <unordered_set>

#include "third_party/flat_hash_map.h"
Expand Down Expand Up @@ -1459,6 +1460,71 @@ Status train_bpe(const string &input_path, const string &model_path,
return Status();
}


template<typename T>
class BasePriorityQueue {
public:
virtual void push(T x) = 0;
virtual bool pop(T& x) = 0;
virtual ~BasePriorityQueue() {}
};

template<typename T>
class STLQueue : public BasePriorityQueue<T> {
std::priority_queue<T> q;
void push(T x) override {
q.push(x);
}
bool pop(T& x) override {
if (q.empty()) {
return false;
}
x = q.top();
q.pop();
return true;
}
};

std::mt19937 rnd;

template<typename T>
class DropoutQueue : public BasePriorityQueue<T> {
double skip_prob;
std::uniform_real_distribution<> dist;
std::priority_queue<T> q;
vector<T> skipped_elements;
public:
explicit DropoutQueue(double _skip_prob):skip_prob(_skip_prob), dist(std::uniform_real_distribution<>(0, 1)) {}
void push(T x) override {
q.push(x);
}
bool pop(T& x) override {
assert(skipped_elements.empty());
while (true) {
if (q.empty()) {
for (auto y: skipped_elements) {
q.push(y);
}
skipped_elements.clear();
return false;
}
T temp = q.top();
q.pop();
if (dist(rnd) < skip_prob) {
skipped_elements.push_back(temp);
}
else {
for (auto y: skipped_elements) {
q.push(y);
}
skipped_elements.clear();
x = temp;
return true;
}
}
}
};

DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8,
const EncodingConfig &encoding_config,
OutputType output_type) const {
Expand Down Expand Up @@ -1539,27 +1605,36 @@ DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8,
}
list.back().next = -1;

std::priority_queue<MergeEvent2> queue;

auto pair_code = [&](uint64_t first_pos) {
auto second_pos = list[first_pos].next;
return int2comb(list[first_pos].token_id, list[second_pos].token_id);
};

std::unique_ptr<BasePriorityQueue<MergeEvent2>> queue(nullptr);
if (encoding_config.dropout_prob == 0) {
queue.reset(new STLQueue<MergeEvent2>());
}
else {
queue.reset(new DropoutQueue<MergeEvent2>(encoding_config.dropout_prob));
}

auto push_in_queue_if_rule_exist = [&](uint64_t pos) {
auto it = rule2id.find(pair_code(pos));
if (it != rule2id.end()) {
queue.push({it->second, static_cast<int>(pos)});
queue->push({it->second, static_cast<int>(pos)});
}
};

for (uint64_t j = 0; j + 1 < list.size(); j++) {
push_in_queue_if_rule_exist(j);
}

while (!queue.empty()) {
MergeEvent2 event = queue.top();
queue.pop();
while (true) {
MergeEvent2 event;
if (!queue->pop(event)) {
break;
}
int rule_id = event.priority;
int pos_1 = event.pos;
int pos_2 = list[pos_1].next;
Expand Down Expand Up @@ -1737,8 +1812,8 @@ Status BaseEncoder::encode_parallel(

Status BaseEncoder::encode_as_ids(const vector<string> &sentences, vector<vector<int>> *ids,
bool bos, bool eos,
bool reverse) const {
EncodingConfig encoding_config = {bos, eos, reverse};
bool reverse, double dropout_prob) const {
EncodingConfig encoding_config = {bos, eos, reverse, dropout_prob};

std::vector<DecodeResult> decode_results;
Status status = encode_parallel(sentences, encoding_config, ID, &decode_results);
Expand All @@ -1755,9 +1830,9 @@ Status BaseEncoder::encode_as_ids(const vector<string> &sentences, vector<vector
Status BaseEncoder::encode_as_subwords(
const vector<string> &sentences,
vector<vector<string>> *subwords,
bool bos, bool eos, bool reverse) const {
bool bos, bool eos, bool reverse, double dropout_prob) const {
time_check("");
EncodingConfig encoding_config = {bos, eos, reverse};
EncodingConfig encoding_config = {bos, eos, reverse, dropout_prob};
std::vector<DecodeResult> decode_results;
Status status = encode_parallel(sentences, encoding_config, SUBWORD, &decode_results);
if (!status.ok()) {
Expand Down Expand Up @@ -1939,7 +2014,7 @@ void BaseEncoder::vocab_cli(bool verbose) const {
}

Status BaseEncoder::encode_cli(const string &output_type_str, bool stream,
bool bos, bool eos, bool reverse) const {
bool bos, bool eos, bool reverse, double dropout_prob) const {
std::ios_base::sync_with_stdio(false);
OutputType output_type;
if (output_type_str == "id") {
Expand All @@ -1953,7 +2028,7 @@ Status BaseEncoder::encode_cli(const string &output_type_str, bool stream,
string sentence;
while (getline(std::cin, sentence)) {
vector<vector<string>> subwords;
Status status = encode_as_subwords({sentence}, &subwords, bos, eos, reverse);
Status status = encode_as_subwords({sentence}, &subwords, bos, eos, reverse, dropout_prob);
if (!status.ok()) {
return status;
}
Expand All @@ -1964,7 +2039,7 @@ Status BaseEncoder::encode_cli(const string &output_type_str, bool stream,
string sentence;
while (getline(std::cin, sentence)) {
vector<vector<int>> ids;
Status status = encode_as_ids({sentence}, &ids, bos, eos, reverse);
Status status = encode_as_ids({sentence}, &ids, bos, eos, reverse, dropout_prob);
if (!status.ok()) {
return status;
}
Expand All @@ -1983,15 +2058,15 @@ Status BaseEncoder::encode_cli(const string &output_type_str, bool stream,
auto sentences = read_lines_from_stdin(batch_limit, &processed);
if (output_type == SUBWORD) {
vector<vector<string>> subwords;
Status status = encode_as_subwords(sentences, &subwords, bos, eos, reverse);
Status status = encode_as_subwords(sentences, &subwords, bos, eos, reverse, dropout_prob);
if (!status.ok()) {
return status;
}
write_to_stdout(subwords, false);
} else {
assert(output_type == ID);
vector<vector<int>> ids;
Status status = encode_as_ids(sentences, &ids, bos, eos, reverse);
Status status = encode_as_ids(sentences, &ids, bos, eos, reverse, dropout_prob);
if (!status.ok()) {
return status;
}
Expand Down
6 changes: 3 additions & 3 deletions youtokentome/cpp/bpe.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ class BaseEncoder {

Status encode_as_ids(
const std::vector<std::string> &sentences, std::vector<std::vector<int>> *ids, bool bos = false,
bool eos = false, bool reverse = false) const;
bool eos = false, bool reverse = false, double dropout_prob=0) const;

Status encode_as_subwords(
const std::vector<std::string> &sentences,
std::vector<std::vector<std::string>> *subwords,
bool bos = false,
bool eos = false, bool reverse = false) const;
bool eos = false, bool reverse = false, double dropout_prob=0) const;

Status id_to_subword(int id, std::string *subword, bool replace_space = false) const;

Expand All @@ -65,7 +65,7 @@ class BaseEncoder {
std::vector<std::string> vocabulary() const;

Status encode_cli(const std::string &output_type, bool stream, bool bos = false,
bool eos = false, bool reverse = false) const;
bool eos = false, bool reverse = false, double dropout_prob = 0) const;

Status decode_cli(const std::unordered_set<int> *ignore_ids) const;

Expand Down
1 change: 1 addition & 0 deletions youtokentome/cpp/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ struct EncodingConfig {
bool bos;
bool eos;
bool reverse;
double dropout_prob;
};

bool is_space(uint32_t ch);
Expand Down
23 changes: 12 additions & 11 deletions youtokentome/cpp/yttm.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ cdef extern from "bpe.h" namespace "vkcom":
cdef cppclass BaseEncoder:
BaseEncoder(const string& model_path, int n_threads, Status* status)

Status encode_as_ids(const vector[string] &sentences, vector[vector[int]]* ids, bool bos, bool eos, bool reverse) const
Status encode_as_subwords(const vector[string]& sentences, vector[vector[string]]* subwords, bool bos, bool eos, bool reverse) const
Status encode_as_ids(const vector[string] &sentences, vector[vector[int]]* ids, bool bos, bool eos, bool reverse, double dropout_prob) const
Status encode_as_subwords(const vector[string]& sentences, vector[vector[string]]* subwords, bool bos, bool eos, bool reverse, double dropout_prob) const

Status encode_cli(string output_type, bool stream, bool bos, bool eos, bool reverse) const
Status encode_cli(string output_type, bool stream, bool bos, bool eos, bool reverse, double dropout_prob) const

Status decode_cli(const unordered_set[int]* ignore_ids) const

Expand Down Expand Up @@ -84,38 +84,39 @@ cdef class BPE:
if status.code != 0:
raise ValueError(status.message.decode())

def encode(self, sentences, output_type, bos, eos, reverse):
def encode(self, sentences, output_type, bos, eos, reverse, dropout_prob):
cdef vector[string] s
cdef vector[vector[string]] ret_subwords
cdef vector[vector[int]] ret_ids
cdef Status status
if dropout_prob < 0 or dropout_prob > 1:
raise ValueError("dropout_prob value must be in the range [0, 1]. Current value of dropout_prob = " + str(dropout_prob))
if output_type == 'id':
if isinstance(sentences, str):
s = [sentences.encode()]
assert len(s) == 1
status = self.encoder.encode_as_ids(s, &ret_ids, bos, eos, reverse)
status = self.encoder.encode_as_ids(s, &ret_ids, bos, eos, reverse, dropout_prob)
if status.code != 0:
raise ValueError(status.message.decode())
return ret_ids[0]

assert isinstance(sentences, list) or isinstance(sentences, tuple)
s = [x.encode() for x in sentences]
status = self.encoder.encode_as_ids(s, &ret_ids, bos, eos, reverse)
status = self.encoder.encode_as_ids(s, &ret_ids, bos, eos, reverse, dropout_prob)
if status.code != 0:
raise ValueError(status.message.decode())
return ret_ids
elif output_type == 'subword':
if isinstance(sentences, str):
s = [sentences.encode()]
status = self.encoder.encode_as_subwords(s, &ret_subwords, bos, eos, reverse)
status = self.encoder.encode_as_subwords(s, &ret_subwords, bos, eos, reverse, dropout_prob)
if status.code != 0:
raise ValueError(status.message.decode())
assert len(ret_subwords) == 1
return [piece.decode() for piece in ret_subwords[0]]

assert isinstance(sentences, list) or isinstance(sentences, tuple)
s = [x.encode() for x in sentences]
status = self.encoder.encode_as_subwords(s, &ret_subwords, bos, eos, reverse)
status = self.encoder.encode_as_subwords(s, &ret_subwords, bos, eos, reverse, dropout_prob)
if status.code != 0:
raise ValueError(status.message.decode())
return [[piece.decode() for piece in sentence] for sentence in ret_subwords]
Expand Down Expand Up @@ -163,8 +164,8 @@ cdef class BPE:
cdef vector[string] vocab = self.encoder.vocabulary()
return [token.decode() for token in vocab]

def encode_cli(self, output_type, stream, bos, eos, reverse):
cdef Status status = self.encoder.encode_cli(output_type.encode(), stream, bos, eos, reverse)
def encode_cli(self, output_type, stream, bos, eos, reverse, dropout_prob):
cdef Status status = self.encoder.encode_cli(output_type.encode(), stream, bos, eos, reverse, dropout_prob)
if status.code != 0:
raise ValueError(status.message.decode())

Expand Down
2 changes: 2 additions & 0 deletions youtokentome/youtokentome.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def encode(
bos: bool = False,
eos: bool = False,
reverse: bool = False,
dropout_prob: float = 0,
) -> Union[List[List[int]], List[List[str]]]:
if not isinstance(output_type, OutputType):
raise TypeError(
Expand All @@ -62,6 +63,7 @@ def encode(
bos=bos,
eos=eos,
reverse=reverse,
dropout_prob=dropout_prob,
)

def vocab_size(self) -> int:
Expand Down
13 changes: 10 additions & 3 deletions youtokentome/yttm_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def main():
@click.option(
"--coverage",
type=click.FLOAT,
help="Amount of characters covered by the model.",
help="Percentage of characters covered by the model.",
default=1.0,
show_default=True,
)
Expand Down Expand Up @@ -98,7 +98,14 @@ def bpe(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eo
@click.option(
"--stream", is_flag=True, help="Process each line before reading the next one."
)
def encode(model, output_type, n_threads, bos, eos, reverse, stream):
@click.option(
"--dropout_prob",
type=click.FLOAT,
default=0,
show_default=True,
help="BPE-dropout probability (the probability of a merge being dropped)",
)
def encode(model, output_type, n_threads, bos, eos, reverse, stream, dropout_prob):
"""Encode text to ids or subwords."""
if n_threads < -1 or n_threads == 0:
raise ValueError(
Expand All @@ -107,7 +114,7 @@ def encode(model, output_type, n_threads, bos, eos, reverse, stream):
)

bpe = yttmc.BPE(model, n_threads)
bpe.encode_cli(output_type, stream, bos, eos, reverse)
bpe.encode_cli(output_type, stream, bos, eos, reverse, dropout_prob)


def validate_ignore_ids(ctx, param, value):
Expand Down

0 comments on commit 8d72e38

Please sign in to comment.