Skip to content

Commit

Permalink
Rework Python driver/device creation.
Browse files Browse the repository at this point in the history
APIs removed:
  * HalDriver.create() (use iree.runtime.get_driver(driver_name) to get a
    cached instance).
  * Environment variable IREE_DEFAULT_DRIVER renamed to
    IREE_DEFAULT_DEVICE to better reflect the new syntax.
  * Config.driver attribute (no longer captured by this class)

APIs added:
  * iree.runtime.query_available_drivers() (alias of HalDriver.query())
  * iree.runtime.get_driver(driver_name)
  * iree.runtime.get_device_by_name(name_spec)
  * iree.runtime.get_first_device_by_name(name_specs)
  * iree.runtime.Config(, device: HalDevice) (to configure with an
    explicit device)
  * HalDriver.create_device(device_id: Union[int, tuple])
  * HalDriver.query_available_devices()

Devices can now be queried/constructed explicitly using
HalDriver.query_available_devices() and passing found device_ids to
HalDriver.create_device. Default configuration is extended to take "name
specs" instead of driver name. This can either be a raw driver name
(i.e. "vmvx") but can also be driver:index (i.e. "vulkan:3"). Some
logging is added to make it clearer what was selected. Devices created
in this way are cached now, since this facility is meant to be used for
trivial/default configuration. If explicitly creating devices, the user
is on their own to cache as desired.

Fixes #9277
  • Loading branch information
Stella Laurenzo committed Jun 4, 2022
1 parent a30c840 commit a55003d
Show file tree
Hide file tree
Showing 11 changed files with 328 additions and 85 deletions.
8 changes: 8 additions & 0 deletions runtime/bindings/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ iree_py_library(
"iree/runtime/flags.py"
"iree/runtime/function.py"
"iree/runtime/system_api.py"
"iree/runtime/system_setup.py"
"iree/runtime/tracing.py"
"iree/runtime/scripts/iree_benchmark_trace/__main__.py"
"iree/runtime/scripts/iree_run_trace/__main__.py"
Expand Down Expand Up @@ -146,6 +147,13 @@ iree_py_test(
"tests/system_api_test.py"
)

iree_py_test(
NAME
system_setup_test
SRCS
"tests/system_setup_test.py"
)

iree_py_test(
NAME
vm_test
Expand Down
65 changes: 63 additions & 2 deletions runtime/bindings/python/hal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,23 @@ HalDriver HalDriver::Create(const std::string& driver_name) {
return HalDriver::StealFromRawPtr(driver);
}

py::list HalDriver::QueryAvailableDevices() {
iree_hal_device_info_t* device_infos;
iree_host_size_t count;
CheckApiStatus(iree_hal_driver_query_available_devices(
raw_ptr(), iree_allocator_system(), &device_infos, &count),
"Error querying devices");
py::list results;
for (iree_host_size_t i = 0; i < count; ++i) {
results.append(py::make_tuple(
py::cast(device_infos[i].device_id),
py::str(device_infos[i].name.data, device_infos[i].name.size)));
}

iree_allocator_free(iree_allocator_system(), device_infos);
return results;
}

HalDevice HalDriver::CreateDefaultDevice() {
iree_hal_device_t* device;
CheckApiStatus(iree_hal_driver_create_default_device(
Expand All @@ -244,6 +261,38 @@ HalDevice HalDriver::CreateDefaultDevice() {
return HalDevice::StealFromRawPtr(device);
}

HalDevice HalDriver::CreateDevice(iree_hal_device_id_t device_id) {
// Since the device ids are supposed to be opaque, we need to verify
// them by querying available devices.
py::list available_devices = QueryAvailableDevices();
bool found = false;
py::object compare_device_id = py::cast(device_id);
for (auto record : available_devices) {
// Each record is a tuple of (device_id, name).
auto record_tuple = py::cast<py::tuple>(record);
py::object found_device_id = record_tuple[0];
if (found_device_id == compare_device_id) {
found = true;
break;
}
}

if (!found) {
std::string msg;
msg.append("Device id ");
msg.append(std::to_string(device_id));
msg.append(" not found. Available devices: ");
msg.append(py::repr(available_devices));
throw std::invalid_argument(std::move(msg));
}

iree_hal_device_t* device;
CheckApiStatus(iree_hal_driver_create_device(
raw_ptr(), device_id, iree_allocator_system(), &device),
"Error creating default device");
return HalDevice::StealFromRawPtr(device);
}

//------------------------------------------------------------------------------
// Enum helpers
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -416,9 +465,21 @@ void SetupHalBindings(pybind11::module m) {

py::class_<HalDriver>(m, "HalDriver")
.def_static("query", &HalDriver::Query)
.def_static("create", &HalDriver::Create, py::arg("driver_name"))
.def("create_default_device", &HalDriver::CreateDefaultDevice,
py::keep_alive<0, 1>());
py::keep_alive<0, 1>())
.def("create_device", &HalDriver::CreateDevice, py::keep_alive<0, 1>())
.def("create_device", [](HalDriver &self, py::tuple device_info) -> HalDevice {
// Alias of create_device that takes a tuple as returned from
// query_available_devices for convenience.
auto device_id = py::cast<iree_hal_device_id_t>(device_info[0]);
return self.CreateDevice(device_id);
}, py::keep_alive<0, 1>())
.def("query_available_devices", &HalDriver::QueryAvailableDevices);

// We cache drivers at the Python level so we hide the actual native
// entry point to create them directly so that it doesn't show up in
// the public API.
m.def("_create_hal_driver", &HalDriver::Create, py::arg("driver_name"));

py::class_<HalAllocator>(m, "HalAllocator")
.def("trim",
Expand Down
2 changes: 2 additions & 0 deletions runtime/bindings/python/hal.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ class HalDriver : public ApiRefCounted<HalDriver, iree_hal_driver_t> {
static std::vector<std::string> Query();
static HalDriver Create(const std::string& driver_name);

py::list QueryAvailableDevices();
HalDevice CreateDefaultDevice();
HalDevice CreateDevice(iree_hal_device_id_t device_id);
};

class HalAllocator : public ApiRefCounted<HalAllocator, iree_hal_allocator_t> {
Expand Down
6 changes: 6 additions & 0 deletions runtime/bindings/python/iree/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@

from .array_interop import *
from .system_api import *
from .system_setup import (
get_device_by_name,
get_first_device_by_name,
get_driver,
query_available_drivers,
)
from .function import *
from .tracing import *

Expand Down
82 changes: 14 additions & 68 deletions runtime/bindings/python/iree/runtime/system_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,11 @@

from . import _binding
from .function import FunctionInvoker
from .system_setup import get_first_device_by_name
from . import tracing

import numpy as np

# Environment key for a comma-delimitted list of drivers to try to load.
PREFERRED_DRIVER_ENV_KEY = "IREE_DEFAULT_DRIVER"

# Default value for IREE_DRIVER
DEFAULT_IREE_DRIVER_VALUE = "dylib,vulkan,vmvx"

# Mapping from IREE target backends to their corresponding drivers.
TARGET_BACKEND_TO_DRIVER = {
"dylib-llvm-aot": "dylib",
Expand All @@ -53,68 +48,30 @@
}


def _create_default_iree_driver(
driver_names: Optional[Sequence[str]] = None) -> _binding.HalDriver:
"""Returns a default driver based on environment settings."""
# TODO(laurenzo): Ideally this should take a VmModule and join any explicitly
# provided driver list with environmental constraints and what the module
# was compiled for.
if driver_names is None:
# Read from environment.
driver_names = os.environ.get(PREFERRED_DRIVER_ENV_KEY)
if driver_names is None:
driver_names = DEFAULT_IREE_DRIVER_VALUE
driver_names = driver_names.split(",")
available_driver_names = _binding.HalDriver.query()
driver_exceptions = {}
for driver_name in driver_names:
if driver_name not in available_driver_names:
logging.error("Could not create driver %s (not registered)", driver_name)
continue
try:
driver = _binding.HalDriver.create(driver_name)
except Exception as ex: # pylint: disable=broad-except
logging.exception("Could not create default driver %s", driver_name)
driver_exceptions[driver_name] = ex
continue

# Sanity check creation of the default device and skip the driver if
# this fails (this works around issues where the driver is present
# but there are no devices). This default initialization scheme needs
# to be improved.
try:
device = driver.create_default_device()
except Exception as ex:
logging.exception("Could not create default driver device %s",
driver_name)
driver_exceptions[driver_name] = ex
continue

logging.debug("Created IREE driver %s: %r", driver_name, driver)
return driver

# All failed.
raise RuntimeError(
f"Could not create any requested driver {repr(driver_names)} (available="
f"{repr(available_driver_names)}) : {repr(driver_exceptions)}")


class Config:
"""System configuration."""

driver: _binding.HalDriver
device: _binding.HalDevice
vm_instance: _binding.VmInstance
default_vm_modules: Tuple[_binding.VmModule, ...]
tracer: Optional[tracing.Tracer]

def __init__(self,
driver_name: Optional[str] = None,
*,
device: Optional[_binding.HalDevice] = None,
tracer: Optional[tracing.Tracer] = None):
# Either use an explicit device or auto config based on driver names.
if device is not None and driver_name is not None:
raise ValueError(
"Either 'device' or 'driver_name' can be specified (not both)")
if device is not None:
self.device = device
else:
self.device = get_first_device_by_name(
driver_name.split(",") if driver_name is not None else None)

self.vm_instance = _binding.VmInstance()
self.driver = _create_default_iree_driver(
driver_name.split(",") if driver_name is not None else None)
self.device = self.driver.create_default_device()
hal_module = _binding.create_hal_module(self.device)
self.default_vm_modules = (hal_module,)
self.tracer = tracer or tracing.get_default_tracer()
Expand All @@ -125,16 +82,6 @@ def __init__(self,
self.tracer = None


_global_config = None


def _get_global_config():
global _global_config
if _global_config is None:
_global_config = Config()
return _global_config


def _bool_to_int8(
array: Any) -> Optional[Union[np.ndarray, List[Any], Tuple[Any]]]:
if not isinstance(array, np.ndarray):
Expand Down Expand Up @@ -245,8 +192,7 @@ class SystemContext:
"""Global system."""

def __init__(self, vm_modules=None, config: Optional[Config] = None):
self._config = config if config is not None else _get_global_config()
logging.debug("SystemContext driver=%r", self._config.driver)
self._config = config if config is not None else Config()
self._is_dynamic = vm_modules is None
if self._is_dynamic:
init_vm_modules = None
Expand Down
Loading

0 comments on commit a55003d

Please sign in to comment.