Skip to content

Commit

Permalink
Merge segmenter_lstm with segmenter
Browse files Browse the repository at this point in the history
  • Loading branch information
makotokato committed Jun 17, 2022
1 parent 8900f22 commit 9010803
Show file tree
Hide file tree
Showing 19 changed files with 120 additions and 518 deletions.
1 change: 0 additions & 1 deletion CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ experimental/collator/ @hsivonen @echeran
experimental/normalizer/ @hsivonen @echeran
experimental/provider_ppucd/ @echeran
experimental/segmenter/ @aethanyc @makotokato
experimental/segmenter_lstm/ @aethanyc @sffc
ffi/capi/ @Manishearth
ffi/cpp/ @Manishearth
ffi/ecma402/ @filmil
Expand Down
16 changes: 3 additions & 13 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ members = [
"experimental/collator",
"experimental/normalizer",
"experimental/segmenter",
"experimental/segmenter_lstm",
"ffi/capi_cdylib",
"ffi/diplomat",
"ffi/capi_staticlib",
Expand Down
7 changes: 5 additions & 2 deletions experimental/segmenter/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,15 @@ skip_optional_dependencies = true
icu_char16trie = { version = "0.1", path = "../char16trie" }
icu_codepointtrie = { path = "../../utils/codepointtrie" }
icu_provider = { version = "0.6", path = "../../provider/core", features = ["macros"] }
icu_segmenter_lstm = { version = "0.1", path = "../segmenter_lstm", optional = true }
serde = { version = "1.0", default-features = false, features = ["derive", "alloc"], optional = true }
serde_json = { version = "1.0", default-features = false, features = ["alloc"] }
lazy_static = { version = "1.0", features = ["spin_no_std"] }
zerovec = { version = "0.7", path = "../../utils/zerovec", features = ["yoke"] }
crabbake = { version = "0.4", path = "../../experimental/crabbake", optional = true, features = ["derive"] }
litemap = { version = "0.4.0", path = "../../utils/litemap", optional = true, features = ["serde"] }
ndarray = { git = "https://github.com/rust-ndarray/ndarray", rev = "31244100631382bb8ee30721872a928bfdf07f44", default-features = false, optional = true, features = ["serde"] }
unicode-segmentation = { version = "1.3.0", optional = true }
num-traits = { version = "0.2", optional = true }

[dev-dependencies]
criterion = "0.3"
Expand All @@ -57,6 +60,6 @@ required-features = ["lstm"]

[features]
default = []
lstm = ["icu_segmenter_lstm"]
lstm = ["litemap", "ndarray", "num-traits", "serde", "unicode-segmentation"]
serde = ["dep:serde", "zerovec/serde", "icu_codepointtrie/serde"]
datagen = ["serde", "crabbake", "zerovec/crabbake", "icu_codepointtrie/crabbake"]
10 changes: 10 additions & 0 deletions experimental/segmenter/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ extern crate lazy_static;
// Use the LSTM when the feature is enabled.
#[cfg(feature = "lstm")]
mod lstm;
#[cfg(feature = "lstm")]
mod lstm_bies;
#[cfg(feature = "lstm")]
mod lstm_error;
#[cfg(feature = "lstm")]
mod lstm_structs;
#[cfg(feature = "lstm")]
mod math_helper;

pub use crate::dictionary::{DictionaryBreakIterator, DictionarySegmenter};
pub use crate::grapheme::{
Expand All @@ -167,6 +175,8 @@ pub use crate::line::{
Latin1Char, LineBreakIterator, LineBreakOptions, LineBreakRule, LineBreakSegmenter, Utf16Char,
WordBreakRule,
};
#[cfg(feature = "lstm")]
pub use crate::lstm_structs::LstmDataMarker;
pub use crate::provider::{
GraphemeClusterBreakDataV1Marker, LineBreakDataV1Marker, RuleBreakDataV1,
RuleBreakPropertyTable, RuleBreakStateTable, SentenceBreakDataV1Marker,
Expand Down
12 changes: 6 additions & 6 deletions experimental/segmenter/src/lstm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).

use crate::language::*;
use crate::lstm_bies::Lstm;
use crate::lstm_structs::{LstmData, LstmDataMarker};

use alloc::string::String;
use alloc::string::ToString;
use core::char::decode_utf16;
use icu_provider::DataError;
use icu_provider::DataPayload;
use icu_segmenter_lstm::lstm::Lstm;
use icu_segmenter_lstm::structs;

// TODO:
// The JSON file is big, so we should use another binary format such as npy.
Expand All @@ -21,14 +21,14 @@ const BURMESE_MODEL: &[u8; 475209] =
include_bytes!("../tests/testdata/json/core/segmenter_lstm@1/my.json");

lazy_static! {
static ref THAI_LSTM: structs::LstmData<'static> =
static ref THAI_LSTM: LstmData<'static> =
serde_json::from_slice(THAI_MODEL).expect("JSON syntax error");
static ref BURMESE_LSTM: structs::LstmData<'static> =
static ref BURMESE_LSTM: LstmData<'static> =
serde_json::from_slice(BURMESE_MODEL).expect("JSON syntax error");
}

// The LSTM model depends on the language, so we have to switch models per language.
pub fn get_best_lstm_model(codepoint: u32) -> Option<DataPayload<structs::LstmDataMarker>> {
pub fn get_best_lstm_model(codepoint: u32) -> Option<DataPayload<LstmDataMarker>> {
let lang = get_language(codepoint);
match lang {
Language::Thai => Some(DataPayload::from_owned(THAI_LSTM.clone())),
Expand Down Expand Up @@ -111,7 +111,7 @@ pub struct LstmSegmenter {
}

impl LstmSegmenter {
pub fn try_new(payload: DataPayload<structs::LstmDataMarker>) -> Result<Self, DataError> {
pub fn try_new(payload: DataPayload<LstmDataMarker>) -> Result<Self, DataError> {
let lstm = Lstm::try_new(payload).unwrap();

Ok(Self { lstm })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
// called LICENSE at the top level of the ICU4X source tree
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).

use crate::error::Error;
use crate::lstm_error::Error;
use crate::lstm_structs::LstmDataMarker;
use crate::math_helper;
use crate::structs;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use core::str;
Expand All @@ -13,12 +13,12 @@ use ndarray::{Array1, Array2, ArrayBase, Dim, ViewRepr};
use unicode_segmentation::UnicodeSegmentation;

pub struct Lstm {
data: DataPayload<structs::LstmDataMarker>,
data: DataPayload<LstmDataMarker>,
}

impl Lstm {
/// `try_new` is the initiator of struct `Lstm`
pub fn try_new(data: DataPayload<structs::LstmDataMarker>) -> Result<Self, Error> {
pub fn try_new(data: DataPayload<LstmDataMarker>) -> Result<Self, Error> {
if data.get().dic.len() > core::i16::MAX as usize {
return Err(Error::Limit);
}
Expand All @@ -43,6 +43,7 @@ impl Lstm {
}

/// `get_model_name` returns the name of the LSTM model.
#[allow(dead_code)]
pub fn get_model_name(&self) -> &str {
&self.data.get().model
}
Expand Down Expand Up @@ -168,3 +169,94 @@ impl Lstm {
bies
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use std::fs::File;
    use std::io::BufReader;

    /// A single test case. `unseg` is the unsegmented input line,
    /// `expected_bies` is the BIES sequence the model is expected to produce,
    /// and `true_bies` is the BIES sequence representing the true (reference)
    /// segmentation.
    #[derive(PartialEq, Debug, Serialize, Deserialize)]
    pub struct TestCase {
        pub unseg: String,
        pub expected_bies: String,
        pub true_bies: String,
    }

    /// The collection of test cases loaded from one test-text JSON file.
    #[derive(PartialEq, Debug, Serialize, Deserialize)]
    pub struct TestTextData {
        pub testcases: Vec<TestCase>,
    }

    /// Thin wrapper around `TestTextData`.
    #[derive(Debug)]
    pub struct TestText {
        pub data: TestTextData,
    }

    impl TestText {
        pub fn new(data: TestTextData) -> Self {
            Self { data }
        }
    }

    /// Reads LSTM model weights (JSON) from `path` into a `DataPayload`.
    fn load_lstm_data(path: &str) -> DataPayload<LstmDataMarker> {
        let bytes = std::fs::read(path).expect("File can read to end");
        DataPayload::<LstmDataMarker>::try_from_rc_buffer_badly(bytes.into(), |buf| {
            serde_json::from_slice(buf)
        })
        .expect("JSON syntax error")
    }

    /// Reads the segmentation test cases from the JSON file at `path`.
    fn load_test_text(path: &str) -> TestTextData {
        let reader = BufReader::new(File::open(path).expect("File should be present"));
        serde_json::from_reader(reader).expect("JSON syntax error")
    }

    #[test]
    fn test_model_loading() {
        let payload =
            load_lstm_data("tests/testdata/Thai_graphclust_exclusive_model4_heavy/weights.json");
        let segmenter = Lstm::try_new(payload).unwrap();
        assert_eq!(
            segmenter.get_model_name(),
            "Thai_graphclust_exclusive_model4_heavy"
        );
    }

    #[test]
    fn segment_file_by_lstm() {
        // Choosing the embedding system. It can be "graphclust" or "codepoints".
        let embedding: &str = "codepoints";
        let model_path = format!(
            "tests/testdata/Thai_{}_exclusive_model4_heavy/weights.json",
            embedding
        );
        let segmenter = Lstm::try_new(load_lstm_data(&model_path)).unwrap();

        // Importing the test data.
        let text_path = format!("tests/testdata/test_text_{}.json", embedding);
        let test_text = TestText::new(load_test_text(&text_path));

        // Run the segmenter over every case and compare against the expected BIES.
        for case in test_text.data.testcases {
            let estimated = segmenter.word_segmenter(&case.unseg);
            println!("Test case : {}", case.unseg);
            println!("Expected bies : {}", case.expected_bies);
            println!("Estimated bies : {}", estimated);
            println!("True bies : {}", case.true_bies);
            println!("****************************************************");
            assert_eq!(case.expected_bies, estimated);
        }
    }
}
File renamed without changes.
File renamed without changes.
File renamed without changes.
45 changes: 0 additions & 45 deletions experimental/segmenter_lstm/Cargo.toml

This file was deleted.

Loading

0 comments on commit 9010803

Please sign in to comment.