Skip to content

Commit

Permalink
chore:(script_translator, syllabifier): use borrowed corrector
Browse files Browse the repository at this point in the history
  • Loading branch information
lotem committed Feb 26, 2019
1 parent c587900 commit bbef968
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 41 deletions.
9 changes: 4 additions & 5 deletions src/rime/algo/syllabifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ int Syllabifier::BuildSyllableGraph(const string &input,
for (auto &m : matches) {
match_set.insert(m.value);
}
if (enable_correction_) {
if (corrector_) {
Corrections corrections;
corrector_->ToleranceSearch(prism, current_input, &corrections, 5);
for (const auto &m : corrections) {
Expand Down Expand Up @@ -188,7 +188,7 @@ int Syllabifier::BuildSyllableGraph(const string &input,
good.insert(i);
}

if (enable_completion_ && farthest < input.length()) {
if (corrector_ && farthest < input.length()) {
DLOG(INFO) << "completion enabled";
const size_t kExpandSearchLimit = 512;
vector<Prism::Match> keys;
Expand Down Expand Up @@ -284,9 +284,8 @@ void Syllabifier::Transpose(SyllableGraph* graph) {
}
}

void Syllabifier::EnableCorrection(an<Corrector> corrector) {
enable_correction_ = true;
corrector_ = std::move(corrector);
void Syllabifier::EnableCorrection(Corrector* corrector) {
corrector_ = corrector;
}

} // namespace rime
5 changes: 2 additions & 3 deletions src/rime/algo/syllabifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class Syllabifier {
RIME_API int BuildSyllableGraph(const string &input,
Prism &prism,
SyllableGraph *graph);
RIME_API void EnableCorrection(an<Corrector> corrector);
RIME_API void EnableCorrection(Corrector* corrector);

protected:
void CheckOverlappedSpellings(SyllableGraph *graph,
Expand All @@ -66,8 +66,7 @@ class Syllabifier {
string delimiters_;
bool enable_completion_ = false;
bool strict_spelling_ = false;
an<Corrector> corrector_ = nullptr;
bool enable_correction_ = false;
Corrector* corrector_ = nullptr;
};

} // namespace rime
Expand Down
53 changes: 30 additions & 23 deletions src/rime/gear/script_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,16 @@ static bool syllabify_dfs(SyllabifyTask* task,
class ScriptSyllabifier : public PhraseSyllabifier {
public:
ScriptSyllabifier(ScriptTranslator* translator,
Corrector* corrector,
const string& input,
size_t start)
: translator_(translator), input_(input), start_(start) {
: translator_(translator), input_(input), start_(start),
syllabifier_(translator->delimiters(),
translator->enable_completion(),
translator->strict_spelling()) {
if (corrector) {
syllabifier_.EnableCorrection(corrector);
}
}

virtual Spans Syllabify(const Phrase* phrase);
Expand All @@ -88,17 +95,21 @@ class ScriptSyllabifier : public PhraseSyllabifier {
ScriptTranslator* translator_;
string input_;
size_t start_;
Syllabifier syllabifier_;
SyllableGraph syllable_graph_;
};

class ScriptTranslation : public Translation {
public:
ScriptTranslation(ScriptTranslator* translator,
const string& input, size_t start,
bool enable_correction = false)
: translator_(translator), start_(start),
syllabifier_(New<ScriptSyllabifier>(translator, input, start)),
enable_correction_(enable_correction) {
Corrector* corrector,
const string& input,
size_t start)
: translator_(translator),
start_(start),
syllabifier_(New<ScriptSyllabifier>(
translator, corrector, input, start)),
enable_correction_(corrector) {
set_exhausted(true);
}
bool Evaluate(Dictionary* dict, UserDictionary* user_dict);
Expand Down Expand Up @@ -147,13 +158,14 @@ ScriptTranslator::ScriptTranslator(const Ticket& ticket)
config->GetBool(name_space_ + "/enable_correction", &enable_correction_);
}
if (enable_correction_) {
auto corrector = Corrector::Require("corrector");
corrector_.reset(corrector->Create(ticket));
if (auto* corrector = Corrector::Require("corrector")) {
corrector_.reset(corrector->Create(ticket));
}
}
}

an<Translation> ScriptTranslator::Query(const string& input,
const Segment& segment) {
const Segment& segment) {
if (!dict_ || !dict_->loaded())
return nullptr;
if (!segment.HasTag(tag_))
Expand All @@ -167,7 +179,10 @@ an<Translation> ScriptTranslator::Query(const string& input,
!IsUserDictDisabledFor(input);

// the translator should survive translations it creates
auto result = New<ScriptTranslation>(this, input, segment.start, enable_correction_);
auto result = New<ScriptTranslation>(this,
corrector_.get(),
input,
segment.start);
if (!result ||
!result->Evaluate(dict_.get(),
enable_user_dict ? user_dict_.get() : NULL)) {
Expand Down Expand Up @@ -239,17 +254,9 @@ Spans ScriptSyllabifier::Syllabify(const Phrase* phrase) {
}

size_t ScriptSyllabifier::BuildSyllableGraph(Prism& prism) {
Syllabifier syllabifier(translator_->delimiters(),
translator_->enable_completion(),
translator_->strict_spelling());
if (translator_->enable_correction()) {
syllabifier.EnableCorrection(translator_->corrector());
}
auto consumed = (size_t)syllabifier.BuildSyllableGraph(input_,
prism,
&syllable_graph_);

return consumed;
return (size_t)syllabifier_.BuildSyllableGraph(input_,
prism,
&syllable_graph_);
}

bool ScriptSyllabifier::IsCandidateCorrection(const rime::Phrase &cand) const {
Expand Down Expand Up @@ -499,8 +506,8 @@ bool ScriptTranslation::CheckEmpty() {
return exhausted();
}

an<Sentence>
ScriptTranslation::MakeSentence(Dictionary* dict, UserDictionary* user_dict) {
an<Sentence> ScriptTranslation::MakeSentence(Dictionary* dict,
UserDictionary* user_dict) {
const int kMaxSyllablesForUserPhraseQuery = 5;
const auto& syllable_graph = syllabifier_->syllable_graph();
WordGraph graph;
Expand Down
8 changes: 3 additions & 5 deletions src/rime/gear/script_translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
namespace rime {

class Code;
class Corrector;
struct DictEntry;
struct DictEntryCollector;
class Dictionary;
class UserDictionary;
class EditDistanceCorrector;
struct SyllableGraph;

class ScriptTranslator : public Translator,
Expand All @@ -31,7 +31,7 @@ class ScriptTranslator : public Translator,
ScriptTranslator(const Ticket& ticket);

virtual an<Translation> Query(const string& input,
const Segment& segment);
const Segment& segment);
virtual bool Memorize(const CommitEntry& commit_entry);

string FormatPreedit(const string& preedit);
Expand All @@ -40,14 +40,12 @@ class ScriptTranslator : public Translator,
// options
int spelling_hints() const { return spelling_hints_; }
bool always_show_comments() const { return always_show_comments_; }
bool enable_correction() const { return enable_correction_; }
an<Corrector> corrector() const { return corrector_; }

protected:
int spelling_hints_ = 0;
bool always_show_comments_ = false;
bool enable_correction_ = false;
an<Corrector> corrector_ = nullptr;
the<Corrector> corrector_;
};

} // namespace rime
Expand Down
15 changes: 10 additions & 5 deletions test/corrector_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ class RimeCorrectorSearchTest : public ::testing::Test {
std::inserter(keyset, keyset.begin()));
prism_->Build(keyset);

corrector_.reset(new rime::NearSearchCorrector);
}
void TearDown() override {}
protected:
rime::map<rime::string, rime::SyllableId> syllable_id_;
rime::the<rime::Prism> prism_;
rime::the<rime::Corrector> corrector_;
};

class RimeCorrectorTest : public ::testing::Test {
Expand All @@ -56,6 +58,8 @@ class RimeCorrectorTest : public ::testing::Test {
std::copy(syllables.begin(), syllables.end(),
std::inserter(keyset, keyset.begin()));
prism_->Build(keyset);

corrector_.reset(new rime::NearSearchCorrector);
}

virtual void TearDown() {
Expand All @@ -64,11 +68,12 @@ class RimeCorrectorTest : public ::testing::Test {
protected:
rime::map<rime::string, rime::SyllableId> syllable_id_;
rime::the<rime::Prism> prism_;
rime::the<rime::Corrector> corrector_;
};

TEST_F(RimeCorrectorSearchTest, CaseNearSubstitute) {
rime::Syllabifier s;
s.EnableCorrection(std::make_shared<rime::NearSearchCorrector>());
s.EnableCorrection(corrector_.get());
rime::SyllableGraph g;
const rime::string input("chsng");
s.BuildSyllableGraph(input, *prism_, &g);
Expand All @@ -83,7 +88,7 @@ TEST_F(RimeCorrectorSearchTest, CaseNearSubstitute) {

TEST_F(RimeCorrectorSearchTest, CaseFarSubstitute) {
rime::Syllabifier s;
s.EnableCorrection(std::make_shared<rime::NearSearchCorrector>());
s.EnableCorrection(corrector_.get());
rime::SyllableGraph g;
const rime::string input("chpng");
s.BuildSyllableGraph(input, *prism_, &g);
Expand All @@ -95,7 +100,7 @@ TEST_F(RimeCorrectorSearchTest, CaseFarSubstitute) {

TEST_F(RimeCorrectorSearchTest, DISABLED_CaseTranspose) {
rime::Syllabifier s;
s.EnableCorrection(std::make_shared<rime::NearSearchCorrector>());
s.EnableCorrection(corrector_.get());
rime::SyllableGraph g;
const rime::string input("cahng");
s.BuildSyllableGraph(input, *prism_, &g);
Expand All @@ -110,7 +115,7 @@ TEST_F(RimeCorrectorSearchTest, DISABLED_CaseTranspose) {

TEST_F(RimeCorrectorSearchTest, CaseCorrectionSyllabify) {
rime::Syllabifier s;
s.EnableCorrection(std::make_shared<rime::NearSearchCorrector>());
s.EnableCorrection(corrector_.get());
rime::SyllableGraph g;
const rime::string input("chabgtyan");
s.BuildSyllableGraph(input, *prism_, &g);
Expand All @@ -130,7 +135,7 @@ TEST_F(RimeCorrectorSearchTest, CaseCorrectionSyllabify) {

TEST_F(RimeCorrectorTest, CaseMultipleEdges1) {
rime::Syllabifier s;
s.EnableCorrection(std::make_shared<rime::NearSearchCorrector>());
s.EnableCorrection(corrector_.get());
rime::SyllableGraph g;
const rime::string input("jiejue"); // jie'jue jie'jie jue'jue jue'jie
s.BuildSyllableGraph(input, *prism_, &g);
Expand Down

0 comments on commit bbef968

Please sign in to comment.