[serving] Scaffolding for llm serving. (#409)
(needs bump for IREE runtime updates)
stellaraccident authored Feb 15, 2024
1 parent e955627 commit f1c3d16
Showing 13 changed files with 1,929 additions and 56 deletions.
4 changes: 2 additions & 2 deletions core/iree-requirements.txt
@@ -1,2 +1,2 @@
-iree-compiler==20240207.794
-iree-runtime==20240207.794
+iree-compiler==20240215.802
+iree-runtime==20240215.802
63 changes: 63 additions & 0 deletions serving/tests/framework/device_session_test.py
@@ -0,0 +1,63 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import pytest

from turbine_serving.framework.session import (
    DeviceSession,
)


@pytest.fixture
def local_device_session():
    session = DeviceSession(uri="local-task")
    yield session
    session.shutdown()


def test_start_shutdown_no_host_contexts(local_device_session: DeviceSession):
    ms = local_device_session.create_module_set("default")
    ms.initialize()


def test_host_context_start_stop(local_device_session: DeviceSession):
    ms = local_device_session.create_module_set("default")
    ms.initialize()
    hc = ms.host_context


def test_host_context_scheduling(local_device_session: DeviceSession):
    device = local_device_session.device
    ms = local_device_session.create_module_set("default")
    ms.initialize()
    hc = ms.host_context

    sem = device.create_semaphore(0)

    async def task1():
        print("[coro1] test_host_context_scheduling.task")
        await hc.on_semaphore(sem, 1, True)
        print("[coro1] await completed")
        sem.signal(2)

    async def task2():
        print("[coro2] waiting for 2")
        await hc.on_semaphore(sem, 2, True)
        sem.fail("Fail from task2")

    f1 = hc.run_concurrent(task1())
    f2 = hc.run_concurrent(task2())
    sem.signal(1)
    print("[main] Waiting for semaphore")

    # Ensure task completion. Important to consume to ensure that exceptions
    # propagate.
    f1.result()
    f2.result()

    print("[main] Waiting on semaphore payload 3")
    with pytest.raises(Exception, match="Fail from task2"):
        sem.wait(3)
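
The scheduling test above boils down to the following pattern; this is a condensed sketch using only the APIs the test exercises (DeviceSession, create_module_set, host_context, on_semaphore, run_concurrent), and the wrapper function name and comments are illustrative rather than part of the commit:

from turbine_serving.framework.session import DeviceSession


def example_semaphore_scheduling():
    # Condensed sketch of the pattern exercised by test_host_context_scheduling.
    session = DeviceSession(uri="local-task")
    ms = session.create_module_set("default")
    ms.initialize()
    hc = ms.host_context

    sem = session.device.create_semaphore(0)

    async def waiter():
        # Suspend the coroutine until the semaphore reaches payload 1.
        await hc.on_semaphore(sem, 1, True)

    future = hc.run_concurrent(waiter())
    sem.signal(1)
    future.result()  # Consume the result so any exception propagates.
    session.shutdown()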
@@ -21,7 +21,8 @@ def __init__(self, args):
             [
                 sys.executable,
                 "-m",
-                "turbine_serving.llm.entrypoints.api_server",
+                "turbine_serving.llm.api.rest_server",
+                "--testing-mock-service",
             ]
             + args,
             env=env,
@@ -39,9 +40,9 @@ def _wait_for_ready(self):
             except Exception as e:
                 if self.process.poll() is not None:
                     raise RuntimeError("API server process terminated") from e
-                time.sleep(0.25)
+                time.sleep(1.0)
                 if time.time() - start > 30:
-                    raise RuntimeError("Timeout waiting for server start") from e
+                    raise RuntimeError("Timeout waiting for server start")
 
     def __del__(self):
         try:
@@ -59,5 +60,30 @@ def server():
     yield runner
 
 
-def test_basic(server: ServerRunner):
+def test_health(server: ServerRunner):
+    # Health check is part of getting the fixture.
     ...
+
+
+def test_generate_non_streaming(server: ServerRunner):
+    resp = requests.post(
+        f"{server.url}/generate",
+        json={
+            "prompt": "Hi Bob",
+        },
+    )
+    resp.raise_for_status()
+    d = resp.json()
+    assert d["text"] == "Hi Bob", repr(d)
+
+
+def test_generate_streaming(server: ServerRunner):
+    resp = requests.post(
+        f"{server.url}/generate", json={"prompt": "Hi Bob!", "stream": True}
+    )
+    resp.raise_for_status()
+    full_contents = resp.content
+    expected_contents = b'{"text": "Hi Bob!"}\x00' * 5
+    assert (
+        full_contents == expected_contents
+    ), f"Expected {expected_contents!r} vs {full_contents!r}"
130 changes: 130 additions & 0 deletions serving/tests/llm/service_v1_test.py
@@ -0,0 +1,130 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import pytest

from iree.runtime import (  # type: ignore
    HalElementType,
)

from turbine_serving.framework.session import DeviceSession
from turbine_serving.llm.config import (
    CacheParams,
    ModelParams,
    ServiceParams,
)

from turbine_serving.llm.service import (
    GenerateRequest,
    GenerateResponsePart,
)

from turbine_serving.llm.attn_block_cache import (
    create_attn_block_cache_module,
    AttnBlockCache,
)

from turbine_serving.llm.impl.service_v1 import (
    GenerateServiceV1,
)

from turbine_serving.llm.testing.fake_v1_module import (
    create_fake_module,
)


@pytest.fixture
def cache_params(model_params: ModelParams) -> CacheParams:
    return CacheParams(model=model_params, device_block_count=128, block_pos_stride=16)


@pytest.fixture
def model_params() -> ModelParams:
    return ModelParams(
        module_name="AwesomeLLM",
        module_abi_version=1,
        attn_dtype=HalElementType.FLOAT_16,
        max_seq_len=128,
        transformer_block_count=32,
        attn_head_count=32,
        attn_head_dim=128,
        prefill_batch_sizes=[1, 4, 16],
        decode_batch_sizes=[1, 4, 16],
    )


@pytest.fixture
def uninitialized_session(model_params: ModelParams):
    from iree.runtime._binding import disable_leak_checker  # type: ignore

    disable_leak_checker()
    session = DeviceSession(uri="local-task", queue_count=2)
    yield session
    session.shutdown()
    del session


@pytest.fixture
def attn_block_cache(
    uninitialized_session: DeviceSession, cache_params: CacheParams
) -> AttnBlockCache:
    return AttnBlockCache(uninitialized_session, cache_params)


@pytest.fixture
def session(
    model_params: ModelParams,
    uninitialized_session: DeviceSession,
    attn_block_cache: AttnBlockCache,
):
    session = uninitialized_session
    lms = session.create_module_set("AwesomeLLM", context_count=1)
    lms.add(
        create_attn_block_cache_module(attn_block_cache),
        create_fake_module(session.device, "AwesomeLLM", model_params=model_params),
    )
    lms.initialize()
    return session


@pytest.fixture
def service(
    session: DeviceSession,
    cache_params: CacheParams,
    model_params: ModelParams,
    attn_block_cache: AttnBlockCache,
):
    params = ServiceParams(cache=cache_params, model=model_params)
    return GenerateServiceV1(session=session, params=params, cache=attn_block_cache)


def test_single(service: GenerateServiceV1):
    state = service.start()

    async def task():
        await state.set_sequences(
            requests=[
                GenerateRequest(
                    "1",
                    "hello, tell me a story",
                    [3, 4, 5, 12, 23, 88, 10, 2, 5, 9, 12, 13, 99, 56, 33, 124, 73],
                ),
                GenerateRequest("2", "goodbye", [9, 10]),
            ]
        )
        guarded_outputs = await state.prefill()
        prefill_ids = await guarded_outputs.resolve(state.host_context)
        print(
            "PREFILL IDS:",
            prefill_ids,
            ":\n",
            prefill_ids.map().asarray(
                prefill_ids.shape, HalElementType.map_to_dtype(prefill_ids.element_type)
            ),
        )
        await state.recycle()

    state.host_context.run_sync(task())
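
The buffer read-back inside the print() above can be factored into a small helper for readability; a hypothetical sketch using only calls that already appear in the test (map().asarray and HalElementType.map_to_dtype):

from iree.runtime import HalElementType  # type: ignore


def buffer_view_to_array(buffer_view):
    # Hypothetical helper mirroring the read-back in test_single: map the
    # device buffer and view it as a host array with the matching dtype.
    dtype = HalElementType.map_to_dtype(buffer_view.element_type)
    return buffer_view.map().asarray(buffer_view.shape, dtype)


# e.g. inside the task: print("PREFILL IDS:", buffer_view_to_array(prefill_ids))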
48 changes: 48 additions & 0 deletions serving/turbine_serving/framework/logging.py
@@ -0,0 +1,48 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging
import os
import sys


# Whether debug assertions are disabled.
NDEBUG: bool = False

_default_log_level = os.getenv("TURBINE_LOG_LEVEL", "DEBUG")


class DefaultFormatter(logging.Formatter):
    def __init__(self):
        super().__init__(
            "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s",
            "%m-%d %H:%M:%S",
        )


def _setup_logger():
    root_logger = logging.getLogger("turbine_serving")
    root_logger.setLevel(logging.DEBUG)
    default_handler = logging.StreamHandler(sys.stderr)
    default_handler.flush = sys.stderr.flush
    default_handler.setLevel(_default_log_level)
    default_handler.setFormatter(DefaultFormatter())
    root_logger.addHandler(default_handler)
    root_logger.propagate = False
    return root_logger, default_handler


root_logger, default_handler = _setup_logger()

logging.getLogger("asyncio").addHandler(default_handler)


def get_logger(name: str):
    logger = logging.getLogger(name)
    logger.setLevel(_default_log_level)
    logger.addHandler(default_handler)
    logger.propagate = False
    return logger
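
A minimal sketch of how other modules in the package might pick up a logger from this helper; the logger name and message below are illustrative:

from turbine_serving.framework.logging import get_logger

logger = get_logger("turbine_serving.llm.service")
logger.debug("Initialized service session")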