refactor(ml): model sessions #10559

Merged · 4 commits · Jun 25, 2024
18 changes: 12 additions & 6 deletions machine-learning/ann/ann.py
@@ -52,8 +52,6 @@ class Ann(metaclass=_Singleton):
def __init__(self, log_level: int = 3, tuning_level: int = 1, tuning_file: str | None = None) -> None:
if not is_available:
raise RuntimeError("libann is not available!")
if tuning_file and not exists(tuning_file):
raise ValueError("tuning_file must point to an existing (possibly empty) file!")
if tuning_level == 0 and tuning_file is None:
raise ValueError("tuning_level == 0 reads existing tuning information and requires a tuning_file")
if tuning_level < 0 or tuning_level > 3:
@@ -67,6 +65,12 @@ def __init__(self, log_level: int = 3, tuning_level: int = 1, tuning_file: str | None = None) -> None:
self.input_shapes: dict[int, tuple[tuple[int], ...]] = {}
self.ann: int | None = None
self.new()

if self.tuning_file is not None:
# make sure tuning file exists (without clearing contents)
# once filled, the tuning file reduces the cost/time of the first
# inference after model load by 10s of seconds
open(self.tuning_file, "a").close()

def new(self) -> None:
if self.ann is None:
@@ -95,17 +99,19 @@ def load(
model_path: str,
fast_math: bool = True,
fp16: bool = False,
save_cached_network: bool = False,
cached_network_path: str | None = None,
) -> int:
if not model_path.endswith((".armnn", ".tflite", ".onnx")):
raise ValueError("model_path must be a file with extension .armnn, .tflite or .onnx")
if not exists(model_path):
raise ValueError("model_path must point to an existing file!")

save_cached_network = False
if cached_network_path is not None and not exists(cached_network_path):
raise ValueError("cached_network_path must point to an existing (possibly empty) file!")
if save_cached_network and cached_network_path is None:
raise ValueError("save_cached_network is True, cached_network_path must be specified!")
save_cached_network = True
# create empty model cache file
open(cached_network_path, "a").close()

net_id: int = libann.load(
self.ann,
model_path.encode(),
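Taken together, the ann.py changes mean callers no longer need to pre-create the GPU tuning file (it is now opened in append mode during construction), and caching of the parsed network is opt-in through the new save_cached_network parameter. A minimal usage sketch, assuming libann is available; the paths are placeholders:

from pathlib import Path

from ann.ann import Ann

tuning_file = "/cache/gpu-tuning.ann"  # created automatically if missing
cache_file = "/cache/clip.anncache"    # hypothetical cached-network path

# Once populated, the tuning file saves tens of seconds on the first
# inference after a model load.
ann = Ann(tuning_level=2, tuning_file=tuning_file)

# Per the validation above, cached_network_path must point to an existing
# (possibly empty) file, so create it up front.
Path(cache_file).touch()
net_id = ann.load(
    "/models/clip/model.armnn",
    save_cached_network=True,
    cached_network_path=cache_file,
)
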
71 changes: 69 additions & 2 deletions machine-learning/app/conftest.py
@@ -8,6 +8,8 @@
from numpy.typing import NDArray
from PIL import Image

from app.config import log

from .main import app


@@ -96,12 +98,77 @@ def clip_tokenizer_cfg() -> dict[str, Any]:


@pytest.fixture(scope="function")
def providers(request: pytest.FixtureRequest) -> Iterator[dict[str, Any]]:
def providers(request: pytest.FixtureRequest) -> Iterator[mock.Mock]:
marker = request.node.get_closest_marker("providers")
if marker is None:
raise ValueError("Missing marker 'providers'")

providers = marker.args[0]
with mock.patch("app.models.base.ort.get_available_providers") as mocked:
with mock.patch("app.sessions.ort.ort.get_available_providers") as mocked:
mocked.return_value = providers
yield providers


@pytest.fixture(scope="function")
def ort_pybind() -> Iterator[mock.Mock]:
with mock.patch("app.sessions.ort.ort.capi._pybind_state") as mocked:
yield mocked


@pytest.fixture(scope="function")
def ov_device_ids(request: pytest.FixtureRequest, ort_pybind: mock.Mock) -> Iterator[mock.Mock]:
marker = request.node.get_closest_marker("ov_device_ids")
if marker is None:
raise ValueError("Missing marker 'ov_device_ids'")
ort_pybind.get_available_openvino_device_ids.return_value = marker.args[0]
return ort_pybind


@pytest.fixture(scope="function")
def ort_session() -> Iterator[mock.Mock]:
with mock.patch("app.sessions.ort.ort.InferenceSession") as mocked:
yield mocked


@pytest.fixture(scope="function")
def ann_session() -> Iterator[mock.Mock]:
with mock.patch("app.sessions.ann.Ann") as mocked:
yield mocked


@pytest.fixture(scope="function")
def rmtree() -> Iterator[mock.Mock]:
with mock.patch("app.models.base.rmtree", autospec=True) as mocked:
mocked.avoids_symlink_attacks = True
yield mocked


@pytest.fixture(scope="function")
def path() -> Iterator[mock.Mock]:
path = mock.MagicMock()
path.exists.return_value = True
path.is_dir.return_value = True
path.is_file.return_value = True
path.with_suffix.return_value = path
path.return_value = path

with mock.patch("app.models.base.Path", return_value=path) as mocked:
yield mocked


@pytest.fixture(scope="function")
def info() -> Iterator[mock.Mock]:
with mock.patch.object(log, "info") as mocked:
yield mocked


@pytest.fixture(scope="function")
def warning() -> Iterator[mock.Mock]:
with mock.patch.object(log, "warning") as mocked:
yield mocked


@pytest.fixture(scope="function")
def snapshot_download() -> Iterator[mock.Mock]:
with mock.patch("app.models.base.snapshot_download") as mocked:
yield mocked
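
Note that the session fixtures patch the new app.sessions module paths rather than app.models.base, matching where the provider logic now lives. A hypothetical test showing how they compose, assuming OrtSession resolves its providers through ort.get_available_providers as the mock target suggests:

from unittest import mock

import pytest

from app.sessions.ort import OrtSession


@pytest.mark.providers(["CUDAExecutionProvider", "CPUExecutionProvider"])
def test_respects_available_providers(providers: list[str], ort_session: mock.Mock) -> None:
    # `providers` feeds the marker's list through the patched
    # ort.get_available_providers; `ort_session` intercepts the real
    # ort.InferenceSession, so no model file is needed on disk.
    session = OrtSession("model.onnx")
    assert session.providers == providers  # assumes OrtSession preserves this order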
140 changes: 22 additions & 118 deletions machine-learning/app/models/base.py
@@ -5,15 +5,14 @@
from shutil import rmtree
from typing import Any, ClassVar

import onnxruntime as ort
from huggingface_hub import snapshot_download

import ann.ann
from app.models.constants import SUPPORTED_PROVIDERS
from app.sessions.ort import OrtSession

from ..config import clean_name, log, settings
from ..schemas import ModelFormat, ModelIdentity, ModelSession, ModelTask, ModelType
from .ann import AnnSession
from ..sessions.ann import AnnSession


class InferenceModel(ABC):
@@ -24,20 +23,17 @@ def __init__(
self,
model_name: str,
cache_dir: Path | str | None = None,
providers: list[str] | None = None,
provider_options: list[dict[str, Any]] | None = None,
sess_options: ort.SessionOptions | None = None,
preferred_format: ModelFormat | None = None,
session: ModelSession | None = None,
**model_kwargs: Any,
) -> None:
self.loaded = False
self.loaded = session is not None
self.load_attempts = 0
self.model_name = clean_name(model_name)
self.cache_dir = Path(cache_dir) if cache_dir is not None else self.cache_dir_default
self.providers = providers if providers is not None else self.providers_default
self.provider_options = provider_options if provider_options is not None else self.provider_options_default
self.sess_options = sess_options if sess_options is not None else self.sess_options_default
self.preferred_format = preferred_format if preferred_format is not None else self.preferred_format_default
self.cache_dir = Path(cache_dir) if cache_dir is not None else self._cache_dir_default
self.model_format = preferred_format if preferred_format is not None else self._model_format_default
if session is not None:
self.session = session

def download(self) -> None:
if not self.cached:
@@ -70,7 +66,7 @@ def configure(self, **kwargs: Any) -> None:
pass

def _download(self) -> None:
ignore_patterns = [] if self.preferred_format == ModelFormat.ARMNN else ["*.armnn"]
ignore_patterns = [] if self.model_format == ModelFormat.ARMNN else ["*.armnn"]
snapshot_download(
f"immich-app/{clean_name(self.model_name)}",
cache_dir=self.cache_dir,
@@ -105,26 +101,11 @@ def clear_cache(self) -> None:
self.cache_dir.mkdir(parents=True, exist_ok=True)

def _make_session(self, model_path: Path) -> ModelSession:
if not model_path.is_file():
onnx_path = model_path.with_suffix(".onnx")
if not onnx_path.is_file():
raise ValueError(f"Model path '{model_path}' does not exist")

log.warning(
f"Could not find model path '{model_path}'. " f"Falling back to ONNX model path '{onnx_path}' instead.",
)
model_path = onnx_path

match model_path.suffix:
case ".armnn":
session = AnnSession(model_path)
session: ModelSession = AnnSession(model_path)
case ".onnx":
session = ort.InferenceSession(
model_path.as_posix(),
sess_options=self.sess_options,
providers=self.providers,
provider_options=self.provider_options,
)
session = OrtSession(model_path)
case _:
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
return session
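
With this change _make_session returns a ModelSession instead of a raw ort.InferenceSession, so AnnSession and OrtSession are interchangeable behind a single interface. The schemas module is not part of this diff, but a protocol along these lines would be enough for the dispatch above (a sketch, not the actual definition):

from typing import Any, Protocol

from numpy.typing import NDArray


class ModelSession(Protocol):
    # Minimal surface both AnnSession and OrtSession would need to satisfy,
    # mirroring the signature of ort.InferenceSession.run.
    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[Any]],
        run_options: Any = None,
    ) -> list[NDArray[Any]]: ...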
@@ -135,7 +116,7 @@ def model_dir(self) -> Path:

@property
def model_path(self) -> Path:
return self.model_dir / f"model.{self.preferred_format}"
return self.model_dir / f"model.{self.model_format}"

@property
def model_task(self) -> ModelTask:
@@ -154,103 +135,26 @@ def cache_dir(self, cache_dir: Path) -> None:
self._cache_dir = cache_dir

@property
def cache_dir_default(self) -> Path:
def _cache_dir_default(self) -> Path:
return settings.cache_folder / self.model_task.value / self.model_name

@property
def cached(self) -> bool:
return self.model_path.is_file()

@property
def providers(self) -> list[str]:
return self._providers

@providers.setter
def providers(self, providers: list[str]) -> None:
log.info(
(f"Setting '{self.model_name}' execution providers to {providers}, " "in descending order of preference"),
)
self._providers = providers

@property
def providers_default(self) -> list[str]:
available_providers = set(ort.get_available_providers())
log.debug(f"Available ORT providers: {available_providers}")
if (openvino := "OpenVINOExecutionProvider") in available_providers:
device_ids: list[str] = ort.capi._pybind_state.get_available_openvino_device_ids()
log.debug(f"Available OpenVINO devices: {device_ids}")

gpu_devices = [device_id for device_id in device_ids if device_id.startswith("GPU")]
if not gpu_devices:
log.warning("No GPU device found in OpenVINO. Falling back to CPU.")
available_providers.remove(openvino)
return [provider for provider in SUPPORTED_PROVIDERS if provider in available_providers]

@property
def provider_options(self) -> list[dict[str, Any]]:
return self._provider_options

@provider_options.setter
def provider_options(self, provider_options: list[dict[str, Any]]) -> None:
log.debug(f"Setting execution provider options to {provider_options}")
self._provider_options = provider_options

@property
def provider_options_default(self) -> list[dict[str, Any]]:
options = []
for provider in self.providers:
match provider:
case "CPUExecutionProvider" | "CUDAExecutionProvider":
option = {"arena_extend_strategy": "kSameAsRequested"}
case "OpenVINOExecutionProvider":
option = {"device_type": "GPU_FP32", "cache_dir": (self.cache_dir / "openvino").as_posix()}
case _:
option = {}
options.append(option)
return options

@property
def sess_options(self) -> ort.SessionOptions:
return self._sess_options

@sess_options.setter
def sess_options(self, sess_options: ort.SessionOptions) -> None:
log.debug(f"Setting execution_mode to {sess_options.execution_mode.name}")
log.debug(f"Setting inter_op_num_threads to {sess_options.inter_op_num_threads}")
log.debug(f"Setting intra_op_num_threads to {sess_options.intra_op_num_threads}")
self._sess_options = sess_options

@property
def sess_options_default(self) -> ort.SessionOptions:
sess_options = ort.SessionOptions()
sess_options.enable_cpu_mem_arena = False

# avoid thread contention between models
if settings.model_inter_op_threads > 0:
sess_options.inter_op_num_threads = settings.model_inter_op_threads
# these defaults work well for CPU, but bottleneck GPU
elif settings.model_inter_op_threads == 0 and self.providers == ["CPUExecutionProvider"]:
sess_options.inter_op_num_threads = 1

if settings.model_intra_op_threads > 0:
sess_options.intra_op_num_threads = settings.model_intra_op_threads
elif settings.model_intra_op_threads == 0 and self.providers == ["CPUExecutionProvider"]:
sess_options.intra_op_num_threads = 2

if sess_options.inter_op_num_threads > 1:
sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL

return sess_options

@property
def preferred_format(self) -> ModelFormat:
def model_format(self) -> ModelFormat:
return self._preferred_format

@preferred_format.setter
def preferred_format(self, preferred_format: ModelFormat) -> None:
@model_format.setter
def model_format(self, preferred_format: ModelFormat) -> None:
log.debug(f"Setting preferred format to {preferred_format}")
self._preferred_format = preferred_format

@property
def preferred_format_default(self) -> ModelFormat:
return ModelFormat.ARMNN if ann.ann.is_available and settings.ann else ModelFormat.ONNX
def _model_format_default(self) -> ModelFormat:
prefer_ann = ann.ann.is_available and settings.ann
ann_exists = (self.model_dir / "model.armnn").is_file()
if prefer_ann and not ann_exists:
log.warning(f"ARM NN is available, but '{self.model_name}' does not support ARM NN. Falling back to ONNX.")
return ModelFormat.ARMNN if prefer_ann and ann_exists else ModelFormat.ONNX
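
The net effect on model construction: provider plumbing (providers, provider_options, sess_options) moved out of InferenceModel into OrtSession, and callers may instead inject a ready-made session, which marks the model as loaded from the start. A hypothetical sketch of both styles, where SomeModel stands in for any concrete InferenceModel subclass:

from app.schemas import ModelFormat
from app.sessions.ort import OrtSession

# Style 1: let the model pick its format and build its own session.
model = SomeModel("ViT-B-32__openai", preferred_format=ModelFormat.ONNX)
model.download()  # fetches the snapshot only if not already cached

# Style 2: inject a pre-built session; per the new __init__,
# `loaded` is True immediately.
model = SomeModel("ViT-B-32__openai", session=OrtSession(model.model_path))
assert model.loaded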