perf/build/docs: use translator threads instead of workers
winstxnhdw committed Sep 18, 2024
1 parent 323f75d commit fe37941
Showing 5 changed files with 24 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -4,6 +4,6 @@ ENV SERVER_PORT=7860
 ENV OMP_NUM_THREADS=1
 ENV CT2_USE_EXPERIMENTAL_PACKED_GEMM=1
 ENV CT2_FORCE_CPU_ISA=AVX512
-ENV WORKER_COUNT=2
+ENV TRANSLATOR_THREADS=4

 EXPOSE $SERVER_PORT
4 changes: 3 additions & 1 deletion README.md
@@ -231,6 +231,7 @@ Chinese (Simplified) | zho_Hans
 Chinese (Traditional) | zho_Hant
 Standard Malay | zsm_Latn
 Zulu | zul_Latn
+
 </details>

 ```bash
@@ -256,7 +257,7 @@ docker run --rm \

 ### Optimisation

-You can pass the following environment variables to optimise the API for your own uses. The value of `OMP_NUM_THREADS` increases the number of threads used to translate a given batch of inputs, while `WORKER_COUNT` increases the number of workers used to handle requests in parallel.
+You can pass the following environment variables to optimise the API for your own uses. The value of `OMP_NUM_THREADS` increases the number of threads used to translate a given batch of inputs, while `TRANSLATOR_THREADS` increases the number of threads used to handle translation requests in parallel. It is recommended not to modify `WORKER_COUNT`, as spawning multiple workers can lead to increased memory usage and poorer performance.

 > [!IMPORTANT]\
 > `OMP_NUM_THREADS` $\times$ `WORKER_COUNT` should not exceed the physical number of cores on your machine.
@@ -265,6 +266,7 @@ You can pass the following environment variables to optimise the API for your ow
 docker run --rm \
   -e SERVER_PORT=7860 \
   -e OMP_NUM_THREADS=6 \
+  -e TRANSLATOR_THREADS=2 \
   -e WORKER_COUNT=1 \
   -p 7860:7860 \
   ghcr.io/winstxnhdw/nllb-api:main
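Under CTranslate2's threading model, up to `inter_threads` (set here by `TRANSLATOR_THREADS`) translations run concurrently and each uses `OMP_NUM_THREADS` intra-op threads on its batch, so the example above peaks at roughly 6 × 2 = 12 busy threads; choose the two values so their product stays within your physical core count.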
17 changes: 6 additions & 11 deletions server/api/v3/translate.py
@@ -1,5 +1,3 @@
-from asyncio import wrap_future
-from concurrent.futures import ThreadPoolExecutor
 from typing import Annotated, get_args

 from litestar import Controller, get, post
@@ -20,10 +18,9 @@ class TranslateController(Controller):
     """

     path = '/translate'
-    thread_pool = ThreadPoolExecutor()

-    @get(cache=True)
-    async def translate_get(
+    @get(cache=True, sync_to_thread=True)
+    def translate_get(
         self,
         state: AppState,
         text: Annotated[
@@ -58,15 +55,13 @@ async def translate_get(
         -------
         the GET variant of the `/translate` route
         """
-        translate_job = self.thread_pool.submit(state.translator.translate, text, source, target)
-        return Translated(result=await wrap_future(translate_job))
+        return Translated(result=state.translator.translate(text, source, target))

-    @post(status_code=HTTP_200_OK)
-    async def translate_post(self, state: AppState, data: Translation) -> Translated:
+    @post(status_code=HTTP_200_OK, sync_to_thread=True)
+    def translate_post(self, state: AppState, data: Translation) -> Translated:
         """
         Summary
         -------
         the POST variant of the `/translate` route
         """
-        translate_job = self.thread_pool.submit(state.translator.translate, data.text, data.source, data.target)
-        return Translated(result=await wrap_future(translate_job))
+        return Translated(result=state.translator.translate(data.text, data.source, data.target))
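The diff above replaces the controller's hand-rolled `ThreadPoolExecutor` with Litestar's `sync_to_thread=True`, which runs a synchronous handler on a worker thread so the CPU-bound translate call does not block the event loop. Here is a minimal self-contained sketch of the same pattern; the `slow_translate` stand-in and the route below are hypothetical, not the repo's actual handler:

```python
from litestar import Litestar, get


def slow_translate(text: str) -> str:
    # stand-in for the CPU-bound CTranslate2 call
    return text[::-1]


# sync_to_thread=True tells Litestar to run this synchronous
# handler in a worker thread instead of on the event loop
@get('/translate', sync_to_thread=True)
def translate(text: str) -> dict[str, str]:
    return {'result': slow_translate(text)}


app = Litestar(route_handlers=[translate])
```

Litestar dispatches such handlers through anyio's thread pool, so concurrency is bounded by anyio's capacity limiter rather than by a per-controller executor.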
4 changes: 4 additions & 0 deletions server/config.py
@@ -32,6 +32,8 @@ class Config(BaseSettings):
     server_port (int) : the port to run the server on
     server_root_path (str) : the root path for the server
     worker_count (int) : the number of workers to use
+    translator_threads (int) : the number of threads for the translator
+    translator_beam_size (int) : the beam size for the translator
     use_cuda (bool) : whether to use CUDA for inference
     translator_model_name (str) : the name of the translator model
     language_detector_model_name (str) : the name of the language detector model
@@ -40,6 +42,8 @@ class Config(BaseSettings):
     server_port: int = 49494
     server_root_path: str = '/api'
     worker_count: int = 1
+    translator_threads: int = 1
+    translator_beam_size: int = 1
     use_cuda: bool = False
     translator_model_name: str = 'winstxnhdw/nllb-200-distilled-1.3B-ct2-int8'
     language_detector_model_name: str = 'facebook/fasttext-language-identification'
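The new fields follow the class's existing pattern: as a `BaseSettings` subclass, each field can be populated from the matching environment variable, such as `TRANSLATOR_THREADS=4` from the Dockerfile. A minimal sketch of that mechanism, assuming stock pydantic-settings behaviour (the repo reads values off `Config` at class level, which implies some customisation not shown here):

```python
import os

from pydantic_settings import BaseSettings


class Config(BaseSettings):
    # fields from the diff; environment variables match case-insensitively
    translator_threads: int = 1
    translator_beam_size: int = 1


os.environ['TRANSLATOR_THREADS'] = '4'  # normally set in the Dockerfile

config = Config()
print(config.translator_threads)    # 4, parsed from the environment
print(config.translator_beam_size)  # 1, the declared default
```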
18 changes: 10 additions & 8 deletions server/features/translator.py
@@ -85,17 +85,18 @@ class Translator:
     translate the input from the source language to the target language
     """

-    __slots__ = ('tokeniser_pool', 'translator')
+    __slots__ = ('translator', 'tokeniser_pool', 'beam_size')

-    def __init__(self, translator: CTranslator, tokeniser_pool: Iterator[Tokeniser]):
+    def __init__(self, translator: CTranslator, tokeniser_pool: Iterator[Tokeniser], beam_size: int):
         self.tokeniser_pool = tokeniser_pool
         self.translator = translator
+        self.beam_size = beam_size

     def translate(self, text: str, source_language: Languages, target_language: Languages) -> str:
         """
         Summary
         -------
-        translate the input from the source language to the target language using a tokeniser pool
+        translate the input from the source language to the target language using a pool of tokenisers

         Parameters
         ----------
@@ -107,6 +108,7 @@ def translate(self, text: str, source_language: Languages, target_language: Lang
         -------
         translated_text (str) : the translated text
         """
+
         for tokeniser in self.tokeniser_pool:
             if tokeniser.lock:
                 continue
@@ -116,10 +118,10 @@ def translate(self, text: str, source_language: Languages, target_language: Lang
                 (tokeniser.encode(text),),
                 ([target_language],),
                 batch_type='tokens',
-                beam_size=1,
+                beam_size=self.beam_size,
             )

-            return tokeniser.decode(results[0].hypotheses[0][1:])
+            return tokeniser.decode(results[0].hypotheses[0][1:])

         raise RuntimeError('Tokeniser pool has been exhausted. This should never happen.')

@@ -135,12 +137,12 @@ def get_translator() -> Translator:
     translator (TranslatorPool) : the translator pool
     """
     model_path = huggingface_download(Config.translator_model_name)
-    tokeniser_pool = cycle([Tokeniser(model_path) for _ in range(Config.worker_count)])
+    tokeniser_pool = cycle([Tokeniser(model_path) for _ in range(Config.translator_threads)])
     translator = CTranslator(
         model_path,
         'cuda' if Config.use_cuda else 'cpu',
         compute_type='auto',
-        inter_threads=Config.worker_count,
+        inter_threads=Config.translator_threads,
     )

-    return Translator(translator, tokeniser_pool)
+    return Translator(translator, tokeniser_pool, Config.translator_beam_size)
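`get_translator` now sizes both the tokeniser pool and CTranslate2's `inter_threads` from `translator_threads`, giving one tokeniser per concurrent translation. The pool is an infinite `itertools.cycle`, and `translate` skips any tokeniser whose `lock` flag is set. A simplified sketch of that round-robin pattern; `PooledTokeniser` and its flag handling are hypothetical stand-ins for the repo's `Tokeniser`:

```python
from itertools import cycle


class PooledTokeniser:
    # hypothetical stand-in for the repo's Tokeniser
    def __init__(self) -> None:
        self.lock = False  # True while another thread is using this tokeniser

    def encode(self, text: str) -> list[str]:
        return text.split()

    def decode(self, tokens: list[str]) -> str:
        return ' '.join(tokens)


pool = cycle([PooledTokeniser() for _ in range(4)])  # one per translator thread


def translate(text: str) -> str:
    # round-robin over the infinite cycle: busy tokenisers are skipped,
    # so the scan spins until one becomes free (never raises StopIteration)
    for tokeniser in pool:
        if tokeniser.lock:
            continue
        tokeniser.lock = True  # not atomic; the real code presumably guards this properly
        try:
            return tokeniser.decode(tokeniser.encode(text))  # placeholder for translate_batch
        finally:
            tokeniser.lock = False
    raise RuntimeError('unreachable: cycle() never ends')


print(translate('hello from the pool'))
```

Because `inter_threads` matches the pool size, at most that many `translate_batch` calls can be in flight at once, so a free tokeniser should always turn up within one pass over the pool.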
