diff --git a/.clang-format b/.clang-format new file mode 100644 index 000000000..b7f746ef1 --- /dev/null +++ b/.clang-format @@ -0,0 +1,5 @@ +BasedOnStyle: Google + +# Maximum line length 80 is too low even for 1080p monitor. @XapaJIaMnu +# personally would like 120. +ColumnLimit: 120 diff --git a/.clang-format-ignore b/.clang-format-ignore new file mode 100644 index 000000000..50795cacb --- /dev/null +++ b/.clang-format-ignore @@ -0,0 +1,4 @@ +3rd_party +wasm/test_page +src/translator/aligned.h +src/translator/pcqueue.h diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 000000000..bdbadb624 --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,32 @@ +Checks: > + .*, + bugprone-*, + concurrency-*, + google-*, + portability-*, + performance-*, + clang-analyzer-*, + readability-*, + -readability-implicit-bool-conversion, + -readability-isolate-declaration, + -readability-uppercase-literal-suffix, + misc-*, + -misc-noexcept*, + modernize-*, + -modernize-deprecated-headers, + -modernize-use-nodiscard, + -modernize-raw-string-literal, + -modernize-return-braced-init-list, + -modernize-use-equals-delete, + -modernize-use-trailing-return-type, + + + +CheckOptions: + - { key: readability-identifier-naming.ClassCase, value: CamelCase } + - { key: readability-identifier-naming.ClassMethodCase, value: camelBack } + - { key: readability-identifier-naming.VariableCase, value: camelBack } + - { key: readability-identifier-naming.FunctionCase, value: camelBack } + - { key: readability-identifier-naming.PrivateMemberSuffix, value: _ } + - { key: readability-identifier-naming.ParameterCase, value: camelBack } + diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 26f6f4418..fa4f321cf 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,5 +1,31 @@ +# Firefox Translations review group +.dockerignore @mozilla/firefox-translations +.github @mozilla/firefox-translations +.gitignore @mozilla/firefox-translations +.gitmodules @mozilla/firefox-translations +docker 
@mozilla/firefox-translations +docs @mozilla/firefox-translations +utils @mozilla/firefox-translations +CODE_OF_CONDUCT.md @mozilla/firefox-translations +LICENSE @mozilla/firefox-translations +poetry.lock @mozilla/firefox-translations +pyproject.toml @mozilla/firefox-translations +README.md @mozilla/firefox-translations +Taskfile.yml @mozilla/firefox-translations + +# Translations Training review group +configs @mozilla/translations-training +pipeline @mozilla/translations-training +snakemake @mozilla/translations-training +tests @mozilla/translations-training +tracking @mozilla/translations-training + +# Translations Inference review group +inference-engine @mozilla/translations-inference + # Taskcluster pipeline related files. Changes to these ought to be reviewed by # RelEng to watch for security issues and best practices. These should also # be reviewed by people familiar with the pipeline itself. -.taskcluster.yml @mozilla/releng -taskcluster @mozilla/releng +.taskcluster.yml @mozilla/releng @mozilla/translations-training +taskcluster @mozilla/releng @mozilla/translations-training + diff --git a/.gitmodules b/.gitmodules index f1813a444..f88221eb5 100644 --- a/.gitmodules +++ b/.gitmodules @@ -16,3 +16,29 @@ [submodule "3rd_party/preprocess"] path = 3rd_party/preprocess url = https://github.com/kpu/preprocess.git +[submodule "inference/3rd_party/ssplit-cpp"] + path = inference/3rd_party/ssplit-cpp + url = https://github.com/browsermt/ssplit-cpp +# This is the same dependency and repository as `3rd_party/browsermt-marian-dev` below. +# +# When forking `inference-engine` into this project, I made an earnest attempt to utilize the preexisting +# `3rd_party/browsermt-marian-dev` submodule within `inference-engine`. Unfortunately, I ran into several roadblocks: +# +# 1) I cannot directly add `3rd_party/browsermt-marian-dev` as a cmake subdirectory because cmake is aware that +# this path is not a subdirectory of the `inference-engine` project root. 
+# +# 2) Symbolic links do not appear to work for git submodule directories the way that they do for regular directories. +# Even if the symbolic link had linked correctly, it may have still failed due to the considerations of 1). +# +# 3) I tried using cmake to copy the files from `3rd_party/browsermt-marian-dev` into `inference-engine/3rd_party/browsermt-marian-dev` +# at build time, which would ensure that there is no duplicate reference to the URL in this file, however the upstream dependency itself +# has hard-coded expectations that the `.git` directory is only one level up, which appears to work correctly for the way git submodules are +# configured, but does not work if the files are copied over to a regular directory deeper in the repository's directory tree. +# +# It may be possible to remove `3rd_party/browsermt-marian-dev` to instead use `inference-engine/3rd-party/browsermt-marian-dev` everywhere +# within this repository, but I will leave that for a future commit if there is a need to do so. +# +# TODO(#869) +[submodule "inference/3rd_party/browsermt-marian-dev"] + path = inference/3rd_party/browsermt-marian-dev + url = https://github.com/browsermt/marian-dev diff --git a/Taskfile.yml b/Taskfile.yml index c767924e1..745c5f05c 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -75,6 +75,30 @@ tasks: cmds: - poetry run opuscleaner-server serve --host=0.0.0.0 --port=8000 + inference-clean: + desc: Clean build artifacts from the inference directory. + cmds: + - >- + task docker-run -- ./inference/scripts/clean.sh + + inference-build: + desc: Build inference engine. + cmds: + - >- + task docker-run -- ./inference/scripts/build-local.sh + + inference-test: + desc: Run inference tests. + cmds: + - >- + task docker-run -- ./inference/scripts/unit-tests.sh + + inference-build-wasm: + desc: Build inference engine WASM. + cmds: + - >- + task docker-run -- ./inference/scripts/build-wasm.sh + lint-black: desc: Checks the styling of the Python code with Black. 
deps: [poetry-install-black] diff --git a/inference/.gitignore b/inference/.gitignore new file mode 100644 index 000000000..78202d979 --- /dev/null +++ b/inference/.gitignore @@ -0,0 +1,30 @@ +# vim temporary files +*.swp +*.swo + +# CMake +CMakeLists.txt.user +CMakeCache.txt +CMakeFiles +CMakeScripts +Testing +Makefile +cmake_install.cmake +install_manifest.txt +compile_commands.json +CTestTestfile.cmake +_deps + + +wasm/test_page/node_modules +/build +/build-local +/build-native +/build-wasm +/emsdk +models +wasm/module/worker/bergamot-translator-worker.* +wasm/module/browsermt-bergamot-translator-*.tgz + +# VSCode +.vscode diff --git a/inference/3rd_party/CMakeLists.txt b/inference/3rd_party/CMakeLists.txt new file mode 100644 index 000000000..62ba02722 --- /dev/null +++ b/inference/3rd_party/CMakeLists.txt @@ -0,0 +1,32 @@ +# browsermt-marian-dev is tested elsewhere in both paths, turning off here. +set(COMPILE_TESTS OFF) +add_subdirectory(browsermt-marian-dev EXCLUDE_FROM_ALL) + +if(COMPILE_WASM) + # This is a bad way of adding compilation flags. Will be improved soon. + add_compile_options(${WASM_COMPILE_FLAGS}) + add_link_options(${WASM_LINK_FLAGS}) +endif(COMPILE_WASM) + +add_subdirectory(ssplit-cpp EXCLUDE_FROM_ALL) + +# Add include directories for 3rd party targets to be able to use it anywhere in the +# project without explicitly specifying their include directories. Once they +# fix this problem, it can be removed. 
+get_property(INCDIRS DIRECTORY browsermt-marian-dev/src PROPERTY INCLUDE_DIRECTORIES) +target_include_directories(marian PUBLIC ${INCDIRS}) + +get_property(INCLUDE_DIRECTORIES DIRECTORY ssplit-cpp/src PROPERTY INCLUDE_DIRECTORIES) +target_include_directories(ssplit PUBLIC ${INCLUDE_DIRECTORIES}) + +get_property(COMPILE_DEFINITIONS DIRECTORY browsermt-marian-dev PROPERTY COMPILE_DEFINITIONS) +target_compile_definitions(marian PUBLIC ${COMPILE_DEFINITIONS}) + +get_property(COMPILE_OPTIONS DIRECTORY browsermt-marian-dev PROPERTY COMPILE_OPTIONS) +target_compile_options(marian PUBLIC ${COMPILE_OPTIONS}) + +# Compilation flags +get_directory_property(CMAKE_C_FLAGS DIRECTORY browsermt-marian-dev DEFINITION CMAKE_C_FLAGS) +get_directory_property(CMAKE_CXX_FLAGS DIRECTORY browsermt-marian-dev DEFINITION CMAKE_CXX_FLAGS) +set(CMAKE_C_FLAGS ${CMAKE_C_FLAGS} PARENT_SCOPE) +set(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS} PARENT_SCOPE) diff --git a/inference/3rd_party/browsermt-marian-dev b/inference/3rd_party/browsermt-marian-dev new file mode 160000 index 000000000..2781d735d --- /dev/null +++ b/inference/3rd_party/browsermt-marian-dev @@ -0,0 +1 @@ +Subproject commit 2781d735d4a10dca876d61be587afdab2726293c diff --git a/inference/3rd_party/ssplit-cpp b/inference/3rd_party/ssplit-cpp new file mode 160000 index 000000000..a311f9865 --- /dev/null +++ b/inference/3rd_party/ssplit-cpp @@ -0,0 +1 @@ +Subproject commit a311f9865ade34db1e8e080e6cc146f55dafb067 diff --git a/inference/BERGAMOT_VERSION b/inference/BERGAMOT_VERSION new file mode 100644 index 000000000..a423f7f06 --- /dev/null +++ b/inference/BERGAMOT_VERSION @@ -0,0 +1 @@ +v0.4.5 diff --git a/inference/CMakeLists.txt b/inference/CMakeLists.txt new file mode 100644 index 000000000..febff3e6e --- /dev/null +++ b/inference/CMakeLists.txt @@ -0,0 +1,188 @@ +cmake_minimum_required(VERSION 3.5.1) +set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) + +if (POLICY CMP0074) + cmake_policy(SET CMP0074 NEW) # CMake 3.12 +endif () + 
+if (POLICY CMP0077) + cmake_policy(SET CMP0077 NEW) +endif() + +project(bergamot_translator CXX C) + +# Retrieve the parent-directory path of PROJECT_SOURCE_DIR and assign that to REPOSITORY_ROOT_DIR. +cmake_path(GET PROJECT_SOURCE_DIR PARENT_PATH REPOSITORY_ROOT_DIR) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# Generate a compile_commands.json in the build directory. The compile commands allow +# code editors to understand the build process and provide static analysis of the code. +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# Note that with CMake MSVC build, the option CMAKE_BUILD_TYPE is automatically derived from the key +# 'configurationType' in CMakeSettings.json configurations +if(NOT CMAKE_BUILD_TYPE) + message(WARNING "CMAKE_BUILD_TYPE not set; setting to Release") + set(CMAKE_BUILD_TYPE "Release") +endif() + +if(NOT COMPILE_WASM) + # Setting BUILD_ARCH to native invokes CPU intrinsic detection logic below. + # Prevent invoking that logic for WASM builds. + set(BUILD_ARCH native CACHE STRING "Compile for this CPU architecture.") + + # Unfortunately MSVC supports a limited subset of BUILD_ARCH flags. Instead try to guess + # what architecture we can compile to reading BUILD_ARCH and mapping it to MSVC values + # references: https://clang.llvm.org/docs/UsersManual.html https://gcc.gnu.org/onlinedocs/gcc/x86-Options.html https://gcc.gnu.org/onlinedocs/gcc-4.8.5/gcc/i386-and-x86-64-Options.html + # https://docs.microsoft.com/en-us/cpp/build/reference/arch-x86?redirectedfrom=MSDN&amp;view=vs-2019&view=msvc-170 https://devblogs.microsoft.com/oldnewthing/20201026-00/?p=104397 + # This is by no means an exhaustive list but should match the most common flags Linux programmers expect to parse to MSVC + if(MSVC) + if(BUILD_ARCH STREQUAL "native") # avx2 is good default for native. 
Very few desktop systems support avx512 + set(MSVC_BUILD_ARCH "/arch:AVX2") + elseif(BUILD_ARCH STREQUAL "skylake-avx512" OR BUILD_ARCH STREQUAL "cannonlake" OR BUILD_ARCH STREQUAL "x86-64-v4" OR BUILD_ARCH STREQUAL "tigerlake" OR BUILD_ARCH STREQUAL "cooperlake" OR BUILD_ARCH STREQUAL "cascadelake") + set(MSVC_BUILD_ARCH "/arch:AVX512") + elseif(BUILD_ARCH STREQUAL "core-avx2" OR BUILD_ARCH STREQUAL "haswell" OR BUILD_ARCH STREQUAL "x86-64-v3" OR BUILD_ARCH STREQUAL "broadwell" OR BUILD_ARCH STREQUAL "skylake") + set(MSVC_BUILD_ARCH "/arch:AVX2") + elseif(BUILD_ARCH STREQUAL "sandybridge" OR BUILD_ARCH STREQUAL "corei7-avx" OR BUILD_ARCH STREQUAL "core-avx-i" OR BUILD_ARCH STREQUAL "ivybridge") + set(MSVC_BUILD_ARCH "/arch:AVX") + elseif(BUILD_ARCH STREQUAL "nehalem" OR BUILD_ARCH STREQUAL "westmere" OR BUILD_ARCH STREQUAL "x86-64-v2" OR BUILD_ARCH STREQUAL "corei7" OR BUILD_ARCH STREQUAL "core2") + set(MSVC_BUILD_ARCH "/arch:SSE2") # This is MSVC default. We won't go down to SSE because we don't support that hardware at all with intgemm. Marian recommends to only go down to SSE4.1 at most + else() + message(WARNING "Unknown BUILD_ARCH ${BUILD_ARCH} provided. Default to SSE2 for Windows build") + set(MSVC_BUILD_ARCH "/arch:SSE2") + endif() + endif(MSVC) +endif() + +#MSVC can't seem to pick up correct flags otherwise: +if(MSVC) + add_definitions(-DUSE_SSE2=1) # Supposed to fix something in the sse_mathfun.h but not sure it does + set(INTRINSICS ${MSVC_BUILD_ARCH}) # ARCH we're targeting on win32. 
@TODO variable + + set(CMAKE_CXX_FLAGS "/EHsc /DWIN32 /D_WINDOWS /DUNICODE /D_UNICODE /D_CRT_NONSTDC_NO_WARNINGS /D_CRT_SECURE_NO_WARNINGS /bigobj") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS} /MT /O2 ${INTRINSICS} /MP /GL /DNDEBUG") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS} /MTd /Od /Ob0 ${INTRINSICS} /RTC1 /Zi /D_DEBUG") + + # ignores warning LNK4049: locally defined symbol free imported - this comes from zlib + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DEBUG /LTCG:incremental /INCREMENTAL:NO /ignore:4049") + set(CMAKE_EXE_LINKER_FLAGS_RELEASE "${CMAKE_EXE_LINKER_FLAGS} /NODEFAULTLIB:MSVCRT") + set(CMAKE_EXE_LINKER_FLAGS_DEBUG "${CMAKE_EXE_LINKER_FLAGS} /NODEFAULTLIB:MSVCRTD") + set(CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} /LTCG:incremental") +endif(MSVC) + +include(CMakeDependentOption) + +# Project specific cmake options +option(COMPILE_WASM "Compile for WASM" OFF) +cmake_dependent_option(USE_WASM_COMPATIBLE_SOURCE "Use wasm compatible sources" OFF "NOT COMPILE_WASM" ON) + +# WASM disables a million libraries, which also includes the unit test-library. 
+cmake_dependent_option(COMPILE_UNIT_TESTS "Compile unit tests" OFF "USE_WASM_COMPATIBLE_SOURCE" ON) +option(COMPILE_TESTS "Compile bergamot-tests" OFF) +cmake_dependent_option(ENABLE_CACHE_STATS "Enable stats on cache" ON "COMPILE_TESTS" OFF) + + +# Set 3rd party submodule specific cmake options for this project +SET(COMPILE_CUDA OFF CACHE BOOL "Compile GPU version") +SET(USE_SENTENCEPIECE ON CACHE BOOL "Download and compile SentencePiece") +SET(USE_STATIC_LIBS ON CACHE BOOL "Link statically against non-system libs") +SET(SSPLIT_COMPILE_LIBRARY_ONLY ON CACHE BOOL "Do not compile ssplit tests") +if (USE_WASM_COMPATIBLE_SOURCE) + SET(COMPILE_LIBRARY_ONLY ON CACHE BOOL "Build only the Marian library and exclude all executables.") + SET(USE_MKL OFF CACHE BOOL "Compile with MKL support") + # # Setting the ssplit-cpp submodule specific cmake options for wasm + SET(SSPLIT_USE_INTERNAL_PCRE2 ON CACHE BOOL "Use internal PCRE2 instead of system PCRE2") +endif() + +# Documentation: https://cliutils.gitlab.io/modern-cmake/chapters/projects/submodule.html +# Ensures the submodules are set correctly during a build. 
+find_package(Git QUIET) +if(GIT_FOUND AND EXISTS "${REPOSITORY_ROOT_DIR}/.git") +# Update submodules as needed + option(GIT_SUBMODULE "Check submodules during build" ON) + if(GIT_SUBMODULE) + message(STATUS "Submodule update") + execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE GIT_SUBMOD_RESULT) + if(NOT GIT_SUBMOD_RESULT EQUAL "0") + message(FATAL_ERROR "git submodule update --init failed with ${GIT_SUBMOD_RESULT}, please checkout submodules") + endif() + endif() +endif() + +# Project versioning +include(GetVersionFromFile) +message(STATUS "Project name: ${PROJECT_NAME}") +message(STATUS "Project version: ${PROJECT_VERSION_STRING_FULL}") + +if(COMPILE_WASM) + # See https://github.com/emscripten-core/emscripten/blob/main/src/settings.js + list(APPEND WASM_COMPILE_FLAGS + -O3 + # Preserve whitespaces in JS even for release builds; this doesn't increase wasm binary size + $<$:-g1> + # Relevant Debug info only for release with debug builds as this increases wasm binary size + $<$:-g2> + -fPIC + -mssse3 + -msimd128 + # -fno-exceptions # Can't do that because spdlog uses exceptions + -sDISABLE_EXCEPTION_CATCHING=1 + -sSTRICT=1 + ) + list(APPEND WASM_LINK_FLAGS + -O3 + # Preserve whitespaces in JS even for release builds; this doesn't increase wasm binary size + $<$:-g1> + # Relevant Debug info only for release with debug builds as this increases wasm binary size + $<$:-g2> + -lembind + # Save some code, and some speed + -sASSERTIONS=0 + -sDISABLE_EXCEPTION_CATCHING=1 + # the intgemm functions we call will be undefined since these are linked at + # runtime by our own javascript. + -sLLD_REPORT_UNDEFINED + -sERROR_ON_UNDEFINED_SYMBOLS=0 + # Cause we can! + -sSTRICT=1 + # You know we need it + -sALLOW_MEMORY_GROWTH=1 + -sENVIRONMENT=web,worker + # No need to call main(), there's nothing there. 
+ -sINVOKE_RUN=0 + # No need for filesystem code in the generated Javascript + -sFILESYSTEM=0 + # If you turn this on, it will mangle names which makes the dynamic linking hard. + -sDECLARE_ASM_MODULE_EXPORTS=0 + # Export all of the intgemm functions in case we need to fall back to using the embedded intgemm + -sEXPORTED_FUNCTIONS=[_int8PrepareAFallback,_int8PrepareBFallback,_int8PrepareBFromTransposedFallback,_int8PrepareBFromQuantizedTransposedFallback,_int8PrepareBiasFallback,_int8MultiplyAndAddBiasFallback,_int8SelectColumnsOfBFallback] + # Necessary for mozintgemm linking. This prepares the `wasmMemory` variable ahead of time as + # opposed to delegating that task to the wasm binary itself. This way we can link MozIntGEMM + # module to the same memory as the main bergamot-translator module. + -sIMPORTED_MEMORY=1 + # Dynamic execution is either frowned upon or blocked inside browser extensions + -sDYNAMIC_EXECUTION=0 + ) +endif(COMPILE_WASM) + +# Needs to be enabled before including the folder containing tests (src/tests) +if(COMPILE_TESTS) + enable_testing() +endif(COMPILE_TESTS) + +add_subdirectory(3rd_party) +add_subdirectory(src) + +if(COMPILE_WASM) + add_subdirectory(wasm) +endif(COMPILE_WASM) + +option(COMPILE_PYTHON "Compile python bindings. Intended to be activated with setup.py" OFF) +if(COMPILE_PYTHON) + add_subdirectory(bindings/python) +endif(COMPILE_PYTHON) + diff --git a/inference/cmake/GetVersionFromFile.cmake b/inference/cmake/GetVersionFromFile.cmake new file mode 100644 index 000000000..47c35bc23 --- /dev/null +++ b/inference/cmake/GetVersionFromFile.cmake @@ -0,0 +1,60 @@ +## +# This CMake modules sets the project version from a version file. 
+# +# The module sets the following variables: +# +# * PROJECT_VERSION_STRING +# * PROJECT_VERSION_STRING_FULL +# * PROJECT_VERSION_MAJOR +# * PROJECT_VERSION_MINOR +# * PROJECT_VERSION_PATCH +# * PROJECT_VERSION_TWEAK +# * PROJECT_VERSION_GIT_SHA +# +# This module is public domain, use it as it fits you best. +## + +# Get full string version from file +if(PROJECT_VERSION_FILE) + file(STRINGS ${PROJECT_VERSION_FILE} PROJECT_VERSION_STRING) +else() + file(STRINGS ${CMAKE_CURRENT_SOURCE_DIR}/BERGAMOT_VERSION PROJECT_VERSION_STRING) +endif() + +# Get current commit SHA from git +execute_process(COMMAND ${GIT_EXECUTABLE} rev-parse --short HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE PROJECT_VERSION_GIT_SHA + OUTPUT_STRIP_TRAILING_WHITESPACE) + +# Get partial versions into a list +string(REGEX MATCHALL "-.*$|[0-9]+" PROJECT_PARTIAL_VERSION_LIST + ${PROJECT_VERSION_STRING}) + +# Set the version numbers +list(GET PROJECT_PARTIAL_VERSION_LIST 0 PROJECT_VERSION_MAJOR) +list(GET PROJECT_PARTIAL_VERSION_LIST 1 PROJECT_VERSION_MINOR) +list(GET PROJECT_PARTIAL_VERSION_LIST 2 PROJECT_VERSION_PATCH) + +# The tweak part is optional, so check if the list contains it +list(LENGTH PROJECT_PARTIAL_VERSION_LIST PROJECT_PARTIAL_VERSION_LIST_LEN) +if(PROJECT_PARTIAL_VERSION_LIST_LEN GREATER 3) + list(GET PROJECT_PARTIAL_VERSION_LIST 3 PROJECT_VERSION_TWEAK) + string(SUBSTRING ${PROJECT_VERSION_TWEAK} 1 -1 PROJECT_VERSION_TWEAK) +endif() + +# Unset the list +unset(PROJECT_PARTIAL_VERSION_LIST) + +# Set full project version string +set(PROJECT_VERSION_STRING_FULL + ${PROJECT_VERSION_STRING}+${PROJECT_VERSION_GIT_SHA}) + +# Print all variables for debugging +#message(STATUS ${PROJECT_VERSION_STRING_FULL}) +#message(STATUS ${PROJECT_VERSION_STRING}) +#message(STATUS ${PROJECT_VERSION_MAJOR}) +#message(STATUS ${PROJECT_VERSION_MINOR}) +#message(STATUS ${PROJECT_VERSION_PATCH}) +#message(STATUS ${PROJECT_VERSION_TWEAK}) +#message(STATUS ${PROJECT_VERSION_GIT_SHA}) 
diff --git a/inference/examples/run-native.sh b/inference/examples/run-native.sh new file mode 100644 index 000000000..84e1302f0 --- /dev/null +++ b/inference/examples/run-native.sh @@ -0,0 +1,19 @@ +# In source-root folder + +# Obtain an example model from the web. +mkdir -p models +wget --quiet --continue --directory models/ \ + https://data.statmt.org/bergamot/models/deen/ende.student.tiny11.v2.93821e13b3c511b5.tar.gz +(cd models && tar -xzf ende.student.tiny11.v2.93821e13b3c511b5.tar.gz) + +# Patch the config-files generated from marian for use in bergamot. +python3 bergamot-translator-tests/tools/patch-marian-for-bergamot.py \ + --config-path models/ende.student.tiny11/config.intgemm8bitalpha.yml \ + --ssplit-prefix-file $(realpath 3rd_party/ssplit-cpp/nonbreaking_prefixes/nonbreaking_prefix.en) + +# Patched config file will be available with .bergamot.yml suffix. +CONFIG=models/ende.student.tiny11/config.intgemm8bitalpha.yml.bergamot.yml + +build/app/bergamot --model-config-paths $CONFIG --cpu-threads 4 <<< "Hello World!" +# Hallo Welt! + diff --git a/inference/patches/01-marian-fstream-for-macos.patch b/inference/patches/01-marian-fstream-for-macos.patch new file mode 100644 index 000000000..6b521ba7e --- /dev/null +++ b/inference/patches/01-marian-fstream-for-macos.patch @@ -0,0 +1,13 @@ +diff --git a/3rd_party/browsermt-marian-dev/src/3rd_party/zstr/strict_fstream.hpp b/3rd_party/browsermt-marian-dev/src/3rd_party/zstr/strict_fstream.hpp +index 7b1173931df977e69021f3995fa064a492f89d38..948e91eaf99b6b29ce41cf793fba6717f3b5f5b5 100644 +--- a/3rd_party/browsermt-marian-dev/src/3rd_party/zstr/strict_fstream.hpp ++++ b/3rd_party/browsermt-marian-dev/src/3rd_party/zstr/strict_fstream.hpp +@@ -27,7 +27,7 @@ static std::string strerror() + { + buff = "Unknown error"; + } +-#elif (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600 || __APPLE__) && ! 
_GNU_SOURCE ++#elif (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600 || __APPLE__) + // XSI-compliant strerror_r() + if (strerror_r(errno, &buff[0], buff.size()) != 0) + { diff --git a/inference/scripts/build-local.sh b/inference/scripts/build-local.sh new file mode 100755 index 000000000..ae64689fe --- /dev/null +++ b/inference/scripts/build-local.sh @@ -0,0 +1,50 @@ +#!/bin/bash +set -e + +# Run script from the context of inference directory +cd "$(dirname $0)/.." + +# Ensure script is running within docker +./scripts/detect-docker.sh inference-build + +# Return the number of available CPUs, or default to 1 if nproc is unavailable. +detect_cpus() { + if command -v nproc >/dev/null 2>&1; then + nproc + else + echo 1 + fi +} + +# Parse command-line arguments for the --test flag +COMPILE_TESTS=OFF +while [[ "$#" -gt 0 ]]; do + case $1 in + "--test") COMPILE_TESTS=ON ;; + *) echo "Unknown parameter passed: $1"; exit 1 ;; + esac + shift +done + +if [ ! -d "build-local" ]; then + echo "Creating build-local directory..." + mkdir build-local +else + echo "build-local directory already exists. Skipping creation." +fi + +cd build-local || exit + +# Run cmake with optional COMPILE_TESTS flag +echo "Running cmake for build-local..." +if [ "$COMPILE_TESTS" = "ON" ]; then + cmake ../ -DCOMPILE_TESTS=ON +else + cmake ../ +fi + +# Run make using the detected number of CPUs +CPUS=$(detect_cpus) +echo "Running make for build-local with $CPUS CPUs..." +make -j ${CPUS} + diff --git a/inference/scripts/build-wasm.sh b/inference/scripts/build-wasm.sh new file mode 100755 index 000000000..c21eea985 --- /dev/null +++ b/inference/scripts/build-wasm.sh @@ -0,0 +1,69 @@ +#!/usr/bin/env bash +set -e + +# Run script from the context of inference directory +cd "$(dirname $0)/.." 
+ +# Ensure script is running within docker +./scripts/detect-docker.sh inference-build-wasm + +set -x + +# Prerequisite: Download and Install Emscripten using following instructions (unless the EMSDK env var is already set) +if [ "$EMSDK" == "" ]; then + EMSDK_UPDATE_REQUIRED=0 + if [ ! -d "emsdk" ]; then + git clone https://github.com/emscripten-core/emsdk.git + EMSDK_UPDATE_REQUIRED=1 + else + cd emsdk + git fetch + # Only pull if necessary + if [ $(git rev-parse HEAD) != $(git rev-parse @{u}) ]; then + git pull --ff-only + EMSDK_UPDATE_REQUIRED=1 + fi + cd - + fi + if [ "$EMSDK_UPDATE_REQUIRED" == "1" ]; then + cd emsdk + ./emsdk install 3.1.8 + ./emsdk activate 3.1.8 + cd - + fi + source ./emsdk/emsdk_env.sh +fi + +# Compile +# 1. Create a folder where you want to build all the artifacts and compile +BUILD_DIRECTORY="build-wasm" +if [ ! -d ${BUILD_DIRECTORY} ]; then + mkdir ${BUILD_DIRECTORY} +fi +cd ${BUILD_DIRECTORY} + +emcmake cmake -DCOMPILE_WASM=on ../ +emmake make -j2 + +# 2. Import GEMM library from a separate wasm module +bash ../wasm/patch-artifacts-import-gemm-module.sh + +set +x +echo "" +echo "Build complete" +echo "" +echo " ./build-wasm/bergamot-translator-worker.js" +echo " ./build-wasm/bergamot-translator-worker.wasm" + +WASM_SIZE=$(wc -c bergamot-translator-worker.wasm | awk '{print $1}') +GZIP_SIZE=$(gzip -c bergamot-translator-worker.wasm | wc -c | xargs) # xargs trims the whitespace + +# Convert it to human readable. 
+WASM_SIZE="$(awk 'BEGIN {printf "%.2f",'$WASM_SIZE'/1048576}')M ($WASM_SIZE bytes)" +GZIP_SIZE="$(awk 'BEGIN {printf "%.2f",'$GZIP_SIZE'/1048576}')M ($GZIP_SIZE bytes)" + +echo " Uncompressed wasm size: $WASM_SIZE" +echo " Compressed wasm size: $GZIP_SIZE" + +# The artifacts (.js and .wasm files) will be available in the build directory +exit 0 diff --git a/inference/scripts/clean.sh b/inference/scripts/clean.sh new file mode 100755 index 000000000..73f5ae5eb --- /dev/null +++ b/inference/scripts/clean.sh @@ -0,0 +1,29 @@ +#!/bin/bash +set -e + +# Run script from the context of inference directory +cd "$(dirname $0)/.." + +# Ensure script is running within docker +./scripts/detect-docker.sh inference-clean + +# List of directories to clean +dirs=("build-local" "build-wasm" "emsdk") + +# Flag to track if any directories were cleaned +cleaned=false + +# Check and remove directories +for dir in "${dirs[@]}"; do + if [ -d "$dir" ]; then + echo "Removing $dir..." + rm -rf "$dir" + cleaned=true + fi +done + +# If no directories were cleaned, print a message +if [ "$cleaned" = false ]; then + echo "Nothing to clean" +fi + diff --git a/inference/scripts/detect-docker.sh b/inference/scripts/detect-docker.sh new file mode 100755 index 000000000..c1065349a --- /dev/null +++ b/inference/scripts/detect-docker.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +help_task=$1 + +if [ -z "${IS_DOCKER}" ]; then + if [ "${ALLOW_RUN_ON_HOST}" != "1" ]; then + echo >&2 + echo "Error: This script needs to be run inside Docker, or you must set ALLOW_RUN_ON_HOST=1." >&2 + echo >&2 + if [ -n "${help_task}" ]; then + echo " Help: To run this script directly in docker, run: task ${help_task}" >&2 + fi + echo " Help: To enter docker, run: task docker" >&2 + exit 1 + else + echo >&2 + echo "ALLOW_RUN_ON_HOST is set to 1. Continuing..." 
>&2 + fi +fi diff --git a/inference/scripts/unit-tests.sh b/inference/scripts/unit-tests.sh new file mode 100755 index 000000000..dd8be9925 --- /dev/null +++ b/inference/scripts/unit-tests.sh @@ -0,0 +1,49 @@ +#!/bin/bash +set -e + +# Run script from the context of inference directory +cd "$(dirname $0)/.." + +# Ensure script is running within docker +./scripts/detect-docker.sh inference-test + +# Check if build-local/src/tests/units directory exists +if [ ! -d "build-local/src/tests/units" ]; then + echo "Directory build-local/src/tests/units does not exist. Running build." + ./scripts/build-local.sh --test +else + echo "Directory build-local/src/tests/units already exists. Skipping build." +fi + +# Change to the unit tests directory +cd build-local/src/tests/units + +# List of test commands +tests=( + "./run_annotation_tests" + "./run_cache_tests" + "./run_html_tests" + "./run_quality_estimator_tests" + "./run_xh_scanner_tests" +) + +# Run all tests, collect failures +failures=0 + +for test in "${tests[@]}"; do + echo "Running $test..." + if ! $test; then + echo "$test failed!" + failures=$((failures + 1)) + fi +done + +# If any test failed, exit with a non-zero status +if [ $failures -gt 0 ]; then + echo "$failures test(s) failed." + exit 1 +else + echo "All tests passed successfully." + exit 0 +fi + diff --git a/inference/src/CMakeLists.txt b/inference/src/CMakeLists.txt new file mode 100644 index 000000000..856831be9 --- /dev/null +++ b/inference/src/CMakeLists.txt @@ -0,0 +1,6 @@ +add_subdirectory(translator) + +if (COMPILE_TESTS) + add_subdirectory(tests) +endif(COMPILE_TESTS) + diff --git a/inference/src/tests/CMakeLists.txt b/inference/src/tests/CMakeLists.txt new file mode 100644 index 000000000..cd0e4c777 --- /dev/null +++ b/inference/src/tests/CMakeLists.txt @@ -0,0 +1,24 @@ +# Unit tests + +# Include Catch explicitly from marian. 
+set(CATCH_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/3rd_party/browsermt-marian-dev/3rd-party) +add_library(Catch INTERFACE) +target_include_directories(Catch INTERFACE ${CATCH_INCLUDE_DIR}) + +if (COMPILE_UNIT_TESTS) + add_subdirectory(units) +endif (COMPILE_UNIT_TESTS) + + + +if(NOT MSVC) + # Testing apps + set(TEST_BINARIES async blocking intgemm-resolve wasm) + foreach(binary ${TEST_BINARIES}) + add_executable("${binary}" "${binary}.cpp") + target_link_libraries("${binary}" bergamot-translator) + set_target_properties("${binary}" PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/tests/") + endforeach(binary) + +endif(NOT MSVC) + diff --git a/inference/src/tests/async.cpp b/inference/src/tests/async.cpp new file mode 100644 index 000000000..25ba334ae --- /dev/null +++ b/inference/src/tests/async.cpp @@ -0,0 +1,27 @@ +#include "common.h" +#include "translator/parser.h" +#include "translator/service.h" +#include "translator/translation_model.h" + +using namespace marian::bergamot; + +int main(int argc, char *argv[]) { + ConfigParser configParser("AsyncService test-suite", /*multiOpMode=*/true); + configParser.parseArgs(argc, argv); + auto &config = configParser.getConfig(); + + AsyncService service(config.serviceConfig); + + std::vector> models; + + for (auto &modelConfigPath : config.modelConfigPaths) { + TranslationModel::Config modelConfig = parseOptionsFromFilePath(modelConfigPath); + std::shared_ptr model = service.createCompatibleModel(modelConfig); + models.push_back(model); + } + + TestSuite testSuite(service); + testSuite.run(config.opMode, models); + + return 0; +} diff --git a/inference/src/tests/blocking.cpp b/inference/src/tests/blocking.cpp new file mode 100644 index 000000000..3bbb45634 --- /dev/null +++ b/inference/src/tests/blocking.cpp @@ -0,0 +1,25 @@ +#include "common.h" +using namespace marian::bergamot; + +int main(int argc, char *argv[]) { + ConfigParser configParser("BlockingService test-suite", /*multiOpMode=*/true); + 
configParser.parseArgs(argc, argv); + + auto &config = configParser.getConfig(); + BlockingService service(config.serviceConfig); + + TestSuite testSuite(service); + std::vector> models; + + for (auto &modelConfigPath : config.modelConfigPaths) { + TranslationModel::Config modelConfig = parseOptionsFromFilePath(modelConfigPath); + std::shared_ptr model = std::make_shared(modelConfig); + models.push_back(model); + } + + /// WASM is one special case where WASM path is being checked, involving translateMultiple and a multi-line feed. + /// Hence we do not bind it at a single input-blob single Response constraint imposed by the TestSuite. + testSuite.run(config.opMode, models); + + return 0; +} diff --git a/inference/src/tests/common-impl.cpp b/inference/src/tests/common-impl.cpp new file mode 100644 index 000000000..431ddaa71 --- /dev/null +++ b/inference/src/tests/common-impl.cpp @@ -0,0 +1,316 @@ + +#ifndef BERGAMOT_TESTS_COMMON_IMPL +#error "This is an impl file and must not be included directly!" +#endif + +Response Bridge::translate(BlockingService &service, std::shared_ptr &model, + std::string &&source, const ResponseOptions &responseOptions) { + // project source to a vector of std::string, send in, unpack the first element from + // vector, return. 
+ std::vector sources = {source}; + std::vector options = {responseOptions}; + return service.translateMultiple(model, std::move(sources), options).front(); +} + +Response Bridge::translate(AsyncService &service, std::shared_ptr &model, + std::string &&source, const ResponseOptions &responseOptions) { + // downgrade to blocking via promise, future, wait and return response; + std::promise responsePromise; + std::future responseFuture = responsePromise.get_future(); + + auto callback = [&responsePromise](Response &&response) { responsePromise.set_value(std::move(response)); }; + service.translate(model, std::move(source), callback, responseOptions); + + responseFuture.wait(); + + Response response = responseFuture.get(); + return response; +} + +Response Bridge::pivot(BlockingService &service, std::shared_ptr &sourceToPivot, + std::shared_ptr &pivotToTarget, std::string &&source, + const ResponseOptions &responseOptions) { + std::vector sources = {source}; + std::vector options = {responseOptions}; + return service.pivotMultiple(sourceToPivot, pivotToTarget, std::move(sources), options).front(); +} + +Response Bridge::pivot(AsyncService &service, std::shared_ptr &sourceToPivot, + std::shared_ptr &pivotToTarget, std::string &&source, + const ResponseOptions &responseOptions) { + std::promise responsePromise; + std::future responseFuture = responsePromise.get_future(); + + auto callback = [&responsePromise](Response &&response) { responsePromise.set_value(std::move(response)); }; + service.pivot(sourceToPivot, pivotToTarget, std::move(source), callback, responseOptions); + responseFuture.wait(); + Response response = responseFuture.get(); + return response; +} + +template +TestSuite::TestSuite(Service &service) : service_{service} {} + +template +void TestSuite::TestSuite::run(const std::string &opModeAsString, std::vector> &models) { + if (opModeAsString == "decoder") { + benchmarkDecoder(models.front()); + } else if (opModeAsString == 
"test-response-source-sentences") { + annotatedTextSentences(models.front(), /*source=*/true); + } else if (opModeAsString == "test-response-target-sentences") { + annotatedTextSentences(models.front(), /*source=*/false); + } else if (opModeAsString == "test-response-source-words") { + annotatedTextWords(models.front(), /*source=*/true); + } else if (opModeAsString == "test-response-target-words") { + annotatedTextWords(models.front(), /*source=*/false); + } else if (opModeAsString == "test-forward-backward") { + forwardAndBackward(models); + } else if (opModeAsString == "test-quality-estimator-words") { + qualityEstimatorWords(models.front()); + } else if (opModeAsString == "test-quality-estimator-scores") { + qualityEstimatorScores(models.front()); + } else if (opModeAsString == "test-translation-cache") { + translationCache(models.front()); + } else if (opModeAsString == "test-pivot") { + pivotTranslate(models); + } else if (opModeAsString == "test-pivot-with-html") { + pivotTranslateWithHTML(models); + } else if (opModeAsString == "test-html-translation") { + htmlTranslation(models.front()); + } else { + std::cerr << "Incompatible test mode. Choose from the one of the valid test-modes"; + std::abort(); + } +} + +template +void TestSuite::benchmarkDecoder(Ptr &model) { + marian::timer::Timer decoderTimer; + std::string source = readFromStdin(); + + ResponseOptions responseOptions; + Response response = bridge_.translate(service_, model, std::move(source), responseOptions); + + for (size_t sentenceIdx = 0; sentenceIdx < response.size(); sentenceIdx++) { + std::cout << response.target.sentence(sentenceIdx) << "\n"; + } + + std::cerr << "Total time: " << std::setprecision(5) << decoderTimer.elapsed() << "s wall" << std::endl; +} + +// Reads from stdin and translates. Prints the tokens separated by space for each sentence. Prints words from source +// side text annotation if source=true, target annotation otherwise. 
+template +void TestSuite::annotatedTextWords(Ptr model, bool sourceSide /*=true*/) { + ResponseOptions responseOptions; + std::string source = readFromStdin(); + Response response = bridge_.translate(service_, model, std::move(source), responseOptions); + AnnotatedText &annotatedText = sourceSide ? response.source : response.target; + for (size_t s = 0; s < annotatedText.numSentences(); s++) { + for (size_t w = 0; w < annotatedText.numWords(s); w++) { + std::cout << (w == 0 ? "" : "\t"); + std::cout << annotatedText.word(s, w); + } + std::cout << "\n"; + } +} + +// Reads from stdin and translates the read content. Prints the sentences in source or target in constructed response +// in each line, depending on source = true or false respectively. +template +void TestSuite::annotatedTextSentences(Ptr model, bool sourceSide /*=true*/) { + ResponseOptions responseOptions; + std::string source = readFromStdin(); + Response response = bridge_.translate(service_, model, std::move(source), responseOptions); + AnnotatedText &annotatedText = sourceSide ? 
response.source : response.target; + for (size_t s = 0; s < annotatedText.numSentences(); s++) { + std::cout << annotatedText.sentence(s) << "\n"; + } +} + +template +void TestSuite::forwardAndBackward(std::vector> &models) { + ABORT_IF(models.size() != 2, "Forward and backward test needs two models."); + ResponseOptions responseOptions; + std::string source = readFromStdin(); + Response forwardResponse = bridge_.translate(service_, models.front(), std::move(source), responseOptions); + + // Make a copy of target + std::string target = forwardResponse.target.text; + Response backwardResponse = bridge_.translate(service_, models.back(), std::move(target), responseOptions); + + // Print both onto the command-line + std::cout << forwardResponse.source.text; + std::cout << "----------------\n"; + std::cout << forwardResponse.target.text; + std::cout << "----------------\n"; + std::cout << backwardResponse.target.text; +} + +// Reads from stdin and translates the read content. Prints the quality words for each sentence. 
+template +void TestSuite::qualityEstimatorWords(Ptr model) { + ResponseOptions responseOptions; + responseOptions.qualityScores = true; + std::string source = readFromStdin(); + const Response response = bridge_.translate(service_, model, std::move(source), responseOptions); + + for (size_t sentenceIdx = 0; sentenceIdx < response.qualityScores.size(); ++sentenceIdx) { + const auto &sentenceQualityEstimate = response.qualityScores[sentenceIdx]; + std::cout << "[SentenceBegin]\n"; + + for (const auto &wordByteRange : getWordByteRanges(response, sentenceIdx)) { + const string_view word(response.target.text.data() + wordByteRange.begin, wordByteRange.size()); + std::cout << word << "\n"; + } + std::cout << "[SentenceEnd]\n\n"; + } +} + +template +void TestSuite::htmlTranslation(Ptr model) { + ResponseOptions responseOptions; + responseOptions.HTML = true; + responseOptions.alignment = true; + std::string source = readFromStdin(); + const Response response = bridge_.translate(service_, model, std::move(source), responseOptions); + + std::cout << response.target.text; +} + +// Reads from stdin and translates the read content. Prints the quality scores for each sentence. 
+template +void TestSuite::qualityEstimatorScores(Ptr model) { + ResponseOptions responseOptions; + responseOptions.qualityScores = true; + + std::string source = readFromStdin(); + const Response response = bridge_.translate(service_, model, std::move(source), responseOptions); + + for (const auto &sentenceQualityEstimate : response.qualityScores) { + std::cout << std::fixed << std::setprecision(3) << sentenceQualityEstimate.sentenceScore << "\n"; + + for (const float &wordScore : sentenceQualityEstimate.wordScores) { + std::cout << std::fixed << std::setprecision(3) << wordScore << "\n"; + } + std::cout << "\n"; + } +} + +template +void TestSuite::translationCache(Ptr model) { + ResponseOptions responseOptions; + + // Read a large input text blob from stdin + const std::string source = readFromStdin(); + + // Round 1 + std::string buffer = source; + Response firstResponse = bridge_.translate(service_, model, std::move(buffer), responseOptions); + + auto statsFirstRun = service_.cacheStats(); + LOG(info, "Cache Hits/Misses = {}/{}", statsFirstRun.hits, statsFirstRun.misses); + ABORT_IF(statsFirstRun.hits != 0, "Expecting no cache hits, but hits found."); + + // Round 2; There should be cache hits + buffer = source; + Response secondResponse = bridge_.translate(service_, model, std::move(buffer), responseOptions); + + auto statsSecondRun = service_.cacheStats(); + LOG(info, "Cache Hits/Misses = {}/{}", statsSecondRun.hits, statsSecondRun.misses); + ABORT_IF(statsSecondRun.hits <= 0, "At least one hit expected, none found."); + if (statsSecondRun.hits != statsFirstRun.misses) { + std::cerr << "Mismatch in expected hits (Hits, Misses = " << statsSecondRun.hits << ", " << statsSecondRun.misses + << "). This can happen due to random eviction." << std::endl; + } + + ABORT_IF(firstResponse.target.text != secondResponse.target.text, + "Recompiled string provided different output when operated with cache. 
On the same hardware while using " + "same path, this is expected to be same."); + + std::cout << firstResponse.target.text; +} + +template +void TestSuite::pivotTranslateWithHTML(std::vector> &models) { + ABORT_IF(models.size() != 2, "Forward and backward test needs two models."); + ResponseOptions responseOptions; + responseOptions.HTML = true; + std::string source = readFromStdin(); + std::promise responsePromise; + std::future responseFuture = responsePromise.get_future(); + Response response = bridge_.pivot(service_, models.front(), models.back(), std::move(source), responseOptions); + std::cout << response.source.text; + std::cout << response.target.text; +} + +template +void TestSuite::pivotTranslate(std::vector> &models) { + // We expect a source -> pivot; pivot -> source model to get source -> source and build this test using accuracy of + // matches. + ABORT_IF(models.size() != 2, "Forward and backward test needs two models."); + ResponseOptions responseOptions; + responseOptions.alignment = true; + std::string source = readFromStdin(); + std::promise responsePromise; + std::future responseFuture = responsePromise.get_future(); + + Response response = bridge_.pivot(service_, models.front(), models.back(), std::move(source), responseOptions); + + const float EPS = 1e-5; + size_t totalOutcomes = 0; + size_t favourableOutcomes = 0; + + for (size_t sentenceId = 0; sentenceId < response.source.numSentences(); sentenceId++) { + std::cout << "> " << response.source.sentence(sentenceId) << "\n"; + std::cout << "< " << response.target.sentence(sentenceId) << "\n\n"; + + // Assert what we have is a probability distribution over source-tokens given a target token. 
+ for (size_t t = 0; t < response.alignments[sentenceId].size(); t++) { + float sum = 0.0f; + for (size_t s = 0; s < response.alignments[sentenceId][t].size(); s++) { + sum += response.alignments[sentenceId][t][s]; + } + + std::cerr << fmt::format("Sum @ (target-token = {}, sentence = {}) = {}", t, sentenceId, sum) << std::endl; + ABORT_IF((std::abs(sum - 1.0f) > EPS), "Not a probability distribution, something's going wrong"); + } + + // For each target token, find argmax s, i.e find argmax p(s | t), max p(s | t) + for (size_t t = 0; t < response.alignments[sentenceId].size(); t++) { + bool valid = false; + float maxV = 0.0f; + auto argmaxV = std::make_pair(-1, -1); + for (size_t s = 0; s < response.alignments[sentenceId][t].size(); s++) { + auto v = response.alignments[sentenceId][t][s]; + if (v > maxV) { + maxV = v; + argmaxV = std::make_pair(t, s); + } + } + + auto sourceToken = response.source.word(sentenceId, argmaxV.second); + auto targetToken = response.target.word(sentenceId, argmaxV.first); + if (sourceToken == targetToken) { + favourableOutcomes += 1; + } + + std::cerr << sourceToken << " " << targetToken << " " << maxV << std::endl; + + totalOutcomes += 1; + } + + // Assert each alignment over target is a valid probability distribution. + } + + // Measure accuracy of word match. + float accuracy = static_cast(favourableOutcomes) / static_cast(totalOutcomes); + + // This is arbitrary value chosen by @jerinphilip, but should be enough to check if things fail. + // This value is calibrated on bergamot input in BRT. All this is supposed to do is let the developers know if + // something's largely amiss to the point alignments are not working. + ABORT_IF(accuracy < 0.70, "Accuracy {} not enough. 
Please check if something's off.", accuracy * 100); + + std::cout << response.source.text; + std::cout << response.target.text; +} diff --git a/inference/src/tests/common.h b/inference/src/tests/common.h new file mode 100644 index 000000000..238a62357 --- /dev/null +++ b/inference/src/tests/common.h @@ -0,0 +1,100 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +#include "common/definitions.h" +#include "common/timer.h" +#include "common/utils.h" +#include "marian.h" +#include "translator/byte_array_util.h" +#include "translator/parser.h" +#include "translator/response.h" +#include "translator/response_options.h" +#include "translator/service.h" +#include "translator/utils.h" + +namespace marian::bergamot { + +/// Due to the stubborn-ness of the extension and native to not agree on API (e.g, translateMultiple vs translate), +/// different underlying cache we have the following "bridge" at test-applications - taking into account the fact that +/// the most commonly used primitives across both Services is a single text blob in and corresponding Response out, in a +/// blocking fashion. +/// +/// The following contraption constrains a single sentence to single Response parameterized by Service, in a test-suite +/// below. This allows sharing of code for test-suite between WebAssembly's workflows and Native's workflows. +/// +/// The intention here is to use templating to achieve the same thing an ifdef would have at compile-time. Also mandates +/// after bridge layer, both WebAssembly and Native paths compile correctly (this does not guarantee outputs are the +/// same through both code-paths, or that both are tested at runtime - only that both compile and work with a bridge). +/// +/// For any complex workflows involving non-blocking concurrent translation, it is required to write something not +/// constrained by the following. 
+ +template +struct Bridge : public std::false_type {}; + +template <> +struct Bridge : public std::true_type { + Response translate(BlockingService &service, std::shared_ptr &model, std::string &&source, + const ResponseOptions &responseOptions); + Response pivot(BlockingService &service, std::shared_ptr &sourceToPivot, + std::shared_ptr &pivotToTarget, std::string &&source, + const ResponseOptions &responseOptions); +}; + +template <> +struct Bridge : public std::true_type { + Response translate(AsyncService &service, std::shared_ptr &model, std::string &&source, + const ResponseOptions &responseOptions); + Response pivot(AsyncService &service, std::shared_ptr &sourceToPivot, + std::shared_ptr &pivotToTarget, std::string &&source, + const ResponseOptions &responseOptions); +}; + +template +class TestSuite { + private: + Bridge bridge_; + Service &service_; + + public: + TestSuite(Service &service); + void run(const std::string &opModeAsString, std::vector> &models); + + private: + void benchmarkDecoder(Ptr &model); + + // Reads from stdin and translates. Prints the tokens separated by space for each sentence. Prints words from source + // side text annotation if source=true, target annotation otherwise. + void annotatedTextWords(Ptr model, bool sourceSide = true); + + // Reads from stdin and translates the read content. Prints the sentences in source or target in constructed response + // in each line, depending on source = true or false respectively. + void annotatedTextSentences(Ptr model, bool sourceSide = true); + + void forwardAndBackward(std::vector> &models); + + // Reads from stdin and translates the read content. Prints the quality words for each sentence. + void qualityEstimatorWords(Ptr model); + + // Reads from stdin and translates the read content. Prints the quality scores for each sentence. 
+ void qualityEstimatorScores(Ptr model); + + void translationCache(Ptr model); + + void pivotTranslate(std::vector> &models); + + void pivotTranslateWithHTML(std::vector> &models); + + void htmlTranslation(Ptr model); +}; + +#define BERGAMOT_TESTS_COMMON_IMPL +#include "common-impl.cpp" +#undef BERGAMOT_TESTS_COMMON_IMPL + +} // namespace marian::bergamot diff --git a/inference/src/tests/intgemm-resolve.cpp b/inference/src/tests/intgemm-resolve.cpp new file mode 100644 index 000000000..f95d0c449 --- /dev/null +++ b/inference/src/tests/intgemm-resolve.cpp @@ -0,0 +1,8 @@ +#include + +#include "intgemm/intgemm.h" + +int main() { + std::cout << static_cast(intgemm::kCPU) << "\n"; + return 0; +} diff --git a/inference/src/tests/units/CMakeLists.txt b/inference/src/tests/units/CMakeLists.txt new file mode 100644 index 000000000..9cfb50006 --- /dev/null +++ b/inference/src/tests/units/CMakeLists.txt @@ -0,0 +1,25 @@ +# Unit tests +set(UNIT_TESTS + annotation_tests + cache_tests + quality_estimator_tests + html_tests + xh_scanner_tests) + +foreach(test ${UNIT_TESTS}) + add_executable("run_${test}" run_tests.cpp "${test}.cpp") + target_include_directories("run_${test}" PRIVATE ${CATCH_INCLUDE_DIR} "${CMAKE_SOURCE_DIR}/src") + + if(CUDA_FOUND) + target_link_libraries("run_${test}" ${EXT_LIBS} marian ${EXT_LIBS} marian_cuda ${EXT_LIBS} Catch bergamot-translator) + else(CUDA_FOUND) + target_link_libraries("run_${test}" marian ${EXT_LIBS} Catch bergamot-translator) + endif(CUDA_FOUND) + + if(msvc) + # disable c4305: truncation from 'double' to '_ty' + target_compile_options("run_${test}" public /wd4305) + endif(msvc) + + add_test(NAME ${test} COMMAND "run_${test}") +endforeach(test) diff --git a/inference/src/tests/units/annotation_tests.cpp b/inference/src/tests/units/annotation_tests.cpp new file mode 100644 index 000000000..d7178f4df --- /dev/null +++ b/inference/src/tests/units/annotation_tests.cpp @@ -0,0 +1,214 @@ +#include +#include + +#include "catch.hpp" +#include 
"translator/annotation.h" + +using namespace marian::bergamot; + +TEST_CASE("Test Annotation API with random sentences") { + /// Objective here is to test insertion for sentences, and that whatever comes + /// out adheres to the way it was inserted. Towards this, we keep externally + /// which sentence went in where and try to use accessor methods on + /// AnnotatedText to check if what we have as ground-truth by construction is + /// consistent with what is returned. + size_t sentences = 500; + size_t maxWords = 40; + + // Set in case needed to see output. The output is in lines of #sentences + + // header, which can be split and compared for easy understanding. The ideal + // way to inspect what is going wrong is to redirect output and use to split + // the different stages by sentences + 1 lines and check the diff. + bool debug{false}; + + std::mt19937 randomIntGen_; + randomIntGen_.seed(42); + + // External book-keeping so we have ground truths. Each element represents a + // sentence. + + // word byte ranges - for testAnnotation.word(sId, wId) + std::vector> groundTruthWords; + // sentence byte ranges - for testAnnotation.sentence(sId, wId) + std::vector groundTruthSentences; + + // Prepare the text and construct ByteRanges as intended for sentences and + // words. The ByteRanges we construct here are expected to be the + // ground-truths for words and sentences. The string being constructed is like + // as follows: + // + // 0-0 0-1 0-2 0-3 + // 1-0 1-1 1-2 1-3 1-4 + // 2-0 2-1 + // + // 4-0 4-1 4-2 4-3 + // + // Tokens are contiguous because that's how SentencePiece works. + // + // Below, we accumulate the text with intended structure as above, and + // ground-truth tables populated to be aware of the ByteRanges where they are + // meant to be. 
+ if (debug) { + std::cout << "Preparing text and ground truth-tables" << std::endl; + } + std::string text; + for (size_t idx = 0; idx < sentences; idx++) { + if (idx != 0) text += "\n"; + + // Words can be zero, we need to support empty word sentences as well. + size_t numWords = randomIntGen_() % maxWords; + + std::vector wordByteRanges; + wordByteRanges.reserve(numWords); + + // For empty sentence, we expect it to be empty and marked in position where + // the existing string is if needed to be pointed out. + size_t before = text.size() - 1; + size_t sentenceBegin{before}, sentenceEnd{before}; + + for (size_t idw = 0; idw < numWords; idw++) { + // Get new beginning, accounting for space above. + before = text.size(); + + // Add the word + std::string word = std::to_string(idx) + "-" + std::to_string(idw); + text += word; + + // Do math, before, before + new-word's size. + wordByteRanges.push_back((ByteRange){before, before + word.size()}); + + if (debug) { + std::cout << word; + } + + if (idw == 0) { + sentenceBegin = before; + } + if (idw == numWords - 1) { + sentenceEnd = before + word.size(); + } + } + if (debug) { + std::cout << std::endl; + } + + groundTruthWords.push_back(wordByteRanges); + groundTruthSentences.push_back((ByteRange){sentenceBegin, sentenceEnd}); + } + + AnnotatedText testAnnotation(std::move(text)); // This the container we add through API and + // check if the access is correct. + + // We prepare string_views now with the known ByteRanges and use the + // string_view based AnnotatedText.addSentence(...) API to add sentences to + // transparently convert from string_views to ByteRanges, rebasing/working out + // the math underneath. 
+ + if (debug) { + std::cout << "Inserting words onto container and save ground-truth-table:" << std::endl; + } + + std::vector> wordStringViews; + std::vector::const_iterator sentence_iter = groundTruthSentences.begin(); + for (auto &sentence : groundTruthWords) { + std::vector wordByteRanges; + bool first{true}; + for (auto &word : sentence) { + marian::string_view wordView(&testAnnotation.text[word.begin], word.size()); + wordByteRanges.push_back(wordView); + if (debug) { + if (first) { + first = false; + } else { + std::cout << " "; + } + std::cout << std::string(wordView); + } + } + testAnnotation.recordExistingSentence(wordByteRanges.begin(), wordByteRanges.end(), + testAnnotation.text.data() + sentence_iter->begin); + ++sentence_iter; + wordStringViews.push_back(wordByteRanges); + if (debug) { + std::cout << std::endl; + } + } + + if (debug) { + std::cout << "Inserting sentences onto container and save ground-truth-table" << std::endl; + } + std::vector sentenceStringViews; + for (auto &sentenceByteRange : groundTruthSentences) { + char *data = &(testAnnotation.text[sentenceByteRange.begin]); + marian::string_view sentenceView(data, sentenceByteRange.size()); + sentenceStringViews.push_back(sentenceView); + + if (debug) { + std::cout << sentenceView << std::endl; + } + } + + // Access from the sentence(sentenceIdx) API and confirm that the ground truth + // we expect is same as what comes out of the container. 
+ if (debug) { + std::cout << "From container: Sentences" << std::endl; + } + for (int idx = 0; idx < groundTruthSentences.size(); idx++) { + ByteRange expected = groundTruthSentences[idx]; + ByteRange obtained = testAnnotation.sentenceAsByteRange(idx); + if (debug) { + std::cout << std::string(testAnnotation.sentence(idx)) << std::endl; + } + CHECK(expected.begin == obtained.begin); + CHECK(expected.end == obtained.end); + std::string expected_string = std::string(sentenceStringViews[idx]); + std::string obtained_string = std::string(testAnnotation.sentence(idx)); + CHECK(expected_string == obtained_string); + } + + /// Access the word(sentenceIdx, wordIdx) API and confirm what we hold as + /// expected words are the same as those obtained from the container. + if (debug) { + std::cout << "From container: Words" << std::endl; + } + + CHECK(groundTruthWords.size() == testAnnotation.numSentences()); + for (int idx = 0; idx < groundTruthWords.size(); idx++) { + CHECK(groundTruthWords[idx].size() == testAnnotation.numWords(idx)); + } + + for (int idx = 0; idx < groundTruthWords.size(); idx++) { + for (int idw = 0; idw < groundTruthWords[idx].size(); idw++) { + ByteRange expected = groundTruthWords[idx][idw]; + ByteRange obtained = testAnnotation.wordAsByteRange(idx, idw); + if (debug) { + std::cout << std::string(testAnnotation.word(idx, idw)) << " "; + } + CHECK(expected.begin == obtained.begin); + CHECK(expected.end == obtained.end); + + std::string expected_string = std::string(wordStringViews[idx][idw]); + std::string obtained_string = std::string(testAnnotation.word(idx, idw)); + CHECK(expected_string == obtained_string); + } + if (debug) { + std::cout << std::endl; + } + } + + // Try inserting an empty Sentence. This is ensuring we check for empty + // Sentence if the random test above does not cover it for some reason. 
+ int emptySentenceIdx = sentences; + std::vector emptySentence; + testAnnotation.recordExistingSentence(emptySentence.begin(), emptySentence.end(), + testAnnotation.text.data() + testAnnotation.text.size()); + + // There are no words. + CHECK(testAnnotation.numWords(emptySentenceIdx) == 0); + + // Empty sentence expected at output. + std::string expectedEmptyString = ""; + marian::string_view emptyView = testAnnotation.sentence(emptySentenceIdx); + std::string obtainedString = std::string(emptyView.data(), emptyView.size()); + CHECK(expectedEmptyString == obtainedString); +} diff --git a/inference/src/tests/units/cache_tests.cpp b/inference/src/tests/units/cache_tests.cpp new file mode 100644 index 000000000..f2f1b19ed --- /dev/null +++ b/inference/src/tests/units/cache_tests.cpp @@ -0,0 +1,56 @@ + +#include +#include + +#include "catch.hpp" +#include "translator/cache.h" +#include "translator/history.h" + +using namespace marian::bergamot; + +TEST_CASE("Test Cache in a threaded setting") { + size_t numThreads = 100; + size_t numIters = 10000; + using Key = int; + using Value = int; + using TestCache = AtomicCache; + + TestCache cache(/*size=*/300, /*mutexBuckets=*/16); + + auto op = [numIters, &cache]() { + std::mt19937_64 randomGenerator; + randomGenerator.seed(42); // reproducible outputs + Value randMax = 2000; + + for (size_t i = 0; i < numIters; i++) { + Key query = randomGenerator() % randMax; + std::pair result = cache.find(query); + if (result.first) { + REQUIRE(result.second == query); + } + + Value value = query; + cache.store(/*key=*/query, std::move(value)); + } + }; + + std::vector workers; + for (size_t t = 0; t < numThreads; t++) { + workers.emplace_back(op); + } + + for (size_t t = 0; t < numThreads; t++) { + workers[t].join(); + } + + TestCache::Stats stats = cache.stats(); + float hitRate = static_cast(stats.hits) / static_cast(stats.hits + stats.misses); + + // This is non-deterministic due to threads. 
+ std::cout << "Hit-Rate:" << hitRate << "\n"; + std::cout << "(Hits, Misses) = " << stats.hits << " " << stats.misses << "\n"; + + // Can we create a specialization of the actual cache-type we want? Does it compile, at least? + // We already have Ptr, it's easier to move Ptr to cache. + TranslationCache translationCache(/*size=*/300, /*mutexBuckets=*/16); +} diff --git a/inference/src/tests/units/html_tests.cpp b/inference/src/tests/units/html_tests.cpp new file mode 100644 index 000000000..96eff5aad --- /dev/null +++ b/inference/src/tests/units/html_tests.cpp @@ -0,0 +1,880 @@ +#include "html_tests.h" + +#include + +#include "catch.hpp" +#include "data/types.h" // for marian::string_view +#include "translator/html.h" +#include "translator/response.h" + +using namespace marian::bergamot; +using marian::string_view; + +class MarianThrowsExceptionsFixture { + protected: + MarianThrowsExceptionsFixture() : prev_(marian::getThrowExceptionOnAbort()) { + marian::setThrowExceptionOnAbort(true); + } + ~MarianThrowsExceptionsFixture() { marian::setThrowExceptionOnAbort(prev_); } + + private: + bool prev_; +}; + +std::ostream &operator<<(std::ostream &out, std::pair const &b) { + return out << '(' << b.first << ',' << b.second << ')'; +} + +std::ostream &operator<<(std::ostream &out, ByteRange const &b) { return out << '{' << b.begin << ',' << b.end << '}'; } + +std::vector asByteRanges(AnnotatedText const &annotation) { + std::vector words; + words.emplace_back(annotation.annotation.gap(0)); + for (size_t sentenceIdx = 0; sentenceIdx < annotation.numSentences(); ++sentenceIdx) { + for (size_t wordIdx = 0; wordIdx < annotation.numWords(sentenceIdx); ++wordIdx) + words.emplace_back(annotation.wordAsByteRange(sentenceIdx, wordIdx)); + words.emplace_back(annotation.annotation.gap(sentenceIdx + 1)); + } + return words; +} + +std::vector asTokens(AnnotatedText const &annotation) { + std::vector words; + words.emplace_back(annotation.gap(0)); + for (size_t sentenceIdx = 0; 
sentenceIdx < annotation.numSentences(); ++sentenceIdx) { + for (size_t wordIdx = 0; wordIdx < annotation.numWords(sentenceIdx); ++wordIdx) + words.emplace_back(annotation.word(sentenceIdx, wordIdx)); + words.emplace_back(annotation.gap(sentenceIdx + 1)); + } + return words; +} + +void recordSentenceFromByteRange(AnnotatedText &text, std::vector const &ranges) { + assert(ranges.size() > 0); + + std::vector tokens; + tokens.reserve(ranges.size()); + + for (auto &&range : ranges) tokens.emplace_back(text.text.data() + range.begin, range.size()); + + text.recordExistingSentence(tokens.begin(), tokens.end(), text.text.data() + ranges[0].begin); +} + +template +std::vector> identity_matrix(size_t size) { + std::vector> rows(size); + for (size_t row = 0; row < size; ++row) { + rows[row].resize(size, T(0)); + rows[row][row] = T(1); + } + return rows; +} + +TEST_CASE("Ignore HTML if process_markup is false") { + std::string html_code("

This text & has HTML in it

"); + + std::string input(html_code); + HTML html(std::move(input), false); + CHECK(input == html_code); + + Response response; + response.source.text = html_code; + response.target.text = html_code; + // Note: response.alignments is empty, which is allowed in this case + html.restore(response); + + // Assert that restore() does not mess with my HTML code + CHECK(response.source.text == html_code); +} + +TEST_CASE_METHOD(MarianThrowsExceptionsFixture, "Abort if alignments are missing") { + std::string input("

hello world

\n"); + HTML html(std::move(input), true); + + AnnotatedText source("hello world\n"); + recordSentenceFromByteRange(source, { + ByteRange{0, 4}, // 0.0 "hell" + ByteRange{4, 5}, // 0.1 "o" + ByteRange{5, 11}, // 0.2 " world" + ByteRange{11, 11} // 0.3 "" + }); + + AnnotatedText target("hallo Welt\n"); + recordSentenceFromByteRange(target, { + ByteRange{0, 4}, // 0.0 "hall" + ByteRange{4, 5}, // 0.1 "o" + ByteRange{5, 10}, // 0.2 " Welt" + ByteRange{10, 10} // 0.3 "" + }); + + Response response; + response.source = source; + response.target = target; + // Note: explicitly not setting response.alignments + + CHECK_THROWS_WITH( + html.restore(response), + "Response object does not contain alignments. TranslationModel or ResponseOptions is misconfigured?"); +} + +TEST_CASE_METHOD(MarianThrowsExceptionsFixture, "Abort if alignments are misconfigured") { + std::string input("

hello world

\n"); + HTML html(std::move(input), true); + + AnnotatedText source("hello world\n"); + recordSentenceFromByteRange(source, { + ByteRange{0, 4}, // 0.0 "hell" + ByteRange{4, 5}, // 0.1 "o" + ByteRange{5, 11}, // 0.2 " world" + ByteRange{11, 11} // 0.3 "" + }); + + AnnotatedText target("hallo Welt\n"); + recordSentenceFromByteRange(target, { + ByteRange{0, 4}, // 0.0 "hall" + ByteRange{4, 5}, // 0.1 "o" + ByteRange{5, 10}, // 0.2 " Welt" + ByteRange{10, 10} // 0.3 "" + }); + + Response response; + response.source = source; + response.target = target; + + // If the model is misconfigured to not give any alignment information, + // response will have entries for each target word, but they will all be empty. + response.alignments = {{{}, {}, {}, {}}}; + + CHECK_THROWS_WITH( + html.restore(response), + "Response object does not contain alignments. TranslationModel or ResponseOptions is misconfigured?"); +} + +TEST_CASE("Do not abort if the input is just empty") { + std::string input(""); + HTML html(std::move(input), true); + CHECK(input == ""); + + Response response; + html.restore(response); + CHECK(response.source.text == ""); + CHECK(response.target.text == ""); +} + +TEST_CASE("Do not abort if the input is just empty element") { + std::string input("

"); + HTML html(std::move(input), true); + CHECK(input == ""); + + Response response; + html.restore(response); + CHECK(response.source.text == "

"); + CHECK(response.target.text == "

"); +} + +TEST_CASE("Tag names are case insensitive") { + // Tests

vs

and
should be recognized as a void tag
. + // should be recognized as inline. + std::string test_str("

Space
please?

"); + + std::string input(test_str); + HTML html(std::move(input), true); + CHECK(input == "Spa ce\n\nplease?"); +} + +TEST_CASE("Test case html entities") { + // These are all entities I would expect in innerHTML, since all other entities + // can be encoded as UTF-8 so there's no need to encode them through &...; when + // innerHTML encodes the DOM as HTML. + std::string input("

This is a sentence <with> named & entities

"); + HTML html(std::move(input), true); + CHECK(input == "This is a sentence named & entities"); +} + +TEST_CASE("Test self-closing tags should be treated as paragraph break") { + std::string test_str("

Space
please?

"); + + std::string input(test_str); + HTML html(std::move(input), true); + CHECK(input == "Space\n\nplease?"); + + Response response; + std::string source_str("Space\n\nplease?"); + std::vector source_tokens{ + string_view(source_str.data() + 0, 5), // Space + string_view(source_str.data() + 5, 0), // [EOS] + string_view(source_str.data() + 5, 2), // \n\n + string_view(source_str.data() + 7, 1), // p + string_view(source_str.data() + 8, 5), // lease + string_view(source_str.data() + 13, 1), // ? + string_view(source_str.data() + 14, 0), // EOS + }; + response.source.appendSentence("", source_tokens.begin(), source_tokens.begin() + 2); + response.source.appendSentence("\n\n", source_tokens.begin() + 3, source_tokens.end()); + + std::string target_str("Platz\n\nbitte?"); + std::vector target_tokens{ + string_view(target_str.data() + 0, 5), // Platz + string_view(target_str.data() + 5, 0), // [EOS] + string_view(target_str.data() + 5, 2), // \n\n + string_view(target_str.data() + 7, 5), // bitte + string_view(target_str.data() + 12, 1), // ? + string_view(target_str.data() + 13, 0), // [EOS] + }; + response.target.appendSentence("", target_tokens.begin(), target_tokens.begin() + 2); + response.target.appendSentence("", target_tokens.begin() + 3, target_tokens.end()); + response.alignments = {{ + {1.0, 0.0}, // Platz <- Space + {0.0, 1.0} // [EOS] <- [EOS] + }, + { + {0.1, 0.9, 0.0, 0.0}, // _bitte <- _p + lease + {0.0, 0.0, 1.0, 0.0}, // ? <- ? + {0.0, 0.0, 0.0, 1.0}, // [EOS] <- [EOS] + }}; + + // Main focus of this test is that the space that was introduced in the text + // that was being translated does not end up in the translation. + html.restore(response); + CHECK(response.source.text == "

Space
please?

"); + CHECK(response.target.text == "

Platz
bitte?

"); +} + +TEST_CASE("Test inline tags should be treated as spaces") { + std::string test_str("underline"); + + std::string input(test_str); + HTML html(std::move(input), true); + CHECK(input == "un der line"); + + Response response; + std::string source_str("un der line"); + std::vector source_tokens{ + string_view(source_str.data() + 0, 2), // un + string_view(source_str.data() + 2, 3), // _de + string_view(source_str.data() + 5, 1), // r + string_view(source_str.data() + 6, 5), // _line + string_view(source_str.data() + 11, 0), // EOS + }; + response.source.appendSentence("", source_tokens.begin(), source_tokens.end()); + + std::string target_str("una linea der"); + std::vector target_tokens{ + string_view(target_str.data() + 0, 3), // una + string_view(target_str.data() + 3, 6), // _linéa + string_view(target_str.data() + 9, 3), // _de + string_view(target_str.data() + 12, 1), // r + string_view(target_str.data() + 13, 0), // [EOS] + }; + response.target.appendSentence("", target_tokens.begin(), target_tokens.end()); + + response.alignments = {{{0.9795, 0.0127, 0.0002, 0.0066, 0.0009}, + {0.0098, 0.2967, 0.0156, 0.6640, 0.0138}, + {0.0214, 0.7472, 0.0626, 0.0745, 0.0943}, + {0.0022, 0.0230, 0.9357, 0.0165, 0.0226}, + {0.0122, 0.0240, 0.0085, 0.7427, 0.2125}}}; + + html.restore(response); + CHECK(response.source.text == "un der line"); // TODO leave spaces? 
+ CHECK(response.target.text == "una linea der"); +} + +TEST_CASE("Test inline tags should not break words") { + std::string test_str("underline"); + + std::string input(test_str); + HTML::Options options; + options.substituteInlineTagsWithSpaces = false; + HTML html(std::move(input), true, std::move(options)); + CHECK(input == "underline"); + + Response response; + std::string source_str("underline"); + std::vector source_tokens{ + string_view(source_str.data() + 0, 9), // underline + string_view(source_str.data() + 9, 0), // EOS + }; + response.source.appendSentence("", source_tokens.begin(), source_tokens.end()); + + std::string target_str("subrayar"); + std::vector target_tokens{ + string_view(target_str.data() + 0, 8), // subrayar + string_view(target_str.data() + 8, 0), // [EOS] + }; + response.target.appendSentence("", target_tokens.begin(), target_tokens.end()); + + response.alignments = {identity_matrix(2)}; + + html.restore(response); + CHECK(response.source.text == "underline"); // TODO not spread to whole word? + CHECK(response.target.text == "subrayar"); // TODO not spread to the whole word? +} + +TEST_CASE("Test reconstruction of target sentence") { + std::string input("

hello world

\n"); + HTML html(std::move(input), true); + CHECK(input == "hello world\n\n\n"); // tripple \n because \n +

+ + AnnotatedText source("hello world\n\n\n"); + recordSentenceFromByteRange(source, { + ByteRange{0, 4}, // 0.0 "hell" + ByteRange{4, 5}, // 0.1 "o" + ByteRange{5, 11}, // 0.2 " world" + ByteRange{11, 11} // 0.3 "" + }); + + AnnotatedText target("hallo Welt\n\n\n"); + recordSentenceFromByteRange(target, { + ByteRange{0, 4}, // 0.0 "hall" + ByteRange{4, 5}, // 0.1 "o" + ByteRange{5, 10}, // 0.2 " Welt" + ByteRange{10, 10} // 0.3 "" + }); + + Response response; + response.source = source; + response.target = target; + response.alignments = {identity_matrix(4)}; + + html.restore(response); + + std::vector html_tokens_source{"", "

hell", "o", " world", "", "

\n"}; + + std::vector html_tokens_target{"", "

hall", "o", " Welt", "", "

\n"}; + + CHECK(asTokens(response.source) == html_tokens_source); + CHECK(asTokens(response.target) == html_tokens_target); +} + +TEST_CASE("Test reconstruction of target sentence with entities") { + std::string input("

hello world & friends!

"); + HTML html(std::move(input), true); + CHECK(input == "hello world & friends!"); + + AnnotatedText source("hello world & friends!"); + recordSentenceFromByteRange(source, { + ByteRange{0, 4}, // 0.0 "hell" + ByteRange{4, 5}, // 0.1 "o" + ByteRange{5, 11}, // 0.2 " world" + ByteRange{11, 13}, // 0.3 " &" + ByteRange{13, 21}, // 0.4 " friends" + ByteRange{21, 22}, // 0.5 "!" + ByteRange{22, 22} // 0.6 "" + }); + + AnnotatedText target("hallo Welt & Freunde!"); + recordSentenceFromByteRange(target, { + ByteRange{0, 4}, // 0.0 "hall" + ByteRange{4, 5}, // 0.1 "o" + ByteRange{5, 10}, // 0.2 " Welt" + ByteRange{10, 12}, // 0.3 " &" + ByteRange{12, 20}, // 0.4 " Freunde" + ByteRange{20, 21}, // 0.5 "!" + ByteRange{21, 21} // 0.6 "" + }); + + Response response; + response.source = source; + response.target = target; + response.alignments = {identity_matrix(7)}; + + html.restore(response); + + std::vector html_tokens_source{"", "

hell", "o", " world", " &", + " friends", "!", "", "

"}; + + std::vector html_tokens_target{"", "

hall", "o", " Welt", " &", + + " Freunde", "!", "", "

"}; + + CHECK(asTokens(response.source) == html_tokens_source); + CHECK(asTokens(response.target) == html_tokens_target); +} + +TEST_CASE("Test reconstruction of target with multiple sentences") { + std::string input( + "

hello world! How does this deal with multiple sentences? Will it work?

"); + HTML html(std::move(input), true); + + AnnotatedText source("hello world! How does this deal with multiple sentences? Will it work?"); + CHECK(source.text == input); + + recordSentenceFromByteRange(source, { + ByteRange{0, 4}, // 0.0 "hell" + ByteRange{4, 5}, // 0.1 "o" + ByteRange{5, 11}, // 0.2 " world" + ByteRange{11, 12}, // 0.3 "!" + ByteRange{12, 12} // 0.4 "" + }); + recordSentenceFromByteRange(source, { + ByteRange{13, 16}, // 1.0 "How" + ByteRange{16, 21}, // 1.1 " does" + ByteRange{21, 26}, // 1.2 " this" + ByteRange{26, 32}, // 1.3 " deal" + ByteRange{32, 37}, // 1.4 " with" + ByteRange{37, 46}, // 1.5 " multiple" + ByteRange{46, 55}, // 1.6 " sentence" + ByteRange{55, 56}, // 1.7 "s" + ByteRange{56, 57}, // 1.8 "?" + ByteRange{57, 57} // 1.9 "" + }); + recordSentenceFromByteRange(source, { + ByteRange{58, 62}, // 2.0 "Will" + ByteRange{62, 65}, // 2.1 " it" + ByteRange{65, 70}, // 2.2 " work" + ByteRange{70, 71}, // 2.3 "?" + ByteRange{71, 71} // 2.4 "" + }); + + AnnotatedText target("hallo Welt! Wie geht das mit mehreren Sätzen um? Wird es funktionieren?"); + recordSentenceFromByteRange(target, { + ByteRange{0, 4}, // 0.0 "hall" + ByteRange{4, 5}, // 0.1 "o" + ByteRange{5, 10}, // 0.2 " Welt" + ByteRange{10, 11}, // 0.3 "!" + ByteRange{11, 11}, // 0.4 "" + }); + recordSentenceFromByteRange(target, { + ByteRange{12, 15}, // 1.0 "Wie" + ByteRange{15, 20}, // 1.1 " geht" + ByteRange{20, 24}, // 1.2 " das" + ByteRange{24, 28}, // 1.3 " mit" + ByteRange{28, 37}, // 1.4 " mehreren" + ByteRange{37, 44}, // 1.5 " Sätze" + ByteRange{44, 45}, // 1.6 "n" + ByteRange{45, 48}, // 1.7 " um" + ByteRange{48, 49}, // 1.8 "?" + ByteRange{49, 49}, // 1.9 "" + }); + recordSentenceFromByteRange(target, { + ByteRange{50, 54}, // 2.0 "Wird" + ByteRange{54, 57}, // 2.1 " es" + ByteRange{57, 71}, // 2.2 " funktionieren" + ByteRange{71, 72}, // 2.3 "?" 
+ ByteRange{72, 72}, // 2.4 "" + }); + + std::vector text_tokens_source{ + "", "hall", "o", " Welt", "!", "", " ", "Wie", " geht", " das", " mit", " mehreren", + " Sätze", "n", " um", "?", "", " ", "Wird", " es", " funktionieren", "?", "", ""}; + + CHECK(asTokens(target) == text_tokens_source); + + Response response; + response.source = source; + response.target = target; + response.alignments = {identity_matrix(5), identity_matrix(10), identity_matrix(5)}; + html.restore(response); + + std::vector html_tokens_source{"", + "

hell", + "o", + " world", + "!", + "", + " ", + "How", + " does", + " this", + " deal", // note how both spaces moved to __deal + " with", + " multiple", + " sentence", + "s", + "?", + "", + " ", + "Will", + " it", + " work", + "?", + "", + "

"}; + CHECK(asTokens(response.source) == html_tokens_source); +} + +TEST_CASE("Test self-closing tag (HTML5)") { + std::string input("

hello world and other creatures

"); + HTML html(std::move(input), true); + CHECK(input == "hello world and other creatures"); // Note double space between "hello" and "world" +} + +TEST_CASE("Test self-closing tag (XHTML)") { + std::string input("

helloworld

"); + HTML html(std::move(input), true); + CHECK(input == "hello world"); // introduced space +} + +TEST_CASE("Test empty void tag at end of input") { + std::string input("hello
"); + HTML html(std::move(input), true); + CHECK(input == "hello "); + + Response response; + std::string sentence_str("hello "); + std::vector sentence{ + string_view(sentence_str.data() + 0, 4), // 0.0 hell + string_view(sentence_str.data() + 4, 2), // 0.1 o_ + string_view(sentence_str.data() + 6, 0), // 0.2 [EOS] + }; + response.source.appendSentence("", sentence.begin(), sentence.end()); + response.target.appendSentence("", sentence.begin(), sentence.end()); + response.alignments = {identity_matrix(3)}; + + html.restore(response); + CHECK(response.source.text == "hello
"); + CHECK(response.target.text == "hello
"); +} + +TEST_CASE("Test empty tag pair at end of input") { + std::string input("hello "); + HTML html(std::move(input), true); + CHECK(input == "hello "); + + Response response; + std::string sentence_str("hello "); + std::vector sentence{ + string_view(sentence_str.data() + 0, 4), // 0.0 hell + string_view(sentence_str.data() + 4, 2), // 0.1 o_ + string_view(sentence_str.data() + 6, 0), // 0.2 [EOS] + }; + response.source.appendSentence("", sentence.begin(), sentence.end()); + response.target.appendSentence("", sentence.begin(), sentence.end()); + response.alignments = {identity_matrix(3)}; + + html.restore(response); + CHECK(response.source.text == "hello "); + CHECK(response.target.text == "hello "); +} + +TEST_CASE("Test empty self-closing pair at end of input in parent") { + std::string input("

hello

"); + HTML html(std::move(input), true); + CHECK(input == "hello "); +} + +TEST_CASE("Test empty tag") { + std::string test_str( + "

hello world

"); + + std::string input(test_str); + HTML html(std::move(input), true); + CHECK(input == "hello world"); + + Response response; + + std::string sentence_str("hello world"); + std::vector sentence{ + string_view(sentence_str.data() + 0, 4), // 0.0 hell + string_view(sentence_str.data() + 4, 1), // 0.1 o + string_view(sentence_str.data() + 5, 6), // 0.2 _world + string_view(sentence_str.data() + 11, 0), // 0.3 "" + }; + response.source.appendSentence("", sentence.begin(), sentence.end()); + response.target.appendSentence("", sentence.begin(), sentence.end()); + response.alignments = {identity_matrix(4)}; + + html.restore(response); + CHECK(response.source.text == test_str); + CHECK(response.target.text == test_str); +} + +TEST_CASE("Test world"); + + std::string input(test_str); + HTML html(std::move(input), true); + CHECK(input == "hello \n\nworld"); + + Response response; + std::string sentence_str("hello \n\nworld"); + std::vector sentence{ + string_view(sentence_str.data() + 0, 4), // 0.0 hell + string_view(sentence_str.data() + 4, 2), // 0.1 o_ + string_view(sentence_str.data() + 6, 2), // 0.2 \n\n + string_view(sentence_str.data() + 8, 5), // 0.3 world + string_view(sentence_str.data() + 13, 0), // 0.4 "" + }; + response.source.appendSentence("", sentence.begin(), sentence.end()); + response.target.appendSentence("", sentence.begin(), sentence.end()); + response.alignments = {identity_matrix(5)}; + + html.restore(response); + CHECK(response.source.text == test_str); + CHECK(response.target.text == test_str); +} + +TEST_CASE("Test comment") { + std::string test_str("foo bar"); + + std::string input(test_str); + HTML html(std::move(input), true); + CHECK(input == "foo bar"); + + Response response; + std::string sentence_str("foo bar"); + std::vector sentence{ + string_view(sentence_str.data() + 0, 3), // foo + string_view(sentence_str.data() + 3, 4), // _bar + string_view(sentence_str.data() + 7, 0), // "" + }; + response.source.appendSentence("", 
sentence.begin(), sentence.end()); + response.target.appendSentence("", sentence.begin(), sentence.end()); + response.alignments = {identity_matrix(3)}; + + html.restore(response); + CHECK(response.source.text == test_str); + CHECK(response.target.text == test_str); +} + +TEST_CASE("Test element") { + std::string test_str("hello"); + + std::string input(test_str); + HTML html(std::move(input), true); + CHECK(input == "hello"); +} + +TEST_CASE("Test element (case-insensitive)") { + std::string test_str("hello"); + + std::string input(test_str); + HTML html(std::move(input), true); + CHECK(input == "hello"); +} + +TEST_CASE("Test ignored element (nested)") { + std::string test_str("foo nested bar"); + std::string expected_str("foo nestedbar"); + + std::string input(test_str); + HTML html(std::move(input), true); + CHECK(input == "foo bar"); + + Response response; + std::string sentence_str("foo bar"); + std::vector sentence{ + string_view(sentence_str.data() + 0, 3), // foo + string_view(sentence_str.data() + 3, 1), // _ + string_view(sentence_str.data() + 4, 4), // _bar + string_view(sentence_str.data() + 8, 0), // "" + }; + response.source.appendSentence("", sentence.begin(), sentence.end()); + response.target.appendSentence("", sentence.begin(), sentence.end()); + response.alignments = {identity_matrix(4)}; + + html.restore(response); + CHECK(response.source.text == expected_str); + CHECK(response.target.text == expected_str); +} + +TEST_CASE("Test ignored element (with entity)") { + std::string test_str("foo & bar"); + std::string expected_str("foo &bar"); + + std::string input(test_str); + HTML html(std::move(input), true); + CHECK(input == "foo bar"); + + Response response; + std::string sentence_str("foo bar"); + std::vector sentence{ + string_view(sentence_str.data() + 0, 3), // foo + string_view(sentence_str.data() + 3, 1), // _ + string_view(sentence_str.data() + 4, 4), // _bar + string_view(sentence_str.data() + 8, 0), // "" + }; + 
response.source.appendSentence("", sentence.begin(), sentence.end()); + response.target.appendSentence("", sentence.begin(), sentence.end()); + response.alignments = {identity_matrix(4)}; + + html.restore(response); + CHECK(response.source.text == expected_str); + CHECK(response.target.text == expected_str); +} + +TEST_CASE("End-to-end translation", "[!mayfail]") { + std::string input("

I like to drive this car.

"); + HTML html(std::move(input), true); + CHECK(input == "I like to drive this car."); + + Response response; + + // clang-format off + response.alignments = std::vector>>{{ + {0.982376, 0.00742467, 0.00682965, 0.00121767, 0.000848056,6.51436e-05,7.53791e-06,0.00123162}, + {0.165639, 0.368694, 0.230394, 0.222476, 0.00349563, 0.00105052, 0.000603092,0.00764845}, + {0.00493271,0.0805876, 0.0139988, 0.89116, 0.000928116,0.00200724, 0.000512013,0.00587302}, + {0.0194648, 0.411029, 0.087059, 0.0477847, 0.26596, 0.111161, 0.000392092,0.0571499}, + {0.00879706,0.492504, 0.0448291, 0.007779, 0.423114, 0.0125523, 0.00119587, 0.00922804}, + {0.00181909,0.00603626, 0.0335758, 0.037193, 0.747266, 0.102497, 0.0585782, 0.0130341}, + {4.1348e-06,0.000156165,2.16369e-05,0.00275059, 0.00183456, 0.992357, 0.0023765, 0.000499018}, + {0.00149043,0.000719392,0.0168534, 0.00430164, 0.00200343, 0.0106381, 0.948566, 0.0154279}, + {0.0903136, 0.0550843, 0.0699474, 0.0792285, 0.223006, 0.207565, 0.129241, 0.145614}, + }}; + // clang-format on + + { + std::string sentence_str("I like to drive this car."); + std::vector sentence{ + string_view(sentence_str.data() + 0, 1), // 0.0 "I" + string_view(sentence_str.data() + 1, 5), // 0.1 " like" + string_view(sentence_str.data() + 6, 3), // 0.2 " to" + string_view(sentence_str.data() + 9, 6), // 0.3 " drive" + string_view(sentence_str.data() + 15, 5), // 0.4 " this" + string_view(sentence_str.data() + 20, 4), // 0.5 " car" + string_view(sentence_str.data() + 24, 1), // 0.6 "." 
+ string_view(sentence_str.data() + 25, 0), // 0.7 "" + }; + response.source.appendSentence("", sentence.begin(), sentence.end()); + } + + { + std::string sentence_str("Ich fahre gerne dieses Auto."); + std::vector sentence{ + string_view(sentence_str.data() + 0, 3), // 0.0 "Ich" + string_view(sentence_str.data() + 3, 1), // 0.1 " " + string_view(sentence_str.data() + 4, 4), // 0.2 "fahr" + string_view(sentence_str.data() + 8, 1), // 0.3 "e" + string_view(sentence_str.data() + 9, 6), // 0.4 " gerne" + string_view(sentence_str.data() + 15, 7), // 0.5 " dieses" + string_view(sentence_str.data() + 22, 5), // 0.6 " Auto" + string_view(sentence_str.data() + 27, 1), // 0.7 "." + string_view(sentence_str.data() + 28, 0), // 0.8 "" + }; + response.target.appendSentence("", sentence.begin(), sentence.end()); + } + + html.restore(response); + + { + AnnotatedText source; + std::string sentence_str("

I like to drive this car."); + std::vector sentence{ + string_view(sentence_str.data() + 0, 4), // 0.0 "

I" + string_view(sentence_str.data() + 4, 8), // 0.1 " like" + string_view(sentence_str.data() + 12, 7), // 0.2 " to" + string_view(sentence_str.data() + 19, 9), // 0.3 " drive" + string_view(sentence_str.data() + 28, 9), // 0.4 " this" + string_view(sentence_str.data() + 37, 4), // 0.5 " car" + string_view(sentence_str.data() + 41, 1), // 0.6 "." + string_view(sentence_str.data() + 42, 0), // 0.7 "" + }; + source.appendSentence("", sentence.begin(), sentence.end()); + source.appendEndingWhitespace("

"); + + CHECK(asTokens(response.source) == asTokens(source)); + } + + { + AnnotatedText target; + // Empty because the space token after "Ich" has "

" markup, passed down from "like" + std::string sentence_str("

Ich fahre gerne dieses Auto."); + std::vector sentence{ + string_view(sentence_str.data() + 0, 6), // 0.0 "

Ich" + string_view(sentence_str.data() + 6, 4), // 0.1 " " + string_view(sentence_str.data() + 10, 11), // 0.2 "fahr" + string_view(sentence_str.data() + 21, 1), // 0.3 "e" + string_view(sentence_str.data() + 22, 13), // 0.4 " gerne" + string_view(sentence_str.data() + 35, 11), // 0.5 " dieses" + string_view(sentence_str.data() + 46, 5), // 0.6 " Auto" + string_view(sentence_str.data() + 51, 1), // 0.7 "." + string_view(sentence_str.data() + 52, 0), // 0.8 "" + }; + target.appendSentence("", sentence.begin(), sentence.end()); + target.appendEndingWhitespace("

"); + + CHECK(asTokens(response.target) == asTokens(target)); + } +} + +TEST_CASE("End-to-end translation when no words with markup align", "[!mayfail]") { + std::string input("

I like to drive this car.

"); + HTML html(std::move(input), true); + CHECK(input == "I like to drive this car."); + + Response response; + + // clang-format off + response.alignments = std::vector>>{{ + {0.5360, 0.4405, 0.0142, 0.0061, 0.0029, 0.0001, 0.0000, 0.0001}, + {0.0451, 0.0602, 0.5120, 0.2584, 0.1145, 0.0062, 0.0019, 0.0017}, + {0.0392, 0.0009, 0.6535, 0.2293, 0.0492, 0.0199, 0.0014, 0.0067}, + {0.0007, 0.0036, 0.0112, 0.0118, 0.9209, 0.0449, 0.0050, 0.0019}, + {0.0000, 0.0004, 0.0008, 0.0047, 0.0163, 0.9683, 0.0045, 0.0050}, + {0.0011, 0.0046, 0.0039, 0.0090, 0.0023, 0.0024, 0.9648, 0.0119}, + {0.0840, 0.0744, 0.1545, 0.1330, 0.1818, 0.1722, 0.0859, 0.1143}, + }}; + // clang-format on + + { + std::string sentence_str("I like to drive this car."); + std::vector sentence{ + string_view(sentence_str.data() + 0, 1), // 0.0 "I" + string_view(sentence_str.data() + 1, 5), // 0.1 " like" + string_view(sentence_str.data() + 6, 3), // 0.2 " to" + string_view(sentence_str.data() + 9, 6), // 0.3 " drive" + string_view(sentence_str.data() + 15, 5), // 0.4 " this" + string_view(sentence_str.data() + 20, 4), // 0.5 " car" + string_view(sentence_str.data() + 24, 1), // 0.6 "." + string_view(sentence_str.data() + 25, 0), // 0.7 [EOS] + }; + response.source.appendSentence("", sentence.begin(), sentence.end()); + } + + { + std::string sentence_str("Rád řídím to auto."); + std::vector sentence{ + string_view(sentence_str.data() + 0, 4), // 0.0 "Rád" + string_view(sentence_str.data() + 4, 6), // 0.1 " říd" + string_view(sentence_str.data() + 10, 3), // 0.2 "ím" + string_view(sentence_str.data() + 13, 3), // 0.3 "_to" + string_view(sentence_str.data() + 16, 5), // 0.4 " auto" + string_view(sentence_str.data() + 21, 1), // 0.5 "." + string_view(sentence_str.data() + 22, 0), // 0.6 [EOS] + }; + response.target.appendSentence("", sentence.begin(), sentence.end()); + } + + html.restore(response); + + { + AnnotatedText source; + std::string sentence_str("

I like to drive this car."); + std::vector sentence{ + string_view(sentence_str.data() + 0, 4), // 0.0 "

I" + string_view(sentence_str.data() + 4, 8), // 0.1 " like" + string_view(sentence_str.data() + 12, 7), // 0.2 " to" + string_view(sentence_str.data() + 19, 9), // 0.3 " drive" + string_view(sentence_str.data() + 28, 9), // 0.4 " this" + string_view(sentence_str.data() + 37, 4), // 0.5 " car" + string_view(sentence_str.data() + 41, 1), // 0.6 "." + string_view(sentence_str.data() + 42, 0), // 0.7 "" + }; + source.appendSentence("", sentence.begin(), sentence.end()); + source.appendEndingWhitespace("

"); + + CHECK(asTokens(response.source) == asTokens(source)); + } + + { + AnnotatedText target; + std::string sentence_str("

Rád řídím to auto."); + std::vector sentence{ + string_view(sentence_str.data() + 0, 7), // 0.0 "

Rád" + string_view(sentence_str.data() + 7, 13), // 0.1 " říd" + string_view(sentence_str.data() + 20, 3), // 0.2 "ím" + string_view(sentence_str.data() + 23, 10), // 0.3 "_to" + string_view(sentence_str.data() + 33, 5), // 0.4 " auto" + string_view(sentence_str.data() + 38, 1), // 0.5 "." + string_view(sentence_str.data() + 39, 0), // 0.6 [EOS] + }; + target.appendSentence("", sentence.begin(), sentence.end()); + target.appendEndingWhitespace("

"); + + CHECK(asTokens(response.target) == asTokens(target)); + } +} + +// TEST_CASE("") \ No newline at end of file diff --git a/inference/src/tests/units/html_tests.h b/inference/src/tests/units/html_tests.h new file mode 100644 index 000000000..0407b65b2 --- /dev/null +++ b/inference/src/tests/units/html_tests.h @@ -0,0 +1,9 @@ +#pragma once +#include + +#include "translator/definitions.h" + +std::ostream &operator<<(std::ostream &out, marian::bergamot::ByteRange const &b); + +std::ostream &operator<<(std::ostream &out, + std::pair const &b); diff --git a/inference/src/tests/units/quality_estimator_tests.cpp b/inference/src/tests/units/quality_estimator_tests.cpp new file mode 100644 index 000000000..e11c07a7b --- /dev/null +++ b/inference/src/tests/units/quality_estimator_tests.cpp @@ -0,0 +1,62 @@ +#include "quality_estimator_tests.h" + +#include "catch.hpp" +#include "translator/quality_estimator.h" + +using namespace marian::bergamot; + +SCENARIO("Logistic Regressor test", "[QualityEstimator]") { + GIVEN("A feature matrix") { + const std::vector > features = {{-0.3, -0.3, 1.0, -0.183683336}, + {-0.0001, -0.0001, 1.0, -0.183683336}, + {-0.002, -0.002, 1.0, -0.183683336}, + {-0.5, -0.5, 1.0, -0.183683336}, + {-0.15, -0.2, 2.0, -0.183683336}}; + + LogisticRegressorQualityEstimator::Matrix featureMatrix(features.size(), features.begin()->size()); + + for (int i = 0; i < features.size(); ++i) { + for (int j = 0; j < features.begin()->size(); ++j) { + featureMatrix.at(i, j) = features[i][j]; + } + } + + AND_GIVEN("A LogistRegressor") { + LogisticRegressorQualityEstimator::Array coefficients = {0.99000001, 0.899999976, -0.200000003, 0.5}; + const float intercept = {-0.300000012}; + + LogisticRegressorQualityEstimator::Scale scale; + scale.stds = {0.200000003, 0.300000012, 2.5, 0.100000001}; + scale.means = {-0.100000001, -0.769999981, 5, -0.5}; + + LogisticRegressorQualityEstimator lrQE(std::move(scale), std::move(coefficients), intercept); + + WHEN("It's call 
predict") { + const std::vector prediction = lrQE.predict(featureMatrix); + + THEN("return the prediction") { + CHECK(prediction == std::vector{-2.14596, -4.41793, -4.403, -0.93204, -3.03343}); + } + } + + WHEN("LR is construct by aligned memory") { + const auto lrQEAlignedMemory = LogisticRegressorQualityEstimator::fromAlignedMemory(lrQE.toAlignedMemory()); + + WHEN("It's call predict") { + const std::vector prediction = lrQEAlignedMemory.predict(featureMatrix); + + THEN("return the prediction") { + CHECK(prediction == std::vector{-2.14596, -4.41793, -4.403, -0.93204, -3.03343}); + } + } + } + } + } +} + +bool operator==(const std::vector& value1, const std::vector& value2) { + return std::equal(value1.begin(), value1.end(), value2.begin(), value2.end(), [](const auto& a, const auto& b) { + auto value = Approx(b).epsilon(0.001); + return a == value; + }); +} diff --git a/inference/src/tests/units/quality_estimator_tests.h b/inference/src/tests/units/quality_estimator_tests.h new file mode 100644 index 000000000..37cba3ef3 --- /dev/null +++ b/inference/src/tests/units/quality_estimator_tests.h @@ -0,0 +1,5 @@ +#pragma once + +#include + +bool operator==(const std::vector& value1, const std::vector& value2); diff --git a/inference/src/tests/units/run_tests.cpp b/inference/src/tests/units/run_tests.cpp new file mode 100644 index 000000000..0c7c351f4 --- /dev/null +++ b/inference/src/tests/units/run_tests.cpp @@ -0,0 +1,2 @@ +#define CATCH_CONFIG_MAIN +#include "catch.hpp" diff --git a/inference/src/tests/units/xh_scanner_tests.cpp b/inference/src/tests/units/xh_scanner_tests.cpp new file mode 100644 index 000000000..0fdfa3566 --- /dev/null +++ b/inference/src/tests/units/xh_scanner_tests.cpp @@ -0,0 +1,261 @@ +#include + +#include "catch.hpp" +#include "translator/xh_scanner.h" + +TEST_CASE("scan element with attributes") { + markup::instream in("
"); + markup::Scanner scanner(in); + + CHECK(scanner.next() == markup::Scanner::TT_TAG_START); + CHECK(scanner.tag() == "div"); + + CHECK(scanner.next() == markup::Scanner::TT_ATTRIBUTE); + CHECK(scanner.attribute() == "id"); + CHECK(scanner.value() == "test"); + + CHECK(scanner.next() == markup::Scanner::TT_ATTRIBUTE); + CHECK(scanner.attribute() == "class"); + CHECK(scanner.value() == "a b c "); + + CHECK(scanner.next() == markup::Scanner::TT_EOF); +} + +TEST_CASE("scan element with valueless attributes") { + markup::instream in(""); + markup::Scanner scanner(in); + + CHECK(scanner.next() == markup::Scanner::TT_TAG_START); + CHECK(scanner.tag() == "input"); + + CHECK(scanner.next() == markup::Scanner::TT_ATTRIBUTE); + CHECK(scanner.attribute() == "checked"); + CHECK(scanner.value() == ""); + + CHECK(scanner.next() == markup::Scanner::TT_ATTRIBUTE); + CHECK(scanner.attribute() == "hidden"); + CHECK(scanner.value() == ""); + + CHECK(scanner.next() == markup::Scanner::TT_EOF); +} + +TEST_CASE("scan element with unquoted attributes") { + markup::instream in("