diff --git a/src/rime/commit_history.h b/src/rime/commit_history.h index 8a56c03b5..dfe8ea0fe 100644 --- a/src/rime/commit_history.h +++ b/src/rime/commit_history.h @@ -29,6 +29,9 @@ class CommitHistory : public list { void Push(const KeyEvent& key_event); void Push(const Composition& composition, const string& input); string repr() const; + string latest_text() const { + return empty() ? string() : back().text; + } }; } // Namespace rime diff --git a/src/rime/gear/poet.cc b/src/rime/gear/poet.cc index f15b6373e..5a5f87038 100644 --- a/src/rime/gear/poet.cc +++ b/src/rime/gear/poet.cc @@ -42,7 +42,8 @@ bool Poet::LeftAssociateCompare(const Sentence& one, const Sentence& other) { } an Poet::MakeSentence(const WordGraph& graph, - size_t total_length) { + size_t total_length, + const string& preceding_text) { // TODO: save more intermediate sentence candidates map> sentences; sentences[0] = New(language_); @@ -61,7 +62,8 @@ an Poet::MakeSentence(const WordGraph& graph, const DictEntryList& entries(x.second); for (const auto& entry : entries) { auto new_sentence = New(*sentences[start_pos]); - new_sentence->Extend(*entry, end_pos, is_rear, grammar_.get()); + new_sentence->Extend( + *entry, end_pos, is_rear, preceding_text, grammar_.get()); if (sentences.find(end_pos) == sentences.end() || compare_(*sentences[end_pos], *new_sentence)) { DLOG(INFO) << "updated sentences " << end_pos << ") with " diff --git a/src/rime/gear/poet.h b/src/rime/gear/poet.h index ab47c7672..bbd10c2f4 100644 --- a/src/rime/gear/poet.h +++ b/src/rime/gear/poet.h @@ -35,7 +35,9 @@ class Poet { Compare compare = CompareWeight); ~Poet(); - an MakeSentence(const WordGraph& graph, size_t total_length); + an MakeSentence(const WordGraph& graph, + size_t total_length, + const string& preceding_text); private: const Language* language_; diff --git a/src/rime/gear/script_translator.cc b/src/rime/gear/script_translator.cc index ebf16fecf..ce3cfcae1 100644 --- a/src/rime/gear/script_translator.cc +++ b/src/rime/gear/script_translator.cc @@ -214,6 +214,11 @@ string ScriptTranslator::Spell(const Code& code) { return result; } +string ScriptTranslator::GetPrecedingText() const { + return contextual_suggestions_ ? + engine_->context()->commit_history().latest_text() : string(); +} + bool ScriptTranslator::Memorize(const CommitEntry& commit_entry) { bool update_elements = false; // avoid updating single character entries within a phrase which is @@ -538,12 +543,14 @@ an ScriptTranslation::MakeSentence(Dictionary* dict, } } } - auto sentence = poet_->MakeSentence(graph, syllable_graph.interpreted_length); - if (sentence) { + if (auto sentence = poet_->MakeSentence(graph, + syllable_graph.interpreted_length, + translator_->GetPrecedingText())) { sentence->Offset(start_); sentence->set_syllabifier(syllabifier_); + return sentence; } - return sentence; + return nullptr; } } // namespace rime diff --git a/src/rime/gear/script_translator.h b/src/rime/gear/script_translator.h index ea799e65e..139a4be33 100644 --- a/src/rime/gear/script_translator.h +++ b/src/rime/gear/script_translator.h @@ -37,6 +37,7 @@ class ScriptTranslator : public Translator, string FormatPreedit(const string& preedit); string Spell(const Code& code); + string GetPrecedingText() const; // options int max_homophones() const { return max_homophones_; } diff --git a/src/rime/gear/table_translator.cc b/src/rime/gear/table_translator.cc index 5d35339be..24c8ae6a3 100644 --- a/src/rime/gear/table_translator.cc +++ b/src/rime/gear/table_translator.cc @@ -342,7 +342,9 @@ bool TableTranslator::Memorize(const CommitEntry& commit_entry) { } string phrase; for (; it != history.rend(); ++it) { - if (it->type != "table" && it->type != "sentence" && it->type != "uniquified") + if (it->type != "table" && + it->type != "sentence" && + it->type != "uniquified") break; if (phrase.empty()) { phrase = it->text; // last word @@ -362,6 +364,11 @@ bool TableTranslator::Memorize(const CommitEntry& commit_entry) { return true; } +string TableTranslator::GetPrecedingText() const { + return contextual_suggestions_ ? + engine_->context()->commit_history().latest_text() : string(); +} + // SentenceSyllabifier class SentenceSyllabifier : public PhraseSyllabifier { @@ -680,7 +687,9 @@ TableTranslator::MakeSentence(const string& input, size_t start, } } } - if (auto sentence = poet_->MakeSentence(graph, input.length())) { + if (auto sentence = poet_->MakeSentence(graph, + input.length(), + GetPrecedingText())) { auto result = Cached( this, std::move(sentence), diff --git a/src/rime/gear/table_translator.h b/src/rime/gear/table_translator.h index 05acf9dbd..1977f5942 100644 --- a/src/rime/gear/table_translator.h +++ b/src/rime/gear/table_translator.h @@ -35,7 +35,7 @@ class TableTranslator : public Translator, an MakeSentence(const string& input, size_t start, bool include_prefix_phrases = false); - + string GetPrecedingText() const; UnityTableEncoder* encoder() const { return encoder_.get(); } protected: diff --git a/src/rime/gear/translator_commons.cc b/src/rime/gear/translator_commons.cc index 0192b7cf9..8e305fcd9 100644 --- a/src/rime/gear/translator_commons.cc +++ b/src/rime/gear/translator_commons.cc @@ -91,8 +91,10 @@ bool Spans::HasVertex(size_t vertex) const { void Sentence::Extend(const DictEntry& entry, size_t end_pos, bool is_rear, + const string& preceding_text, Grammar* grammar) { - entry_->weight += Grammar::Evaluate(text(), entry, is_rear, grammar); + const string& context = empty() ? preceding_text : text(); + entry_->weight += Grammar::Evaluate(context, entry, is_rear, grammar); entry_->text.append(entry.text); entry_->code.insert(entry_->code.end(), entry.code.begin(), @@ -118,6 +120,8 @@ TranslatorOptions::TranslatorOptions(const Ticket& ticket) { config->GetString(ticket.name_space + "/delimiter", &delimiters_) || config->GetString("speller/delimiter", &delimiters_); config->GetString(ticket.name_space + "/tag", &tag_); + config->GetBool(ticket.name_space + "/contextual_suggestions", + &contextual_suggestions_); config->GetBool(ticket.name_space + "/enable_completion", &enable_completion_); config->GetBool(ticket.name_space + "/strict_spelling", diff --git a/src/rime/gear/translator_commons.h b/src/rime/gear/translator_commons.h index 091fca691..2f2efff5c 100644 --- a/src/rime/gear/translator_commons.h +++ b/src/rime/gear/translator_commons.h @@ -124,6 +124,7 @@ class Sentence : public Phrase { void Extend(const DictEntry& entry, size_t end_pos, bool is_rear, + const string& preceding_text, Grammar* grammar); void Offset(size_t offset); @@ -159,6 +160,10 @@ class TranslatorOptions { const string& delimiters() const { return delimiters_; } const string& tag() const { return tag_; } void set_tag(const string& tag) { tag_ = tag; } + bool contextual_suggestions() const { return contextual_suggestions_; } + void set_contextual_suggestions(bool enabled) { + contextual_suggestions_ = enabled; + } bool enable_completion() const { return enable_completion_; } void set_enable_completion(bool enabled) { enable_completion_ = enabled; } bool strict_spelling() const { return strict_spelling_; } @@ -171,6 +176,7 @@ class TranslatorOptions { protected: string delimiters_; string tag_ = "abc"; + bool contextual_suggestions_ = false; bool enable_completion_ = true; bool strict_spelling_ = false; double initial_quality_ = 0.;