diff --git a/.clang-format b/.clang-format
new file mode 100644
index 000000000..b7f746ef1
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1,5 @@
+BasedOnStyle: Google
+
+# Maximum line length 80 is too low even for 1080p monitor. @XapaJIaMnu
+# personally would like 120.
+ColumnLimit: 120
diff --git a/.clang-format-ignore b/.clang-format-ignore
new file mode 100644
index 000000000..50795cacb
--- /dev/null
+++ b/.clang-format-ignore
@@ -0,0 +1,4 @@
+3rd_party
+wasm/test_page
+src/translator/aligned.h
+src/translator/pcqueue.h
diff --git a/.clang-tidy b/.clang-tidy
new file mode 100644
index 000000000..bdbadb624
--- /dev/null
+++ b/.clang-tidy
@@ -0,0 +1,32 @@
+Checks: >
+ -*,
+ bugprone-*,
+ concurrency-*,
+ google-*,
+ portability-*,
+ performance-*,
+ clang-analyzer-*,
+ readability-*,
+ -readability-implicit-bool-conversion,
+ -readability-isolate-declaration,
+ -readability-uppercase-literal-suffix,
+ misc-*,
+ -misc-noexcept*,
+ modernize-*,
+ -modernize-deprecated-headers,
+ -modernize-use-nodiscard,
+ -modernize-raw-string-literal,
+ -modernize-return-braced-init-list,
+ -modernize-use-equals-delete,
+ -modernize-use-trailing-return-type,
+
+
+
+CheckOptions:
+ - { key: readability-identifier-naming.ClassCase, value: CamelCase }
+ - { key: readability-identifier-naming.ClassMethodCase, value: camelBack }
+ - { key: readability-identifier-naming.VariableCase, value: camelBack }
+ - { key: readability-identifier-naming.FunctionCase, value: camelBack }
+ - { key: readability-identifier-naming.PrivateMemberSuffix, value: _ }
+ - { key: readability-identifier-naming.ParameterCase, value: camelBack }
+
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 26f6f4418..fa4f321cf 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -1,5 +1,31 @@
+# Firefox Translations review group
+.dockerignore @mozilla/firefox-translations
+.github @mozilla/firefox-translations
+.gitignore @mozilla/firefox-translations
+.gitmodules @mozilla/firefox-translations
+docker @mozilla/firefox-translations
+docs @mozilla/firefox-translations
+utils @mozilla/firefox-translations
+CODE_OF_CONDUCT.md @mozilla/firefox-translations
+LICENSE @mozilla/firefox-translations
+poetry.lock @mozilla/firefox-translations
+pyproject.toml @mozilla/firefox-translations
+README.md @mozilla/firefox-translations
+Taskfile.yml @mozilla/firefox-translations
+
+# Translations Training review group
+configs @mozilla/translations-training
+pipeline @mozilla/translations-training
+snakemake @mozilla/translations-training
+tests @mozilla/translations-training
+tracking @mozilla/translations-training
+
+# Translations Inference review group
+inference-engine @mozilla/translations-inference
+
# Taskcluster pipeline related files. Changes to these ought to be reviewed by
# RelEng to watch for security issues and best practices. These should also
# be reviewed by people familiar with the pipeline itself.
-.taskcluster.yml @mozilla/releng
-taskcluster @mozilla/releng
+.taskcluster.yml @mozilla/releng @mozilla/translations-training
+taskcluster @mozilla/releng @mozilla/translations-training
+
diff --git a/.gitmodules b/.gitmodules
index f1813a444..f88221eb5 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -16,3 +16,29 @@
[submodule "3rd_party/preprocess"]
path = 3rd_party/preprocess
url = https://github.com/kpu/preprocess.git
+[submodule "inference/3rd_party/ssplit-cpp"]
+ path = inference/3rd_party/ssplit-cpp
+ url = https://github.com/browsermt/ssplit-cpp
+# This is the same dependency and repository as `3rd_party/browsermt-marian-dev` below.
+#
+# When forking `inference-engine` into this project, I made an earnest attempt to utilize the preexisting
+# `3rd_party/browsermt-marian-dev` submodule within `inference-engine`. Unfortunately, I ran into several roadblocks:
+#
+# 1) I cannot directly add `3rd_party/browsermt-marian-dev` as a cmake subdirectory because cmake is aware that
+# this path is not a subdirectory of the `inference-engine` project root.
+#
+# 2) Symbolic links do not appear to work for git submodule directories the way that they do for regular directories.
+# Even if the symbolic link had linked correctly, it may have still failed due to the considerations of 1).
+#
+# 3) I tried using cmake to copy the files from `3rd_party/browsermt-marian-dev` into `inference-engine/3rd_party/browsermt-marian-dev`
+# at build time, which would ensure that there is no duplicate reference to the URL in this file, however the upstream dependency itself
+# has hard-coded expectations that the `.git` directory is only one level up, which appears to work correctly for the way git submodules are
+# configured, but does not work if the files are copied over to a regular directory deeper in the repository's directory tree.
+#
+# It may be possible to remove `3rd_party/browsermt-marian-dev` to instead use `inference-engine/3rd-party/browsermt-marian-dev` everywhere
+# within this repository, but I will leave that for a future commit if there is a need to do so.
+#
+# TODO(#869)
+[submodule "inference/3rd_party/browsermt-marian-dev"]
+ path = inference/3rd_party/browsermt-marian-dev
+ url = https://github.com/browsermt/marian-dev
diff --git a/Taskfile.yml b/Taskfile.yml
index c767924e1..745c5f05c 100644
--- a/Taskfile.yml
+++ b/Taskfile.yml
@@ -75,6 +75,30 @@ tasks:
cmds:
- poetry run opuscleaner-server serve --host=0.0.0.0 --port=8000
+ inference-clean:
+ desc: Clean build artifacts from the inference directory.
+ cmds:
+ - >-
+ task docker-run -- ./inference/scripts/clean.sh
+
+ inference-build:
+ desc: Build inference engine.
+ cmds:
+ - >-
+ task docker-run -- ./inference/scripts/build-local.sh
+
+ inference-test:
+ desc: Run inference tests.
+ cmds:
+ - >-
+ task docker-run -- ./inference/scripts/unit-tests.sh
+
+ inference-build-wasm:
+ desc: Build inference engine WASM.
+ cmds:
+ - >-
+ task docker-run -- ./inference/scripts/build-wasm.sh
+
lint-black:
desc: Checks the styling of the Python code with Black.
deps: [poetry-install-black]
diff --git a/inference/.gitignore b/inference/.gitignore
new file mode 100644
index 000000000..78202d979
--- /dev/null
+++ b/inference/.gitignore
@@ -0,0 +1,30 @@
+# vim temporary files
+*.swp
+*.swo
+
+# CMake
+CMakeLists.txt.user
+CMakeCache.txt
+CMakeFiles
+CMakeScripts
+Testing
+Makefile
+cmake_install.cmake
+install_manifest.txt
+compile_commands.json
+CTestTestfile.cmake
+_deps
+
+
+wasm/test_page/node_modules
+/build
+/build-local
+/build-native
+/build-wasm
+/emsdk
+models
+wasm/module/worker/bergamot-translator-worker.*
+wasm/module/browsermt-bergamot-translator-*.tgz
+
+# VSCode
+.vscode
diff --git a/inference/3rd_party/CMakeLists.txt b/inference/3rd_party/CMakeLists.txt
new file mode 100644
index 000000000..62ba02722
--- /dev/null
+++ b/inference/3rd_party/CMakeLists.txt
@@ -0,0 +1,32 @@
+# browsermt-marian-dev is tested elsewhere in both paths, turning off here.
+set(COMPILE_TESTS OFF)
+add_subdirectory(browsermt-marian-dev EXCLUDE_FROM_ALL)
+
+if(COMPILE_WASM)
+ # This is a bad way of adding compilation flags. Will be improved soon.
+ add_compile_options(${WASM_COMPILE_FLAGS})
+ add_link_options(${WASM_LINK_FLAGS})
+endif(COMPILE_WASM)
+
+add_subdirectory(ssplit-cpp EXCLUDE_FROM_ALL)
+
+# Add include directories for 3rd party targets to be able to use it anywhere in the
+# project without explicitly specifying their include directories. Once they
+# fix this problem, it can be removed.
+get_property(INCDIRS DIRECTORY browsermt-marian-dev/src PROPERTY INCLUDE_DIRECTORIES)
+target_include_directories(marian PUBLIC ${INCDIRS})
+
+get_property(INCLUDE_DIRECTORIES DIRECTORY ssplit-cpp/src PROPERTY INCLUDE_DIRECTORIES)
+target_include_directories(ssplit PUBLIC ${INCLUDE_DIRECTORIES})
+
+get_property(COMPILE_DEFINITIONS DIRECTORY browsermt-marian-dev PROPERTY COMPILE_DEFINITIONS)
+target_compile_definitions(marian PUBLIC ${COMPILE_DEFINITIONS})
+
+get_property(COMPILE_OPTIONS DIRECTORY browsermt-marian-dev PROPERTY COMPILE_OPTIONS)
+target_compile_options(marian PUBLIC ${COMPILE_OPTIONS})
+
+# Compilation flags
+get_directory_property(CMAKE_C_FLAGS DIRECTORY browsermt-marian-dev DEFINITION CMAKE_C_FLAGS)
+get_directory_property(CMAKE_CXX_FLAGS DIRECTORY browsermt-marian-dev DEFINITION CMAKE_CXX_FLAGS)
+set(CMAKE_C_FLAGS ${CMAKE_C_FLAGS} PARENT_SCOPE)
+set(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS} PARENT_SCOPE)
diff --git a/inference/3rd_party/browsermt-marian-dev b/inference/3rd_party/browsermt-marian-dev
new file mode 160000
index 000000000..2781d735d
--- /dev/null
+++ b/inference/3rd_party/browsermt-marian-dev
@@ -0,0 +1 @@
+Subproject commit 2781d735d4a10dca876d61be587afdab2726293c
diff --git a/inference/3rd_party/ssplit-cpp b/inference/3rd_party/ssplit-cpp
new file mode 160000
index 000000000..a311f9865
--- /dev/null
+++ b/inference/3rd_party/ssplit-cpp
@@ -0,0 +1 @@
+Subproject commit a311f9865ade34db1e8e080e6cc146f55dafb067
diff --git a/inference/BERGAMOT_VERSION b/inference/BERGAMOT_VERSION
new file mode 100644
index 000000000..a423f7f06
--- /dev/null
+++ b/inference/BERGAMOT_VERSION
@@ -0,0 +1 @@
+v0.4.5
diff --git a/inference/CMakeLists.txt b/inference/CMakeLists.txt
new file mode 100644
index 000000000..febff3e6e
--- /dev/null
+++ b/inference/CMakeLists.txt
@@ -0,0 +1,188 @@
+cmake_minimum_required(VERSION 3.5.1)
+set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
+
+if (POLICY CMP0074)
+ cmake_policy(SET CMP0074 NEW) # CMake 3.12
+endif ()
+
+if (POLICY CMP0077)
+ cmake_policy(SET CMP0077 NEW)
+endif()
+
+project(bergamot_translator CXX C)
+
+# Retrieve the parent-directory path of PROJECT_SOURCE_DIR and assign that to REPOSITORY_ROOT_DIR.
+cmake_path(GET PROJECT_SOURCE_DIR PARENT_PATH REPOSITORY_ROOT_DIR)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+# Generate a compile_commands.json in the build directory. The compile commands allow
+# code editors to understand the build process and provide static analysis of the code.
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+# Note that with CMake MSVC build, the option CMAKE_BUILD_TYPE is automatically derived from the key
+# 'configurationType' in CMakeSettings.json configurations
+if(NOT CMAKE_BUILD_TYPE)
+ message(WARNING "CMAKE_BUILD_TYPE not set; setting to Release")
+ set(CMAKE_BUILD_TYPE "Release")
+endif()
+
+if(NOT COMPILE_WASM)
+ # Setting BUILD_ARCH to native invokes CPU intrinsic detection logic below.
+ # Prevent invoking that logic for WASM builds.
+ set(BUILD_ARCH native CACHE STRING "Compile for this CPU architecture.")
+
+ # Unfortunately MSVC supports a limited subset of BUILD_ARCH flags. Instead try to guess
+ # what architecture we can compile to reading BUILD_ARCH and mapping it to MSVC values
+ # references: https://clang.llvm.org/docs/UsersManual.html https://gcc.gnu.org/onlinedocs/gcc/x86-Options.html https://gcc.gnu.org/onlinedocs/gcc-4.8.5/gcc/i386-and-x86-64-Options.html
+ # https://docs.microsoft.com/en-us/cpp/build/reference/arch-x86?redirectedfrom=MSDN&view=vs-2019&view=msvc-170 https://devblogs.microsoft.com/oldnewthing/20201026-00/?p=104397
+ # This is by no means an exhaustive list but should match the most common flags Linux programmers expect to parse to MSVC
+ if(MSVC)
+ if(BUILD_ARCH STREQUAL "native") # avx2 is good default for native. Very few desktop systems support avx512
+ set(MSVC_BUILD_ARCH "/arch:AVX2")
+ elseif(BUILD_ARCH STREQUAL "skylake-avx512" OR BUILD_ARCH STREQUAL "cannonlake" OR BUILD_ARCH STREQUAL "x86-64-v4" OR BUILD_ARCH STREQUAL "tigerlake" OR BUILD_ARCH STREQUAL "cooperlake" OR BUILD_ARCH STREQUAL "cascadelake")
+ set(MSVC_BUILD_ARCH "/arch:AVX512")
+ elseif(BUILD_ARCH STREQUAL "core-avx2" OR BUILD_ARCH STREQUAL "haswell" OR BUILD_ARCH STREQUAL "x86-64-v3" OR BUILD_ARCH STREQUAL "broadwell" OR BUILD_ARCH STREQUAL "skylake")
+ set(MSVC_BUILD_ARCH "/arch:AVX2")
+ elseif(BUILD_ARCH STREQUAL "sandybridge" OR BUILD_ARCH STREQUAL "corei7-avx" OR BUILD_ARCH STREQUAL "core-avx-i" OR BUILD_ARCH STREQUAL "ivybridge")
+ set(MSVC_BUILD_ARCH "/arch:AVX")
+ elseif(BUILD_ARCH STREQUAL "nehalem" OR BUILD_ARCH STREQUAL "westmere" OR BUILD_ARCH STREQUAL "x86-64-v2" OR BUILD_ARCH STREQUAL "corei7" OR BUILD_ARCH STREQUAL "core2")
+ set(MSVC_BUILD_ARCH "/arch:SSE2") # This is MSVC default. We won't go down to SSE because we don't support that hardware at all with intgemm. Marian recommends to only go down to SSE4.1 at most
+ else()
+ message(WARNING "Unknown BUILD_ARCH ${BUILD_ARCH} provided. Default to SSE2 for Windows build")
+ set(MSVC_BUILD_ARCH "/arch:SSE2")
+ endif()
+ endif(MSVC)
+endif()
+
+#MSVC can't seem to pick up correct flags otherwise:
+if(MSVC)
+ add_definitions(-DUSE_SSE2=1) # Supposed to fix something in the sse_mathfun.h but not sure it does
+  set(INTRINSICS ${MSVC_BUILD_ARCH}) # ARCH we're targeting on win32. @TODO variable
+
+ set(CMAKE_CXX_FLAGS "/EHsc /DWIN32 /D_WINDOWS /DUNICODE /D_UNICODE /D_CRT_NONSTDC_NO_WARNINGS /D_CRT_SECURE_NO_WARNINGS /bigobj")
+ set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS} /MT /O2 ${INTRINSICS} /MP /GL /DNDEBUG")
+ set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS} /MTd /Od /Ob0 ${INTRINSICS} /RTC1 /Zi /D_DEBUG")
+
+ # ignores warning LNK4049: locally defined symbol free imported - this comes from zlib
+ set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DEBUG /LTCG:incremental /INCREMENTAL:NO /ignore:4049")
+ set(CMAKE_EXE_LINKER_FLAGS_RELEASE "${CMAKE_EXE_LINKER_FLAGS} /NODEFAULTLIB:MSVCRT")
+ set(CMAKE_EXE_LINKER_FLAGS_DEBUG "${CMAKE_EXE_LINKER_FLAGS} /NODEFAULTLIB:MSVCRTD")
+ set(CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} /LTCG:incremental")
+endif(MSVC)
+
+include(CMakeDependentOption)
+
+# Project specific cmake options
+option(COMPILE_WASM "Compile for WASM" OFF)
+cmake_dependent_option(USE_WASM_COMPATIBLE_SOURCE "Use wasm compatible sources" OFF "NOT COMPILE_WASM" ON)
+
+# WASM disables a million libraries, which also includes the unit test-library.
+cmake_dependent_option(COMPILE_UNIT_TESTS "Compile unit tests" OFF "USE_WASM_COMPATIBLE_SOURCE" ON)
+option(COMPILE_TESTS "Compile bergamot-tests" OFF)
+cmake_dependent_option(ENABLE_CACHE_STATS "Enable stats on cache" ON "COMPILE_TESTS" OFF)
+
+
+# Set 3rd party submodule specific cmake options for this project
+SET(COMPILE_CUDA OFF CACHE BOOL "Compile GPU version")
+SET(USE_SENTENCEPIECE ON CACHE BOOL "Download and compile SentencePiece")
+SET(USE_STATIC_LIBS ON CACHE BOOL "Link statically against non-system libs")
+SET(SSPLIT_COMPILE_LIBRARY_ONLY ON CACHE BOOL "Do not compile ssplit tests")
+if (USE_WASM_COMPATIBLE_SOURCE)
+ SET(COMPILE_LIBRARY_ONLY ON CACHE BOOL "Build only the Marian library and exclude all executables.")
+ SET(USE_MKL OFF CACHE BOOL "Compile with MKL support")
+ # # Setting the ssplit-cpp submodule specific cmake options for wasm
+ SET(SSPLIT_USE_INTERNAL_PCRE2 ON CACHE BOOL "Use internal PCRE2 instead of system PCRE2")
+endif()
+
+# Documentation: https://cliutils.gitlab.io/modern-cmake/chapters/projects/submodule.html
+# Ensures the submodules are set correctly during a build.
+find_package(Git QUIET)
+if(GIT_FOUND AND EXISTS "${REPOSITORY_ROOT_DIR}/.git")
+# Update submodules as needed
+ option(GIT_SUBMODULE "Check submodules during build" ON)
+ if(GIT_SUBMODULE)
+ message(STATUS "Submodule update")
+ execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+ RESULT_VARIABLE GIT_SUBMOD_RESULT)
+ if(NOT GIT_SUBMOD_RESULT EQUAL "0")
+ message(FATAL_ERROR "git submodule update --init failed with ${GIT_SUBMOD_RESULT}, please checkout submodules")
+ endif()
+ endif()
+endif()
+
+# Project versioning
+include(GetVersionFromFile)
+message(STATUS "Project name: ${PROJECT_NAME}")
+message(STATUS "Project version: ${PROJECT_VERSION_STRING_FULL}")
+
+if(COMPILE_WASM)
+ # See https://github.com/emscripten-core/emscripten/blob/main/src/settings.js
+ list(APPEND WASM_COMPILE_FLAGS
+ -O3
+ # Preserve whitespaces in JS even for release builds; this doesn't increase wasm binary size
    $<$<CONFIG:Release>:-g1>
+ # Relevant Debug info only for release with debug builds as this increases wasm binary size
    $<$<CONFIG:RelWithDebInfo>:-g2>
+ -fPIC
+ -mssse3
+ -msimd128
+ # -fno-exceptions # Can't do that because spdlog uses exceptions
+ -sDISABLE_EXCEPTION_CATCHING=1
+ -sSTRICT=1
+ )
+ list(APPEND WASM_LINK_FLAGS
+ -O3
+ # Preserve whitespaces in JS even for release builds; this doesn't increase wasm binary size
    $<$<CONFIG:Release>:-g1>
+ # Relevant Debug info only for release with debug builds as this increases wasm binary size
    $<$<CONFIG:RelWithDebInfo>:-g2>
+ -lembind
+ # Save some code, and some speed
+ -sASSERTIONS=0
+ -sDISABLE_EXCEPTION_CATCHING=1
+ # the intgemm functions we call will be undefined since these are linked at
+ # runtime by our own javascript.
+ -sLLD_REPORT_UNDEFINED
+ -sERROR_ON_UNDEFINED_SYMBOLS=0
+ # Cause we can!
+ -sSTRICT=1
+ # You know we need it
+ -sALLOW_MEMORY_GROWTH=1
+ -sENVIRONMENT=web,worker
+ # No need to call main(), there's nothing there.
+ -sINVOKE_RUN=0
+ # No need for filesystem code in the generated Javascript
+ -sFILESYSTEM=0
+ # If you turn this on, it will mangle names which makes the dynamic linking hard.
+ -sDECLARE_ASM_MODULE_EXPORTS=0
+ # Export all of the intgemm functions in case we need to fall back to using the embedded intgemm
+ -sEXPORTED_FUNCTIONS=[_int8PrepareAFallback,_int8PrepareBFallback,_int8PrepareBFromTransposedFallback,_int8PrepareBFromQuantizedTransposedFallback,_int8PrepareBiasFallback,_int8MultiplyAndAddBiasFallback,_int8SelectColumnsOfBFallback]
+ # Necessary for mozintgemm linking. This prepares the `wasmMemory` variable ahead of time as
+ # opposed to delegating that task to the wasm binary itself. This way we can link MozIntGEMM
+ # module to the same memory as the main bergamot-translator module.
+ -sIMPORTED_MEMORY=1
+ # Dynamic execution is either frowned upon or blocked inside browser extensions
+ -sDYNAMIC_EXECUTION=0
+ )
+endif(COMPILE_WASM)
+
+# Needs to be enabled before including the folder containing tests (src/tests)
+if(COMPILE_TESTS)
+ enable_testing()
+endif(COMPILE_TESTS)
+
+add_subdirectory(3rd_party)
+add_subdirectory(src)
+
+if(COMPILE_WASM)
+ add_subdirectory(wasm)
+endif(COMPILE_WASM)
+
+option(COMPILE_PYTHON "Compile python bindings. Intended to be activated with setup.py" OFF)
+if(COMPILE_PYTHON)
+ add_subdirectory(bindings/python)
+endif(COMPILE_PYTHON)
+
diff --git a/inference/cmake/GetVersionFromFile.cmake b/inference/cmake/GetVersionFromFile.cmake
new file mode 100644
index 000000000..47c35bc23
--- /dev/null
+++ b/inference/cmake/GetVersionFromFile.cmake
@@ -0,0 +1,60 @@
+##
+# This CMake modules sets the project version from a version file.
+#
+# The module sets the following variables:
+#
+# * PROJECT_VERSION_STRING
+# * PROJECT_VERSION_STRING_FULL
+# * PROJECT_VERSION_MAJOR
+# * PROJECT_VERSION_MINOR
+# * PROJECT_VERSION_PATCH
+# * PROJECT_VERSION_TWEAK
+# * PROJECT_VERSION_GIT_SHA
+#
+# This module is public domain, use it as it fits you best.
+##
+
+# Get full string version from file
+if(PROJECT_VERSION_FILE)
+ file(STRINGS ${PROJECT_VERSION_FILE} PROJECT_VERSION_STRING)
+else()
+ file(STRINGS ${CMAKE_CURRENT_SOURCE_DIR}/BERGAMOT_VERSION PROJECT_VERSION_STRING)
+endif()
+
+# Get current commit SHA from git
+execute_process(COMMAND ${GIT_EXECUTABLE} rev-parse --short HEAD
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+ OUTPUT_VARIABLE PROJECT_VERSION_GIT_SHA
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+
+# Get partial versions into a list
+string(REGEX MATCHALL "-.*$|[0-9]+" PROJECT_PARTIAL_VERSION_LIST
+ ${PROJECT_VERSION_STRING})
+
+# Set the version numbers
+list(GET PROJECT_PARTIAL_VERSION_LIST 0 PROJECT_VERSION_MAJOR)
+list(GET PROJECT_PARTIAL_VERSION_LIST 1 PROJECT_VERSION_MINOR)
+list(GET PROJECT_PARTIAL_VERSION_LIST 2 PROJECT_VERSION_PATCH)
+
+# The tweak part is optional, so check if the list contains it
+list(LENGTH PROJECT_PARTIAL_VERSION_LIST PROJECT_PARTIAL_VERSION_LIST_LEN)
+if(PROJECT_PARTIAL_VERSION_LIST_LEN GREATER 3)
+ list(GET PROJECT_PARTIAL_VERSION_LIST 3 PROJECT_VERSION_TWEAK)
+ string(SUBSTRING ${PROJECT_VERSION_TWEAK} 1 -1 PROJECT_VERSION_TWEAK)
+endif()
+
+# Unset the list
+unset(PROJECT_PARTIAL_VERSION_LIST)
+
+# Set full project version string
+set(PROJECT_VERSION_STRING_FULL
+ ${PROJECT_VERSION_STRING}+${PROJECT_VERSION_GIT_SHA})
+
+# Print all variables for debugging
+#message(STATUS ${PROJECT_VERSION_STRING_FULL})
+#message(STATUS ${PROJECT_VERSION_STRING})
+#message(STATUS ${PROJECT_VERSION_MAJOR})
+#message(STATUS ${PROJECT_VERSION_MINOR})
+#message(STATUS ${PROJECT_VERSION_PATCH})
+#message(STATUS ${PROJECT_VERSION_TWEAK})
+#message(STATUS ${PROJECT_VERSION_GIT_SHA})
diff --git a/inference/examples/run-native.sh b/inference/examples/run-native.sh
new file mode 100644
index 000000000..84e1302f0
--- /dev/null
+++ b/inference/examples/run-native.sh
@@ -0,0 +1,19 @@
+# In source-root folder
+
+# Obtain an example model from the web.
+mkdir -p models
+wget --quiet --continue --directory models/ \
+ https://data.statmt.org/bergamot/models/deen/ende.student.tiny11.v2.93821e13b3c511b5.tar.gz
+(cd models && tar -xzf ende.student.tiny11.v2.93821e13b3c511b5.tar.gz)
+
+# Patch the config-files generated from marian for use in bergamot.
+python3 bergamot-translator-tests/tools/patch-marian-for-bergamot.py \
+ --config-path models/ende.student.tiny11/config.intgemm8bitalpha.yml \
+ --ssplit-prefix-file $(realpath 3rd_party/ssplit-cpp/nonbreaking_prefixes/nonbreaking_prefix.en)
+
+# Patched config file will be available with .bergamot.yml suffix.
+CONFIG=models/ende.student.tiny11/config.intgemm8bitalpha.yml.bergamot.yml
+
+build/app/bergamot --model-config-paths $CONFIG --cpu-threads 4 <<< "Hello World!"
+# Hallo Welt!
+
diff --git a/inference/patches/01-marian-fstream-for-macos.patch b/inference/patches/01-marian-fstream-for-macos.patch
new file mode 100644
index 000000000..6b521ba7e
--- /dev/null
+++ b/inference/patches/01-marian-fstream-for-macos.patch
@@ -0,0 +1,13 @@
+diff --git a/3rd_party/browsermt-marian-dev/src/3rd_party/zstr/strict_fstream.hpp b/3rd_party/browsermt-marian-dev/src/3rd_party/zstr/strict_fstream.hpp
+index 7b1173931df977e69021f3995fa064a492f89d38..948e91eaf99b6b29ce41cf793fba6717f3b5f5b5 100644
+--- a/3rd_party/browsermt-marian-dev/src/3rd_party/zstr/strict_fstream.hpp
++++ b/3rd_party/browsermt-marian-dev/src/3rd_party/zstr/strict_fstream.hpp
+@@ -27,7 +27,7 @@ static std::string strerror()
+ {
+ buff = "Unknown error";
+ }
+-#elif (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600 || __APPLE__) && ! _GNU_SOURCE
++#elif (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600 || __APPLE__)
+ // XSI-compliant strerror_r()
+ if (strerror_r(errno, &buff[0], buff.size()) != 0)
+ {
diff --git a/inference/scripts/build-local.sh b/inference/scripts/build-local.sh
new file mode 100755
index 000000000..ae64689fe
--- /dev/null
+++ b/inference/scripts/build-local.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+set -e
+
+# Run script from the context of inference directory
+cd "$(dirname $0)/.."
+
+# Ensure script is running within docker
+./scripts/detect-docker.sh inference-build
+
+# Return the number of available CPUs, or default to 1 if nproc is unavailable.
+detect_cpus() {
+ if command -v nproc >/dev/null 2>&1; then
+ nproc
+ else
+ echo 1
+ fi
+}
+
+# Parse command-line arguments for the --test flag
+COMPILE_TESTS=OFF
+while [[ "$#" -gt 0 ]]; do
+ case $1 in
+ "--test") COMPILE_TESTS=ON ;;
+ *) echo "Unknown parameter passed: $1"; exit 1 ;;
+ esac
+ shift
+done
+
+if [ ! -d "build-local" ]; then
+ echo "Creating build-local directory..."
+ mkdir build-local
+else
+ echo "build-local directory already exists. Skipping creation."
+fi
+
+cd build-local || exit
+
+# Run cmake with optional COMPILE_TESTS flag
+echo "Running cmake for build-local..."
+if [ "$COMPILE_TESTS" = "ON" ]; then
+ cmake ../ -DCOMPILE_TESTS=ON
+else
+ cmake ../
+fi
+
+# Run make using the detected number of CPUs
+CPUS=$(detect_cpus)
+echo "Running make for build-local with $CPUS CPUs..."
+make -j ${CPUS}
+
diff --git a/inference/scripts/build-wasm.sh b/inference/scripts/build-wasm.sh
new file mode 100755
index 000000000..c21eea985
--- /dev/null
+++ b/inference/scripts/build-wasm.sh
@@ -0,0 +1,69 @@
+#!/usr/bin/env bash
+set -e
+
+# Run script from the context of inference directory
+cd "$(dirname $0)/.."
+
+# Ensure script is running within docker
+./scripts/detect-docker.sh inference-build-wasm
+
+set -x
+
+# Prerequisite: Download and Install Emscripten using following instructions (unless the EMSDK env var is already set)
+if [ "$EMSDK" == "" ]; then
+ EMSDK_UPDATE_REQUIRED=0
+ if [ ! -d "emsdk" ]; then
+ git clone https://github.com/emscripten-core/emsdk.git
+ EMSDK_UPDATE_REQUIRED=1
+ else
+ cd emsdk
+ git fetch
+ # Only pull if necessary
+ if [ $(git rev-parse HEAD) != $(git rev-parse @{u}) ]; then
+ git pull --ff-only
+ EMSDK_UPDATE_REQUIRED=1
+ fi
+ cd -
+ fi
+ if [ "$EMSDK_UPDATE_REQUIRED" == "1" ]; then
+ cd emsdk
+ ./emsdk install 3.1.8
+ ./emsdk activate 3.1.8
+ cd -
+ fi
+ source ./emsdk/emsdk_env.sh
+fi
+
+# Compile
+# 1. Create a folder where you want to build all the artifacts and compile
+BUILD_DIRECTORY="build-wasm"
+if [ ! -d ${BUILD_DIRECTORY} ]; then
+ mkdir ${BUILD_DIRECTORY}
+fi
+cd ${BUILD_DIRECTORY}
+
+emcmake cmake -DCOMPILE_WASM=on ../
+emmake make -j2
+
+# 2. Import GEMM library from a separate wasm module
+bash ../wasm/patch-artifacts-import-gemm-module.sh
+
+set +x
+echo ""
+echo "Build complete"
+echo ""
+echo " ./build-wasm/bergamot-translator-worker.js"
+echo " ./build-wasm/bergamot-translator-worker.wasm"
+
+WASM_SIZE=$(wc -c bergamot-translator-worker.wasm | awk '{print $1}')
+GZIP_SIZE=$(gzip -c bergamot-translator-worker.wasm | wc -c | xargs) # xargs trims the whitespace
+
+# Convert it to human readable.
+WASM_SIZE="$(awk 'BEGIN {printf "%.2f",'$WASM_SIZE'/1048576}')M ($WASM_SIZE bytes)"
+GZIP_SIZE="$(awk 'BEGIN {printf "%.2f",'$GZIP_SIZE'/1048576}')M ($GZIP_SIZE bytes)"
+
+echo " Uncompressed wasm size: $WASM_SIZE"
+echo " Compressed wasm size: $GZIP_SIZE"
+
+# The artifacts (.js and .wasm files) will be available in the build directory
+exit 0
diff --git a/inference/scripts/clean.sh b/inference/scripts/clean.sh
new file mode 100755
index 000000000..73f5ae5eb
--- /dev/null
+++ b/inference/scripts/clean.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+set -e
+
+# Run script from the context of inference directory
+cd "$(dirname $0)/.."
+
+# Ensure script is running within docker
+./scripts/detect-docker.sh inference-clean
+
+# List of directories to clean
+dirs=("build-local" "build-wasm" "emsdk")
+
+# Flag to track if any directories were cleaned
+cleaned=false
+
+# Check and remove directories
+for dir in "${dirs[@]}"; do
+ if [ -d "$dir" ]; then
+ echo "Removing $dir..."
+ rm -rf "$dir"
+ cleaned=true
+ fi
+done
+
+# If no directories were cleaned, print a message
+if [ "$cleaned" = false ]; then
+ echo "Nothing to clean"
+fi
+
diff --git a/inference/scripts/detect-docker.sh b/inference/scripts/detect-docker.sh
new file mode 100755
index 000000000..c1065349a
--- /dev/null
+++ b/inference/scripts/detect-docker.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+help_task=$1
+
+if [ -z "${IS_DOCKER}" ]; then
+ if [ "${ALLOW_RUN_ON_HOST}" != "1" ]; then
+ echo >&2
+ echo "Error: This script needs to be run inside Docker, or you must set ALLOW_RUN_ON_HOST=1." >&2
+ echo >&2
+ if [ -n "${help_task}" ]; then
+ echo " Help: To run this script directly in docker, run: task ${help_task}" >&2
+ fi
+ echo " Help: To enter docker, run: task docker" >&2
+ exit 1
+ else
+ echo >&2
+ echo "ALLOW_RUN_ON_HOST is set to 1. Continuing..." >&2
+ fi
+fi
diff --git a/inference/scripts/unit-tests.sh b/inference/scripts/unit-tests.sh
new file mode 100755
index 000000000..dd8be9925
--- /dev/null
+++ b/inference/scripts/unit-tests.sh
@@ -0,0 +1,49 @@
+#!/bin/bash
+set -e
+
+# Run script from the context of inference directory
+cd "$(dirname $0)/.."
+
+# Ensure script is running within docker
+./scripts/detect-docker.sh inference-test
+
+# Check if build-local/src/tests/units directory exists
+if [ ! -d "build-local/src/tests/units" ]; then
+ echo "Directory build-local/src/tests/units does not exist. Running build."
+ ./scripts/build-local.sh --test
+else
+ echo "Directory build-local/src/tests/units already exists. Skipping build."
+fi
+
+# Change to the unit tests directory
+cd build-local/src/tests/units
+
+# List of test commands
+tests=(
+ "./run_annotation_tests"
+ "./run_cache_tests"
+ "./run_html_tests"
+ "./run_quality_estimator_tests"
+ "./run_xh_scanner_tests"
+)
+
+# Run all tests, collect failures
+failures=0
+
+for test in "${tests[@]}"; do
+ echo "Running $test..."
+ if ! $test; then
+ echo "$test failed!"
+ failures=$((failures + 1))
+ fi
+done
+
+# If any test failed, exit with a non-zero status
+if [ $failures -gt 0 ]; then
+ echo "$failures test(s) failed."
+ exit 1
+else
+ echo "All tests passed successfully."
+ exit 0
+fi
+
diff --git a/inference/src/CMakeLists.txt b/inference/src/CMakeLists.txt
new file mode 100644
index 000000000..856831be9
--- /dev/null
+++ b/inference/src/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_subdirectory(translator)
+
+if (COMPILE_TESTS)
+ add_subdirectory(tests)
+endif(COMPILE_TESTS)
+
diff --git a/inference/src/tests/CMakeLists.txt b/inference/src/tests/CMakeLists.txt
new file mode 100644
index 000000000..cd0e4c777
--- /dev/null
+++ b/inference/src/tests/CMakeLists.txt
@@ -0,0 +1,24 @@
+# Unit tests
+
+# Include Catch explicitly from marian.
+set(CATCH_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/3rd_party/browsermt-marian-dev/3rd-party)
+add_library(Catch INTERFACE)
+target_include_directories(Catch INTERFACE ${CATCH_INCLUDE_DIR})
+
+if (COMPILE_UNIT_TESTS)
+ add_subdirectory(units)
+endif (COMPILE_UNIT_TESTS)
+
+
+
+if(NOT MSVC)
+ # Testing apps
+ set(TEST_BINARIES async blocking intgemm-resolve wasm)
+ foreach(binary ${TEST_BINARIES})
+ add_executable("${binary}" "${binary}.cpp")
+ target_link_libraries("${binary}" bergamot-translator)
+ set_target_properties("${binary}" PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/tests/")
+ endforeach(binary)
+
+endif(NOT MSVC)
+
diff --git a/inference/src/tests/async.cpp b/inference/src/tests/async.cpp
new file mode 100644
index 000000000..25ba334ae
--- /dev/null
+++ b/inference/src/tests/async.cpp
@@ -0,0 +1,27 @@
+#include "common.h"
+#include "translator/parser.h"
+#include "translator/service.h"
+#include "translator/translation_model.h"
+
+using namespace marian::bergamot;
+
+int main(int argc, char *argv[]) {
+  ConfigParser<AsyncService> configParser("AsyncService test-suite", /*multiOpMode=*/true);
+ configParser.parseArgs(argc, argv);
+ auto &config = configParser.getConfig();
+
+ AsyncService service(config.serviceConfig);
+
+  std::vector<std::shared_ptr<TranslationModel>> models;
+
+ for (auto &modelConfigPath : config.modelConfigPaths) {
+ TranslationModel::Config modelConfig = parseOptionsFromFilePath(modelConfigPath);
+    std::shared_ptr<TranslationModel> model = service.createCompatibleModel(modelConfig);
+ models.push_back(model);
+ }
+
+  TestSuite<AsyncService> testSuite(service);
+ testSuite.run(config.opMode, models);
+
+ return 0;
+}
diff --git a/inference/src/tests/blocking.cpp b/inference/src/tests/blocking.cpp
new file mode 100644
index 000000000..3bbb45634
--- /dev/null
+++ b/inference/src/tests/blocking.cpp
@@ -0,0 +1,25 @@
+#include "common.h"
+using namespace marian::bergamot;
+
+int main(int argc, char *argv[]) {
+  ConfigParser<BlockingService> configParser("BlockingService test-suite", /*multiOpMode=*/true);
+ configParser.parseArgs(argc, argv);
+
+ auto &config = configParser.getConfig();
+ BlockingService service(config.serviceConfig);
+
+  TestSuite<BlockingService> testSuite(service);
+  std::vector<std::shared_ptr<TranslationModel>> models;
+
+ for (auto &modelConfigPath : config.modelConfigPaths) {
+ TranslationModel::Config modelConfig = parseOptionsFromFilePath(modelConfigPath);
+    std::shared_ptr<TranslationModel> model = std::make_shared<TranslationModel>(modelConfig);
+ models.push_back(model);
+ }
+
+ /// WASM is one special case where WASM path is being checked, involving translateMultiple and a multi-line feed.
+ /// Hence we do not bind it at a single input-blob single Response constraint imposed by the TestSuite.
+ testSuite.run(config.opMode, models);
+
+ return 0;
+}
diff --git a/inference/src/tests/common-impl.cpp b/inference/src/tests/common-impl.cpp
new file mode 100644
index 000000000..431ddaa71
--- /dev/null
+++ b/inference/src/tests/common-impl.cpp
@@ -0,0 +1,316 @@
+
+#ifndef BERGAMOT_TESTS_COMMON_IMPL
+#error "This is an impl file and must not be included directly!"
+#endif
+
+Response Bridge<BlockingService>::translate(BlockingService &service, std::shared_ptr<TranslationModel> &model,
+ std::string &&source, const ResponseOptions &responseOptions) {
+ // project source to a vector of std::string, send in, unpack the first element from
+ // vector, return.
+ std::vector<std::string> sources = {source};
+ std::vector<ResponseOptions> options = {responseOptions};
+ return service.translateMultiple(model, std::move(sources), options).front();
+}
+
+Response Bridge<AsyncService>::translate(AsyncService &service, std::shared_ptr<TranslationModel> &model,
+ std::string &&source, const ResponseOptions &responseOptions) {
+ // downgrade to blocking via promise, future, wait and return response;
+ std::promise<Response> responsePromise;
+ std::future<Response> responseFuture = responsePromise.get_future();
+
+ auto callback = [&responsePromise](Response &&response) { responsePromise.set_value(std::move(response)); };
+ service.translate(model, std::move(source), callback, responseOptions);
+
+ responseFuture.wait();
+
+ Response response = responseFuture.get();
+ return response;
+}
+
+Response Bridge<BlockingService>::pivot(BlockingService &service, std::shared_ptr<TranslationModel> &sourceToPivot,
+ std::shared_ptr<TranslationModel> &pivotToTarget, std::string &&source,
+ const ResponseOptions &responseOptions) {
+ std::vector<std::string> sources = {source};
+ std::vector<ResponseOptions> options = {responseOptions};
+ return service.pivotMultiple(sourceToPivot, pivotToTarget, std::move(sources), options).front();
+}
+
+Response Bridge<AsyncService>::pivot(AsyncService &service, std::shared_ptr<TranslationModel> &sourceToPivot,
+ std::shared_ptr<TranslationModel> &pivotToTarget, std::string &&source,
+ const ResponseOptions &responseOptions) {
+ std::promise<Response> responsePromise;
+ std::future<Response> responseFuture = responsePromise.get_future();
+
+ auto callback = [&responsePromise](Response &&response) { responsePromise.set_value(std::move(response)); };
+ service.pivot(sourceToPivot, pivotToTarget, std::move(source), callback, responseOptions);
+ responseFuture.wait();
+ Response response = responseFuture.get();
+ return response;
+}
+
+template <class Service>
+TestSuite<Service>::TestSuite(Service &service) : service_{service} {}
+
+template <class Service>
+void TestSuite<Service>::run(const std::string &opModeAsString, std::vector<Ptr<TranslationModel>> &models) {
+ if (opModeAsString == "decoder") {
+ benchmarkDecoder(models.front());
+ } else if (opModeAsString == "test-response-source-sentences") {
+ annotatedTextSentences(models.front(), /*source=*/true);
+ } else if (opModeAsString == "test-response-target-sentences") {
+ annotatedTextSentences(models.front(), /*source=*/false);
+ } else if (opModeAsString == "test-response-source-words") {
+ annotatedTextWords(models.front(), /*source=*/true);
+ } else if (opModeAsString == "test-response-target-words") {
+ annotatedTextWords(models.front(), /*source=*/false);
+ } else if (opModeAsString == "test-forward-backward") {
+ forwardAndBackward(models);
+ } else if (opModeAsString == "test-quality-estimator-words") {
+ qualityEstimatorWords(models.front());
+ } else if (opModeAsString == "test-quality-estimator-scores") {
+ qualityEstimatorScores(models.front());
+ } else if (opModeAsString == "test-translation-cache") {
+ translationCache(models.front());
+ } else if (opModeAsString == "test-pivot") {
+ pivotTranslate(models);
+ } else if (opModeAsString == "test-pivot-with-html") {
+ pivotTranslateWithHTML(models);
+ } else if (opModeAsString == "test-html-translation") {
+ htmlTranslation(models.front());
+ } else {
+ std::cerr << "Incompatible test mode. Choose from the one of the valid test-modes";
+ std::abort();
+ }
+}
+
+template <class Service>
+void TestSuite<Service>::benchmarkDecoder(Ptr<TranslationModel> &model) {
+ marian::timer::Timer decoderTimer;
+ std::string source = readFromStdin();
+
+ ResponseOptions responseOptions;
+ Response response = bridge_.translate(service_, model, std::move(source), responseOptions);
+
+ for (size_t sentenceIdx = 0; sentenceIdx < response.size(); sentenceIdx++) {
+ std::cout << response.target.sentence(sentenceIdx) << "\n";
+ }
+
+ std::cerr << "Total time: " << std::setprecision(5) << decoderTimer.elapsed() << "s wall" << std::endl;
+}
+
+// Reads from stdin and translates. Prints the tokens separated by space for each sentence. Prints words from source
+// side text annotation if source=true, target annotation otherwise.
+template <class Service>
+void TestSuite<Service>::annotatedTextWords(Ptr<TranslationModel> model, bool sourceSide /*=true*/) {
+ ResponseOptions responseOptions;
+ std::string source = readFromStdin();
+ Response response = bridge_.translate(service_, model, std::move(source), responseOptions);
+ AnnotatedText &annotatedText = sourceSide ? response.source : response.target;
+ for (size_t s = 0; s < annotatedText.numSentences(); s++) {
+ for (size_t w = 0; w < annotatedText.numWords(s); w++) {
+ std::cout << (w == 0 ? "" : "\t");
+ std::cout << annotatedText.word(s, w);
+ }
+ std::cout << "\n";
+ }
+}
+
+// Reads from stdin and translates the read content. Prints the sentences in source or target in constructed response
+// in each line, depending on source = true or false respectively.
+template <class Service>
+void TestSuite<Service>::annotatedTextSentences(Ptr<TranslationModel> model, bool sourceSide /*=true*/) {
+ ResponseOptions responseOptions;
+ std::string source = readFromStdin();
+ Response response = bridge_.translate(service_, model, std::move(source), responseOptions);
+ AnnotatedText &annotatedText = sourceSide ? response.source : response.target;
+ for (size_t s = 0; s < annotatedText.numSentences(); s++) {
+ std::cout << annotatedText.sentence(s) << "\n";
+ }
+}
+
+template <class Service>
+void TestSuite<Service>::forwardAndBackward(std::vector<Ptr<TranslationModel>> &models) {
+ ABORT_IF(models.size() != 2, "Forward and backward test needs two models.");
+ ResponseOptions responseOptions;
+ std::string source = readFromStdin();
+ Response forwardResponse = bridge_.translate(service_, models.front(), std::move(source), responseOptions);
+
+ // Make a copy of target
+ std::string target = forwardResponse.target.text;
+ Response backwardResponse = bridge_.translate(service_, models.back(), std::move(target), responseOptions);
+
+ // Print both onto the command-line
+ std::cout << forwardResponse.source.text;
+ std::cout << "----------------\n";
+ std::cout << forwardResponse.target.text;
+ std::cout << "----------------\n";
+ std::cout << backwardResponse.target.text;
+}
+
+// Reads from stdin and translates the read content. Prints the quality words for each sentence.
+template <class Service>
+void TestSuite<Service>::qualityEstimatorWords(Ptr<TranslationModel> model) {
+ ResponseOptions responseOptions;
+ responseOptions.qualityScores = true;
+ std::string source = readFromStdin();
+ const Response response = bridge_.translate(service_, model, std::move(source), responseOptions);
+
+ for (size_t sentenceIdx = 0; sentenceIdx < response.qualityScores.size(); ++sentenceIdx) {
+ const auto &sentenceQualityEstimate = response.qualityScores[sentenceIdx];
+ std::cout << "[SentenceBegin]\n";
+
+ for (const auto &wordByteRange : getWordByteRanges(response, sentenceIdx)) {
+ const string_view word(response.target.text.data() + wordByteRange.begin, wordByteRange.size());
+ std::cout << word << "\n";
+ }
+ std::cout << "[SentenceEnd]\n\n";
+ }
+}
+
+template <class Service>
+void TestSuite<Service>::htmlTranslation(Ptr<TranslationModel> model) {
+ ResponseOptions responseOptions;
+ responseOptions.HTML = true;
+ responseOptions.alignment = true;
+ std::string source = readFromStdin();
+ const Response response = bridge_.translate(service_, model, std::move(source), responseOptions);
+
+ std::cout << response.target.text;
+}
+
+// Reads from stdin and translates the read content. Prints the quality scores for each sentence.
+template <class Service>
+void TestSuite<Service>::qualityEstimatorScores(Ptr<TranslationModel> model) {
+ ResponseOptions responseOptions;
+ responseOptions.qualityScores = true;
+
+ std::string source = readFromStdin();
+ const Response response = bridge_.translate(service_, model, std::move(source), responseOptions);
+
+ for (const auto &sentenceQualityEstimate : response.qualityScores) {
+ std::cout << std::fixed << std::setprecision(3) << sentenceQualityEstimate.sentenceScore << "\n";
+
+ for (const float &wordScore : sentenceQualityEstimate.wordScores) {
+ std::cout << std::fixed << std::setprecision(3) << wordScore << "\n";
+ }
+ std::cout << "\n";
+ }
+}
+
+template <class Service>
+void TestSuite<Service>::translationCache(Ptr<TranslationModel> model) {
+ ResponseOptions responseOptions;
+
+ // Read a large input text blob from stdin
+ const std::string source = readFromStdin();
+
+ // Round 1
+ std::string buffer = source;
+ Response firstResponse = bridge_.translate(service_, model, std::move(buffer), responseOptions);
+
+ auto statsFirstRun = service_.cacheStats();
+ LOG(info, "Cache Hits/Misses = {}/{}", statsFirstRun.hits, statsFirstRun.misses);
+ ABORT_IF(statsFirstRun.hits != 0, "Expecting no cache hits, but hits found.");
+
+ // Round 2; There should be cache hits
+ buffer = source;
+ Response secondResponse = bridge_.translate(service_, model, std::move(buffer), responseOptions);
+
+ auto statsSecondRun = service_.cacheStats();
+ LOG(info, "Cache Hits/Misses = {}/{}", statsSecondRun.hits, statsSecondRun.misses);
+ ABORT_IF(statsSecondRun.hits <= 0, "At least one hit expected, none found.");
+ if (statsSecondRun.hits != statsFirstRun.misses) {
+ std::cerr << "Mismatch in expected hits (Hits, Misses = " << statsSecondRun.hits << ", " << statsSecondRun.misses
+ << "). This can happen due to random eviction." << std::endl;
+ }
+
+ ABORT_IF(firstResponse.target.text != secondResponse.target.text,
+ "Recompiled string provided different output when operated with cache. On the same hardware while using "
+ "same path, this is expected to be same.");
+
+ std::cout << firstResponse.target.text;
+}
+
+template <class Service>
+void TestSuite<Service>::pivotTranslateWithHTML(std::vector<Ptr<TranslationModel>> &models) {
+ ABORT_IF(models.size() != 2, "Forward and backward test needs two models.");
+ ResponseOptions responseOptions;
+ responseOptions.HTML = true;
+ std::string source = readFromStdin();
+ std::promise<Response> responsePromise;
+ std::future<Response> responseFuture = responsePromise.get_future();
+ Response response = bridge_.pivot(service_, models.front(), models.back(), std::move(source), responseOptions);
+ std::cout << response.source.text;
+ std::cout << response.target.text;
+}
+
+template <class Service>
+void TestSuite<Service>::pivotTranslate(std::vector<Ptr<TranslationModel>> &models) {
+ // We expect a source -> pivot; pivot -> source model to get source -> source and build this test using accuracy of
+ // matches.
+ ABORT_IF(models.size() != 2, "Forward and backward test needs two models.");
+ ResponseOptions responseOptions;
+ responseOptions.alignment = true;
+ std::string source = readFromStdin();
+ std::promise<Response> responsePromise;
+ std::future<Response> responseFuture = responsePromise.get_future();
+
+ Response response = bridge_.pivot(service_, models.front(), models.back(), std::move(source), responseOptions);
+
+ const float EPS = 1e-5;
+ size_t totalOutcomes = 0;
+ size_t favourableOutcomes = 0;
+
+ for (size_t sentenceId = 0; sentenceId < response.source.numSentences(); sentenceId++) {
+ std::cout << "> " << response.source.sentence(sentenceId) << "\n";
+ std::cout << "< " << response.target.sentence(sentenceId) << "\n\n";
+
+ // Assert what we have is a probability distribution over source-tokens given a target token.
+ for (size_t t = 0; t < response.alignments[sentenceId].size(); t++) {
+ float sum = 0.0f;
+ for (size_t s = 0; s < response.alignments[sentenceId][t].size(); s++) {
+ sum += response.alignments[sentenceId][t][s];
+ }
+
+ std::cerr << fmt::format("Sum @ (target-token = {}, sentence = {}) = {}", t, sentenceId, sum) << std::endl;
+ ABORT_IF((std::abs(sum - 1.0f) > EPS), "Not a probability distribution, something's going wrong");
+ }
+
+ // For each target token, find argmax s, i.e find argmax p(s | t), max p(s | t)
+ for (size_t t = 0; t < response.alignments[sentenceId].size(); t++) {
+ bool valid = false;
+ float maxV = 0.0f;
+ auto argmaxV = std::make_pair(-1, -1);
+ for (size_t s = 0; s < response.alignments[sentenceId][t].size(); s++) {
+ auto v = response.alignments[sentenceId][t][s];
+ if (v > maxV) {
+ maxV = v;
+ argmaxV = std::make_pair(t, s);
+ }
+ }
+
+ auto sourceToken = response.source.word(sentenceId, argmaxV.second);
+ auto targetToken = response.target.word(sentenceId, argmaxV.first);
+ if (sourceToken == targetToken) {
+ favourableOutcomes += 1;
+ }
+
+ std::cerr << sourceToken << " " << targetToken << " " << maxV << std::endl;
+
+ totalOutcomes += 1;
+ }
+
+ // Assert each alignment over target is a valid probability distribution.
+ }
+
+ // Measure accuracy of word match.
+ float accuracy = static_cast<float>(favourableOutcomes) / static_cast<float>(totalOutcomes);
+
+ // This is arbitrary value chosen by @jerinphilip, but should be enough to check if things fail.
+ // This value is calibrated on bergamot input in BRT. All this is supposed to do is let the developers know if
+ // something's largely amiss to the point alignments are not working.
+ ABORT_IF(accuracy < 0.70, "Accuracy {} not enough. Please check if something's off.", accuracy * 100);
+
+ std::cout << response.source.text;
+ std::cout << response.target.text;
+}
diff --git a/inference/src/tests/common.h b/inference/src/tests/common.h
new file mode 100644
index 000000000..238a62357
--- /dev/null
+++ b/inference/src/tests/common.h
@@ -0,0 +1,100 @@
+#pragma once
+#include <future>
+#include <iomanip>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "common/definitions.h"
+#include "common/timer.h"
+#include "common/utils.h"
+#include "marian.h"
+#include "translator/byte_array_util.h"
+#include "translator/parser.h"
+#include "translator/response.h"
+#include "translator/response_options.h"
+#include "translator/service.h"
+#include "translator/utils.h"
+
+namespace marian::bergamot {
+
+/// Because the extension (WASM) and native paths do not agree on an API (e.g., translateMultiple vs translate) and
+/// use different underlying caches, we provide the following "bridge" at the test-application level - taking into
+/// account the fact that the most commonly used primitive across both Services is a single text blob in and a
+/// corresponding Response out, in a blocking fashion.
+///
+/// The following contraption constrains a single sentence to single Response parameterized by Service, in a test-suite
+/// below. This allows sharing of code for test-suite between WebAssembly's workflows and Native's workflows.
+///
+/// The intention here is to use templating to achieve the same thing an ifdef would have at compile-time. Also mandates
+/// after bridge layer, both WebAssembly and Native paths compile correctly (this does not guarantee outputs are the
+/// same through both code-paths, or that both are tested at runtime - only that both compile and work with a bridge).
+///
+/// For any complex workflows involving non-blocking concurrent translation, it is required to write something not
+/// constrained by the following.
+
+template <class Service>
+struct Bridge : public std::false_type {};
+
+template <>
+struct Bridge<BlockingService> : public std::true_type {
+ Response translate(BlockingService &service, std::shared_ptr<TranslationModel> &model, std::string &&source,
+ const ResponseOptions &responseOptions);
+ Response pivot(BlockingService &service, std::shared_ptr<TranslationModel> &sourceToPivot,
+ std::shared_ptr<TranslationModel> &pivotToTarget, std::string &&source,
+ const ResponseOptions &responseOptions);
+};
+
+template <>
+struct Bridge<AsyncService> : public std::true_type {
+ Response translate(AsyncService &service, std::shared_ptr<TranslationModel> &model, std::string &&source,
+ const ResponseOptions &responseOptions);
+ Response pivot(AsyncService &service, std::shared_ptr<TranslationModel> &sourceToPivot,
+ std::shared_ptr<TranslationModel> &pivotToTarget, std::string &&source,
+ const ResponseOptions &responseOptions);
+};
+
+template <class Service>
+class TestSuite {
+ private:
+ Bridge<Service> bridge_;
+ Service &service_;
+
+ public:
+ TestSuite(Service &service);
+ void run(const std::string &opModeAsString, std::vector<Ptr<TranslationModel>> &models);
+
+ private:
+ void benchmarkDecoder(Ptr<TranslationModel> &model);
+
+ // Reads from stdin and translates. Prints the tokens separated by space for each sentence. Prints words from source
+ // side text annotation if source=true, target annotation otherwise.
+ void annotatedTextWords(Ptr<TranslationModel> model, bool sourceSide = true);
+
+ // Reads from stdin and translates the read content. Prints the sentences in source or target in constructed response
+ // in each line, depending on source = true or false respectively.
+ void annotatedTextSentences(Ptr<TranslationModel> model, bool sourceSide = true);
+
+ void forwardAndBackward(std::vector<Ptr<TranslationModel>> &models);
+
+ // Reads from stdin and translates the read content. Prints the quality words for each sentence.
+ void qualityEstimatorWords(Ptr<TranslationModel> model);
+
+ // Reads from stdin and translates the read content. Prints the quality scores for each sentence.
+ void qualityEstimatorScores(Ptr<TranslationModel> model);
+
+ void translationCache(Ptr<TranslationModel> model);
+
+ void pivotTranslate(std::vector<Ptr<TranslationModel>> &models);
+
+ void pivotTranslateWithHTML(std::vector<Ptr<TranslationModel>> &models);
+
+ void htmlTranslation(Ptr<TranslationModel> model);
+};
+
+#define BERGAMOT_TESTS_COMMON_IMPL
+#include "common-impl.cpp"
+#undef BERGAMOT_TESTS_COMMON_IMPL
+
+} // namespace marian::bergamot
diff --git a/inference/src/tests/intgemm-resolve.cpp b/inference/src/tests/intgemm-resolve.cpp
new file mode 100644
index 000000000..f95d0c449
--- /dev/null
+++ b/inference/src/tests/intgemm-resolve.cpp
@@ -0,0 +1,8 @@
+#include <iostream>
+
+#include "intgemm/intgemm.h"
+
+int main() {
+ std::cout << static_cast<int>(intgemm::kCPU) << "\n";
+ return 0;
+}
diff --git a/inference/src/tests/units/CMakeLists.txt b/inference/src/tests/units/CMakeLists.txt
new file mode 100644
index 000000000..9cfb50006
--- /dev/null
+++ b/inference/src/tests/units/CMakeLists.txt
@@ -0,0 +1,25 @@
+# Unit tests
+set(UNIT_TESTS
+ annotation_tests
+ cache_tests
+ quality_estimator_tests
+ html_tests
+ xh_scanner_tests)
+
+foreach(test ${UNIT_TESTS})
+ add_executable("run_${test}" run_tests.cpp "${test}.cpp")
+ target_include_directories("run_${test}" PRIVATE ${CATCH_INCLUDE_DIR} "${CMAKE_SOURCE_DIR}/src")
+
+ if(CUDA_FOUND)
+ target_link_libraries("run_${test}" ${EXT_LIBS} marian ${EXT_LIBS} marian_cuda ${EXT_LIBS} Catch bergamot-translator)
+ else(CUDA_FOUND)
+ target_link_libraries("run_${test}" marian ${EXT_LIBS} Catch bergamot-translator)
+ endif(CUDA_FOUND)
+
+ if(msvc)
+ # disable c4305: truncation from 'double' to '_ty'
+ target_compile_options("run_${test}" public /wd4305)
+ endif(msvc)
+
+ add_test(NAME ${test} COMMAND "run_${test}")
+endforeach(test)
diff --git a/inference/src/tests/units/annotation_tests.cpp b/inference/src/tests/units/annotation_tests.cpp
new file mode 100644
index 000000000..d7178f4df
--- /dev/null
+++ b/inference/src/tests/units/annotation_tests.cpp
@@ -0,0 +1,214 @@
+#include <random>
+#include <vector>
+
+#include "catch.hpp"
+#include "translator/annotation.h"
+
+using namespace marian::bergamot;
+
+TEST_CASE("Test Annotation API with random sentences") {
+ /// Objective here is to test insertion for sentences, and that whatever comes
+ /// out adheres to the way it was inserted. Towards this, we keep externally
+ /// which sentence went in where and try to use accessor methods on
+ /// AnnotatedText to check if what we have as ground-truth by construction is
+ /// consistent with what is returned.
+ size_t sentences = 500;
+ size_t maxWords = 40;
+
+ // Set in case needed to see output. The output is in lines of #sentences +
+ // header, which can be split and compared for easy understanding. The ideal
+ // way to inspect what is going wrong is to redirect output and use to split
+ // the different stages by sentences + 1 lines and check the diff.
+ bool debug{false};
+
+ std::mt19937 randomIntGen_;
+ randomIntGen_.seed(42);
+
+ // External book-keeping so we have ground truths. Each element represents a
+ // sentence.
+
+ // word byte ranges - for testAnnotation.word(sId, wId)
+ std::vector<std::vector<ByteRange>> groundTruthWords;
+ // sentence byte ranges - for testAnnotation.sentence(sId, wId)
+ std::vector<ByteRange> groundTruthSentences;
+
+ // Prepare the text and construct ByteRanges as intended for sentences and
+ // words. The ByteRanges we construct here are expected to be the
+ // ground-truths for words and sentences. The string being constructed is like
+ // as follows:
+ //
+ // 0-0 0-1 0-2 0-3
+ // 1-0 1-1 1-2 1-3 1-4
+ // 2-0 2-1
+ //
+ // 4-0 4-1 4-2 4-3
+ //
+ // Tokens are contiguous because that's how SentencePiece works.
+ //
+ // Below, we accumulate the text with intended structure as above, and
+ // ground-truth tables populated to be aware of the ByteRanges where they are
+ // meant to be.
+ if (debug) {
+ std::cout << "Preparing text and ground truth-tables" << std::endl;
+ }
+ std::string text;
+ for (size_t idx = 0; idx < sentences; idx++) {
+ if (idx != 0) text += "\n";
+
+ // Words can be zero, we need to support empty word sentences as well.
+ size_t numWords = randomIntGen_() % maxWords;
+
+ std::vector<ByteRange> wordByteRanges;
+ wordByteRanges.reserve(numWords);
+
+ // For empty sentence, we expect it to be empty and marked in position where
+ // the existing string is if needed to be pointed out.
+ size_t before = text.size() - 1;
+ size_t sentenceBegin{before}, sentenceEnd{before};
+
+ for (size_t idw = 0; idw < numWords; idw++) {
+ // Get new beginning, accounting for space above.
+ before = text.size();
+
+ // Add the word
+ std::string word = std::to_string(idx) + "-" + std::to_string(idw);
+ text += word;
+
+ // Do math, before, before + new-word's size.
+ wordByteRanges.push_back((ByteRange){before, before + word.size()});
+
+ if (debug) {
+ std::cout << word;
+ }
+
+ if (idw == 0) {
+ sentenceBegin = before;
+ }
+ if (idw == numWords - 1) {
+ sentenceEnd = before + word.size();
+ }
+ }
+ if (debug) {
+ std::cout << std::endl;
+ }
+
+ groundTruthWords.push_back(wordByteRanges);
+ groundTruthSentences.push_back((ByteRange){sentenceBegin, sentenceEnd});
+ }
+
+ AnnotatedText testAnnotation(std::move(text)); // This the container we add through API and
+ // check if the access is correct.
+
+ // We prepare string_views now with the known ByteRanges and use the
+ // string_view based AnnotatedText.addSentence(...) API to add sentences to
+ // transparently convert from string_views to ByteRanges, rebasing/working out
+ // the math underneath.
+
+ if (debug) {
+ std::cout << "Inserting words onto container and save ground-truth-table:" << std::endl;
+ }
+
+ std::vector<std::vector<marian::string_view>> wordStringViews;
+ std::vector<ByteRange>::const_iterator sentence_iter = groundTruthSentences.begin();
+ for (auto &sentence : groundTruthWords) {
+ std::vector<marian::string_view> wordByteRanges;
+ bool first{true};
+ for (auto &word : sentence) {
+ marian::string_view wordView(&testAnnotation.text[word.begin], word.size());
+ wordByteRanges.push_back(wordView);
+ if (debug) {
+ if (first) {
+ first = false;
+ } else {
+ std::cout << " ";
+ }
+ std::cout << std::string(wordView);
+ }
+ }
+ testAnnotation.recordExistingSentence(wordByteRanges.begin(), wordByteRanges.end(),
+ testAnnotation.text.data() + sentence_iter->begin);
+ ++sentence_iter;
+ wordStringViews.push_back(wordByteRanges);
+ if (debug) {
+ std::cout << std::endl;
+ }
+ }
+
+ if (debug) {
+ std::cout << "Inserting sentences onto container and save ground-truth-table" << std::endl;
+ }
+ std::vector<marian::string_view> sentenceStringViews;
+ for (auto &sentenceByteRange : groundTruthSentences) {
+ char *data = &(testAnnotation.text[sentenceByteRange.begin]);
+ marian::string_view sentenceView(data, sentenceByteRange.size());
+ sentenceStringViews.push_back(sentenceView);
+
+ if (debug) {
+ std::cout << sentenceView << std::endl;
+ }
+ }
+
+ // Access from the sentence(sentenceIdx) API and confirm that the ground truth
+ // we expect is same as what comes out of the container.
+ if (debug) {
+ std::cout << "From container: Sentences" << std::endl;
+ }
+ for (int idx = 0; idx < groundTruthSentences.size(); idx++) {
+ ByteRange expected = groundTruthSentences[idx];
+ ByteRange obtained = testAnnotation.sentenceAsByteRange(idx);
+ if (debug) {
+ std::cout << std::string(testAnnotation.sentence(idx)) << std::endl;
+ }
+ CHECK(expected.begin == obtained.begin);
+ CHECK(expected.end == obtained.end);
+ std::string expected_string = std::string(sentenceStringViews[idx]);
+ std::string obtained_string = std::string(testAnnotation.sentence(idx));
+ CHECK(expected_string == obtained_string);
+ }
+
+ /// Access the word(sentenceIdx, wordIdx) API and confirm what we hold as
+ /// expected words are the same as those obtained from the container.
+ if (debug) {
+ std::cout << "From container: Words" << std::endl;
+ }
+
+ CHECK(groundTruthWords.size() == testAnnotation.numSentences());
+ for (int idx = 0; idx < groundTruthWords.size(); idx++) {
+ CHECK(groundTruthWords[idx].size() == testAnnotation.numWords(idx));
+ }
+
+ for (int idx = 0; idx < groundTruthWords.size(); idx++) {
+ for (int idw = 0; idw < groundTruthWords[idx].size(); idw++) {
+ ByteRange expected = groundTruthWords[idx][idw];
+ ByteRange obtained = testAnnotation.wordAsByteRange(idx, idw);
+ if (debug) {
+ std::cout << std::string(testAnnotation.word(idx, idw)) << " ";
+ }
+ CHECK(expected.begin == obtained.begin);
+ CHECK(expected.end == obtained.end);
+
+ std::string expected_string = std::string(wordStringViews[idx][idw]);
+ std::string obtained_string = std::string(testAnnotation.word(idx, idw));
+ CHECK(expected_string == obtained_string);
+ }
+ if (debug) {
+ std::cout << std::endl;
+ }
+ }
+
+ // Try inserting an empty Sentence. This is ensuring we check for empty
+ // Sentence if the random test above does not cover it for some reason.
+ int emptySentenceIdx = sentences;
+ std::vector<marian::string_view> emptySentence;
+ testAnnotation.recordExistingSentence(emptySentence.begin(), emptySentence.end(),
+ testAnnotation.text.data() + testAnnotation.text.size());
+
+ // There are no words.
+ CHECK(testAnnotation.numWords(emptySentenceIdx) == 0);
+
+ // Empty sentence expected at output.
+ std::string expectedEmptyString = "";
+ marian::string_view emptyView = testAnnotation.sentence(emptySentenceIdx);
+ std::string obtainedString = std::string(emptyView.data(), emptyView.size());
+ CHECK(expectedEmptyString == obtainedString);
+}
diff --git a/inference/src/tests/units/cache_tests.cpp b/inference/src/tests/units/cache_tests.cpp
new file mode 100644
index 000000000..f2f1b19ed
--- /dev/null
+++ b/inference/src/tests/units/cache_tests.cpp
@@ -0,0 +1,56 @@
+
+#include <random>
+#include <thread>
+
+#include "catch.hpp"
+#include "translator/cache.h"
+#include "translator/history.h"
+
+using namespace marian::bergamot;
+
+TEST_CASE("Test Cache in a threaded setting") {
+ size_t numThreads = 100;
+ size_t numIters = 10000;
+ using Key = int;
+ using Value = int;
+ using TestCache = AtomicCache<Key, Value>;
+
+ TestCache cache(/*size=*/300, /*mutexBuckets=*/16);
+
+ auto op = [numIters, &cache]() {
+ std::mt19937_64 randomGenerator;
+ randomGenerator.seed(42); // reproducible outputs
+ Value randMax = 2000;
+
+ for (size_t i = 0; i < numIters; i++) {
+ Key query = randomGenerator() % randMax;
+ std::pair<bool, Value> result = cache.find(query);
+ if (result.first) {
+ REQUIRE(result.second == query);
+ }
+
+ Value value = query;
+ cache.store(/*key=*/query, std::move(value));
+ }
+ };
+
+ std::vector<std::thread> workers;
+ for (size_t t = 0; t < numThreads; t++) {
+ workers.emplace_back(op);
+ }
+
+ for (size_t t = 0; t < numThreads; t++) {
+ workers[t].join();
+ }
+
+ TestCache::Stats stats = cache.stats();
+ float hitRate = static_cast<float>(stats.hits) / static_cast<float>(stats.hits + stats.misses);
+
+ // This is non-deterministic due to threads.
+ std::cout << "Hit-Rate:" << hitRate << "\n";
+ std::cout << "(Hits, Misses) = " << stats.hits << " " << stats.misses << "\n";
+
+ // Can we create a specialization of the actual cache-type we want? Does it compile, at least?
+ // We already have Ptr<History>, it's easier to move Ptr<History> to cache.
+ TranslationCache translationCache(/*size=*/300, /*mutexBuckets=*/16);
+}
diff --git a/inference/src/tests/units/html_tests.cpp b/inference/src/tests/units/html_tests.cpp
new file mode 100644
index 000000000..96eff5aad
--- /dev/null
+++ b/inference/src/tests/units/html_tests.cpp
@@ -0,0 +1,880 @@
+#include "html_tests.h"
+
+#include <algorithm>
+
+#include "catch.hpp"
+#include "data/types.h" // for marian::string_view
+#include "translator/html.h"
+#include "translator/response.h"
+
+using namespace marian::bergamot;
+using marian::string_view;
+
+class MarianThrowsExceptionsFixture {
+ protected:
+ MarianThrowsExceptionsFixture() : prev_(marian::getThrowExceptionOnAbort()) {
+ marian::setThrowExceptionOnAbort(true);
+ }
+ ~MarianThrowsExceptionsFixture() { marian::setThrowExceptionOnAbort(prev_); }
+
+ private:
+ bool prev_;
+};
+
+std::ostream &operator<<(std::ostream &out, std::pair<ByteRange, ByteRange> const &b) {
+ return out << '(' << b.first << ',' << b.second << ')';
+}
+
+std::ostream &operator<<(std::ostream &out, ByteRange const &b) { return out << '{' << b.begin << ',' << b.end << '}'; }
+
+std::vector<ByteRange> asByteRanges(AnnotatedText const &annotation) {
+ std::vector<ByteRange> words;
+ words.emplace_back(annotation.annotation.gap(0));
+ for (size_t sentenceIdx = 0; sentenceIdx < annotation.numSentences(); ++sentenceIdx) {
+ for (size_t wordIdx = 0; wordIdx < annotation.numWords(sentenceIdx); ++wordIdx)
+ words.emplace_back(annotation.wordAsByteRange(sentenceIdx, wordIdx));
+ words.emplace_back(annotation.annotation.gap(sentenceIdx + 1));
+ }
+ return words;
+}
+
+std::vector<std::string> asTokens(AnnotatedText const &annotation) {
+ std::vector<std::string> words;
+ words.emplace_back(annotation.gap(0));
+ for (size_t sentenceIdx = 0; sentenceIdx < annotation.numSentences(); ++sentenceIdx) {
+ for (size_t wordIdx = 0; wordIdx < annotation.numWords(sentenceIdx); ++wordIdx)
+ words.emplace_back(annotation.word(sentenceIdx, wordIdx));
+ words.emplace_back(annotation.gap(sentenceIdx + 1));
+ }
+ return words;
+}
+
+void recordSentenceFromByteRange(AnnotatedText &text, std::vector<ByteRange> const &ranges) {
+ assert(ranges.size() > 0);
+
+ std::vector<string_view> tokens;
+ tokens.reserve(ranges.size());
+
+ for (auto &&range : ranges) tokens.emplace_back(text.text.data() + range.begin, range.size());
+
+ text.recordExistingSentence(tokens.begin(), tokens.end(), text.text.data() + ranges[0].begin);
+}
+
+template <typename T>
+std::vector<std::vector<T>> identity_matrix(size_t size) {
+ std::vector<std::vector<T>> rows(size);
+ for (size_t row = 0; row < size; ++row) {
+ rows[row].resize(size, T(0));
+ rows[row][row] = T(1);
+ }
+ return rows;
+}
+
+TEST_CASE("Ignore HTML if process_markup is false") {
+ std::string html_code("
This text & has HTML in it
");
+
+ std::string input(html_code);
+ HTML html(std::move(input), false);
+ CHECK(input == html_code);
+
+ Response response;
+ response.source.text = html_code;
+ response.target.text = html_code;
+ // Note: response.alignments is empty, which is allowed in this case
+ html.restore(response);
+
+ // Assert that restore() does not mess with my HTML code
+ CHECK(response.source.text == html_code);
+}
+
+TEST_CASE_METHOD(MarianThrowsExceptionsFixture, "Abort if alignments are missing") {
+ std::string input("
\n");
+ HTML html(std::move(input), true);
+
+ AnnotatedText source("hello world\n");
+ recordSentenceFromByteRange(source, {
+ ByteRange{0, 4}, // 0.0 "hell"
+ ByteRange{4, 5}, // 0.1 "o"
+ ByteRange{5, 11}, // 0.2 " world"
+ ByteRange{11, 11} // 0.3 ""
+ });
+
+ AnnotatedText target("hallo Welt\n");
+ recordSentenceFromByteRange(target, {
+ ByteRange{0, 4}, // 0.0 "hall"
+ ByteRange{4, 5}, // 0.1 "o"
+ ByteRange{5, 10}, // 0.2 " Welt"
+ ByteRange{10, 10} // 0.3 ""
+ });
+
+ Response response;
+ response.source = source;
+ response.target = target;
+
+ // If the model is misconfigured to not give any alignment information,
+ // response will have entries for each target word, but they will all be empty.
+ response.alignments = {{{}, {}, {}, {}}};
+
+ CHECK_THROWS_WITH(
+ html.restore(response),
+ "Response object does not contain alignments. TranslationModel or ResponseOptions is misconfigured?");
+}
+
+TEST_CASE("Do not abort if the input is just empty") {
+ std::string input("");
+ HTML html(std::move(input), true);
+ CHECK(input == "");
+
+ Response response;
+ html.restore(response);
+ CHECK(response.source.text == "");
+ CHECK(response.target.text == "");
+}
+
+TEST_CASE("Do not abort if the input is just empty element") {
+ std::string input("");
+ HTML html(std::move(input), true);
+ CHECK(input == "");
+
+ Response response;
+ html.restore(response);
+ CHECK(response.source.text == "");
+ CHECK(response.target.text == "");
+}
+
+TEST_CASE("Tag names are case insensitive") {
+ // Tests
vs
and should be recognized as a void tag .
+ // should be recognized as inline.
+ std::string test_str("
Space please?
");
+
+ std::string input(test_str);
+ HTML html(std::move(input), true);
+ CHECK(input == "Spa ce\n\nplease?");
+}
+
+TEST_CASE("Test case html entities") {
+ // These are all entities I would expect in innerHTML, since all other entities
+ // can be encoded as UTF-8 so there's no need to encode them through &...; when
+ // innerHTML encodes the DOM as HTML.
+ std::string input("
This is a sentence <with> named & entities
");
+ HTML html(std::move(input), true);
+ CHECK(input == "This is a sentence named & entities");
+}
+
+TEST_CASE("Test self-closing tags should be treated as paragraph break") {
+ std::string test_str("
Space please?
");
+
+ std::string input(test_str);
+ HTML html(std::move(input), true);
+ CHECK(input == "Space\n\nplease?");
+
+ Response response;
+ std::string source_str("Space\n\nplease?");
+ std::vector<string_view> source_tokens{
+ string_view(source_str.data() + 0, 5), // Space
+ string_view(source_str.data() + 5, 0), // [EOS]
+ string_view(source_str.data() + 5, 2), // \n\n
+ string_view(source_str.data() + 7, 1), // p
+ string_view(source_str.data() + 8, 5), // lease
+ string_view(source_str.data() + 13, 1), // ?
+ string_view(source_str.data() + 14, 0), // EOS
+ };
+ response.source.appendSentence("", source_tokens.begin(), source_tokens.begin() + 2);
+ response.source.appendSentence("\n\n", source_tokens.begin() + 3, source_tokens.end());
+
+ std::string target_str("Platz\n\nbitte?");
+ std::vector<string_view> target_tokens{
+ string_view(target_str.data() + 0, 5), // Platz
+ string_view(target_str.data() + 5, 0), // [EOS]
+ string_view(target_str.data() + 5, 2), // \n\n
+ string_view(target_str.data() + 7, 5), // bitte
+ string_view(target_str.data() + 12, 1), // ?
+ string_view(target_str.data() + 13, 0), // [EOS]
+ };
+ response.target.appendSentence("", target_tokens.begin(), target_tokens.begin() + 2);
+ response.target.appendSentence("", target_tokens.begin() + 3, target_tokens.end());
+ response.alignments = {{
+ {1.0, 0.0}, // Platz <- Space
+ {0.0, 1.0} // [EOS] <- [EOS]
+ },
+ {
+ {0.1, 0.9, 0.0, 0.0}, // _bitte <- _p + lease
+ {0.0, 0.0, 1.0, 0.0}, // ? <- ?
+ {0.0, 0.0, 0.0, 1.0}, // [EOS] <- [EOS]
+ }};
+
+ // Main focus of this test is that the space that was introduced in the text
+ // that was being translated does not end up in the translation.
+ html.restore(response);
+ CHECK(response.source.text == "
");
+ HTML html(std::move(input), true);
+ CHECK(input == "hello world and other creatures"); // Note double space between "hello" and "world"
+}
+
+TEST_CASE("Test self-closing tag (XHTML)") {
+ std::string input("
";
+ markup::instream in(html_str.data());
+ markup::Scanner scanner(in);
+
+ CHECK(scanner.next() == markup::Scanner::TT_TAG_START);
+ CHECK(scanner.tag() == "div");
+ CHECK(scanner.next() == markup::Scanner::TT_ATTRIBUTE);
+ CHECK(scanner.attribute() == "id");
+ CHECK(scanner.value() == "test-id");
+ CHECK(scanner.next() == markup::Scanner::TT_ATTRIBUTE);
+ CHECK(scanner.attribute() == "class");
+ CHECK(scanner.value() == "a b c ");
+ CHECK(scanner.next() == markup::Scanner::TT_TEXT);
+ CHECK(scanner.value() == "\n");
+ CHECK(scanner.next() == markup::Scanner::TT_TAG_START);
+ CHECK(scanner.tag() == "span");
+ CHECK(scanner.next() == markup::Scanner::TT_ATTRIBUTE);
+ CHECK(scanner.attribute() == "x-custom-attribute");
+ CHECK(scanner.value() == "Hello "world""); // We do not decode entities in attributes
+ CHECK(scanner.next() == markup::Scanner::TT_COMMENT_START);
+ CHECK(scanner.next() == markup::Scanner::TT_DATA);
+ CHECK(scanner.value() == "\nthis is a comment ");
+ CHECK(scanner.next() == markup::Scanner::TT_COMMENT_END);
+ CHECK(scanner.next() == markup::Scanner::TT_TEXT);
+ CHECK(scanner.value() == "this is ");
+ CHECK(scanner.next() == markup::Scanner::TT_TEXT);
+ CHECK(scanner.value() == "&");
+ CHECK(scanner.next() == markup::Scanner::TT_TEXT);
+ CHECK(scanner.value() == " text\n");
+ CHECK(scanner.next() == markup::Scanner::TT_TAG_END);
+ CHECK(scanner.tag() == "span");
+ CHECK(scanner.next() == markup::Scanner::TT_TAG_END);
+ CHECK(scanner.tag() == "div");
+ CHECK(scanner.next() == markup::Scanner::TT_EOF);
+}
+
+TEST_CASE("test long text (#273)") {
+ std::string test_str;
+ for (size_t i = 0; i < 1024; ++i) test_str.append("testing ");
+
+ markup::instream in(test_str.data());
+ markup::Scanner scanner(in);
+
+ CHECK(scanner.next() == markup::Scanner::TT_TEXT);
+ CHECK(scanner.value() == test_str);
+ CHECK(scanner.next() == markup::Scanner::TT_EOF);
+}
+
+TEST_CASE("scan self-closing element") {
+ markup::instream in("before <img src=\"#\"/> after");
+ markup::Scanner scanner(in);
+
+ CHECK(scanner.next() == markup::Scanner::TT_TEXT);
+ CHECK(scanner.value() == "before ");
+ CHECK(scanner.next() == markup::Scanner::TT_TAG_START);
+ CHECK(scanner.tag() == "img");
+ CHECK(scanner.next() == markup::Scanner::TT_ATTRIBUTE);
+ CHECK(scanner.attribute() == "src");
+ CHECK(scanner.value() == "#");
+ CHECK(scanner.next() == markup::Scanner::TT_TAG_END);
+ CHECK(scanner.tag() == "img");
+ CHECK(scanner.next() == markup::Scanner::TT_TEXT);
+ CHECK(scanner.value() == " after");
+ CHECK(scanner.next() == markup::Scanner::TT_EOF);
+}
+
+TEST_CASE("scan script") {
+ markup::instream in("<script async>true && document.body.length > 10</script>");
+ markup::Scanner scanner(in);
+
+ CHECK(scanner.next() == markup::Scanner::TT_TAG_START);
+ CHECK(scanner.tag() == "script");
+ CHECK(scanner.next() == markup::Scanner::TT_ATTRIBUTE);
+ CHECK(scanner.attribute() == "async");
+ CHECK(scanner.value() == "");
+ CHECK(scanner.next() == markup::Scanner::TT_DATA);
+ CHECK(scanner.value() == "true && document.body.length > 10");
+ CHECK(scanner.next() == markup::Scanner::TT_TAG_END);
+ CHECK(scanner.next() == markup::Scanner::TT_EOF);
+}
+
+TEST_CASE("scan style") {
+ markup::instream in("<style>body { background: url(test.png); }</style>");
+ markup::Scanner scanner(in);
+
+ CHECK(scanner.next() == markup::Scanner::TT_TAG_START);
+ CHECK(scanner.tag() == "style");
+ CHECK(scanner.next() == markup::Scanner::TT_DATA);
+ CHECK(scanner.value() == "body { background: url(test.png); }");
+ CHECK(scanner.next() == markup::Scanner::TT_TAG_END);
+ CHECK(scanner.next() == markup::Scanner::TT_EOF);
+}
+
+TEST_CASE("scan processing instruction") {
+ // Based on https://searchfox.org/mozilla-central/source/dom/base/nsContentUtils.cpp#8961
+ // element.outerHTML can produce processing instructions in the html. These
+ // should be treated similar to .
+ markup::instream in("<?xml version=\"1.0\"?>");
+ markup::Scanner scanner(in);
+
+ CHECK(scanner.next() == markup::Scanner::TT_PROCESSING_INSTRUCTION_START);
+ CHECK(scanner.next() == markup::Scanner::TT_DATA);
+ CHECK(scanner.value() == "xml version=\"1.0\"");
+ CHECK(scanner.next() == markup::Scanner::TT_PROCESSING_INSTRUCTION_END);
+ CHECK(scanner.next() == markup::Scanner::TT_EOF);
+}
\ No newline at end of file
diff --git a/inference/src/tests/wasm.cpp b/inference/src/tests/wasm.cpp
new file mode 100644
index 000000000..97f0fc801
--- /dev/null
+++ b/inference/src/tests/wasm.cpp
@@ -0,0 +1,54 @@
+#include "common.h"
+using namespace marian::bergamot;
+
+void wasm(BlockingService &service, std::shared_ptr<TranslationModel> &model) {
+ std::vector<ResponseOptions> responseOptions;
+ std::vector<std::string> texts;
+
+ // WASM always requires HTML and alignment.
+ // TODO(jerinphilip): Fix this, bring in actual tests.
+ // responseOptions.HTML = true;
+ // responseOptions.alignment = true; // Necessary for HTML
+
+ // Hide the translateMultiple operation
+ for (std::string line; std::getline(std::cin, line);) {
+ texts.emplace_back(line);
+ responseOptions.emplace_back();
+ }
+
+ auto results = service.translateMultiple(model, std::move(texts), responseOptions);
+
+ for (auto &result : results) {
+ std::cout << result.getTranslatedText() << std::endl;
+ }
+}
+
+int main(int argc, char *argv[]) {
+ ConfigParser configParser("WebAssembly test-suite", /*multiOpMode=*/true);
+ configParser.parseArgs(argc, argv);
+
+ auto &config = configParser.getConfig();
+ BlockingService service(config.serviceConfig);
+
+ TestSuite testSuite(service);
+ std::vector<std::shared_ptr<TranslationModel>> models;
+
+ for (auto &modelConfigPath : config.modelConfigPaths) {
+ TranslationModel::Config modelConfig = parseOptionsFromFilePath(modelConfigPath);
+ // Anything WASM is expected to use the byte-array-loads. So we hard-code grabbing MemoryBundle from FS and use the
+ // MemoryBundle capable constructor.
+ MemoryBundle memoryBundle = getMemoryBundleFromConfig(modelConfig);
+ std::shared_ptr<TranslationModel> model = std::make_shared<TranslationModel>(modelConfig, std::move(memoryBundle));
+ models.push_back(model);
+ }
+
+ /// WASM is one special case where WASM path is being checked, involving translateMultiple and a multi-line feed.
+ /// Hence we do not bind it at a single input-blob single Response constraint imposed by the TestSuite.
+ if (config.opMode == "wasm") {
+ wasm(service, models.front());
+ } else {
+ testSuite.run(config.opMode, models);
+ }
+
+ return 0;
+}
diff --git a/inference/src/translator/CMakeLists.txt b/inference/src/translator/CMakeLists.txt
new file mode 100644
index 000000000..1d773b46b
--- /dev/null
+++ b/inference/src/translator/CMakeLists.txt
@@ -0,0 +1,45 @@
+# Generate version file
+configure_file(${CMAKE_CURRENT_SOURCE_DIR}/project_version.h.in
+ ${CMAKE_CURRENT_BINARY_DIR}/project_version.h @ONLY)
+
+add_library(bergamot-translator STATIC
+ byte_array_util.cpp
+ text_processor.cpp
+ translation_model.cpp
+ request.cpp
+ batching_pool.cpp
+ aggregate_batching_pool.cpp
+ response_builder.cpp
+ quality_estimator.cpp
+ batch.cpp
+ annotation.cpp
+ service.cpp
+ parser.cpp
+ response.cpp
+ html.cpp
+ xh_scanner.cpp
+)
+if (USE_WASM_COMPATIBLE_SOURCE)
+ # Using wasm compatible sources should include this compile definition;
+ # Has to be done here because we are including marian headers + some sources
+ # in local repository use these definitions
+ target_compile_definitions(bergamot-translator PUBLIC USE_SSE2 WASM_COMPATIBLE_SOURCE)
+endif()
+
+if(COMPILE_WASM)
+ target_compile_definitions(bergamot-translator PUBLIC WASM)
+ # Enable code that is required for generating JS bindings
+ target_compile_definitions(bergamot-translator PRIVATE WASM_BINDINGS)
+ target_compile_options(bergamot-translator PRIVATE ${WASM_COMPILE_FLAGS})
+ target_link_options(bergamot-translator PRIVATE ${WASM_LINK_FLAGS})
+endif(COMPILE_WASM)
+
+if(ENABLE_CACHE_STATS)
+ target_compile_definitions(bergamot-translator PUBLIC ENABLE_CACHE_STATS)
+endif(ENABLE_CACHE_STATS)
+
+target_link_libraries(bergamot-translator marian ssplit)
+
+target_include_directories(bergamot-translator
+ PUBLIC ${PROJECT_SOURCE_DIR}
+ ${PROJECT_SOURCE_DIR}/src)
diff --git a/inference/src/translator/aggregate_batching_pool.cpp b/inference/src/translator/aggregate_batching_pool.cpp
new file mode 100644
index 000000000..5f405110a
--- /dev/null
+++ b/inference/src/translator/aggregate_batching_pool.cpp
@@ -0,0 +1,36 @@
+
+#include "aggregate_batching_pool.h"
+
+namespace marian {
+namespace bergamot {
+
+AggregateBatchingPool::AggregateBatchingPool() {
+ // TODO(@jerinphilip): Set aggregate limits
+}
+
+size_t AggregateBatchingPool::enqueueRequest(Ptr<TranslationModel> model, Ptr<Request> request) {
+ size_t sentencesEnqueued = model->enqueueRequest(request);
+ aggregateQueue_.insert(model);
+ return sentencesEnqueued;
+}
+
+size_t AggregateBatchingPool::generateBatch(Ptr<TranslationModel>& model, Batch& batch) {
+ while (!aggregateQueue_.empty()) {
+ auto candidateItr = aggregateQueue_.begin();
+ Ptr<TranslationModel> candidate = *candidateItr;
+ size_t numSentences = candidate->generateBatch(batch);
+ if (numSentences > 0) {
+ model = candidate;
+ return numSentences;
+ } else {
+ // Try the next model's batching pool.
+ aggregateQueue_.erase(candidateItr);
+ }
+ }
+ return /*numSentences=*/0;
+}
+
+void AggregateBatchingPool::clear() { aggregateQueue_.clear(); }
+
+} // namespace bergamot
+} // namespace marian
diff --git a/inference/src/translator/aggregate_batching_pool.h b/inference/src/translator/aggregate_batching_pool.h
new file mode 100644
index 000000000..6775591e0
--- /dev/null
+++ b/inference/src/translator/aggregate_batching_pool.h
@@ -0,0 +1,72 @@
+#ifndef SRC_BERGAMOT_AGGREGATE_BATCHING_POOL_H_
+#define SRC_BERGAMOT_AGGREGATE_BATCHING_POOL_H_
+
+#include <memory>
+#include <unordered_set>
+
+#include "data/types.h"
+#include "translation_model.h"
+
+namespace marian {
+namespace bergamot {
+
+/// Hashes a pointer to an object using the address the pointer points to. If two pointers point to the same address,
+/// they hash to the same value. Useful to put widely shared_ptrs of entities (eg: TranslationModel, Vocab, Shortlist)
+/// etc into containers which require the members to be hashable (std::unordered_set, std::unordered_map).
+template <class T>
+struct HashPtr {
+ size_t operator()(const std::shared_ptr<T>& t) const {
+ size_t address = reinterpret_cast<size_t>(t.get());
+ return std::hash<size_t>()(address);
+ }
+};
+
+/// Aggregates request queueing and generation of batches from multiple TranslationModels (BatchingPools within,
+/// specifically), thereby acting as an intermediary to enable multiple translation model capability in BlockingService
+/// and AsyncService.
+///
+/// A simple queue containing shared owning references to TranslationModels are held here from which batches are
+/// generated on demand. Since a queue is involved, the ordering is first-come first serve on requests except there are
+/// leaks effectively doing priority inversion if an earlier request with the same TranslationModel is pending
+/// to be consumed for translation.
+//
+/// Actual storage for the request and batch generation are within the respective TranslationModels, which owns its own
+/// BatchingPool.
+///
+/// Matches API provided by BatchingPool except arguments additionally parameterized by TranslationModel.
+///
+/// Note: This class is not thread-safe. You may use this class wrapped with ThreadsafeBatchingPool for a thread-safe
+/// equivalent of this class, if needed.
+class AggregateBatchingPool {
+ public:
+ /// Create an AggregateBatchingPool with (tentatively) global (across all BatchingPools) limits
+ /// imposed here.
+ AggregateBatchingPool();
+
+ /// Enqueue an existing request onto model, also keep account of that this model and request are now pending.
+ ///
+ /// @param [in] model: Model to use in translation. A shared ownership to this model is accepted by this object to
+ /// keep the model alive until translation is complete.
+ /// @param [in] request: A request to be enqueued to model.
+ /// @returns number of sentences added for translation.
+ size_t enqueueRequest(Ptr<TranslationModel> model, Ptr<Request> request);
+
+ /// Generate a batch from pending requests, obtained from available TranslationModels.
+ ///
+ /// @param [out] model: TranslationModel
+ /// @param [out] batch: Batch to write onto, which is consumed at translation elsewhere.
+ /// @returns Number of sentences in the generated batch.
+ size_t generateBatch(Ptr<TranslationModel>& model, Batch& batch);
+
+ /// Clear the aggregate queue. Does not clear the underlying model/request pairs but the next call
+ /// to `generateBatch()` will return 0. (Unless `enqueueRequest()` was called in the mean time.)
+ void clear();
+
+ private:
+ std::unordered_set<std::shared_ptr<TranslationModel>, HashPtr<TranslationModel>> aggregateQueue_;
+};
+
+} // namespace bergamot
+} // namespace marian
+
+#endif // SRC_BERGAMOT_AGGREGATE_BATCHING_POOL_H_
diff --git a/inference/src/translator/aligned.h b/inference/src/translator/aligned.h
new file mode 100644
index 000000000..73e82edc3
--- /dev/null
+++ b/inference/src/translator/aligned.h
@@ -0,0 +1,92 @@
+#pragma once
+#include <cstdlib>
+#include <new>
+#ifdef _MSC_VER
+// Ensure _HAS_EXCEPTIONS is defined
+#include <vcruntime.h>
+#include <malloc.h>
+#endif
+
+#if !((defined(_MSC_VER) && !defined(__clang__)) ? (_HAS_EXCEPTIONS) : (__EXCEPTIONS))
+#include <cstdlib>
+#endif
+
+// Aligned simple vector.
+
+namespace marian {
+namespace bergamot {
+
+template <class T> class AlignedVector {
+ public:
+ AlignedVector() : mem_(nullptr), size_(0) {}
+
+ explicit AlignedVector(std::size_t size, std::size_t alignment = 64 /* CPU cares about this */)
+ : size_(size) {
+#ifdef _MSC_VER
+ mem_ = static_cast<T*>(_aligned_malloc(size * sizeof(T), alignment));
+ if (!mem_) {
+# if (defined(_MSC_VER) && !defined(__clang__)) ? (_HAS_EXCEPTIONS) : (__EXCEPTIONS)
+ throw std::bad_alloc();
+# else
+ std::abort();
+# endif
+ }
+#else
+ if (posix_memalign(reinterpret_cast<void **>(&mem_), alignment, size * sizeof(T))) {
+# if (defined(_MSC_VER) && !defined(__clang__)) ? (_HAS_EXCEPTIONS) : (__EXCEPTIONS)
+ throw std::bad_alloc();
+# else
+ std::abort();
+# endif
+ }
+#endif
+ }
+
+ AlignedVector(AlignedVector &&from) : mem_(from.mem_), size_(from.size_) {
+ from.mem_ = nullptr;
+ from.size_ = 0;
+ }
+
+ AlignedVector &operator=(AlignedVector &&from) {
+ if (this == &from) return *this;
+ release();
+ mem_ = from.mem_;
+ size_ = from.size_;
+ from.mem_ = nullptr;
+ from.size_ = 0;
+ return *this;
+ }
+
+ AlignedVector(const AlignedVector&) = delete;
+ AlignedVector& operator=(const AlignedVector&) = delete;
+
+ ~AlignedVector() { release(); }
+
+ std::size_t size() const { return size_; }
+
+ T &operator[](std::size_t offset) { return mem_[offset]; }
+ const T &operator[](std::size_t offset) const { return mem_[offset]; }
+
+ T *begin() { return mem_; }
+ const T *begin() const { return mem_; }
+ T *end() { return mem_ + size_; }
+ const T *end() const { return mem_ + size_; }
+
+ template <class ReturnType>
+ ReturnType *as() { return reinterpret_cast<ReturnType *>(mem_); }
+
+ private:
+ T *mem_;
+ std::size_t size_;
+
+ void release() {
+#ifdef _MSC_VER
+ _aligned_free(mem_);
+#else
+ std::free(mem_);
+#endif
+ }
+};
+
+} // namespace bergamot
+} // namespace marian
diff --git a/inference/src/translator/annotation.cpp b/inference/src/translator/annotation.cpp
new file mode 100644
index 000000000..e05a6a77d
--- /dev/null
+++ b/inference/src/translator/annotation.cpp
@@ -0,0 +1,70 @@
+#include "annotation.h"
+
+#include <cassert>
+
+namespace marian {
+namespace bergamot {
+
+AnnotatedText::AnnotatedText(std::string &&t) : text(std::move(t)) {
+ // Treat the entire text as a gap that recordExistingSentence will break.
+ annotation.token_begin_.back() = text.size();
+}
+
+void AnnotatedText::appendSentence(string_view prefix, std::vector<string_view>::iterator begin,
+ std::vector<string_view>::iterator end) {
+ assert(annotation.token_begin_.back() == text.size());
+
+ // prefix is just end of the previous one.
+ appendEndingWhitespace(prefix);
+
+ // Appending sentence text.
+ std::size_t offset = text.size();
+ for (std::vector<string_view>::iterator token = begin; token != end; ++token) {
+ offset += token->size();
+ annotation.token_begin_.push_back(offset);
+ }
+ if (begin != end) {
+ text.append(begin->data(), (end - 1)->data() + (end - 1)->size());
+ assert(offset == text.size()); // Tokens should be contiguous.
+ }
+
+ // Add the gap after the sentence. This is empty for now, but will be
+ // extended with appendEndingWhitespace or another appendSentence.
+ annotation.gap_.push_back(annotation.token_begin_.size() - 1);
+ annotation.token_begin_.push_back(offset);
+}
+
+void AnnotatedText::appendEndingWhitespace(string_view whitespace) {
+ text.append(whitespace.data(), whitespace.size());
+ annotation.token_begin_.back() = text.size();
+}
+
+void AnnotatedText::recordExistingSentence(std::vector<string_view>::iterator begin,
+ std::vector<string_view>::iterator end, const char *sentence_begin) {
+ assert(sentence_begin >= text.data());
+ assert(sentence_begin <= text.data() + text.size());
+ assert(begin == end || sentence_begin == begin->data());
+ assert(!annotation.token_begin_.empty());
+ assert(annotation.token_begin_.back() == text.size());
+ // Clip off size token ending.
+ annotation.token_begin_.pop_back();
+ for (std::vector<string_view>::iterator i = begin; i != end; ++i) {
+ assert(i->data() >= text.data()); // In range.
+ assert(i->data() + i->size() <= text.data() + text.size()); // In range
+ assert(i + 1 == end || i->data() + i->size() == (i + 1)->data()); // Contiguous
+ annotation.token_begin_.push_back(i->data() - text.data());
+ }
+ // Gap token after sentence.
+ annotation.gap_.push_back(annotation.token_begin_.size());
+ if (begin != end) {
+ annotation.token_begin_.push_back((end - 1)->data() + (end - 1)->size() - text.data());
+ } else {
+ // empty sentence.
+ annotation.token_begin_.push_back(sentence_begin - text.data());
+ }
+ // Add back size token ending.
+ annotation.token_begin_.push_back(text.size());
+}
+
+} // namespace bergamot
+} // namespace marian
diff --git a/inference/src/translator/annotation.h b/inference/src/translator/annotation.h
new file mode 100644
index 000000000..5a17dfcfe
--- /dev/null
+++ b/inference/src/translator/annotation.h
@@ -0,0 +1,232 @@
+#ifndef BERGAMOT_SENTENCE_RANGES_H_
+#define BERGAMOT_SENTENCE_RANGES_H_
+
+#include <cassert>
+#include <string>
+#include <vector>
+
+#include "data/types.h"
+#include "definitions.h"
+
+namespace marian {
+namespace bergamot {
+
+/// Annotation expresses sentence and token boundary information as ranges of
+/// bytes in a string, but does not itself own the string.
+///
+/// See also AnnotatedText, which owns Annotation and the string. AnnotatedText
+/// wraps these ByteRange functions to provide a string_view interface.
+///
+/// Text is divided into gaps (whitespace between sentences) and sentences like
+/// so:
+/// gap sentence gap sentence gap
+/// Because gaps appear at the beginning and end of the text, there's always
+/// one more gap than there are sentences.
+///
+/// The entire text is a unbroken sequence of tokens (i.e. the end of a token
+/// is the beginning of the next token). A gap is exactly one token containing
+/// whatever whitespace is between the sentences. A sentence is a sequence of
+/// tokens.
+///
+/// Since we are using SentencePiece, a token can include whitespace. The term
+/// "word" is used, somewhat incorrectly, as a synonym of token.
+///
+/// A gap can be empty (for example there may not have been whitespace at the
+/// beginning). A sentence can also be empty (typically the translation system
+/// produced empty output). That's fine, these are just empty ranges as you
+/// would expect.
+class Annotation {
+ public:
+ /// Initially an empty string. Populated by AnnotatedText.
+ Annotation() {
+ token_begin_.push_back(0);
+ token_begin_.push_back(0);
+ gap_.push_back(0);
+ }
+
+ size_t numSentences() const { return gap_.size() - 1; }
+
+ /// Returns number of words in the sentence identified by `sentenceIdx`.
+ size_t numWords(size_t sentenceIdx) const {
+ return gap_[sentenceIdx + 1] - gap_[sentenceIdx] - 1 /* minus the gap */;
+ }
+
+ /// Returns a ByteRange representing `wordIdx` in sentence indexed by
+ /// `sentenceIdx`. `wordIdx` follows 0-based indexing, and should be less than
+ /// `.numWords()` for `sentenceIdx` for defined behaviour.
+ ByteRange word(size_t sentenceIdx, size_t wordIdx) const {
+ size_t tokenIdx = gap_[sentenceIdx] + 1 + wordIdx;
+ return ByteRange{token_begin_[tokenIdx], token_begin_[tokenIdx + 1]};
+ }
+
+ /// Returns a ByteRange representing sentence corresponding to `sentenceIdx`.
+ /// `sentenceIdx` follows 0-based indexing, and behaviour is defined only when
+ /// less than `.numSentences()`.
+ ByteRange sentence(size_t sentenceIdx) const {
+ return ByteRange{
+ token_begin_[gap_[sentenceIdx] + 1], /*end of whitespace before */
+ token_begin_[gap_[sentenceIdx + 1]] /*beginning of whitespace after */
+ };
+ }
+
+ ByteRange gap(size_t gapIdx) const {
+ size_t tokenIdx = gap_[gapIdx];
+ return ByteRange{token_begin_[tokenIdx], token_begin_[tokenIdx + 1]};
+ }
+
+ private:
+ friend class AnnotatedText;
+ /// Map from token index to byte offset at which it begins. Token i is:
+ /// [token_begin_[i], token_begin_[i+1])
+ /// The vector is padded so that these indices are always valid, even at the
+ /// end. So tokens_begin_.size() is the number of tokens plus 1.
+ std::vector<size_t> token_begin_;
+
+ /// Indices of tokens that correspond to gaps between sentences. These are
+ /// indices into token_begin_.
+ /// Gap g is byte range:
+ /// [token_begin_[gap_[w]], token_begin_[gap_[w]+1])
+ /// Sentence s is byte range:
+ /// [token_begin_[gap_[s]+1], token_begin_[gap_[s+1]])
+ /// A sentence does not include whitespace at the beginning or end.
+ ///
+ /// gap_.size() == numSentences() + 1.
+ ///
+ /// Example: empty text "" -> just an empty gap.
+ /// token_begin_ = {0, 0};
+ /// gap_ = {0};
+ ///
+ /// Example: only space " " -> just a gap containing the space.
+ /// token_begin_ = {0, 1};
+ /// gap_ = {0};
+ ///
+ /// Example: one token "hi" -> empty gap, sentence with one token, empty gap
+ /// token_begin_ = {0, 0, 2, 2};
+ /// gap_ = {0, 2};
+ std::vector<size_t> gap_;
+};
+
+/// AnnotatedText is effectively std::string text + Annotation, providing the
+/// following additional desiderata.
+///
+/// 1. Access to processed string_views for convenience rather than ByteRanges
+/// (which only provides index information).
+///
+/// 2. Transparently convert string_views into ByteRanges for the Annotation
+/// referring to the text bound by this structure.
+///
+/// 3. Bind the text and annotations together, to move around as a meaningful
+/// unit.
+struct AnnotatedText {
+ public:
+ std::string text; ///< Blob of string elements in annotation refers to.
+ Annotation annotation; ///< sentence and (sub-) word annotations.
+
+ /// Construct an empty AnnotatedText. This is useful when the target string or
+ /// ByteRanges are not known yet, but the public members can be used to
+ /// populate it. One use-case, when translated-text is created decoding from
+ /// histories and the ByteRanges only known after the string has been
+ /// constructed.
+ AnnotatedText() {}
+
+ /// Construct moving in a string (for efficiency purposes, copying string
+ /// constructor is disallowed).
+ AnnotatedText(std::string &&text);
+
+ /// Appends a sentence to the existing text and transparently rebases
+ /// string_views. Since this tracks only prefix, remember
+ /// appendEndingWhitespace.
+ /// The string_views must not already be in text.
+ void appendSentence(string_view prefix, std::vector<string_view>::iterator tokens_begin,
+ std::vector<string_view>::iterator tokens_end);
+
+ /// Append the whitespace at the end of input. string_view must not be in
+ /// text.
+ void appendEndingWhitespace(string_view whitespace);
+
+ /// Record the existence of a sentence that is already in text. The
+ /// iterators are over string_views for each token that must be in text
+ /// already. This function must be called to record sentences in order.
+ /// Normally the beginning of the sentence can be inferred from
+ /// tokens_begin->data() but the tokens could be empty, so sentence_begin is
+ /// required to know where the sentence is.
+ void recordExistingSentence(std::vector<string_view>::iterator tokens_begin,
+ std::vector<string_view>::iterator tokens_end, const char *sentence_begin);
+
+ /// Returns the number of sentences in the annotation structure.
+ const size_t numSentences() const { return annotation.numSentences(); }
+
+ /// Returns number of words in the sentece identified by sentenceIdx.
+ const size_t numWords(size_t sentenceIdx) const { return annotation.numWords(sentenceIdx); }
+
+ /// Returns a string_view representing wordIdx in sentenceIdx
+ string_view word(size_t sentenceIdx, size_t wordIdx) const {
+ return asStringView(annotation.word(sentenceIdx, wordIdx));
+ }
+
+ /// Returns a string_view representing sentence corresponding to sentenceIdx.
+ string_view sentence(size_t sentenceIdx) const { return asStringView(annotation.sentence(sentenceIdx)); }
+
+ /// Returns the string_view of the gap between two sentences in the container.
+ ///
+ /// More precisely where `i = sentenceIdx, N = numSentences()` for brevity:
+ ///
+ /// * For `i = 0`: The gap between the start of text and the 0th sentence.
+ /// * For `i = 1...N-1`, returns the text comprising of the gap
+ /// between the `i`-th and `i+1`-th sentence.
+ /// * For `i = N`, the gap between the last (N-1th) sentence and end of
+ /// text.
+ /// @param sentenceIdx: Can be between `[0, numSentences()]`.
+ string_view gap(size_t sentenceIdx) const { return asStringView(annotation.gap(sentenceIdx)); }
+
+ /// Returns a ByteRange representing wordIdx in sentenceIdx
+ ByteRange wordAsByteRange(size_t sentenceIdx, size_t wordIdx) const { return annotation.word(sentenceIdx, wordIdx); }
+
+ /// Returns a ByteRange representing sentence corresponding to sentenceIdx.
+ ByteRange sentenceAsByteRange(size_t sentenceIdx) const { return annotation.sentence(sentenceIdx); }
+
+ /// Utility function to call `fun` on each word (subword token effectively) in
+ /// an `AnnotatedText`. `fun` is called with the `ByteRange`, the `string_view`
+ /// with the word, and a `bool` to indicate whether it is the last word in the
+ /// `AnnotatedText`, which is also the ending whitespace slot of AnnotatedText.
+ template <typename Fun>
+ AnnotatedText apply(Fun fun) const {
+ AnnotatedText out;
+
+ for (size_t sentenceIdx = 0; sentenceIdx < numSentences(); ++sentenceIdx) {
+ std::string sentence;
+ std::vector<ByteRange> tokens;
+
+ std::string prefix = fun(annotation.gap(sentenceIdx), gap(sentenceIdx), false);
+
+ for (size_t wordIdx = 0; wordIdx < numWords(sentenceIdx); ++wordIdx) {
+ std::string token = fun(wordAsByteRange(sentenceIdx, wordIdx), word(sentenceIdx, wordIdx), false);
+ tokens.push_back(ByteRange{sentence.size(), sentence.size() + token.size()});
+ sentence += token;
+ }
+
+ // Convert our ByteRanges to string_views since that's what appendSentence
+ // expects
+ std::vector<string_view> views(tokens.size());
+ std::transform(tokens.begin(), tokens.end(), views.begin(), [&](ByteRange const &range) {
+ return marian::string_view(sentence.data() + range.begin, range.size());
+ });
+
+ out.appendSentence(prefix, views.begin(), views.end());
+ }
+
+ out.appendEndingWhitespace(fun(annotation.gap(numSentences()), gap(numSentences()), true));
+
+ return out;
+ }
+
+ private:
+ string_view asStringView(const ByteRange &byteRange) const {
+ return string_view(text.data() + byteRange.begin, byteRange.size());
+ }
+};
+
+} // namespace bergamot
+} // namespace marian
+
+#endif // BERGAMOT_SENTENCE_RANGES_H_
diff --git a/inference/src/translator/batch.cpp b/inference/src/translator/batch.cpp
new file mode 100644
index 000000000..08d3d02c6
--- /dev/null
+++ b/inference/src/translator/batch.cpp
@@ -0,0 +1,26 @@
+#include "batch.h"
+
+#include "request.h"
+
+namespace marian {
+namespace bergamot {
+
+void Batch::log() {
+  size_t numTokens{0}, maxLength{0};
+  for (auto &sentence : sentences_) {
+    numTokens += sentence.numTokens();
+    maxLength = std::max(maxLength, static_cast<size_t>(sentence.numTokens()));
+  }
+
+  LOG(info, "Batch(tokens={}, max-length={}, sentences_={})", numTokens, maxLength, sentences_.size());
+}
+
+void Batch::add(const RequestSentence &sentence) { sentences_.push_back(sentence); }
+
+void Batch::completeBatch(const Histories &histories) {
+ for (size_t i = 0; i < sentences_.size(); i++) {
+ sentences_[i].completeSentence(histories[i]);
+ }
+}
+} // namespace bergamot
+} // namespace marian
diff --git a/inference/src/translator/batch.h b/inference/src/translator/batch.h
new file mode 100644
index 000000000..2f67252be
--- /dev/null
+++ b/inference/src/translator/batch.h
@@ -0,0 +1,43 @@
+#ifndef SRC_BERGAMOT_BATCH_H
+#define SRC_BERGAMOT_BATCH_H
+
+#include "request.h"
+#include "translator/beam_search.h"
+
+namespace marian {
+namespace bergamot {
+
+// An empty batch is poison.
+class Batch {
+ public:
+ Batch() {}
+ void clear() { sentences_.clear(); }
+
+ size_t size() const { return sentences_.size(); }
+
+ void add(const RequestSentence &sentence);
+
+ // Accessors to read from a Batch. For use in BatchTranslator (consumer on a
+ // PCQueue holding batches).
+ //
+ // sentences() are used to access sentences to construct marian internal
+ // batch.
+ const RequestSentences &sentences() { return sentences_; }
+
+ // On obtaining Histories after translating a batch, completeBatch can be
+ // called with Histories , which forwards the call to Request through
+ // RequestSentence and triggers completion, by setting the promised value to
+ // the future given to client.
+ void completeBatch(const Histories &histories);
+
+ // Convenience function to log batch-statistics. numTokens, max-length.
+ void log();
+
+ private:
+ RequestSentences sentences_;
+};
+
+} // namespace bergamot
+} // namespace marian
+
+#endif // SRC_BERGAMOT_BATCH_H_
diff --git a/inference/src/translator/batching_pool.cpp b/inference/src/translator/batching_pool.cpp
new file mode 100644
index 000000000..61dd1920e
--- /dev/null
+++ b/inference/src/translator/batching_pool.cpp
@@ -0,0 +1,89 @@
+#include "batching_pool.h"
+
+#include <cassert>
+
+#include "batch.h"
+#include "common/logging.h"
+
+namespace marian {
+namespace bergamot {
+
+BatchingPool::BatchingPool(Ptr<Options> options)
+    : miniBatchWords_(options->get<int>("mini-batch-words")), maxActiveBucketLength_(0) {
+  size_t maxLengthBreak = options->get<size_t>("max-length-break");
+  float maxLengthFactor = options->get<float>("max-length-factor", 3.0);
+
+  // For the time being, we add some slack, which only BatchingPool is aware of. Since the TextProcessor still wraps at
+  // first request in, most of the Batches generated will be under max-length break.
+  //
+  // In the unlikely event of a few sentences overflowing, this allows the exceeding words to be put in the slack area.
+  // Very few batches are expected to be generated at a higher length.
+  size_t pivotSlack = maxLengthBreak * maxLengthFactor - maxLengthBreak;
+  bucket_.resize(maxLengthBreak + pivotSlack + 1);
+
+  ABORT_IF(bucket_.size() - 1 > miniBatchWords_,
+           "Fatal: max-length-break > mini-batch-words will lead to sentences "
+           "longer than what can fit in a batch.");
+}
+
+size_t BatchingPool::generateBatch(Batch &batch) {
+ // For now simply iterates on buckets and converts batches greedily. This
+ // has to be enhanced with optimizing over priority. The baseline
+ // implementation should at least be as fast as marian's maxi-batch with full
+ // corpus size as maxi-batch size.
+ batch.clear();
+ size_t paddedBatchSize = 0;
+
+ for (size_t length = 0; length <= maxActiveBucketLength_; length++) {
+ auto p = bucket_[length].begin();
+ while (p != bucket_[length].end()) {
+ paddedBatchSize = (batch.size() + 1) * length;
+ if (paddedBatchSize <= miniBatchWords_) {
+ auto q = p++;
+ batch.add(*q);
+ bucket_[length].erase(q);
+ } else {
+ // Check if elements exist
+ assert(batch.size() > 0);
+ return batch.size();
+ }
+ }
+ }
+
+ return batch.size();
+}
+
+size_t BatchingPool::enqueueRequest(Ptr<Request> request) {
+  size_t toBeFreshlyTranslated = 0;
+  for (size_t i = 0; i < request->numSegments(); i++) {
+    if (!request->cacheHitPrefilled(i)) {
+      RequestSentence sentence(i, request);
+      size_t bucket_id = sentence.numTokens();
+
+      // Due to a workaround for pivoting, unless we can discipline the
+      // vocabulary to get stronger static requirements, it is difficult to
+      // rework the rest of the components. Instead, we allow dynamic growth
+      // here. We let std::vector take care of the dynamic growth.
+      // https://en.cppreference.com/w/cpp/container/vector/resize#Complexity
+      if (bucket_id >= bucket_.size()) {
+        bucket_.resize(bucket_id + 1);
+      }
+
+      bucket_[bucket_id].insert(sentence);
+      maxActiveBucketLength_ = std::max(bucket_id, maxActiveBucketLength_);
+
+      toBeFreshlyTranslated += 1;
+    }
+  }
+
+  return toBeFreshlyTranslated;
+}
+
+void BatchingPool::clear() {
+ for (size_t length = 0; length < bucket_.size(); length++) {
+ bucket_[length].clear();
+ }
+}
+
+} // namespace bergamot
+} // namespace marian
diff --git a/inference/src/translator/batching_pool.h b/inference/src/translator/batching_pool.h
new file mode 100644
index 000000000..58cd2ca8b
--- /dev/null
+++ b/inference/src/translator/batching_pool.h
@@ -0,0 +1,42 @@
+#ifndef SRC_BERGAMOT_BATCHING_POOL_H_
+#define SRC_BERGAMOT_BATCHING_POOL_H_
+
+#include <set>
+#include <vector>
+
+#include "batch.h"
+#include "common/options.h"
+#include "data/corpus_base.h"
+#include "definitions.h"
+#include "request.h"
+
+namespace marian {
+namespace bergamot {
+
+class BatchingPool {
+ public:
+  explicit BatchingPool(Ptr<Options> options);
+
+  // RequestSentence incorporates (tentative) notions of priority with each
+  // sentence. This method inserts the sentence into the internal data-structure
+  // which maintains priority among sentences from multiple concurrent requests.
+  size_t enqueueRequest(Ptr<Request> request);
+
+  // Loads sentences with sentences compiled from (tentatively) multiple
+  // requests optimizing for both padding and priority.
+  size_t generateBatch(Batch &batch);
+
+  // Removes any pending requests from the pool.
+  void clear();
+
+ private:
+  size_t miniBatchWords_;
+  std::vector<std::set<RequestSentence>> bucket_;
+  size_t batchNumber_{0};
+  size_t maxActiveBucketLength_;
+};
+
+} // namespace bergamot
+} // namespace marian
+
+#endif // SRC_BERGAMOT_BATCHING_POOL_H_
diff --git a/inference/src/translator/byte_array_util.cpp b/inference/src/translator/byte_array_util.cpp
new file mode 100644
index 000000000..c7515e797
--- /dev/null
+++ b/inference/src/translator/byte_array_util.cpp
@@ -0,0 +1,178 @@
+#include "byte_array_util.h"
+
+#include <memory>
+#include <unordered_map>
+
+#include "common/io.h"
+#include "data/shortlist.h"
+
+namespace marian {
+namespace bergamot {
+
+namespace {
+// This is a basic validator that checks if the file has not been truncated
+// it basically loads up the header and checks
+
+// This struct and the getter are copied from the marian source, because it's located
+// inside src/common/binary.cpp:15 and we can't include it.
+struct Header {
+ uint64_t nameLength;
+ uint64_t type;
+ uint64_t shapeLength;
+ uint64_t dataLength;
+};
+
+// cast current void pointer to T pointer and move forward by num elements
+template <class T>
+const T* get(const void*& current, uint64_t num = 1) {
+  const T* ptr = (const T*)current;
+  current = (const T*)current + num;
+  return ptr;
+}
+} // Anonymous namespace
+
+bool validateBinaryModel(const AlignedMemory& model, uint64_t fileSize) {
+  const void* current = model.begin();
+  uint64_t memoryNeeded =
+      sizeof(uint64_t) * 2;  // We keep track of how much memory we would need if we have a complete file
+  uint64_t numHeaders;
+  if (fileSize >= memoryNeeded) {  // We have enough filesize to fetch the headers.
+    uint64_t binaryFileVersion = *get<uint64_t>(current);
+    numHeaders = *get<uint64_t>(current);  // number of item headers that follow
+  } else {
+    return false;
+  }
+  memoryNeeded += numHeaders * sizeof(Header);
+  const Header* headers;
+  if (fileSize >= memoryNeeded) {
+    headers = get<Header>(current, numHeaders);  // read that many headers
+  } else {
+    return false;
+  }
+
+  // Calculate how many bytes we are going to for reading just the names and the shape
+  for (uint64_t i = 0; i < numHeaders; i++) {
+    memoryNeeded += headers[i].nameLength + headers[i].shapeLength * sizeof(int);
+    // Advance the pointers.
+    get<char>(current, headers[i].nameLength);
+    get<int>(current, headers[i].shapeLength);
+  }
+
+  // Before we start reading the data, there is a small padding to ensure alignment
+  // Read that in, before calculating the actual tensor memory requirements.
+  uint64_t aligned_offset;
+  if (fileSize >= memoryNeeded) {
+    aligned_offset = *get<uint64_t>(current);  // Offset to align memory to 256 size
+    memoryNeeded += aligned_offset + sizeof(uint64_t);
+  } else {
+    return false;
+  }
+
+  // Finally the tensor size:
+  for (uint64_t i = 0; i < numHeaders; i++) {
+    memoryNeeded += headers[i].dataLength;
+  }
+
+  // If this final check passes, the file is at least big enough to contain the model
+  if (fileSize >= memoryNeeded) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+AlignedMemory loadFileToMemory(const std::string& path, size_t alignment) {
+  uint64_t fileSize = filesystem::fileSize(path);
+  io::InputFileStream in(path);
+  ABORT_IF(in.bad(), "Failed opening file stream: {}", path);
+  AlignedMemory alignedMemory(fileSize, alignment);
+  in.read(reinterpret_cast<char*>(alignedMemory.begin()), fileSize);
+  ABORT_IF(alignedMemory.size() != fileSize, "Error reading file {}", path);
+  return alignedMemory;
+}
+
+std::vector<AlignedMemory> getModelMemoryFromConfig(marian::Ptr<marian::Options> options) {
+  auto models = options->get<std::vector<std::string>>("models");
+
+  std::vector<AlignedMemory> modelMemories(models.size());
+  for (size_t i = 0; i < models.size(); ++i) {
+    const auto model = models[i];
+    if (marian::io::isBin(model)) {
+      modelMemories[i] = loadFileToMemory(model, 256);
+    } else if (marian::io::isNpz(model)) {
+      // if any of the models are npz format, we revert to loading from file for all models.
+      LOG(debug, "Encountered an npz file {}; will use file loading for {} models", model, models.size());
+      return {};
+    } else {
+      ABORT("Unknown extension for model: {}, should be one of `.bin` or `.npz`", model);
+    }
+  }
+
+  return modelMemories;
+}
+
+AlignedMemory getShortlistMemoryFromConfig(marian::Ptr<marian::Options> options) {
+  auto shortlist = options->get<std::vector<std::string>>("shortlist");
+  if (!shortlist.empty()) {
+    ABORT_IF(!marian::data::isBinaryShortlist(shortlist[0]),
+             "Loading non-binary shortlist file into memory is not supported");
+    return loadFileToMemory(shortlist[0], 64);
+  }
+  return AlignedMemory();
+}
+
+void getVocabsMemoryFromConfig(marian::Ptr<marian::Options> options,
+                               std::vector<std::shared_ptr<AlignedMemory>>& vocabMemories) {
+  auto vfiles = options->get<std::vector<std::string>>("vocabs");
+  ABORT_IF(vfiles.size() < 2, "Insufficient number of vocabularies.");
+  vocabMemories.resize(vfiles.size());
+  std::unordered_map<std::string, std::shared_ptr<AlignedMemory>> vocabMap;
+  for (size_t i = 0; i < vfiles.size(); ++i) {
+    ABORT_IF(marian::filesystem::Path(vfiles[i]).extension() != marian::filesystem::Path(".spm"),
+             "Loading non-SentencePiece vocab files into memory is not supported");
+    auto m = vocabMap.emplace(std::make_pair(vfiles[i], std::shared_ptr<AlignedMemory>()));
+    if (m.second) {
+      m.first->second = std::make_shared<AlignedMemory>(loadFileToMemory(vfiles[i], 64));
+    }
+    vocabMemories[i] = m.first->second;
+  }
+}
+
+AlignedMemory getQualityEstimatorModel(const marian::Ptr<marian::Options>& options) {
+  const auto qualityEstimatorPath = options->get<std::string>("quality", "");
+  if (qualityEstimatorPath.empty()) {
+    return {};
+  }
+  return loadFileToMemory(qualityEstimatorPath, 64);
+}
+
+AlignedMemory getQualityEstimatorModel(MemoryBundle& memoryBundle, const marian::Ptr<marian::Options>& options) {
+  if (memoryBundle.qualityEstimatorMemory.size() == 0) {
+    return getQualityEstimatorModel(options);
+  }
+
+  return std::move(memoryBundle.qualityEstimatorMemory);
+}
+
+MemoryBundle getMemoryBundleFromConfig(marian::Ptr<marian::Options> options) {
+  MemoryBundle memoryBundle;
+  memoryBundle.models = getModelMemoryFromConfig(options);
+  memoryBundle.shortlist = getShortlistMemoryFromConfig(options);
+  getVocabsMemoryFromConfig(options, memoryBundle.vocabs);
+  memoryBundle.ssplitPrefixFile = getSsplitPrefixFileMemoryFromConfig(options);
+  memoryBundle.qualityEstimatorMemory = getQualityEstimatorModel(options);
+
+  return memoryBundle;
+}
+
+AlignedMemory getSsplitPrefixFileMemoryFromConfig(marian::Ptr<marian::Options> options) {
+  std::string fpath = options->get<std::string>("ssplit-prefix-file", "");
+  if (!fpath.empty()) {
+    return loadFileToMemory(fpath, 64);
+  }
+  // Return empty AlignedMemory
+  return AlignedMemory();
+}
+
+} // namespace bergamot
+} // namespace marian
diff --git a/inference/src/translator/byte_array_util.h b/inference/src/translator/byte_array_util.h
new file mode 100644
index 000000000..851a175fd
--- /dev/null
+++ b/inference/src/translator/byte_array_util.h
@@ -0,0 +1,18 @@
+#include "definitions.h"
+#include "marian.h"
+
+namespace marian {
+namespace bergamot {
+
+AlignedMemory loadFileToMemory(const std::string& path, size_t alignment);
+std::vector<AlignedMemory> getModelMemoryFromConfig(marian::Ptr<marian::Options> options);
+AlignedMemory getQualityEstimatorModel(const marian::Ptr<marian::Options>& options);
+AlignedMemory getQualityEstimatorModel(MemoryBundle& memoryBundle, const marian::Ptr<marian::Options>& options);
+AlignedMemory getShortlistMemoryFromConfig(marian::Ptr<marian::Options> options);
+AlignedMemory getSsplitPrefixFileMemoryFromConfig(marian::Ptr<marian::Options> options);
+void getVocabsMemoryFromConfig(marian::Ptr<marian::Options> options,
+                               std::vector<std::shared_ptr<AlignedMemory>>& vocabMemories);
+bool validateBinaryModel(const AlignedMemory& model, uint64_t fileSize);
+MemoryBundle getMemoryBundleFromConfig(marian::Ptr<marian::Options> options);
+} // namespace bergamot
+} // namespace marian
diff --git a/inference/src/translator/cache.h b/inference/src/translator/cache.h
new file mode 100644
index 000000000..ceeca5d32
--- /dev/null
+++ b/inference/src/translator/cache.h
@@ -0,0 +1,91 @@
+#pragma once
+#include <atomic>
+#include <memory>
+#include <mutex>
+#include <vector>
+
+#include "definitions.h"
+#include "translator/history.h"
+
+namespace marian::bergamot {
+
+template <class Key, class Value, class Hash = std::hash<Key>, class Equals = std::equal_to<Key>>
+class AtomicCache {
+ public:
+  struct Stats {
+    size_t hits{0};
+    size_t misses{0};
+  };
+
+  explicit AtomicCache(size_t size, size_t buckets) : records_(size), mutexBuckets_(buckets) {}
+
+  std::pair<bool, Value> find(const Key &key) const {
+    Value value;
+    bool found = atomicLoad(key, value);
+    return std::make_pair(found, value);
+  }
+
+  void store(const Key &key, Value value) { atomicStore(key, value); }
+
+  const Stats stats() const {
+#ifdef ENABLE_CACHE_STATS
+    return Stats{hits_.load(), misses_.load()};
+#else
+    ABORT("Cache statistics requested without enabling in builds. Please use -DENABLE_CACHE_STATS with cmake.");
+    return Stats{0, 0};
+#endif
+  }
+
+ private:
+  using Record = std::pair<Key, Value>;
+
+  bool atomicLoad(const Key &key, Value &value) const {
+    // No probing, direct map onto records_
+    size_t index = hash_(key) % records_.size();
+    size_t mutexId = index % mutexBuckets_.size();
+
+    std::lock_guard<std::mutex> lock(mutexBuckets_[mutexId]);
+    const Record &candidate = records_[index];
+    if (equals_(key, candidate.first)) {
+      value = candidate.second;
+#ifdef ENABLE_CACHE_STATS
+      ++hits_;
+#endif
+      return true;
+    } else {
+#ifdef ENABLE_CACHE_STATS
+      ++misses_;
+#endif
+    }
+
+    return false;
+  }
+
+  void atomicStore(const Key &key, Value value) {
+    // No probing, direct map onto records_
+    size_t index = hash_(key) % records_.size();
+    size_t mutexId = index % mutexBuckets_.size();
+
+    std::lock_guard<std::mutex> lock(mutexBuckets_[mutexId]);
+    Record &candidate = records_[index];
+
+    candidate.first = key;
+    candidate.second = value;
+  }
+
+  std::vector<Record> records_;
+
+  mutable std::vector<std::mutex> mutexBuckets_;
+
+#ifdef ENABLE_CACHE_STATS
+  mutable std::atomic<size_t> hits_{0};
+  mutable std::atomic<size_t> misses_{0};
+#endif
+
+  Hash hash_;
+  Equals equals_;
+};
+
+typedef AtomicCache<size_t, Ptr<History>> TranslationCache;
+
+} // namespace marian::bergamot
diff --git a/inference/src/translator/definitions.h b/inference/src/translator/definitions.h
new file mode 100644
index 000000000..efba3f9f6
--- /dev/null
+++ b/inference/src/translator/definitions.h
@@ -0,0 +1,78 @@
+#ifndef SRC_BERGAMOT_DEFINITIONS_H_
+#define SRC_BERGAMOT_DEFINITIONS_H_
+
+#include <vector>
+
+#include "aligned.h"
+#include "data/types.h"
+#include "data/vocab_base.h"
+
+namespace marian {
+namespace bergamot {
+
+typedef marian::Words Segment;
+typedef std::vector<Segment> Segments;
+
+/// Shortcut to AlignedVector<char> for byte arrays
+typedef AlignedVector<char> AlignedMemory;
+
+/// Memory bundle for all byte-arrays.
+/// Can be a set/subset of model, shortlist, vocabs and ssplitPrefixFile bytes.
+struct MemoryBundle {
+  std::vector<AlignedMemory> models{};  ///< Byte-array of model (each element is aligned to 256)
+  AlignedMemory shortlist{};            ///< Byte-array of shortlist (aligned to 64)
+
+  /// Vector of vocabulary memories (aligned to 64).
+  /// If two vocabularies are the same (based on the filenames), two entries (shared
+  /// pointers) will be generated which share the same AlignedMemory object.
+  std::vector<std::shared_ptr<AlignedMemory>> vocabs{};
+
+  /// @todo Not implemented yet
+  AlignedMemory ssplitPrefixFile{};
+
+  AlignedMemory qualityEstimatorMemory;  ///< Byte-array of qe model (aligned to 64)
+};
+
+/// ByteRange stores indices for half-interval [begin, end) in a string. Can be
+/// used to represent a sentence, word.
+struct ByteRange {
+  size_t begin;
+  size_t end;
+  size_t size() const { return end - begin; }
+  bool operator==(ByteRange other) const { return begin == other.begin && end == other.end; }
+};
+
+/// A Subword range is mechanically the same as a `ByteRange`, but instead of
+/// describing a span of bytes, it describes a span of Subword tokens. Using
+/// `Annotation.word()` you can switch between the two.
+struct SubwordRange {
+  size_t begin;
+  size_t end;
+  size_t size() const { return end - begin; }
+  bool operator==(SubwordRange other) const { return begin == other.begin && end == other.end; }
+};
+
+class Response;
+using CallbackType = std::function<void(Response &&)>;
+
+} // namespace bergamot
+} // namespace marian
+
+// @TODO at the moment the usage of string_view in this repository is a hot mess and a disaster waiting to happen.
+// ssplit uses std::string_view if the compiler supports c++17, else falls back to c++11 and absl::string_view
+// bergamot-translator uses, depending on the source file std::string_view (which will break if ssplit-cpp uses
+// absl::string_view) and marian::string_view which is an export of (confusingly) the sentencepiece module that
+// marian has. marian::string_view is our addition to the marian fork, which will make merging even nicer. Not.
+// This is just an ugly patchwork that allos gcc5, our lowest targetted gcc to run. We don't actually try to run
+// on older compilers.
+
+#if defined(__GNUC__) && __GNUC__ < 6 && !defined(__clang__)
+#include <experimental/string_view>
+namespace std {
+using string_view = std::experimental::string_view;
+}  // namespace std
+#else
+#include <string_view>
+#endif
+
+#endif // SRC_BERGAMOT_DEFINITIONS_H_
diff --git a/inference/src/translator/html.cpp b/inference/src/translator/html.cpp
new file mode 100644
index 000000000..421074aa1
--- /dev/null
+++ b/inference/src/translator/html.cpp
@@ -0,0 +1,811 @@
+#include "html.h"
+
+#include <algorithm>
+
+#include "response.h"
+#include "translator/definitions.h"
+#include "xh_scanner.h"
+
+namespace {
+using marian::bergamot::AnnotatedText;
+using marian::bergamot::ByteRange;
+using marian::bergamot::HTML;
+using marian::bergamot::Response;
+
+/// Encodes the minimum of HTML entities.
+void encodeEntities(marian::string_view const &input, std::string &output) {
+  output.clear();
+  output.reserve(input.size());  // assumes there are no entities in most cases
+
+  for (char it : input) {
+    switch (it) {
+      case '&':
+        output.append("&amp;");
+        break;
+      case '<':
+        output.append("&lt;");
+        break;
+      case '>':
+        output.append("&gt;");
+        break;
+      // case ' ':
+      //   output.append("&nbsp;");
+      //   break;
+      // case '"':
+      //   output.append("&quot;");
+      //   break;
+      // case '\'':
+      //   output.append("&apos;");
+      //   break;
+      default:
+        output.push_back(it);
+        break;
+    }
+  }
+}
+
+/// Counts number of whitespace characters at the start of the input. Used
+/// for determining where to insert an open or close tag.
+size_t countPrefixWhitespaces(marian::string_view const &input) {
+  size_t size = 0;
+  while (size < input.size() && std::isspace(static_cast<unsigned char>(input[size]))) ++size;
+  return size;
+}
+
+std::string toLowerCase(std::string_view const &input) {
+ std::string out;
+ out.resize(input.size());
+ std::transform(input.begin(), input.end(), out.begin(), [](unsigned char c) { return std::tolower(c); });
+ return out;
+}
+
+/// Very simple replacement for std::format introduced in C++20. Only supports
+/// replacing `{}` in the template string with whatever `operator<<` for that
+/// type turns it into.
+std::string format(std::string const &formatTemplate) { return formatTemplate; }
+
+template <typename Arg>
+std::string format(std::string const &formatTemplate, Arg arg) {
+  std::ostringstream os;
+  auto index = formatTemplate.find("{}");
+  assert(index != std::string::npos);
+  os << formatTemplate.substr(0, index) << arg << formatTemplate.substr(index + 2);
+  return os.str();
+}
+
+template <typename Arg, typename... Args>
+std::string format(std::string const &formatTemplate, Arg arg, Args... args) {
+  std::ostringstream os;
+  auto index = formatTemplate.find("{}");
+  assert(index != std::string::npos);
+  os << formatTemplate.substr(0, index) << arg << format(formatTemplate.substr(index + 2), std::forward<Args>(args)...);
+  return os.str();
+}
+
+/// Syntactic sugar around rbegin() and rend() that allows me to write
+/// `for (auto &&item : reversed(container))` instead of the needlessly verbose
+/// `for (auto it = container.rbegin(); it != container.rend(); ++it)`
+template <typename T>
+class Reversed {
+ public:
+  using iterator = typename T::const_reverse_iterator;
+  explicit Reversed(T const &container) : container_(container){};
+  iterator begin() const { return container_.rbegin(); }
+  iterator end() const { return container_.rend(); }
+
+ private:
+  T const &container_;
+};
+
+/// When comparing two tag stacks, determine which tags need to be closed and
+/// opened to get from one stack to the other.
+void diffTags(HTML::TagStack const &prev, HTML::TagStack const &curr, HTML::TagStack &opening,
+ HTML::TagStack &closing) {
+ opening.clear();
+ closing.clear();
+
+ size_t i = 0;
+
+ // Find first difference
+ for (; i < prev.size(); ++i)
+ if (i >= curr.size() || prev[i] != curr[i]) break;
+
+ // Only nodes of type ELEMENT can have children and thus would need a closing tag.
+ // NOLINTNEXTLINE(bugprone-narrowing-conversions)
+ std::copy_if(prev.begin() + i, prev.end(), std::back_inserter(closing),
+ [&](HTML::Tag *tag) { return tag->type == HTML::Tag::ELEMENT; });
+
+ // NOLINTNEXTLINE(bugprone-narrowing-conversions)
+ opening.insert(opening.end(), curr.begin() + i, curr.end());
+}
+
+bool intersects(ByteRange const &range, HTML::Span const &span) {
+ return range.begin <= span.end && range.end >= span.begin;
+};
+
+bool contains(HTML::TagNameSet const &set, std::string_view const &name) { return set.find(name) != set.end(); }
+
+bool contains(HTML::TagStack const &stack, HTML::Tag const *tag) {
+ return std::find(stack.rbegin(), stack.rend(), tag) != stack.rend();
+}
+
+/// Is tag stack B an extended version of A? I.e. same tags, but maybe a few
+/// more nested deeper.
+bool extends(HTML::TagStack const &b, HTML::TagStack const &a) {
+ if (a.size() > b.size()) return false;
+
+ for (auto i = a.begin(), j = b.begin(); i != a.end(); ++i, ++j)
+ if (*i != *j) return false;
+
+ return true;
+}
+
+/// Tests whether `response` has alignment info associated with it or not.
+bool hasAlignments(Response const &response) {
+ // Test for each sentence individually as a sentence may be empty (or there)
+ // might be no sentences, so just testing for alignments.empty() would not be
+ // sufficient.
+ for (size_t sentenceIdx = 0; sentenceIdx < response.target.numSentences(); ++sentenceIdx) {
+ // If response.alignments is just empty, this might catch it.
+ if (response.alignments.size() <= sentenceIdx ||
+ response.alignments[sentenceIdx].size() != response.target.numWords(sentenceIdx))
+ return false;
+
+ // If response.alignments is "empty" because the model did not provide alignments,
+ // it still has entries for each target word. But all these entries are empty.
+ for (size_t wordIdx = 0; wordIdx < response.target.numWords(sentenceIdx); ++wordIdx)
+ if (response.alignments[sentenceIdx][wordIdx].size() != response.source.numWords(sentenceIdx)) return false;
+ }
+ return true;
+}
+
+/// Helper class to append HTML tags to a token. Also makes sure the token is
+/// encoded as valid HTML.
+class TokenFormatter {
+ public:
+  explicit TokenFormatter(marian::string_view token)
+      : offset_(0), whitespaceOffset_(0), whitespaceSize_(countPrefixWhitespaces(token)), closeLeft_(true) {
+    // Do encoding of any entities that popped up in the translation
+    encodeEntities(token, html_);
+  }
+
+  std::string &&html() { return std::move(html_); }
+
+  // Append the markup necessary for moving from `prev` set of tags to `curr`.
+  void append(HTML::TagStack const &prev, HTML::TagStack const &curr) {
+    HTML::TagStack opening, closing;
+
+    diffTags(prev, curr, opening, closing);
+
+    for (HTML::Tag const *tag : Reversed(closing)) {
+      assert(tag->type == HTML::Tag::ELEMENT);
+      std::string closeTag = format("</{}>", tag->name);
+      html_.insert(offset_ + (closeLeft_ ? 0 : whitespaceSize_), closeTag);
+      offset_ += closeTag.size();
+      if (closeLeft_) whitespaceOffset_ += closeTag.size();
+    }
+
+    for (HTML::Tag const *tag : opening) {
+      std::string openTag;
+      switch (tag->type) {
+        case HTML::Tag::ELEMENT:
+        case HTML::Tag::VOID_ELEMENT:
+          openTag = format("<{}{}>{}", tag->name, tag->attributes, tag->data);
+          break;
+        case HTML::Tag::COMMENT:
+          openTag = format("<!--{}-->", tag->data);
+          break;
+        case HTML::Tag::PROCESSING_INSTRUCTION:
+          openTag = format("<?{}?>", tag->data);
+          break;
+        case HTML::Tag::WHITESPACE: {
+          // Try to eat two newlines (paragraph break) from our segment
+          auto pos = html_.find("\n\n", whitespaceOffset_);
+          if (pos != std::string::npos && pos < whitespaceOffset_ + whitespaceSize_) {
+            html_.erase(pos, 2);
+            whitespaceSize_ -= 2;
+          }
+        } break;
+      }
+
+      html_.insert(offset_ + whitespaceSize_, openTag);
+      offset_ += openTag.size();
+      closeLeft_ = closeLeft_ && openTag.empty();
+    }
+  }
+
+ private:
+  std::string html_;         // Output html
+  size_t offset_;            // Size added by prepending HTML
+  size_t whitespaceOffset_;  // position of prefix whitespace characters
+                             // (it moves as closing tags are prepended)
+  size_t whitespaceSize_;    // number of prefix whitespace characters
+
+  // Close tags we want to show up left (before) the token, but open tags
+  // ideally come directly after any prefix whitespace. However, some tokens
+  // match multiple spans. If a previous span has added an open tag, after any
+  // whitespace, and the next span closes said tag again, we need to close
+  // it after the whitespace. So after the first open tag, any closing tag
+  // should also align right, after whitespace, not before. Hence this bool.
+  bool closeLeft_;
+};
+
+/// Count the number of tokens in an AnnotatedText. Used to assert we're not
+/// running out of sync when creating vectors that describe each token.
+size_t debugCountTokens(AnnotatedText const &text) {
+ size_t tokens = 1; // for the ending gap
+ for (size_t sentenceIdx = 0; sentenceIdx < text.numSentences(); ++sentenceIdx) {
+ tokens += 1 + text.numWords(sentenceIdx); // pre-sentence prefix/gap + each word
+ }
+ return tokens;
+}
+
+/// Helper function that consumes a tag as if it is a special tag, except that
+/// it takes nesting into account. I.e. `` will be consumed to the
+// last ``. Assumes TT_TAG_START is already consumed, which was necessary
+/// to determine whether this was an element that needed to be ignored.
+void consumeIgnoredTag(markup::Scanner &scanner, HTML::Tag &tag, std::string const &name) {
+  // Only full elements can be consumed this way. With void tags we don't know
+  // where to stop scanning. All other types cannot be nested anyway.
+  assert(tag.type == HTML::Tag::ELEMENT);
+
+  // TT_TAG_START is already consumed.
+  markup::Scanner::TokenType token;
+  size_t inside = 0;
+
+  // Consume the full open tag, i.e. all its attributes
+  while (!inside) {
+    token = scanner.next();
+    switch (token) {
+      case markup::Scanner::TT_ERROR:
+        ABORT("HTML parse error");
+      case markup::Scanner::TT_EOF:
+        ABORT("Did not find closing tag </{}>", name);
+      case markup::Scanner::TT_ATTRIBUTE:
+        tag.attributes += format(" {}=\"{}\"", scanner.attribute(), scanner.value());
+        break;
+      default:
+        // Not an attribute! Must be something inside the body or the closing
+        // tag already. Time to jump to the next loop.
+        ++inside;
+        break;
+    }
+  }
+
+  // Last token was something that would have triggered Scanner::scanBody(),
+  // which sets value() to start pointing at the body.
+  const char *start = scanner.start();
+
+  // Consume the rest of the HTML until (including) the final closing tag. We
+  // start with the token that caused the previous loop to fall into the default
+  // case.
+  while (inside) {
+    switch (token) {
+      case markup::Scanner::TT_ERROR:
+        ABORT("HTML parse error");
+      case markup::Scanner::TT_EOF:
+        ABORT("Did not find closing tag </{}>", name);
+      case markup::Scanner::TT_TAG_START:
+        // Note: Looking specifically for only our own type of tag so we don't
+        // have to care about whether other tags we encounter are void tags or
+        // not. Does assume the HTML is valid, as no stack is kept.
+        if (toLowerCase(scanner.tag()) == name) ++inside;
+        break;
+      case markup::Scanner::TT_TAG_END:
+        if (toLowerCase(scanner.tag()) == name) --inside;
+        break;
+      default:
+        break;
+    }
+
+    // Only continue scanning if we're still inside. We could have just read the
+    // TT_TAG_END token that ended this element, and we don't want to continue
+    // consuming tokens at that point.
+    if (inside) token = scanner.next();
+  }
+
+  // Only a TAG_END could have stopped the previous loop. We take the start
+  // of the final closing tag as the end of our data.
+  assert(token == markup::Scanner::TT_TAG_END);
+  const char *end = scanner.start();
+
+  // All data between the end of the first open element, and the start of the
+  // last close element, we just treat as raw data that will be printed when
+  // this tag is eventually printed.
+  assert(end >= start);
+  tag.data = std::string_view(start, end - start);
+}
+
+} // namespace
+
+namespace marian::bergamot {
+
+/// Formatters used for formatting error messages in ABORT() calls.
+std::ostream &operator<<(std::ostream &out, HTML::Tag const *tag) {
+  if (tag == nullptr) return out << "[nullptr]";
+  switch (tag->type) {
+    case HTML::Tag::ELEMENT:
+      return out << '<' << tag->name << tag->attributes << '>';
+    case HTML::Tag::VOID_ELEMENT:
+      return out << '<' << tag->name << tag->attributes << "/>";
+    case HTML::Tag::COMMENT:
+      return out << "<!--" << tag->data << "-->";
+    case HTML::Tag::PROCESSING_INSTRUCTION:
+      return out << "<?" << tag->data << "?>";
+    case HTML::Tag::WHITESPACE:
+      return out << "[inserted space]";
+  }
+  return out << "[Unknown tag type]";
+}
+
+std::ostream &operator<<(std::ostream &out, HTML::TagStack const &tags) {
+ for (auto it = tags.begin(); it != tags.end(); ++it) {
+ if (it != tags.begin()) out << ' ';
+ out << *it;
+ }
+ return out;
+}
+
+/// Parses `source` as HTML, strips all markup from it (the plain text is
+/// moved back into `source`), and records in `spans_` which tags were open
+/// for each span of the remaining text so restore() can re-insert them later.
+/// If `processMarkup` is false this is a no-op and `spans_` stays empty.
+HTML::HTML(std::string &&source, bool processMarkup, Options &&options) : options_(std::move(options)) {
+  if (!processMarkup) return;
+
+  std::string original = std::move(source);
+  markup::instream in(original.data(), original.data() + original.size());
+  markup::Scanner scanner(in);
+  source.clear();  // source is moved out of, so should be clear anyway
+
+  Tag *tag = nullptr;             // current tag (after opening at least)
+  TagStack stack;                 // stack of currently open tags
+  bool addSentenceBreak = false;  // whether to add a sentence break next text segment
+  bool addWordBreak = false;      // whether to add a word break next text segment
+
+  // Starting point: an empty span with no open tags.
+  spans_.push_back(Span{0, 0, {}});
+
+  bool stop = false;
+  while (!stop) {
+    switch (scanner.next()) {
+      case markup::Scanner::TT_ERROR:
+        ABORT("HTML parse error");
+
+      case markup::Scanner::TT_EOF:
+        stop = true;
+        break;
+
+      case markup::Scanner::TT_TEXT: {
+        // If the previous segment was the open or close tag of a block element
+        // we treat the text after it as a new sentence.
+        if (addSentenceBreak) {
+          // If there isn't already a \n\n at the end of source...
+          if (source.size() >= 2 && source.substr(source.size() - 2) != "\n\n") {
+            stack.push_back(makeTag({Tag::WHITESPACE}));
+            // Important: span->size() == 0 to make it behave as a void element.
+            // Also important: position before the \n\n tokens, not after, to
+            // make it easier to remove them later through apply().
+            spans_.push_back(Span{source.size(), source.size(), stack});
+            source.append("\n\n");  // Should work with ssplit-mode = wrapped_text
+            stack.pop_back();
+          }
+          addSentenceBreak = false;
+        }
+
+        // If the previous segment was an open or close tag, it might be best
+        // to add a space to make sure we don't append to the previous word.
+        if (addWordBreak) {
+          // Only add the space when it would be inside a word. Do not add it if
+          // it would be between a word and punctuation.
+          if (options_.substituteInlineTagsWithSpaces && isContinuation(source, scanner.value())) {
+            source.push_back(' ');
+          }
+          addWordBreak = false;
+        }
+
+        // Store which tags were open when this span of text was encountered.
+        auto begin = source.size();
+        source.append(scanner.value());
+        spans_.push_back(Span{begin, source.size(), stack});
+      } break;
+
+      case markup::Scanner::TT_TAG_START: {
+        std::string name = toLowerCase(scanner.tag());
+
+        // Tag *tag is used by attribute parsing
+        auto type = contains(options_.voidTags, name) ? Tag::VOID_ELEMENT : Tag::ELEMENT;
+        tag = makeTag({type, std::string(scanner.tag())});
+
+        stack.push_back(tag);
+
+        // Empty elements (e.g. <img>) are not applicable to a span of text
+        // so instead we "apply" them to an empty span in between, and then
+        // immediately remove them again from the stack.
+        if (tag->type == Tag::VOID_ELEMENT) {
+          spans_.push_back(Span{source.size(), source.size(), stack});
+          stack.pop_back();
+        }
+
+        // Ignored tags have same semantics as void tags with regards to moving
+        // them around with the rest of the content.
+        if (contains(options_.ignoredTags, name)) {
+          consumeIgnoredTag(scanner, *tag, name);
+          spans_.push_back(Span{source.size(), source.size(), stack});
+          stack.pop_back();
+        }
+
+        // Treat non-inline HTML tags as spaces that break up words.
+        if (!contains(options_.inlineTags, name)) {
+          addSentenceBreak = true;
+        } else if (!contains(options_.inWordTags, name)) {
+          addWordBreak = true;
+        }
+      } break;
+
+      case markup::Scanner::TT_TAG_END: {
+        std::string tagName = toLowerCase(scanner.tag());
+        // If this is the closing bit of a void tag, i.e. triggered by the "/>"
+        // bit of e.g. "<img/>", then completely ignore it.
+        if (contains(options_.voidTags, tagName)) break;
+
+        ABORT_IF(stack.empty(), "Encountered more closing tags ({}) than opening tags", scanner.tag());
+
+        ABORT_IF(toLowerCase(stack.back()->name) != toLowerCase(scanner.tag()),
+                 "Encountered unexpected closing tag </{}>, stack is {}", scanner.tag(), stack);
+
+        // What to do with the "<a></a>" case, where a tag is immediately closed
+        // so it never makes it into the taint of any of the spans? This adds
+        // an empty span so it still gets recorded in spans_.
+        if (spans_.empty() || !contains(spans_.back().tags, stack.back()))
+          spans_.push_back(Span{source.size(), source.size(), stack});
+
+        stack.pop_back();
+
+        // Add space if necessary
+        if (!contains(options_.inlineTags, tagName)) {
+          addSentenceBreak = true;
+        } else if (!contains(options_.inWordTags, tagName)) {
+          addWordBreak = true;
+        }
+      } break;
+
+      case markup::Scanner::TT_ATTRIBUTE:
+        assert(tag != nullptr);
+        tag->attributes += format(" {}=\"{}\"", scanner.attribute(), scanner.value());
+        break;
+
+      case markup::Scanner::TT_COMMENT_START:
+        // Tag *tag is used when TT_DATA is seen to add the comment's content.
+        tag = makeTag({Tag::COMMENT});
+        stack.push_back(tag);
+        spans_.push_back(Span{source.size(), source.size(), stack});
+        stack.pop_back();
+        break;
+
+      case markup::Scanner::TT_PROCESSING_INSTRUCTION_START:
+        // Tag *tag is used when TT_DATA is seen to add the PI's content.
+        tag = makeTag({Tag::PROCESSING_INSTRUCTION});
+        stack.push_back(tag);
+        spans_.push_back(Span{source.size(), source.size(), stack});
+        stack.pop_back();
+        break;
+
+      case markup::Scanner::TT_COMMENT_END:
+      case markup::Scanner::TT_PROCESSING_INSTRUCTION_END:
+        tag = nullptr;
+        break;
+
+      case markup::Scanner::TT_DATA:
+        assert(tag != nullptr);
+        tag->data = scanner.value();
+        break;
+
+      default:
+        ABORT("Unsupported scanner token type");
+    }
+  }
+
+  ABORT_IF(!stack.empty(), "Not all tags were closed: {}", stack);
+
+  // Add a trailing span (that's empty) to signify all closed tags.
+  spans_.emplace_back(Span{source.size(), source.size(), stack});
+}
+
+/// Re-inserts the HTML that was stripped by the constructor into both the
+/// source and target texts of `response`, using the word alignments to decide
+/// where tags belong in the translation. Requires alignment info.
+void HTML::restore(Response &response) {
+  // No-op if process_markup was false (and thus spans_ is empty)
+  // TODO: replace this with optional at a higher level
+  if (spans_.empty()) return;
+
+  // We need alignment info to transfer the HTML tags from the input to the
+  // translation. If those are not available, no HTML in translations for you.
+  ABORT_UNLESS(hasAlignments(response),
+               "Response object does not contain alignments. TranslationModel or ResponseOptions is misconfigured?");
+
+  // Reconstruction of HTML tags:
+  // 1. Map each token to a Span
+  // 2. Reconstruct the source HTML with these tainted tokens
+  // 3. Transfer the spans from the source tokens to the target tokens using alignment information
+  // 4. For spans that represent empty elements (e.g. <img>) figure out their position
+  // 5. Reconstruct the target HTML with these tainted tokens
+
+  // sourceTokenSpans is a vector with a pointer to a span for each token. We
+  // use iterators here to point to these positions so we can easily compare if
+  // one span comes before or after another, information we'll need when we need
+  // to figure out whether we've skipped spans (of empty elements) when
+  // reconstructing HTML in response.target.
+  std::vector<SpanIterator> sourceTokenSpans;
+
+  // RestoreSource re-inserts HTML into the source text, but also identifies
+  // which span each source token fits into best.
+  AnnotatedText source = restoreSource(response.source, sourceTokenSpans);
+  assert(sourceTokenSpans.size() == debugCountTokens(response.source));
+
+  // Find for every token in target the token in source that best matches.
+  std::vector<std::vector<size_t>> alignments;
+  hardAlignments(response, alignments, sourceTokenSpans);
+
+  std::vector<SpanIterator> targetTokenSpans;
+  copyTagStack(response, alignments, sourceTokenSpans, targetTokenSpans);
+  assert(targetTokenSpans.size() == debugCountTokens(response.target));
+
+  // Take the spans, and use them to make a taint for every word in the
+  // translation. Optionally add extra tags, like quality score metadata.
+  std::vector<TagStack> targetTokenTags;
+  annotateTagStack(response, targetTokenSpans, targetTokenTags);
+
+  AnnotatedText target = restoreTarget(response.target, targetTokenSpans, targetTokenTags);
+
+  response.source = source;
+  response.target = target;
+}
+
+/// Re-inserts the stripped markup into the source-side AnnotatedText. Also
+/// records, for every token, an iterator into `spans_` pointing at the last
+/// span that overlaps that token (one entry per token in `sourceTokenSpans`).
+AnnotatedText HTML::restoreSource(AnnotatedText const &in, std::vector<SpanIterator> &sourceTokenSpans) {
+  auto spanIt = spans_.begin();
+  auto prevIt = spans_.begin();  // safe because the first span is always the empty
+                                 // span, and the while-loop below will do the rest
+  assert(prevIt == spans_.end() || prevIt->tags.empty());
+
+  return in.apply([&](ByteRange range, string_view token, bool last) {
+    TokenFormatter formatter(token);
+
+    // Potential issue: spans and tokens can intersect, e.g.
+    //
+    //    text  <p> h <u> e </u> ll o </p>
+    //   spans     |1|   |2|    |3333|  (so only 2 is tainted with <p><u>, others only <p>)
+    //  tokens     |111111111111111|2|
+    //
+    // Now 1 covers span 1 to 3, so what taint should it get? Just `<p>`, or
+    // `<p><u>`?
+    // Note: only relevant if `substituteInlineTagsWithSpaces` is true. If we
+    // just insert spaces around all elements, every segment of `hello` will be
+    // a token.
+
+    // Seek to the last span that overlaps with this token
+    while (true) {
+      formatter.append(prevIt->tags, spanIt->tags);
+      prevIt = spanIt;
+
+      if (spanIt + 1 != spans_.end() && ((spanIt + 1)->begin < range.end || last)) {
+        spanIt++;
+        continue;
+      }
+
+      break;
+    }
+
+    // TODO: This is just the taint of the last span, not the ones in between.
+    // This makes us lose some markup of parts of tokens as described above.
+    sourceTokenSpans.emplace_back(prevIt);
+
+    return std::move(formatter.html());
+  });
+}
+
+/// Re-inserts HTML into the translated target text. For each target token it
+/// emits the tag changes between the previous token's tags and this token's
+/// (from `targetTokenTags`), plus any empty/void elements whose spans were
+/// skipped over ("stragglers"). Closes all remaining tags after the last token.
+AnnotatedText HTML::restoreTarget(AnnotatedText const &in, std::vector<SpanIterator> const &targetTokenSpans,
+                                  std::vector<TagStack> const &targetTokenTags) {
+  auto prevTags = spans_.cbegin()->tags;
+  auto stragglerSpanIt = spans_.cbegin();
+  auto targetSpanIt = targetTokenSpans.begin();
+  auto targetTagIt = targetTokenTags.begin();
+
+  AnnotatedText out = in.apply([&]([[maybe_unused]] ByteRange range, string_view token, bool last) {
+    TokenFormatter formatter(token);
+
+    // First we scan through spans_ to catch up to the span assigned to this
+    // token. We're only interested in empty spans (empty and void elements)
+    for (; stragglerSpanIt < *targetSpanIt; stragglerSpanIt++) {
+      // We're only interested in empty spans or spans that would otherwise get
+      // lost because they didn't align with anything between the spans in
+      // targetSpanIt
+      // TODO That std::find makes this O(N*N) NOT GOOD NOT GOOD
+      if (stragglerSpanIt->size() != 0 &&
+          std::find(targetTokenSpans.begin(), targetTokenSpans.end(), stragglerSpanIt) != targetTokenSpans.end())
+        continue;
+
+      formatter.append(prevTags, stragglerSpanIt->tags);
+      prevTags = stragglerSpanIt->tags;
+    }
+
+    // Now do the same thing but for our target set of tags. Note that we cannot
+    // combine this in the for-loop above (i.e. `span_it <= *targetSpanIt`)
+    // because there is no guarantee that the order in `targetTokenSpans` is
+    // the same as that of `spans`.
+
+    formatter.append(prevTags, *targetTagIt);
+
+    // If this is the last token of the response, close all open tags.
+    if (last) {
+      // Note: this assert is true due to our current implementation of
+      // HardAlignments() that always matches the last token of the input with
+      // the last token of the output. But lets assume someone someday changes
+      // HardAlignments(), and then this for-loop will be necessary.
+      // assert((*targetSpanIt)->tags.empty());
+      formatter.append(*targetTagIt, HTML::TagStack());
+    }
+
+    prevTags = *targetTagIt;
+    ++targetSpanIt;
+    ++targetTagIt;
+
+    return std::move(formatter.html());
+  });
+
+  // Assert that we did in fact use all our taints
+  assert(targetSpanIt == targetTokenSpans.end());
+
+  return out;
+}
+
+/// Moves `tag` into the pool that owns all Tag objects and returns its
+/// address. NOTE(review): assumes pool_ is a container whose elements keep a
+/// stable address after further push_front calls (e.g. forward_list) — the
+/// returned pointer is stored long-term in TagStacks. Confirm against html.h.
+HTML::Tag *HTML::makeTag(Tag &&tag) {
+  pool_.push_front(std::move(tag));
+  return &pool_.front();
+}
+
+/// Copies the span assigned to each source token over to the target tokens
+/// using the per-sentence token alignments in `alignments`. Also assigns spans
+/// to the sentence gap tokens (prefix gaps and the trailing whitespace token).
+void HTML::copyTagStack(Response const &response, std::vector<std::vector<size_t>> const &alignments,
+                        std::vector<SpanIterator> const &sourceTokenSpans,
+                        std::vector<SpanIterator> &targetTokenSpans) {
+  size_t offset = 0;  // Sentence offset in sourceTokenSpans
+
+  // Fill targetTokenSpans based on the alignments we just made up.
+  // NOTE: this should match the exact order of Apply()
+  for (size_t sentenceIdx = 0; sentenceIdx < response.target.numSentences(); ++sentenceIdx) {
+    targetTokenSpans.push_back(sourceTokenSpans[offset]);  // token_tag for sentence ending gap
+    for (size_t t = 0; t < response.target.numWords(sentenceIdx); ++t) {
+      size_t s = alignments[sentenceIdx][t];
+      assert(s < response.source.numWords(sentenceIdx));
+      targetTokenSpans.push_back(sourceTokenSpans[offset + 1 + s]);  // +1 for prefix gap
+    }
+
+    offset += response.source.numWords(sentenceIdx) + 1;  // +1 for prefix gap
+  }
+
+  assert(offset + 1 == sourceTokenSpans.size());
+  targetTokenSpans.push_back(sourceTokenSpans[offset]);  // token_tag for ending whitespace
+}
+
+/// Converts the per-token span iterators into concrete TagStacks, and — when
+/// quality scores are present in `response` — augments those stacks with
+/// sentence-level and word-level <font> tags carrying the score metadata.
+void HTML::annotateTagStack(Response const &response, std::vector<SpanIterator> const &targetTokenSpans,
+                            std::vector<TagStack> &targetTokenTags) {
+  auto spanIt = targetTokenSpans.begin();
+  for (size_t sentenceIdx = 0; sentenceIdx < response.target.numSentences(); ++sentenceIdx) {
+    // Sentence prefix
+    targetTokenTags.push_back((*spanIt)->tags);
+    spanIt++;
+
+    // Offset in targetTokenTags at which this sentence's tags start.
+    size_t tagOffset = targetTokenTags.size();
+
+    // Initially, just copy the span's tags to this token
+    for (size_t t = 0; t < response.target.numWords(sentenceIdx); ++t) {
+      targetTokenTags.emplace_back((*spanIt)->tags);
+      spanIt++;
+    }
+
+    // If we have quality score information, add that as metadata as well.
+    if (!response.qualityScores.empty()) {
+      auto const &sentenceQuality = response.qualityScores[sentenceIdx];
+      // Create a single tag for this sentence with sentence level info
+      Tag *sentenceTag = makeTag({Tag::ELEMENT, "font"});
+      sentenceTag->attributes += format(" x-bergamot-sentence-index=\"{}\" x-bergamot-sentence-score=\"{}\"",
+                                        sentenceIdx, sentenceQuality.sentenceScore);
+
+      // Add that tag to all tokens in this sentence.
+      for (size_t tokenIdx = 0; tokenIdx < response.target.numWords(sentenceIdx); ++tokenIdx) {
+        targetTokenTags[tagOffset + tokenIdx].push_back(sentenceTag);
+      }
+
+      // Add word level tags as well to all tokens that make up a word.
+      for (size_t wordIdx = 0; wordIdx < sentenceQuality.wordRanges.size(); ++wordIdx) {
+        Tag *wordTag = makeTag({Tag::ELEMENT, "font"});
+        wordTag->attributes += format(" x-bergamot-word-index=\"{}\" x-bergamot-word-score=\"{}\"", wordIdx,
+                                      sentenceQuality.wordScores[wordIdx]);
+        auto const &range = sentenceQuality.wordRanges[wordIdx];
+        for (size_t tokenIdx = range.begin; tokenIdx < range.end; ++tokenIdx) {
+          targetTokenTags[tagOffset + tokenIdx].push_back(wordTag);
+        }
+      }
+    }
+  }
+
+  // Suffix
+  targetTokenTags.push_back((*spanIt)->tags);
+  spanIt++;
+
+  assert(spanIt == targetTokenSpans.end());
+}
+
+// Reports if token `str` is likely to be a continuation of a word. This is used
+// to determine whether we should share the markup, or whether we should see
+// this token as a fresh start. This implementation will treat "hello[world]"
+// as 4 words, assuming its tokenised as something like `h ell o [ wor ld ]`.
+bool HTML::isContinuation(std::string_view prev, std::string_view str) const {
+  auto const &delimiters = options_.continuationDelimiters;
+  if (delimiters.empty() || prev.empty() || str.empty()) return false;
+  // A continuation requires that neither the boundary character of the
+  // previous token nor the first character of this token is a delimiter.
+  bool startsFresh = delimiters.find(str[0]) != std::string::npos;
+  bool endedWord = delimiters.find(prev.back()) != std::string::npos;
+  return !startsFresh && !endedWord;
+}
+
+/// Overload for marian::string_view: converts both arguments and delegates to
+/// the std::string_view implementation above.
+bool HTML::isContinuation(marian::string_view prev, marian::string_view str) const {
+  std::string_view prevView(prev.data(), prev.size());
+  std::string_view strView(str.data(), str.size());
+  return isContinuation(prevView, strView);
+}
+
+/// Selects for each token in `response.target` a best source token from
+/// `response.source` and writes this selection to `alignments`. The source
+/// token spans are used to also look at the markup applied to each token to
+/// figure out which source token best represents each target token.
+void HTML::hardAlignments(Response const &response, std::vector<std::vector<size_t>> &alignments,
+                          std::vector<SpanIterator> const &sourceTokenSpans) {
+  size_t offset = 0;  // sentence offset in sourceTokenSpans
+
+  // For each sentence...
+  for (size_t sentenceIdx = 0; sentenceIdx < response.target.numSentences(); ++sentenceIdx) {
+    alignments.emplace_back();
+
+    // Hard-align: find for each target token the most prevalent source token
+    // Note: only search from 0 to N-1 because token N is end-of-sentence token
+    // that can only align with the end-of-sentence token of the target
+    for (size_t t = 0; t + 1 < response.target.numWords(sentenceIdx); ++t) {
+      alignments.back().push_back(
+          std::max_element(response.alignments[sentenceIdx][t].begin(), response.alignments[sentenceIdx][t].end()) -
+          response.alignments[sentenceIdx][t].begin());
+    }
+
+    // Next, we try to smooth out these selected alignments with a few heuristics
+    for (size_t t = 1; t + 1 < response.target.numWords(sentenceIdx); ++t) {
+      // If this token is a continuation of a previous token, pick the tags from the most
+      // prevalent token for the whole word.
+      if (isContinuation(response.target.word(sentenceIdx, t - 1), response.target.word(sentenceIdx, t))) {
+        // Note: only looking at the previous token since that will already
+        // have this treatment applied to it.
+        // NOTE(review): despite their names, these two hold *source token*
+        // indices selected for target tokens t and t-1, not sentence indices.
+        size_t currSentenceIdx = alignments.back()[t];
+        size_t prevSentenceIdx = alignments.back()[t - 1];
+        float currScore = response.alignments[sentenceIdx][t][currSentenceIdx];
+        float prevScore = response.alignments[sentenceIdx][t - 1][prevSentenceIdx];
+
+        TagStack const &currTagStack = sourceTokenSpans[offset + 1 + currSentenceIdx]->tags;
+        TagStack const &prevTagStack = sourceTokenSpans[offset + 1 + prevSentenceIdx]->tags;
+
+        // If this token has more markup, or a better score than the previous
+        // token (and they together are part of a word-ish thing) then mark
+        // this word as aligning. Otherwise just copy the alignment source of
+        // the previous token.
+        if (extends(currTagStack, prevTagStack) || currScore >= prevScore) {
+          // Apply this to all previous tokens in the word
+          for (size_t i = t;; --i) {
+            alignments.back()[i] = currSentenceIdx;
+
+            // Stop if this was the first token or the beginning of the word
+            if (i == 0 ||
+                !isContinuation(response.target.word(sentenceIdx, i - 1), response.target.word(sentenceIdx, i)))
+              break;
+          }
+        } else {
+          alignments.back()[t] = prevSentenceIdx;
+        }
+      }
+    }
+
+    // Always align target end with source end
+    alignments.back().push_back(response.source.numWords(sentenceIdx) - 1);
+
+    offset += response.source.numWords(sentenceIdx) + 1;  // +1 for prefix gap
+  }
+}
+
+} // namespace marian::bergamot
diff --git a/inference/src/translator/html.h b/inference/src/translator/html.h
new file mode 100644
index 000000000..f3c6dad19
--- /dev/null
+++ b/inference/src/translator/html.h
@@ -0,0 +1,224 @@
+#ifndef SRC_BERGAMOT_HTML_H_
+#define SRC_BERGAMOT_HTML_H_
+
+#include
+#include
+#include
+#include
+#include
+
+#include "annotation.h"
+#include "data/types.h"
+#include "definitions.h"
+
+namespace marian::bergamot {
+
+struct Response;
+
+/// HTML class parses and removes HTML from input text, and places it back into
+/// the translated output text.
+///
+/// When parsing the HTML, it treats tags as markup, where a list of nested tags
+/// can be seen as a list of markups that are applicable to all the text that
+/// follows. This list is stored as a `TagStack`. Whenever an HTML tag opens or
+/// closes, a new TagStack is created to reflect that. TagStack used to be
+/// called `Taint` because it *tainted* the text it was associated with with
+/// those tags as markup. The text between tags themselves is stored in the
+/// input variable. In `spans_`, the TagStack that is associated with a
+/// substring of that text is stored.
+/// When transferring the HTML from the source text to the translated target
+/// text, the TagStacks are first associated with each of the subwords from the
+/// source text. Using hard alignment, each subword in the source text is linked
+/// to a subword in the target text. The TagStacks are then copied over these
+/// links. Finally, the HTML is inserted back into the target text by for each
+/// subword, comparing the TagStack from the previous word to that word, and
+/// opening and closing elements to make up for the difference.
+///
+/// There are a couple of complexities though:
+/// 1. Not all tags can be treated as markup applied to text. For example, an
+/// `` does not contain text itself. Or `` does not. We do want
+/// those tags to remain in the output though. We do this by associating
+/// them to an empty `Span`. When inserting HTML back into the translation
+/// input or output, we keep track of where in the `spans_` vector we are,
+/// and insert any elements from empty spans that we might have skipped over
+/// because empty spans are never linked to tokens/subwords. These are
+/// *stragglers* in some parts of the code, or *void* or *empty* elements in
+/// other parts.
+/// 2. Some tags should be treated as paragraph indicators, and break up
+/// sentences. These are the usual suspects like `
`, but also `
` and
+/// `
`, to make sure we don't translate two table cells into a single
+/// word. This is the `addSentenceBreak` flag in the HTML parsing bit.
+/// We mark these breaks with `\n\n` in the input text and with a special
+/// WHITESPACE tag that we treat as any other void tag. Hopefully this tag
+/// moves with the added `\n\n` and it is easy for us to remove it again.
+/// (in practise it is since these only occur at the end of sentences and
+/// the end of sentences are always aligned between source and target.)
+/// 3. We treat most tags as word-breaking. We do this by adding spaces just
+/// after where we saw the open or close tag occur. If there is already
+/// some whitespace in that place, we do not add extra spaces.
+/// 4. TODO
+class HTML {
+ public:
+ using TagNameSet = std::set>;
+
+ /// Options struct that controls how HTML is interpreted.
+ struct Options {
+ /// List of elements for which we do not expect a closing tag, or
+ /// self-closing elements in XHTML. We do not need to see a closing tag
+ /// for these elements, and they cannot contain text or tags themselves.
+ /// See also:
+ /// https://developer.mozilla.org/en-US/docs/Glossary/Empty_element.
+ /// More relevant source of this list:
+ /// https://searchfox.org/mozilla-central/rev/7d17fd1fe9f0005a2fb19e5d53da4741b06a98ba/dom/base/FragmentOrElement.cpp#1791
+ TagNameSet voidTags{"area", "base", "basefont", "bgsound", "br", "col", "embed", "frame", "hr",
+ "img", "input", "keygen", "link", "meta", "param", "source", "track", "wbr"};
+
+ /// List of elements that are treated as inline, meaning they do not break
+ /// up sentences. Any element *not* in this list will cause the text that
+ /// follows its open or close tag to be treated as a separate sentence.
+ TagNameSet inlineTags{"abbr", "a", "b", "em", "i", "kbd", "mark", "math",
+ "output", "q", "ruby", "small", "span", "strong", "sub", "sup",
+ "time", "u", "var", "wbr", "ins", "del", "img"};
+
+ /// List of elements that are, regardless of `substituteInlineTagsWithSpaces`,
+ /// not substituted with spaces. Technically almost all inline elements
+ /// should be treated like this, except ` ` maybe, But in practice it
+ /// seems to be more effective to limit this set to just that one tag that
+ /// that can only really be used *inside* words: ``.
+ /// See also: https://developer.mozilla.org/en-US/docs/Web/HTML/Element/wbr
+ TagNameSet inWordTags{"wbr"};
+
+ /// List of elements we copy as is, but do parse as if they're HTML because
+ /// they could be nested. For because
+ /// the script tag may not be nested, but that is not the case for these
+ /// elements per se. Some tags, like
+ // ^-- or here
+ //
+ // ^-- or here
+ // comes after TT_COMMENT_START, TT_PI_START, or TT_TAG_START
+ // if the tag was
+ TokenType scanSpecial();
+
+ // Consumes
+ TokenType scanTag();
+
+ // Consumes '&' etc, emits parent_token_type
+ TokenType scanEntity(TokenType parentTokenType);
+
+ size_t skipWhitespace();
+
+ bool resolveEntity(string_ref const &buffer, string_ref &decoded) const;
+
+ static bool isWhitespace(char c);
+
+ private: /* data */
+ string_ref value_;
+ string_ref tagName_;
+ string_ref attributeName_;
+
+ ScanPtr scanFun_; // current 'reader'
+
+ instream &input_;
+
+ // Start position of a token.
+ const char *start_;
+
+ bool gotTail_; // aux flag used in scanComment, scanSpecial, scanProcessingInstruction
+};
+} // namespace markup
diff --git a/inference/wasm/CMakeLists.txt b/inference/wasm/CMakeLists.txt
new file mode 100644
index 000000000..ef8fd988a
--- /dev/null
+++ b/inference/wasm/CMakeLists.txt
@@ -0,0 +1,29 @@
+# WASM worker executable exposing the translator to JavaScript via the
+# embind bindings below.
+add_executable(bergamot-translator-worker
+  bindings/service_bindings.cpp
+  bindings/response_options_bindings.cpp
+  bindings/response_bindings.cpp
+)
+
+# Generate version file that can be included in the wasm artifacts
+configure_file(${CMAKE_CURRENT_SOURCE_DIR}/project_version.js.in
+               ${CMAKE_CURRENT_BINARY_DIR}/project_version.js @ONLY)
+
+# This header inclusion needs to go away later as path to public headers of bergamot
+# translator should be directly available from "bergamot-translator" target
+target_include_directories(bergamot-translator-worker
+    PRIVATE ${CMAKE_SOURCE_DIR}/src/translator
+    PRIVATE ${CMAKE_SOURCE_DIR}
+)
+
+# This compile definition is required for generating binding code properly
+target_compile_definitions(bergamot-translator-worker PRIVATE WASM_BINDINGS)
+target_compile_options(bergamot-translator-worker PRIVATE ${WASM_COMPILE_FLAGS})
+target_link_options(bergamot-translator-worker PRIVATE ${WASM_LINK_FLAGS})
+# Prepend the generated version stamp to the emitted JS worker.
+target_link_options(bergamot-translator-worker PRIVATE --extern-pre-js=${CMAKE_CURRENT_BINARY_DIR}/project_version.js)
+
+# Emit a .js wrapper (plus the .wasm binary) at the top of the build tree.
+set_target_properties(bergamot-translator-worker PROPERTIES
+                        SUFFIX ".js"
+                        RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}
+                      )
+
+target_link_libraries(bergamot-translator-worker bergamot-translator)
diff --git a/inference/wasm/bindings/response_bindings.cpp b/inference/wasm/bindings/response_bindings.cpp
new file mode 100644
index 000000000..51a46ab84
--- /dev/null
+++ b/inference/wasm/bindings/response_bindings.cpp
@@ -0,0 +1,32 @@
+/*
+ * Bindings for Response class
+ *
+ */
+
+#include
+
+#include
+
+#include "response.h"
+
+using Response = marian::bergamot::Response;
+using ByteRange = marian::bergamot::ByteRange;
+
+using namespace emscripten;
+
+// Binding code. The embind template arguments (class_<T>, value_object<T>,
+// register_vector<T>) were missing; restored from the using-declarations above.
+EMSCRIPTEN_BINDINGS(byte_range) {
+  // Expose ByteRange as a plain JS object with begin/end fields.
+  value_object<ByteRange>("ByteRange").field("begin", &ByteRange::begin).field("end", &ByteRange::end);
+}
+
+EMSCRIPTEN_BINDINGS(response) {
+  // Expose Response and the accessors the JS side needs.
+  class_<Response>("Response")
+      .constructor<>()
+      .function("size", &Response::size)
+      .function("getOriginalText", &Response::getOriginalText)
+      .function("getTranslatedText", &Response::getTranslatedText)
+      .function("getSourceSentence", &Response::getSourceSentenceAsByteRange)
+      .function("getTranslatedSentence", &Response::getTargetSentenceAsByteRange);
+
+  register_vector<Response>("VectorResponse");
+}
diff --git a/inference/wasm/bindings/response_options_bindings.cpp b/inference/wasm/bindings/response_options_bindings.cpp
new file mode 100644
index 000000000..06c152a7c
--- /dev/null
+++ b/inference/wasm/bindings/response_options_bindings.cpp
@@ -0,0 +1,21 @@
+/*
+ * Bindings for ResponseOptions class
+ *
+ */
+
+#include
+
+#include "response_options.h"
+
+using ResponseOptions = marian::bergamot::ResponseOptions;
+
+using namespace emscripten;
+
+// Binding code. Restored the stripped embind template arguments
+// (value_object<ResponseOptions>, register_vector<ResponseOptions>).
+EMSCRIPTEN_BINDINGS(response_options) {
+  // Expose ResponseOptions as a plain JS object.
+  value_object<ResponseOptions>("ResponseOptions")
+      .field("qualityScores", &ResponseOptions::qualityScores)
+      .field("alignment", &ResponseOptions::alignment)
+      .field("html", &ResponseOptions::HTML);
+  register_vector<ResponseOptions>("VectorResponseOptions");
+}
diff --git a/inference/wasm/bindings/service_bindings.cpp b/inference/wasm/bindings/service_bindings.cpp
new file mode 100644
index 000000000..54675a498
--- /dev/null
+++ b/inference/wasm/bindings/service_bindings.cpp
@@ -0,0 +1,93 @@
+/*
+ * Bindings for Service class
+ */
+
+#include
+
+#include "service.h"
+
+using namespace emscripten;
+
+using BlockingService = marian::bergamot::BlockingService;
+using TranslationModel = marian::bergamot::TranslationModel;
+using AlignedMemory = marian::bergamot::AlignedMemory;
+using MemoryBundle = marian::bergamot::MemoryBundle;
+
+/// Returns a JS typed-array view over the bytes owned by `alignedMemory`.
+/// NOTE(review): per emscripten docs, typed_memory_view aliases the wasm heap
+/// and is invalidated if the heap grows — callers must not hold it long-term.
+val getByteArrayView(AlignedMemory& alignedMemory) {
+  return val(typed_memory_view(alignedMemory.size(), alignedMemory.as<char>()));
+}
+
+EMSCRIPTEN_BINDINGS(aligned_memory) {
+  // Restored stripped template arguments; constructor takes (size, alignment).
+  class_<AlignedMemory>("AlignedMemory")
+      .constructor<std::size_t, std::size_t>()
+      .function("size", &AlignedMemory::size)
+      .function("getByteArrayView", &getByteArrayView);
+
+  register_vector<AlignedMemory*>("AlignedMemoryList");
+}
+
+// When source and target vocab files are same, only one memory object is passed from JS to
+// avoid allocating memory twice for the same file. However, the constructor of the Service
+// class still expects 2 entries in this case, where each entry has the shared ownership of the
+// same AlignedMemory object. This function prepares these smart pointer based AlignedMemory objects
+// for unique AlignedMemory objects passed from JS.
+std::vector<std::shared_ptr<AlignedMemory>> prepareVocabsSmartMemories(std::vector<AlignedMemory*>& vocabsMemories) {
+  auto sourceVocabMemory = std::make_shared<AlignedMemory>(std::move(*(vocabsMemories[0])));
+  std::vector<std::shared_ptr<AlignedMemory>> vocabsSmartMemories;
+  vocabsSmartMemories.push_back(sourceVocabMemory);
+  if (vocabsMemories.size() == 2) {
+    auto targetVocabMemory = std::make_shared<AlignedMemory>(std::move(*(vocabsMemories[1])));
+    vocabsSmartMemories.push_back(std::move(targetVocabMemory));
+  } else {
+    // Same file serves both directions: share ownership of the single copy.
+    vocabsSmartMemories.push_back(sourceVocabMemory);
+  }
+  return vocabsSmartMemories;
+}
+
+/// Assembles a MemoryBundle from the raw pointers handed over the JS boundary.
+/// The pointed-to memories are moved from, so the JS-side objects end up empty.
+MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory, AlignedMemory* shortlistMemory,
+                                 std::vector<AlignedMemory*> uniqueVocabsMemories,
+                                 AlignedMemory* qualityEstimatorMemory) {
+  MemoryBundle memoryBundle;
+  memoryBundle.models.emplace_back(std::move(*modelMemory));
+  memoryBundle.shortlist = std::move(*shortlistMemory);
+  // No std::move needed: the call already yields a prvalue (and moving a
+  // prvalue can inhibit copy elision).
+  memoryBundle.vocabs = prepareVocabsSmartMemories(uniqueVocabsMemories);
+  if (qualityEstimatorMemory != nullptr) {
+    memoryBundle.qualityEstimatorMemory = std::move(*qualityEstimatorMemory);
+  }
+
+  return memoryBundle;
+}
+
+// This allows only shared_ptrs to be operational in JavaScript, according to emscripten.
+// https://emscripten.org/docs/porting/connecting_cpp_and_javascript/embind.html#smart-pointers
+std::shared_ptr<TranslationModel> TranslationModelFactory(const std::string& config, AlignedMemory* model,
+                                                          AlignedMemory* shortlist, std::vector<AlignedMemory*> vocabs,
+                                                          AlignedMemory* qualityEstimator) {
+  MemoryBundle memoryBundle = prepareMemoryBundle(model, shortlist, vocabs, qualityEstimator);
+  return std::make_shared<TranslationModel>(config, std::move(memoryBundle));
+}
+
+EMSCRIPTEN_BINDINGS(translation_model) {
+  class_<TranslationModel>("TranslationModel")
+      .smart_ptr_constructor("TranslationModel", &TranslationModelFactory, allow_raw_pointers());
+}
+
+EMSCRIPTEN_BINDINGS(blocking_service_config) {
+  value_object<BlockingService::Config>("BlockingServiceConfig")
+      .field("cacheSize", &BlockingService::Config::cacheSize);
+}
+
+/// Creates a BlockingService with logging forced to "critical" so marian log
+/// output does not flood the browser console.
+std::shared_ptr<BlockingService> BlockingServiceFactory(const BlockingService::Config& config) {
+  auto copy = config;
+  copy.logger.level = "critical";
+  return std::make_shared<BlockingService>(copy);
+}
+
+EMSCRIPTEN_BINDINGS(blocking_service) {
+  class_<BlockingService>("BlockingService")
+      .smart_ptr_constructor("BlockingService", &BlockingServiceFactory)
+      .function("translate", &BlockingService::translateMultiple)
+      .function("translateViaPivoting", &BlockingService::pivotMultiple);
+
+  register_vector<std::string>("VectorString");
+}
diff --git a/inference/wasm/import-gemm-module.js b/inference/wasm/import-gemm-module.js
new file mode 100644
index 000000000..6430096dc
--- /dev/null
+++ b/inference/wasm/import-gemm-module.js
@@ -0,0 +1,46 @@
+
+/* Use an optimized gemm implementation if available, otherwise use the fallback
+ * implementation.
+ */
+function createWasmGemm() {
+ // A map of expected gemm function to the corresponding fallback gemm function names.
+ const GEMM_TO_FALLBACK_FUNCTIONS_MAP = {
+ "int8_prepare_a": "int8PrepareAFallback",
+ "int8_prepare_b": "int8PrepareBFallback",
+ "int8_prepare_b_from_transposed": "int8PrepareBFromTransposedFallback",
+ "int8_prepare_b_from_quantized_transposed": "int8PrepareBFromQuantizedTransposedFallback",
+ "int8_prepare_bias": "int8PrepareBiasFallback",
+ "int8_multiply_and_add_bias": "int8MultiplyAndAddBiasFallback",
+ "int8_select_columns_of_b": "int8SelectColumnsOfBFallback"
+ };
+
+ // Name of the optimized gemm implementation.
+ const OPTIMIZED_GEMM = "mozIntGemm";
+
+ const optimizedGemmModule = WebAssembly[OPTIMIZED_GEMM];
+ if (!optimizedGemmModule) {
+ return fallbackGemm(GEMM_TO_FALLBACK_FUNCTIONS_MAP);
+ }
+
+ const optimizedGemmModuleExports = new WebAssembly.Instance(optimizedGemmModule(), {"": {memory: wasmMemory}}).exports;
+ for (let key in GEMM_TO_FALLBACK_FUNCTIONS_MAP) {
+ if (!optimizedGemmModuleExports[key]) {
+ return fallbackGemm(GEMM_TO_FALLBACK_FUNCTIONS_MAP);
+ }
+ }
+ console.log(`Using optimized gemm (${OPTIMIZED_GEMM}) implementation`);
+ return optimizedGemmModuleExports;
+}
+
+// Return the fallback gemm implementation.
+function fallbackGemm(gemmToFallbackFunctionsMap) {
+ // The fallback gemm implementation
+ const FALLBACK_GEMM = "asm";
+
+ let fallbackGemmModuleExports = {};
+ for (let key in gemmToFallbackFunctionsMap) {
+ fallbackGemmModuleExports[key] = (...a) => Module[FALLBACK_GEMM][gemmToFallbackFunctionsMap[key]](...a)
+ }
+ console.log(`Using fallback gemm implementation`);
+ return fallbackGemmModuleExports;
+}
diff --git a/inference/wasm/module/main.js b/inference/wasm/module/main.js
new file mode 100644
index 000000000..d712a2199
--- /dev/null
+++ b/inference/wasm/module/main.js
@@ -0,0 +1,21 @@
+import * as readline from 'node:readline/promises';
+import {stdin, stdout} from 'node:process';
+import {BatchTranslator} from "./translator.js";
+
+const rl = readline.createInterface({input: stdin, output: stdout});
+
+const translator = new BatchTranslator();
+
+for await (const line of rl) {
+ const response = await translator.translate({
+ from: "en",
+ to: "es",
+ text: line,
+ html: false,
+ qualityScores: false
+ });
+
+ console.log(response.target.text);
+}
+
+translator.delete();
diff --git a/inference/wasm/module/package.json b/inference/wasm/module/package.json
new file mode 100644
index 000000000..f30464665
--- /dev/null
+++ b/inference/wasm/module/package.json
@@ -0,0 +1,39 @@
+{
+ "name": "@browsermt/bergamot-translator",
+ "version": "0.4.9",
+ "description": "Cross platform C++ library focusing on optimized machine translation on the consumer-grade device.",
+ "homepage": "https://github.com/browsermt/bergamot-translator#readme",
+ "repository": {
+ "type": "git",
+ "url": "git+ssh://git@github.com/browsermt/bergamot-translator.git"
+ },
+ "keywords": [
+ "machine",
+ "translation"
+ ],
+ "author": "",
+ "license": "MPL-2.0",
+ "bugs": {
+ "url": "https://github.com/browsermt/bergamot-translator/issues"
+ },
+  "type": "module",
+  "main": "translator.js",
+ "files": [
+ "worker/bergamot-translator-worker.js",
+ "worker/bergamot-translator-worker.wasm",
+ "worker/translator-worker.js",
+ "translator.js",
+ "main.js"
+ ],
+ "config": {
+ "emscripten_version": "3.1.8"
+ },
+ "scripts": {
+ "prepare": "test -f worker/bergamot-translator-worker.wasm || npm run build",
+ "build": "mkdir -p ../../build-wasm && docker run --rm -v $(realpath ../../):/src -v $(realpath ../../build-wasm):/build -v $(pwd)/worker:/dst -w /build emscripten/emsdk:$npm_package_config_emscripten_version sh -c \"emcmake cmake -DCOMPILE_WASM=on -DWORMHOLE=off /src && emmake make -j2 && cp bergamot-translator-worker.wasm bergamot-translator-worker.js /dst\"",
+ "test": "echo \"Hello world!\" | node main.js"
+ }
+}
diff --git a/inference/wasm/module/translator.js b/inference/wasm/module/translator.js
new file mode 100644
index 000000000..f27c07653
--- /dev/null
+++ b/inference/wasm/module/translator.js
@@ -0,0 +1,879 @@
+/**
+ * @typedef {Object} TranslationRequest
+ * @property {String} from
+ * @property {String} to
+ * @property {String} text
+ * @property {Boolean} html
+ * @property {Integer?} priority
+ */
+
+/**
+ * @typedef {Object} TranslationResponse
+ * @property {TranslationRequest} request
+ * @property {{text: string}} target
+ */
+
+/**
+ * NodeJS compatibility, a thin WebWorker layer around node:worker_threads.
+ */
+if (!(typeof window !== 'undefined' && window.Worker)) {
+ globalThis.Worker = class {
+ #worker;
+
+ constructor(url) {
+ this.#worker = new Promise(async (accept) => {
+ const {Worker} = await import(/* webpackIgnore: true */ 'node:worker_threads');
+ accept(new Worker(url));
+ });
+ }
+
+ addEventListener(eventName, callback) {
+ this.#worker.then(worker => worker.on(eventName, (data) => callback({data})));
+ }
+
+ postMessage(message) {
+ this.#worker.then(worker => worker.postMessage(message));
+ }
+
+ terminate() {
+ this.#worker.then(worker => worker.terminate());
+ }
+ }
+}
+
+/**
+ * Thrown when a pending translation is replaced by another newer pending
+ * translation.
+ */
+export class SupersededError extends Error {}
+
+
+/**
+ * Thrown when a translation was removed from the queue.
+ */
+export class CancelledError extends Error {}
+
+
+/**
+ * Wrapper around bergamot-translator loading and model management.
+ */
+ export class TranslatorBacking {
+
+ /**
+ * @param {{
+ * cacheSize?: number,
+ * useNativeIntGemm?: boolean,
+ * downloadTimeout?: number,
+ * registryUrl?: string
+ * pivotLanguage?: string?
+ * onerror?: (err: Error)
+ * }} options
+ */
+ constructor(options) {
+ this.options = options || {};
+
+ this.registryUrl = this.options.registryUrl || 'https://bergamot.s3.amazonaws.com/models/index.json';
+
+ this.downloadTimeout = 'downloadTimeout' in this.options ? parseInt(this.options.downloadTimeout) : 60000;
+
+ /**
+ * registry of all available models and their urls
+ * @type {Promise}
+ */
+ this.registry = this.loadModelRegistery();
+
+ /**
+ * Map of downloaded model data files as buffers per model.
+ * @type {Map<{from:string,to:string}, Promise