Skip to content

Commit

Permalink
Merge branch 'main' into jp/collapse-wasm-bindings
Browse files Browse the repository at this point in the history
Conflicts:
  wasm/bindings/TranslationModelBindings.cpp
  • Loading branch information
Jerin Philip committed Apr 29, 2021
2 parents f2efc5a + de0abfd commit 9d80a08
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 39 deletions.
9 changes: 6 additions & 3 deletions src/translator/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ inline marian::ConfigParser createConfigParser() {
}

inline std::shared_ptr<marian::Options>
parseOptions(const std::string &config) {
parseOptions(const std::string &config, bool validate = true) {
marian::Options options;

// @TODO(jerinphilip) There's something off here, @XapaJIaMnu suggests
Expand All @@ -58,8 +58,11 @@ parseOptions(const std::string &config) {
options.parse(config);
YAML::Node configCopy = options.cloneToYamlNode();

marian::ConfigValidator validator(configCopy);
validator.validateOptions(marian::cli::mode::translation);
if (validate) {
// Perform validation on parsed options only when requested
marian::ConfigValidator validator(configCopy);
validator.validateOptions(marian::cli::mode::translation);
}

return std::make_shared<marian::Options>(options);
}
Expand Down
2 changes: 1 addition & 1 deletion src/translator/service.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class Service {
explicit Service(const std::string &config,
AlignedMemory modelMemory = AlignedMemory(),
AlignedMemory shortlistMemory = AlignedMemory())
: Service(parseOptions(config), std::move(modelMemory),
: Service(parseOptions(config, /*validate=*/false), std::move(modelMemory),
std::move(shortlistMemory)) {}

/// Explicit destructor to clean up after any threads initialized in
Expand Down
19 changes: 16 additions & 3 deletions wasm/README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
## Using Bergamot Translator in JavaScript
The example file `bergamot.html` in the folder `test_page` demonstrates how to use the bergamot translator in JavaScript via a `<script>` tag.

Please note that everything below assumes that the [bergamot project specific model files](https://github.com/mozilla-applied-ml/bergamot-models) were packaged in wasm binary (using the compile instructions given in the top level README).
### <a name="Pre-requisite"></a> Pre-requisite: Download files required for translation

### Using JS APIs
Please note that the [Using JS APIs](#Using-JS-APIs) and [Demo](#Demo) sections below assume that the [bergamot project specific model files](https://github.com/mozilla-applied-ml/bergamot-models) are already downloaded and present in the `test_page` folder. If they are not, use the following instructions to download them:

```bash
cd test_page
mkdir models
git clone --depth 1 --branch main --single-branch https://github.com/mozilla-applied-ml/bergamot-models
cp -rf bergamot-models/prod/* models
gunzip models/*/*
```

### <a name="Using-JS-APIs"></a> Using JS APIs

```js
// The model configuration as YAML formatted string. For available configuration options, please check: https://marian-nmt.github.io/docs/cmd/marian-decoder/
Expand Down Expand Up @@ -34,13 +44,16 @@ request.delete();
input.delete();
```

### Demo (see everything in action)
### <a name="Demo"></a> Demo (see everything in action)

* Make sure that you have followed the [Pre-requisite](#Pre-requisite) instructions before moving forward.

* Start the test webserver (ensure you have the latest nodejs installed)
```bash
cd test_page
bash start_server.sh
```

* Open any of the browsers below
* Firefox Nightly +87: make sure the following prefs are on (about:config)
```
Expand Down
29 changes: 23 additions & 6 deletions wasm/bindings/TranslationModelBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,31 @@ using namespace emscripten;
typedef marian::bergamot::Service TranslationModel;
typedef marian::bergamot::Response TranslationResult;

// Binding code
// Returns a `val` wrapping a typed memory view over the bytes of
// `alignedMemory`: `size()` elements of type `char`, starting at the pointer
// returned by `as<char>()`. Bound below as the `getByteArrayView` method so the
// JavaScript side can obtain the view and write into it.
// NOTE(review): per the embind documentation a typed_memory_view presumably
// aliases the underlying storage instead of copying it, so the returned view
// should not be used after `alignedMemory` is destroyed -- confirm.
val getByteArrayView(marian::bergamot::AlignedMemory& alignedMemory) {
  const auto byteCount = alignedMemory.size();
  auto *firstByte = alignedMemory.as<char>();
  return val(typed_memory_view(byteCount, firstByte));
}

// Registers marian::bergamot::AlignedMemory with embind under the JavaScript
// name "AlignedMemory". Exposed surface:
//   - a constructor taking two std::size_t arguments,
//   - size(): forwards to AlignedMemory::size,
//   - getByteArrayView(): forwards to the free function getByteArrayView.
EMSCRIPTEN_BINDINGS(aligned_memory) {
  using BoundMemory = marian::bergamot::AlignedMemory;

  class_<BoundMemory>("AlignedMemory")
      .constructor<std::size_t, std::size_t>()
      .function("size", &BoundMemory::size)
      .function("getByteArrayView", &getByteArrayView);
}

// Factory that builds a TranslationModel from a configuration string and two
// AlignedMemory objects supplied by pointer.
//
// The contents of `*modelMemory` and `*shortlistMemory` are moved into the new
// TranslationModel, so both pointees are left in a moved-from state after the
// call. Neither pointer is checked for null before being dereferenced.
//
// The result is allocated with `new` and returned as a raw pointer.
// NOTE(review): presumably the embind/JavaScript side takes ownership and
// releases it through `.delete()` -- confirm against the binding registration.
TranslationModel* TranslationModelFactory(const std::string &config,
                                          marian::bergamot::AlignedMemory* modelMemory,
                                          marian::bergamot::AlignedMemory* shortlistMemory) {
  auto &modelBytes = *modelMemory;
  auto &shortlistBytes = *shortlistMemory;
  return new TranslationModel(config, std::move(modelBytes), std::move(shortlistBytes));
}

EMSCRIPTEN_BINDINGS(translation_model) {
class_<TranslationModel>("TranslationModel")
.constructor<std::string>()
.function("translate", &TranslationModel::translateMultiple)
.function("isAlignmentSupported",
&TranslationModel::isAlignmentSupported);
// ^ We redirect translateMultiple to translate instead. Sane API is
.constructor(&TranslationModelFactory, allow_raw_pointers())
.function("translate", &TranslationModel::translateMultiple)
.function("isAlignmentSupported", &TranslationModel::isAlignmentSupported)
;
// ^ We redirect Service::translateMultiple to WASMBound::translate instead. Sane API is
// translate. If and when async comes, we can be done with this inconsistency.

register_vector<std::string>("VectorString");
Expand Down
96 changes: 70 additions & 26 deletions wasm/test_page/bergamot.html
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<html>
<head>
<link rel="icon" href="data:,">
<meta http-equiv="Content-Type" content="text/html;charset=ISO-8859-1">
<meta http-equiv="Content-Type" content="text/html;charset=UTF-8">
</head>
<style>
body, html, div {
Expand Down Expand Up @@ -61,9 +61,27 @@
</div>

<script>
// Fetches `url` and resolves with the response body as an ArrayBuffer.
// When the response is not OK, the returned promise rejects with an Error whose
// message contains the HTTP status code and status text.
const downloadAsArrayBuffer = async(url) => {
  const httpResponse = await fetch(url);
  if (httpResponse.ok) {
    return httpResponse.arrayBuffer();
  }
  throw Error(`HTTP ${httpResponse.status} - ${httpResponse.statusText}`);
}

// Builds a Module.AlignedMemory holding a copy of `buffer`.
//   buffer        - ArrayBuffer whose bytes are copied in.
//   alignmentSize - passed through as the second AlignedMemory constructor argument.
// The AlignedMemory is created with the buffer's byte length, then filled by
// writing the bytes through the view returned by getByteArrayView().
// Returns the populated AlignedMemory object.
function constructAlignedMemoryFromBuffer(buffer, alignmentSize) {
  const sourceBytes = new Int8Array(buffer);
  const byteCount = sourceBytes.byteLength;
  console.debug("byteArray size: ", byteCount);
  const alignedMemory = new Module.AlignedMemory(byteCount, alignmentSize);
  alignedMemory.getByteArrayView().set(sourceBytes);
  return alignedMemory;
}

var model, request, input = undefined;
const loadModel = (from, to) => {
var translationModel, request, input = undefined;
const constructTranslationModel = async (from, to) => {

const languagePair = `${from}${to}`;

Expand All @@ -72,7 +90,7 @@

// Set the Model Configuration as YAML formatted string.
// For available configuration options, please check: https://marian-nmt.github.io/docs/cmd/marian-decoder/
const modelConfig = `models:
/*const modelConfig = `models:
- /${languagePair}/model.${languagePair}.intgemm.alphas.bin
vocabs:
- /${vocabLanguagePair}/vocab.${vocabLanguagePair}.spm
Expand All @@ -93,22 +111,53 @@
- 50
- 50
`;
/*
This config is not valid anymore in new APIs
mini-batch: 32
maxi-batch: 100
maxi-batch-sort: src
*/

const modelConfigWithoutModelAndShortList = `vocabs:
- /${vocabLanguagePair}/vocab.${vocabLanguagePair}.spm
- /${vocabLanguagePair}/vocab.${vocabLanguagePair}.spm
beam-size: 1
normalize: 1.0
word-penalty: 0
max-length-break: 128
mini-batch-words: 1024
workspace: 128
max-length-factor: 2.0
skip-cost: true
cpu-threads: 0
quiet: true
quiet-translation: true
`;

// TODO: Use in model config when wormhole is enabled:
// gemm-precision: int8shift
// TODO: Use in model config when loading of binary models is supported and we use model.intgemm.alphas.bin:
// gemm-precision: int8shiftAlphaAll

console.debug("modelConfig: ", modelConfig);

// Instantiate the TranslationModel
if (model) model.delete();
model = new Module.TranslationModel(modelConfig);
const modelFile = `${languagePair}/model.${languagePair}.intgemm.alphas.bin`;
console.debug("modelFile: ", modelFile);
const shortlistFile = `${languagePair}/lex.${languagePair}.s2t.bin`;
console.debug("shortlistFile: ", shortlistFile);

try {
// Download the files as buffers from the given urls
let start = Date.now();
const downloadedBuffers = await Promise.all([downloadAsArrayBuffer(modelFile), downloadAsArrayBuffer(shortlistFile)]);
const modelBuffer = downloadedBuffers[0];
const shortListBuffer = downloadedBuffers[1];
log(`${languagePair} file download took ${(Date.now() - start) / 1000} secs`);

// Construct AlignedMemory objects with downloaded buffers
var alignedModelMemory = constructAlignedMemoryFromBuffer(modelBuffer, 256);
var alignedShortlistMemory = constructAlignedMemoryFromBuffer(shortListBuffer, 64);

// Instantiate the TranslationModel
if (translationModel) translationModel.delete();
console.debug("Creating TranslationModel with config:", modelConfigWithoutModelAndShortList);
translationModel = new Module.TranslationModel(modelConfigWithoutModelAndShortList, alignedModelMemory, alignedShortlistMemory);
} catch (error) {
console.error(error);
}
}

const translate = (paragraphs) => {
Expand All @@ -127,16 +176,9 @@
})
// Access input (just for debugging)
console.log('Input size=', input.size());
/*
for (let i = 0; i < input.size(); i++) {
console.log(' val:' + input.get(i));
}
*/

// Translate the input; the result is a vector<TranslationResult>
let result = model.translate(input, request);
// Access original and translated text from each entry of vector<TranslationResult>
//console.log('Result size=', result.size(), ' - TimeDiff - ', (Date.now() - start)/1000);
let result = translationModel.translate(input, request);
const translatedParagraphs = [];
for (let i = 0; i < result.size(); i++) {
translatedParagraphs.push(result.get(i).getTranslatedText());
Expand All @@ -147,14 +189,16 @@
return translatedParagraphs;
}

document.querySelector("#load").addEventListener("click", () => {
document.querySelector("#load").addEventListener("click", async() => {
document.querySelector("#load").disabled = true;
const lang = document.querySelector('input[name="modellang"]:checked').value;
const from = lang.substring(0, 2);
const to = lang.substring(2, 4);
let start = Date.now();
loadModel(from, to)
log(`model ${from}${to} loaded in ${(Date.now() - start) / 1000} secs`);
//log('Model Alignment:', model.isAlignmentSupported());
await constructTranslationModel(from, to);
log(`translation model ${from}${to} construction took ${(Date.now() - start) / 1000} secs`);
document.querySelector("#load").disabled = false;
//log('Model Alignment:', translationModel.isAlignmentSupported());
});

const translateCall = () => {
Expand Down

0 comments on commit 9d80a08

Please sign in to comment.