Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into fcitx
Browse files Browse the repository at this point in the history
  • Loading branch information
Fcitx Bot committed May 21, 2024
2 parents c44dd82 + 022bbef commit bf3dbbb
Show file tree
Hide file tree
Showing 11 changed files with 45 additions and 104 deletions.
1 change: 1 addition & 0 deletions src/converter/segments.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class Segment final {
DICTIONARY_PREDICTOR_ZERO_QUERY_EMOJI = 1 << 3,
DICTIONARY_PREDICTOR_ZERO_QUERY_BIGRAM = 1 << 4,
DICTIONARY_PREDICTOR_ZERO_QUERY_SUFFIX = 1 << 5,
DICTIONARY_PREDICTOR_ZERO_QUERY_SUPPLEMENTAL_MODEL = 1 << 7,

USER_HISTORY_PREDICTOR = 1 << 6,
};
Expand Down
13 changes: 10 additions & 3 deletions src/engine/supplemental_model_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,8 @@ class SupplementalModelInterface {
// Returns std::nullopt when the composition spellchecker is not
// enabled/available.
virtual std::optional<std::vector<composer::TypeCorrectedQuery>>
CorrectComposition(absl::string_view query, absl::string_view context,
bool disable_toggle_correction,
const commands::Request &request) const {
CorrectComposition(const ConversionRequest &request,
absl::string_view context) const {
return std::nullopt;
}

Expand All @@ -93,6 +92,14 @@ class SupplementalModelInterface {
// The default implementation is a no-op: the body is empty, so `results` is
// left exactly as it was passed in.
virtual void RescoreResults(const ConversionRequest &request,
const Segments &segments,
absl::Span<prediction::Result> results) const {}

// Performs next word/phrase prediction given the context `segments`. Results
// are appended to `results`. Returns true if prediction was performed.
//
// The default implementation performs no prediction: it leaves `results`
// untouched and returns false.
virtual bool Predict(const ConversionRequest &request,
const Segments &segments,
std::vector<prediction::Result> &results) const {
return false;
}
};

} // namespace mozc::engine
Expand Down
1 change: 0 additions & 1 deletion src/prediction/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ mozc_cc_library(
"//request:conversion_request",
"//request:request_util",
"//transliteration",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
Expand Down
33 changes: 7 additions & 26 deletions src/prediction/dictionary_prediction_aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,6 @@ bool IsLanguageAwareInputEnabled(const ConversionRequest &request) {
return lang_aware == Request::LANGUAGE_AWARE_SUGGESTION;
}

#if MOZC_ENABLE_NGRAM_RESCORING
// Reports the value of the `ngram_enable_nwp` flag carried in the decoder
// experiment params of `request`.
bool IsNgramNextWordPredictionEnabled(const ConversionRequest &request) {
  const auto &experiment_params =
      request.request().decoder_experiment_params();
  return experiment_params.ngram_enable_nwp();
}
#endif // MOZC_ENABLE_NGRAM_RESCORING

bool IsZeroQuerySuffixPredictionDisabled(const ConversionRequest &request) {
return request.request()
.decoder_experiment_params()
Expand Down Expand Up @@ -723,6 +717,12 @@ PredictionTypes DictionaryPredictionAggregator::AggregatePredictionForZeroQuery(
selected_types |= BIGRAM;
}
if (segments.history_segments_size() > 0) {
const engine::SupplementalModelInterface *supplemental_model =
modules_.GetSupplementalModel();
if (supplemental_model != nullptr &&
supplemental_model->Predict(request, segments, *results)) {
selected_types |= SUPPLEMENTAL_MODEL;
}
AggregateZeroQuerySuffixPrediction(request, segments, results);
selected_types |= SUFFIX;
}
Expand Down Expand Up @@ -1623,7 +1623,6 @@ void DictionaryPredictionAggregator::AggregateZeroQuerySuffixPrediction(
}
}


void DictionaryPredictionAggregator::AggregateEnglishPrediction(
const ConversionRequest &request, const Segments &segments,
std::vector<Result> *results) const {
Expand Down Expand Up @@ -1687,26 +1686,8 @@ void DictionaryPredictionAggregator::AggregateTypingCorrectedPrediction(
return;
}

// Decides whether toggle correction should be disabled for `request`.
// Yields false unless the request's special romanji table is
// TOGGLE_FLICK_TO_HIRAGANA. Otherwise, each length-1 raw substring of the
// composer is inspected: once its trailing '*' characters are ignored, a
// remainder of two or more bytes makes the result false. If no such substring
// exists, the result is true.
auto disable_toggle_correction = [](const ConversionRequest &request) {
  const bool uses_toggle_flick = request.request().special_romanji_table() ==
                                 Request::TOGGLE_FLICK_TO_HIRAGANA;
  if (!uses_toggle_flick) {
    return false;
  }
  const int num_positions = request.composer().GetLength();
  for (int pos = 0; pos < num_positions; ++pos) {
    const auto raw_input = request.composer().GetRawSubString(pos, 1);
    const absl::string_view view(raw_input);
    // Index of the last byte that is not '*'; npos means the view is empty or
    // consists solely of '*' characters.
    const auto last_kept = view.find_last_not_of('*');
    const auto kept_size =
        (last_kept == absl::string_view::npos) ? 0 : last_kept + 1;
    if (kept_size >= 2) {
      return false;
    }
  }
  return true;
};

const std::string asis = request.composer().GetStringForTypeCorrection();
const std::optional<std::vector<TypeCorrectedQuery>> corrected =
supplemental_model->CorrectComposition(asis, segments.history_key(),
disable_toggle_correction(request),
request.request());
supplemental_model->CorrectComposition(request, segments.history_key());
if (!corrected) {
return;
}
Expand Down
12 changes: 0 additions & 12 deletions src/prediction/dictionary_prediction_aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,6 @@ class DictionaryPredictionAggregator : public PredictionAggregatorInterface {
const ConversionRequest &request,
const Segments &segments) const override;

#if MOZC_ENABLE_NGRAM_RESCORING
// Test-only setter: stores the given `ngram_model` pointer in `ngram_model_`.
// Only the pointer is assigned here; presumably the caller keeps the model
// alive for as long as this aggregator uses it -- TODO confirm.
void SetNgramModelForTesting(const ngram::NgramModelInterface *ngram_model) {
ngram_model_ = ngram_model;
}
#endif // MOZC_ENABLE_NGRAM_RESCORING

private:
class PredictiveLookupCallback;
class PrefixLookupCallback;
Expand Down Expand Up @@ -229,12 +223,6 @@ class DictionaryPredictionAggregator : public PredictionAggregatorInterface {
const Segments &segments,
std::vector<Result> *results) const;

#if MOZC_ENABLE_NGRAM_RESCORING
void AggregateZeroQueryNgramPrediction(const ConversionRequest &request,
const Segments &segments,
std::vector<Result> *results) const;
#endif // MOZC_ENABLE_NGRAM_RESCORING

void AggregateEnglishPrediction(const ConversionRequest &request,
const Segments &segments,
std::vector<Result> *results) const;
Expand Down
13 changes: 2 additions & 11 deletions src/prediction/dictionary_prediction_aggregator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ class DictionaryPredictionAggregatorTestPeer {
aggregator_.AggregateZeroQuerySuffixPrediction(request, segments, results);
}


void AggregateEnglishPrediction(const ConversionRequest &request,
const Segments &segments,
std::vector<Result> *results) const {
Expand Down Expand Up @@ -428,12 +427,6 @@ class MockDataAndAggregator {
modules_.SetSupplementalModel(supplemental_model);
}

#if MOZC_ENABLE_NGRAM_RESCORING
// Forwards `ngram_model` to `aggregator_` by calling its SetNgramModel().
// NOTE(review): the aggregator header declares SetNgramModelForTesting();
// confirm that SetNgramModel() exists on the type `aggregator_` points to.
void set_ngram_model(const ngram::NgramModelInterface *ngram_model) {
aggregator_->SetNgramModel(ngram_model);
}
#endif // MOZC_ENABLE_NGRAM_RESCORING

private:
MockConverter converter_;
MockImmutableConverter mock_immutable_converter_;
Expand Down Expand Up @@ -1961,7 +1954,6 @@ TEST_F(DictionaryPredictionAggregatorTest, AggregateZeroQuerySuffixPrediction) {
}
}


struct EnglishPredictionTestEntry {
std::string name;
transliteration::TransliterationType input_mode;
Expand Down Expand Up @@ -2052,8 +2044,7 @@ TEST_F(DictionaryPredictionAggregatorTest,
public:
MOCK_METHOD(std::optional<std::vector<TypeCorrectedQuery>>,
CorrectComposition,
(absl::string_view, absl::string_view, bool,
const commands::Request &),
(const ConversionRequest &, absl::string_view),
(const, override));
MOCK_METHOD(void, PostCorrect, (Segments *), (const, override));
};
Expand Down Expand Up @@ -2087,7 +2078,7 @@ TEST_F(DictionaryPredictionAggregatorTest,
TypeCorrectedQuery::KANA_MODIFIER_INSENTIVE_ONLY);

auto mock = std::make_unique<MockSupplementalModel>();
EXPECT_CALL(*mock, CorrectComposition("よろさく", "ほんじつは", false, _))
EXPECT_CALL(*mock, CorrectComposition(_, "ほんじつは"))
.WillOnce(Return(expected));

data_and_aggregator->set_supplemental_model(mock.get());
Expand Down
3 changes: 3 additions & 0 deletions src/prediction/dictionary_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,9 @@ std::string DictionaryPredictor::GetPredictionTypeDebugString(
if (types & PredictionType::TYPING_CORRECTION) {
debug_desc.append("T");
}
if (types & PredictionType::SUPPLEMENTAL_MODEL) {
debug_desc.append(1, 'X');
}
return debug_desc;
}

Expand Down
33 changes: 0 additions & 33 deletions src/prediction/dictionary_predictor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1703,39 +1703,6 @@ TEST_F(DictionaryPredictorTest, MaybePopulateTypingCorrectedResultsTest) {
EXPECT_EQ(results.size(), 4);
}

{
request_->mutable_decoder_experiment_params()
->set_typing_correction_literal_on_top_correction_score_max_diff(0.5);
request_->mutable_decoder_experiment_params()
->set_typing_correction_literal_on_top_conversion_cost_max_diff(500);
auto results = base_results;
predictor.MaybePopulateTypingCorrectedResults(*convreq_for_prediction_,
segments, &results);
EXPECT_EQ(results.size(), 4);
}

{
request_->mutable_decoder_experiment_params()
->set_typing_correction_literal_on_top_correction_score_max_diff(1.0);
request_->mutable_decoder_experiment_params()
->set_typing_correction_literal_on_top_conversion_cost_max_diff(500);
auto results = base_results;
predictor.MaybePopulateTypingCorrectedResults(*convreq_for_prediction_,
segments, &results);
EXPECT_EQ(results.size(), 4);
}

{
request_->mutable_decoder_experiment_params()
->set_typing_correction_literal_on_top_correction_score_max_diff(0.5);
request_->mutable_decoder_experiment_params()
->set_typing_correction_literal_on_top_conversion_cost_max_diff(1000);
auto results = base_results;
predictor.MaybePopulateTypingCorrectedResults(*convreq_for_prediction_,
segments, &results);
EXPECT_EQ(results.size(), 4);
}

// disable typing correction.
{
config_->set_use_typing_correction(false);
Expand Down
6 changes: 2 additions & 4 deletions src/prediction/result.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,8 @@ enum PredictionType {
// entries from single kanji dictionary.
SINGLE_KANJI = 256,

#if MOZC_ENABLE_NGRAM_RESCORING
// entries from N-gram model.
NGRAM = 512,
#endif // MOZC_ENABLE_NGRAM_RESCORING
// entries from a supplemental model.
SUPPLEMENTAL_MODEL = 512,

// Suggests from |converter_|. The difference from REALTIME is that it uses
// the full converter with rewriter, history, etc.
Expand Down
33 changes: 20 additions & 13 deletions src/protocol/commands.proto
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ message Capability {
[default = NO_TEXT_DELETION_CAPABILITY];
}

// Next ID: 70
// Next ID: 74
// Bundles together some Android experiment flags so that they can be easily
// retrieved throughout the native code. These flags are generally specific to
// the decoder, and are made available when the decoder is initialized.
Expand Down Expand Up @@ -614,21 +614,13 @@ message DecoderExperimentParams {

reserved 38; // Deprecated typing_correction_conversion_cost_max_diff
reserved 46; // Deprecated enable_typing_correction_mixer_v2
reserved 47; // Deprecated typing_correction_literal_on_top_correction_score_max_diff
reserved 48; // Deprecated typing_correction_literal_on_top_conversion_cost_max_diff

// Trigger literal on top if
// correction_score <=
// typing_correction_literal_on_top_correction_score_max_diff
// OR
// conversion_cost_diff <=
// typing_correction_literal_on_top_spatial_score_max_diff
// OR
// correction_score <=
// typing_correction_literal_on_top_length_score_max_diff
// * (typing_correction_literal_on_top_length_decay^(input_length-3))
optional float typing_correction_literal_on_top_correction_score_max_diff = 47
[default = 0.0];
optional int32 typing_correction_literal_on_top_conversion_cost_max_diff = 48
[default = 0];
// Literal candidate is placed at least second position.
optional bool typing_correction_literal_at_least_second = 49
[default = false];
Expand Down Expand Up @@ -723,9 +715,24 @@ message DecoderExperimentParams {
[default = 0.0];
// When the top correction is kana_modifier_insensitive_correction, suppress
// other corrections.
optional bool
typing_correction_promote_kana_modifier_insensitive_only = 69
optional bool typing_correction_promote_kana_modifier_insensitive_only = 69
[default = false];

// Penalty for the first character mismatch.
optional float typing_correction_first_char_mismatch_penalty = 70
[default = 0.0];

// Penalties for the modifier corrections when intended modifiers exist.
// `local` applies to the same modifier type; `global` ignores the modifier
// type.
optional float typing_correction_intended_modifier_local_penalty = 71
[default = 0.0];
optional float typing_correction_intended_modifier_global_penalty = 72
[default = 0.0];
// Exponential decay factor used to assign a larger penalty when the intended
// and added modifiers are located close to each other.
// The actual penalty is computed as
// (global|local)_penalty * decay^(distance - 1).
optional float typing_correction_intended_modifier_decay = 73 [default = 0.0];
}

// Clients' request to the server.
Expand Down
1 change: 0 additions & 1 deletion src/renderer/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,6 @@ mozc_objc_library(
sdk_frameworks = [
"Carbon",
"Cocoa",
"ObjectiveC",
],
deps = [
":renderer_interface",
Expand Down

0 comments on commit bf3dbbb

Please sign in to comment.