Skip to content

Commit

Permalink
JS bindings for loading model and shortlist files as bytes (#117)
Browse files Browse the repository at this point in the history
* Bindings to load model and shortlist files as bytes
* Modified wasm test page for byte based loading of files
* Updates wasm README for byte loading based usage of TranslationModel
  • Loading branch information
abhi-agg authored Apr 29, 2021
1 parent e5ec5bd commit de0abfd
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 31 deletions.
19 changes: 16 additions & 3 deletions wasm/README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
## Using Bergamot Translator in JavaScript
The example file `bergamot.html` in the folder `test_page` demonstrates how to use the bergamot translator in JavaScript via a `<script>` tag.

Please note that everything below assumes that the [bergamot project specific model files](https://github.com/mozilla-applied-ml/bergamot-models) were packaged in wasm binary (using the compile instructions given in the top level README).
### <a name="Pre-requisite"></a> Pre-requisite: Download files required for translation

### Using JS APIs
Please note that the [Using JS APIs](#Using-JS-APIs) and [Demo](#Demo) sections below assume that the [bergamot project specific model files](https://github.com/mozilla-applied-ml/bergamot-models) are already downloaded and present in the `test_page` folder. If this has not been done yet, use the following instructions to do so:

```bash
cd test_page
mkdir models
git clone --depth 1 --branch main --single-branch https://github.com/mozilla-applied-ml/bergamot-models
cp -rf bergamot-models/prod/* models
gunzip models/*/*
```

### <a name="Using-JS-APIs"></a> Using JS APIs

```js
// The model configuration as YAML formatted string. For available configuration options, please check: https://marian-nmt.github.io/docs/cmd/marian-decoder/
Expand Down Expand Up @@ -34,13 +44,16 @@ request.delete();
input.delete();
```

### Demo (see everything in action)
### <a name="Demo"></a> Demo (see everything in action)

* Make sure that you followed [Pre-requisite](#Pre-requisite) instructions before moving forward.

* Start the test webserver (ensure you have the latest nodejs installed)
```bash
cd test_page
bash start_server.sh
```

* Open any of the browsers below
* Firefox Nightly +87: make sure the following prefs are on (about:config)
```
Expand Down
21 changes: 19 additions & 2 deletions wasm/bindings/TranslationModelBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,27 @@

using namespace emscripten;

// Binding code
// Returns an emscripten `val` wrapping a typed_memory_view that spans
// `alignedMemory.size()` chars starting at `alignedMemory.as<char>()`.
// The JS side uses this view to write bytes into the AlignedMemory.
val getByteArrayView(marian::bergamot::AlignedMemory& alignedMemory) {
  const auto byteCount = alignedMemory.size();
  auto* const firstByte = alignedMemory.as<char>();
  return val(typed_memory_view(byteCount, firstByte));
}

// Registers marian::bergamot::AlignedMemory with embind under the JS name
// "AlignedMemory". Exposed members:
//   - a constructor taking two std::size_t arguments,
//   - size(), forwarded to AlignedMemory::size,
//   - getByteArrayView(), implemented by the free function getByteArrayView.
EMSCRIPTEN_BINDINGS(aligned_memory) {
  using Memory = marian::bergamot::AlignedMemory;

  class_<Memory>("AlignedMemory")
      .constructor<std::size_t, std::size_t>()
      .function("size", &Memory::size)
      .function("getByteArrayView", &getByteArrayView);
}

// Factory that builds a TranslationModel from a configuration string plus the
// model and shortlist bytes held in two AlignedMemory objects.
//
// Both pointers are dereferenced unconditionally, so they must be non-null.
// The pointed-to objects are passed on through std::move; presumably they are
// left in a moved-from state afterwards -- confirm against the
// TranslationModel constructor.
//
// Returns a TranslationModel allocated with `new`; releasing it is up to the
// caller.
TranslationModel* TranslationModelFactory(const std::string &config,
                                          marian::bergamot::AlignedMemory* modelMemory,
                                          marian::bergamot::AlignedMemory* shortlistMemory) {
  marian::bergamot::AlignedMemory& modelBytes = *modelMemory;
  marian::bergamot::AlignedMemory& shortlistBytes = *shortlistMemory;
  return new TranslationModel(config, std::move(modelBytes), std::move(shortlistBytes));
}

EMSCRIPTEN_BINDINGS(translation_model) {
class_<TranslationModel>("TranslationModel")
.constructor<std::string>()
.constructor(&TranslationModelFactory, allow_raw_pointers())
.function("translate", &TranslationModel::translate)
.function("isAlignmentSupported", &TranslationModel::isAlignmentSupported)
;
Expand Down
96 changes: 70 additions & 26 deletions wasm/test_page/bergamot.html
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<html>
<head>
<link rel="icon" href="data:,">
<meta http-equiv="Content-Type" content="text/html;charset=ISO-8859-1">
<meta http-equiv="Content-Type" content="text/html;charset=UTF-8">
</head>
<style>
body, html, div {
Expand Down Expand Up @@ -61,9 +61,27 @@
</div>

<script>
// Fetches `url` and resolves with the response body as an ArrayBuffer.
// Rejects with an Error containing the HTTP status code and status text
// when the response is not OK.
const downloadAsArrayBuffer = async (url) => {
  const reply = await fetch(url);
  if (reply.ok) {
    return reply.arrayBuffer();
  }
  throw Error(`HTTP ${reply.status} - ${reply.statusText}`);
}

// Creates a Module.AlignedMemory with the same byte length as `buffer`,
// constructed with `alignmentSize`, copies the bytes of `buffer` into it
// through its byte-array view, and returns the AlignedMemory.
function constructAlignedMemoryFromBuffer(buffer, alignmentSize) {
  const sourceBytes = new Int8Array(buffer);
  const byteCount = sourceBytes.byteLength;
  console.debug("byteArray size: ", byteCount);

  const memory = new Module.AlignedMemory(byteCount, alignmentSize);
  memory.getByteArrayView().set(sourceBytes);
  return memory;
}

var model, request, input = undefined;
const loadModel = (from, to) => {
var translationModel, request, input = undefined;
const constructTranslationModel = async (from, to) => {

const languagePair = `${from}${to}`;

Expand All @@ -72,7 +90,7 @@

// Set the Model Configuration as YAML formatted string.
// For available configuration options, please check: https://marian-nmt.github.io/docs/cmd/marian-decoder/
const modelConfig = `models:
/*const modelConfig = `models:
- /${languagePair}/model.${languagePair}.intgemm.alphas.bin
vocabs:
- /${vocabLanguagePair}/vocab.${vocabLanguagePair}.spm
Expand All @@ -93,22 +111,53 @@
- 50
- 50
`;
/*
This config is not valid anymore in new APIs
mini-batch: 32
maxi-batch: 100
maxi-batch-sort: src
*/

const modelConfigWithoutModelAndShortList = `vocabs:
- /${vocabLanguagePair}/vocab.${vocabLanguagePair}.spm
- /${vocabLanguagePair}/vocab.${vocabLanguagePair}.spm
beam-size: 1
normalize: 1.0
word-penalty: 0
max-length-break: 128
mini-batch-words: 1024
workspace: 128
max-length-factor: 2.0
skip-cost: true
cpu-threads: 0
quiet: true
quiet-translation: true
`;

// TODO: Use in model config when wormhole is enabled:
// gemm-precision: int8shift
// TODO: Use in model config when loading of binary models is supported and we use model.intgemm.alphas.bin:
// gemm-precision: int8shiftAlphaAll

console.debug("modelConfig: ", modelConfig);

// Instantiate the TranslationModel
if (model) model.delete();
model = new Module.TranslationModel(modelConfig);
const modelFile = `${languagePair}/model.${languagePair}.intgemm.alphas.bin`;
console.debug("modelFile: ", modelFile);
const shortlistFile = `${languagePair}/lex.${languagePair}.s2t.bin`;
console.debug("shortlistFile: ", shortlistFile);

try {
// Download the files as buffers from the given urls
let start = Date.now();
const downloadedBuffers = await Promise.all([downloadAsArrayBuffer(modelFile), downloadAsArrayBuffer(shortlistFile)]);
const modelBuffer = downloadedBuffers[0];
const shortListBuffer = downloadedBuffers[1];
log(`${languagePair} file download took ${(Date.now() - start) / 1000} secs`);

// Construct AlignedMemory objects with downloaded buffers
var alignedModelMemory = constructAlignedMemoryFromBuffer(modelBuffer, 256);
var alignedShortlistMemory = constructAlignedMemoryFromBuffer(shortListBuffer, 64);

// Instantiate the TranslationModel
if (translationModel) translationModel.delete();
console.debug("Creating TranslationModel with config:", modelConfigWithoutModelAndShortList);
translationModel = new Module.TranslationModel(modelConfigWithoutModelAndShortList, alignedModelMemory, alignedShortlistMemory);
} catch (error) {
console.error(error);
}
}

const translate = (paragraphs) => {
Expand All @@ -127,16 +176,9 @@
})
// Access input (just for debugging)
console.log('Input size=', input.size());
/*
for (let i = 0; i < input.size(); i++) {
console.log(' val:' + input.get(i));
}
*/

// Translate the input; the result is a vector<TranslationResult>
let result = model.translate(input, request);
// Access original and translated text from each entry of vector<TranslationResult>
//console.log('Result size=', result.size(), ' - TimeDiff - ', (Date.now() - start)/1000);
let result = translationModel.translate(input, request);
const translatedParagraphs = [];
for (let i = 0; i < result.size(); i++) {
translatedParagraphs.push(result.get(i).getTranslatedText());
Expand All @@ -147,14 +189,16 @@
return translatedParagraphs;
}

document.querySelector("#load").addEventListener("click", () => {
document.querySelector("#load").addEventListener("click", async() => {
document.querySelector("#load").disabled = true;
const lang = document.querySelector('input[name="modellang"]:checked').value;
const from = lang.substring(0, 2);
const to = lang.substring(2, 4);
let start = Date.now();
loadModel(from, to)
log(`model ${from}${to} loaded in ${(Date.now() - start) / 1000} secs`);
//log('Model Alignment:', model.isAlignmentSupported());
await constructTranslationModel(from, to);
log(`translation model ${from}${to} construction took ${(Date.now() - start) / 1000} secs`);
document.querySelector("#load").disabled = false;
//log('Model Alignment:', translationModel.isAlignmentSupported());
});

const translateCall = () => {
Expand Down

0 comments on commit de0abfd

Please sign in to comment.