Skip to content

Commit

Permalink
refactor: use app state and a tokeniser pool instead of a translator pool
Browse files Browse the repository at this point in the history
  • Loading branch information
winstxnhdw committed Sep 18, 2024
1 parent ebb3577 commit 710f9f6
Show file tree
Hide file tree
Showing 20 changed files with 208 additions and 142 deletions.
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
FROM ghcr.io/winstxnhdw/nllb-api:main

ENV SERVER_PORT=7860
ENV OMP_NUM_THREADS=2
ENV OMP_NUM_THREADS=1
ENV CT2_USE_EXPERIMENTAL_PACKED_GEMM=1
ENV CT2_FORCE_CPU_ISA=AVX512
ENV WORKER_COUNT=1
ENV WORKER_COUNT=2

EXPOSE $SERVER_PORT
6 changes: 3 additions & 3 deletions server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from server.api import v3
from server.config import Config
from server.lifespans import load_fasttext_model, load_nllb_model
from server.lifespans import load_fasttext_model, load_translator_model


def exception_handler(_, exception: Exception) -> Response[dict[str, str]]:
Expand All @@ -21,7 +21,7 @@ def exception_handler(_, exception: Exception) -> Response[dict[str, str]]:
request (Request) : the request
exception (Exception) : the exception
"""
getLogger('custom.access').error('Application Exception', exc_info=exception)
getLogger('custom.access').error('', exc_info=exception)

return Response(
content={'detail': 'Internal Server Error'},
Expand Down Expand Up @@ -52,5 +52,5 @@ def app() -> Litestar:
openapi_config=openapi_config,
exception_handlers={HTTP_500_INTERNAL_SERVER_ERROR: exception_handler},
route_handlers=[v3],
lifespan=[load_fasttext_model, load_nllb_model],
lifespan=[load_fasttext_model, load_translator_model],
)
5 changes: 3 additions & 2 deletions server/api/v3/detect_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from litestar.openapi.spec.example import Example
from litestar.params import Parameter

from server.features import LanguageDetector
from server.schemas.v1 import Language
from server.state import AppState


@get('/detect_language', sync_to_thread=False, cache=True)
def detect_language(
state: AppState,
text: Annotated[
str,
Parameter(
Expand All @@ -28,4 +29,4 @@ def detect_language(
-------
the `/detect_language` route detects the language of the input text
"""
return Language(language=LanguageDetector.detect(text))
return Language(language=state.language_detector.detect(text))
16 changes: 11 additions & 5 deletions server/api/v3/translate.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from asyncio import wrap_future
from concurrent.futures import ThreadPoolExecutor
from typing import Annotated, get_args

from litestar import Controller, get, post
from litestar.openapi.spec.example import Example
from litestar.params import Parameter
from litestar.status_codes import HTTP_200_OK

from server.features import TranslatorPool
from server.features.types import Languages
from server.schemas.v1 import Translated, Translation
from server.state import AppState
from server.types import Languages


class TranslateController(Controller):
Expand All @@ -18,10 +20,12 @@ class TranslateController(Controller):
"""

path = '/translate'
thread_pool = ThreadPoolExecutor()

@get(cache=True)
async def translate_get(
self,
state: AppState,
text: Annotated[
str,
Parameter(
Expand Down Expand Up @@ -54,13 +58,15 @@ async def translate_get(
-------
the GET variant of the `/translate` route
"""
return Translated(result=await TranslatorPool.translate(text, source, target))
translate_job = self.thread_pool.submit(state.translator.translate, text, source, target)
return Translated(result=await wrap_future(translate_job))

@post(status_code=HTTP_200_OK)
async def translate_post(self, data: Translation) -> Translated:
async def translate_post(self, state: AppState, data: Translation) -> Translated:
"""
Summary
-------
the POST variant of the `/translate` route
"""
return Translated(result=await TranslatorPool.translate(data.text, data.source, data.target))
translate_job = self.thread_pool.submit(state.translator.translate, data.text, data.source, data.target)
return Translated(result=await wrap_future(translate_job))
1 change: 0 additions & 1 deletion server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ class Config(BaseSettings):
server_port: int = 49494
server_root_path: str = '/api'
worker_count: int = 1
translator_pool_count: int = 2
use_cuda: bool = False
translator_model_name: str = 'winstxnhdw/nllb-200-distilled-1.3B-ct2-int8'
language_detector_model_name: str = 'facebook/fasttext-language-identification'
4 changes: 2 additions & 2 deletions server/features/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from server.features.detect_language import LanguageDetector as LanguageDetector
from server.features.translator import TranslatorPool as TranslatorPool
from server.features.language_detector import get_language_detector as get_language_detector
from server.features.translator import get_translator as get_translator
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from fasttext.FastText import _FastText as FastText # type: ignore

from server.config import Config
from server.features.types.languages import Languages
from server.types.languages import Languages
from server.helpers import huggingface_file_download


Expand All @@ -14,27 +14,16 @@ class LanguageDetector:
Methods
-------
load() -> None
load the model
detect(text: str) -> Languages
detect the language of the input text
"""

model: FastText
__slots__ = ('model',)

@classmethod
def load(cls):
"""
Summary
-------
download and load the model
"""
model_path = huggingface_file_download(Config.language_detector_model_name, 'model.bin')
cls.model: FastText = load_model(model_path)
def __init__(self, model: FastText):
self.model = model

@classmethod
def detect(cls, text: str) -> Languages:
def detect(self, text: str) -> Languages:
"""
Summary
-------
Expand All @@ -48,4 +37,17 @@ def detect(cls, text: str) -> Languages:
-------
language (Languages) : the detected language
"""
return cls.model.predict(text, k=5)[0][0][9:] # type: ignore
return self.model.predict(text, k=5)[0][0][9:] # type: ignore


def get_language_detector() -> LanguageDetector:
    """
    Summary
    -------
    build the FastText-based language detector

    Returns
    -------
    language_detector (LanguageDetector) : the language detector
    """
    # fetch (or reuse from cache) the FastText model binary from the Hugging Face Hub
    model_path = huggingface_file_download(Config.language_detector_model_name, 'model.bin')
    model = load_model(model_path)

    return LanguageDetector(model)
139 changes: 78 additions & 61 deletions server/features/translator.py
Original file line number Diff line number Diff line change
@@ -1,110 +1,101 @@
from asyncio import sleep, wrap_future
from concurrent.futures import ThreadPoolExecutor
from itertools import cycle
from typing import Iterator, Self

from ctranslate2 import Translator as CTranslator
from transformers.models.nllb.tokenization_nllb_fast import NllbTokenizerFast

from server.config import Config
from server.features.types import Languages, TranslatorOptions
from server.helpers import huggingface_download
from server.types import Languages


class Translator:
class Tokeniser:
"""
Summary
-------
a class for the NLLB translator
context manager for the NLLB tokeniser
Methods
-------
translate(input: str, source_language: str, target_language: str) -> str
translate the input from the source language to the target language
"""
encode(text: str) -> list[str]
encode the input text
__slots__ = ('translator', 'tokeniser', 'lock')
decode(tokens: str | list[str]) -> str
decode the input tokens
"""

def __init__(self):
model_path = huggingface_download(Config.translator_model_name)
options: TranslatorOptions = {
'model_path': model_path,
'device': 'cuda' if Config.use_cuda else 'cpu',
'compute_type': 'auto',
'inter_threads': Config.worker_count,
}
__slots__ = ('tokeniser', 'lock')

self.translator = CTranslator(**options)
def __init__(self, model_path: str):
self.tokeniser: NllbTokenizerFast = NllbTokenizerFast.from_pretrained(model_path, local_files_only=True)
self.lock = False

def __call__(self, source_language: Languages) -> Self:
self.tokeniser.src_lang = source_language
return self

def __enter__(self):
self.lock = True

def __exit__(self, *_):
self.lock = False

def translate(self, text: str, source_language: Languages, target_language: Languages) -> str:
def encode(self, text: str) -> list[str]:
"""
Summary
-------
translate the input from the source language to the target language without the Python GIL
encode the input text
Parameters
----------
input (str) : the input to translate
source_language (Languages) : the source language
target_language (Languages) : the target language
text (str) : the input text
Returns
-------
translated_text (str) : the translated text
tokens (list[str]) : the tokenised input text
"""
self.tokeniser.src_lang = source_language
return self.tokeniser(text).tokens()

results = self.translator.translate_batch(
(self.tokeniser(text).tokens(),),
([target_language],),
batch_type='tokens',
beam_size=1,
)
def decode(self, tokens: str | list[str]) -> str:
"""
Summary
-------
decode the input tokens
return self.tokeniser.decode(
self.tokeniser.convert_tokens_to_ids(results[0].hypotheses[0][1:]),
clean_up_tokenization_spaces=False,
)
Parameters
----------
tokens (str | list[str]) : the input tokens
Returns
-------
text (str) : the decoded text
"""
return self.tokeniser.decode(self.tokeniser.convert_tokens_to_ids(tokens), clean_up_tokenization_spaces=False)

class TranslatorPool:

class Translator:
"""
Summary
-------
a static class that encapsulates a pool of translators
a class for the NLLB translator
Methods
-------
load() -> None
load the translator pool
translate(text: str, source_language: Languages, target_language: Languages) -> str
translate the input from the source language to the target language using a pool of translators
translate(input: str, source_language: str, target_language: str) -> str
translate the input from the source language to the target language
"""

@classmethod
def load(cls):
"""
Summary
-------
load the translator pool
"""
cls.thread_pool = ThreadPoolExecutor()
cls.pool = cycle([Translator() for _ in range(Config.translator_pool_count)])
__slots__ = ('tokeniser_pool', 'translator')

def __init__(self, translator: CTranslator, tokeniser_pool: Iterator[Tokeniser]):
self.tokeniser_pool = tokeniser_pool
self.translator = translator

@classmethod
async def translate(cls, text: str, source_language: Languages, target_language: Languages) -> str:
def translate(self, text: str, source_language: Languages, target_language: Languages) -> str:
"""
Summary
-------
translate the input from the source language to the target language using a pool of translators
translate the input from the source language to the target language using a tokeniser pool
Parameters
----------
Expand All @@ -116,14 +107,40 @@ async def translate(cls, text: str, source_language: Languages, target_language:
-------
translated_text (str) : the translated text
"""
for translator in cls.pool:
if translator.lock:
await sleep(0)
for tokeniser in self.tokeniser_pool:
if tokeniser.lock:
continue

with translator:
return await wrap_future(
cls.thread_pool.submit(translator.translate, text, source_language, target_language)
with tokeniser(source_language):
results = self.translator.translate_batch(
(tokeniser.encode(text),),
([target_language],),
batch_type='tokens',
beam_size=1,
)

raise RuntimeError('Translator pool has been exhausted. This should never happen.')
return tokeniser.decode(results[0].hypotheses[0][1:])

raise RuntimeError('Tokeniser pool has been exhausted. This should never happen.')


def get_translator() -> Translator:
    """
    Summary
    -------
    build the NLLB translator backed by a round-robin pool of tokenisers

    Returns
    -------
    translator (Translator) : the translator
    """
    model_path = huggingface_download(Config.translator_model_name)

    # one tokeniser per worker thread; `cycle` lets `Translator.translate`
    # round-robin over them, skipping any tokeniser that is currently locked
    tokeniser_pool = cycle([Tokeniser(model_path) for _ in range(Config.worker_count)])

    translator = CTranslator(
        model_path,
        device='cuda' if Config.use_cuda else 'cpu',
        compute_type='auto',
        inter_threads=Config.worker_count,
    )

    return Translator(translator, tokeniser_pool)
4 changes: 0 additions & 4 deletions server/features/types/__init__.py

This file was deleted.

Loading

0 comments on commit 710f9f6

Please sign in to comment.