Skip to content

Commit

Permalink
Rework Python driver/device creation. (#9330)
Browse files Browse the repository at this point in the history
APIs removed:
  * HalDriver.create() (use iree.runtime.get_driver(driver_name) to get a
    cached instance).
  * Environment variable IREE_DEFAULT_DRIVER renamed to
    IREE_DEFAULT_DEVICE to better reflect the new syntax.
  * Config.driver attribute (no longer captured by this class)

APIs added:
  * iree.runtime.query_available_drivers() (alias of HalDriver.query())
  * iree.runtime.get_driver(device_uri)
  * iree.runtime.get_device(device_uri)
  * iree.runtime.get_first_device(device_uris)
  * iree.runtime.Config(..., device: HalDevice) (to configure with an
    explicit device)
  * HalDriver.create_device(device_id: Union[int, tuple])
  * HalDriver.query_available_devices()
  * HalDriver.create_device_by_uri(device_uri: str)

Both driver and device lookups are done by device URI, as defined by the runtime (when creating a driver, only the 'scheme' is used). Driver instances are cached by name in the native code, which should avoid various bad behavior related to driver lifetimes and a lack of care for process state. Devices are optionally (default True) cached at the Python level.

Fixes #9277
Expected to fix #9936
  • Loading branch information
Stella Laurenzo authored Jul 29, 2022
1 parent 3326029 commit 3813758
Show file tree
Hide file tree
Showing 16 changed files with 394 additions and 92 deletions.
8 changes: 8 additions & 0 deletions runtime/bindings/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ iree_py_library(
"iree/runtime/flags.py"
"iree/runtime/function.py"
"iree/runtime/system_api.py"
"iree/runtime/system_setup.py"
"iree/runtime/tracing.py"
"iree/runtime/scripts/iree_benchmark_trace/__main__.py"
"iree/runtime/scripts/iree_run_trace/__main__.py"
Expand Down Expand Up @@ -147,6 +148,13 @@ iree_py_test(
"tests/system_api_test.py"
)

iree_py_test(
NAME
system_setup_test
SRCS
"tests/system_setup_test.py"
)

iree_py_test(
NAME
vm_test
Expand Down
112 changes: 106 additions & 6 deletions runtime/bindings/python/hal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "./hal.h"

#include "iree/base/internal/path.h"
#include "iree/base/tracing.h"
#include "iree/hal/api.h"
#include "pybind11/numpy.h"
Expand Down Expand Up @@ -226,14 +227,47 @@ std::vector<std::string> HalDriver::Query() {
return driver_names;
}

HalDriver HalDriver::Create(const std::string& driver_name) {
py::object HalDriver::Create(const std::string& device_uri,
py::dict& driver_cache) {
iree_string_view_t driver_name, device_path, params_str;
iree_string_view_t device_uri_sv{device_uri.data(), device_uri.size()};
iree_uri_split(device_uri_sv, &driver_name, &device_path, &params_str);

// Check cache.
py::str cache_key(driver_name.data, driver_name.size);
py::object cached = driver_cache.attr("get")(cache_key);
if (!cached.is_none()) {
return cached;
}

// Create.
iree_hal_driver_t* driver;
CheckApiStatus(iree_hal_driver_registry_try_create(
iree_hal_driver_registry_default(),
{driver_name.data(), driver_name.size()},
iree_hal_driver_registry_default(), driver_name,
iree_allocator_system(), &driver),
"Error creating driver");
return HalDriver::StealFromRawPtr(driver);

// Cache.
py::object driver_obj = py::cast(HalDriver::StealFromRawPtr(driver));
driver_cache[cache_key] = driver_obj;
return driver_obj;
}

// Queries the driver for its available devices.
//
// Returns a Python list with one (device_id, name) tuple per device
// reported by iree_hal_driver_query_available_devices. A failing query is
// reported through CheckApiStatus with the message "Error querying devices".
py::list HalDriver::QueryAvailableDevices() {
  iree_hal_device_info_t* device_infos;
  iree_host_size_t count;
  CheckApiStatus(iree_hal_driver_query_available_devices(
                     raw_ptr(), iree_allocator_system(), &count, &device_infos),
                 "Error querying devices");

  // Release the returned array on every exit path. Building the Python
  // objects below can throw (e.g. if a name cannot be converted to a str),
  // and a trailing free call would then be skipped, leaking the array.
  struct DeviceInfosReleaser {
    iree_hal_device_info_t* infos;
    ~DeviceInfosReleaser() {
      iree_allocator_free(iree_allocator_system(), infos);
    }
  } releaser{device_infos};

  py::list results;
  for (iree_host_size_t i = 0; i < count; ++i) {
    // Each record is a tuple of (device_id, name).
    results.append(py::make_tuple(
        py::cast(device_infos[i].device_id),
        py::str(device_infos[i].name.data, device_infos[i].name.size)));
  }
  return results;
}

HalDevice HalDriver::CreateDefaultDevice() {
Expand All @@ -244,6 +278,50 @@ HalDevice HalDriver::CreateDefaultDevice() {
return HalDevice::StealFromRawPtr(device);
}

// Creates a device from a driver-specific device id.
//
// The id is validated against the records returned by
// QueryAvailableDevices(); if no record carries a matching id,
// std::invalid_argument is thrown with a message listing the available
// devices. A failure of the underlying create call is reported through
// CheckApiStatus.
HalDevice HalDriver::CreateDevice(iree_hal_device_id_t device_id) {
  // Since the device ids are supposed to be opaque, we need to verify
  // them by querying available devices.
  py::list available_devices = QueryAvailableDevices();
  bool found = false;
  py::object compare_device_id = py::cast(device_id);
  for (auto record : available_devices) {
    // Each record is a tuple of (device_id, name).
    auto record_tuple = py::cast<py::tuple>(record);
    py::object found_device_id = record_tuple[0];
    if (found_device_id == compare_device_id) {
      found = true;
      break;
    }
  }

  if (!found) {
    std::string msg;
    msg.append("Device id ");
    msg.append(std::to_string(device_id));
    msg.append(" not found. Available devices: ");
    msg.append(py::repr(available_devices));
    throw std::invalid_argument(std::move(msg));
  }

  // No creation parameters are passed, so the vector is empty. Use data()
  // instead of &front(): calling front() on an empty vector is undefined
  // behavior, whereas data() is always valid to call.
  std::vector<iree_string_pair_t> params;
  iree_hal_device_t* device;
  CheckApiStatus(iree_hal_driver_create_device_by_id(
                     raw_ptr(), device_id, params.size(), params.data(),
                     iree_allocator_system(), &device),
                 "Error creating device");
  return HalDevice::StealFromRawPtr(device);
}

// Creates a device on this driver from a device URI string.
//
// The URI is handed to iree_hal_driver_create_device_by_uri as a string
// view over `device_uri`'s storage. A failure is reported through
// CheckApiStatus with the message "Error creating device"; on success the
// raw device pointer is wrapped via HalDevice::StealFromRawPtr.
HalDevice HalDriver::CreateDeviceByURI(std::string& device_uri) {
  const iree_string_view_t uri_view{device_uri.data(), device_uri.size()};
  iree_hal_device_t* created_device;
  const auto status = iree_hal_driver_create_device_by_uri(
      raw_ptr(), uri_view, iree_allocator_system(), &created_device);
  CheckApiStatus(status, "Error creating device");
  return HalDevice::StealFromRawPtr(created_device);
}

//------------------------------------------------------------------------------
// Enum helpers
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -308,6 +386,8 @@ py::object MapElementTypeToDType(iree_hal_element_type_t element_type) {
//------------------------------------------------------------------------------

void SetupHalBindings(pybind11::module m) {
py::dict driver_cache;

// Enums.
py::enum_<enum iree_hal_memory_type_bits_t>(m, "MemoryType")
.value("NONE", IREE_HAL_MEMORY_TYPE_NONE)
Expand Down Expand Up @@ -443,9 +523,29 @@ void SetupHalBindings(pybind11::module m) {

py::class_<HalDriver>(m, "HalDriver")
.def_static("query", &HalDriver::Query)
.def_static("create", &HalDriver::Create, py::arg("driver_name"))
.def("create_default_device", &HalDriver::CreateDefaultDevice,
py::keep_alive<0, 1>());
py::keep_alive<0, 1>())
.def("create_device", &HalDriver::CreateDevice, py::keep_alive<0, 1>())
.def("create_device_by_uri", &HalDriver::CreateDeviceByURI,
py::keep_alive<0, 1>())
.def(
"create_device",
[](HalDriver& self, py::tuple device_info) -> HalDevice {
// Alias of create_device that takes a tuple as returned from
// query_available_devices for convenience.
auto device_id = py::cast<iree_hal_device_id_t>(device_info[0]);
return self.CreateDevice(device_id);
},
py::keep_alive<0, 1>())
.def("query_available_devices", &HalDriver::QueryAvailableDevices);

m.def(
"get_cached_hal_driver",
[driver_cache](std::string device_uri) {
return HalDriver::Create(device_uri,
const_cast<py::dict&>(driver_cache));
},
py::arg("device_uri"));

py::class_<HalAllocator>(m, "HalAllocator")
.def("trim",
Expand Down
6 changes: 5 additions & 1 deletion runtime/bindings/python/hal.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,13 @@ class HalDevice : public ApiRefCounted<HalDevice, iree_hal_device_t> {
class HalDriver : public ApiRefCounted<HalDriver, iree_hal_driver_t> {
public:
static std::vector<std::string> Query();
static HalDriver Create(const std::string& driver_name);
static py::object Create(const std::string& device_uri,
py::dict& driver_cache);

py::list QueryAvailableDevices();
HalDevice CreateDefaultDevice();
HalDevice CreateDevice(iree_hal_device_id_t device_id);
HalDevice CreateDeviceByURI(std::string& device_uri);
};

class HalAllocator : public ApiRefCounted<HalAllocator, iree_hal_allocator_t> {
Expand Down
6 changes: 6 additions & 0 deletions runtime/bindings/python/iree/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@
from .array_interop import *
from .benchmark import *
from .system_api import *
from .system_setup import (
get_device,
get_first_device,
get_driver,
query_available_drivers,
)
from .function import *
from .tracing import *

Expand Down
1 change: 1 addition & 0 deletions runtime/bindings/python/iree/runtime/_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
import sys

from iree import _runtime

sys.modules[__name__] = _runtime
82 changes: 14 additions & 68 deletions runtime/bindings/python/iree/runtime/system_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,11 @@

from . import _binding
from .function import FunctionInvoker
from .system_setup import get_first_device
from . import tracing

import numpy as np

# Environment key for a comma-delimited list of drivers to try to load.
PREFERRED_DRIVER_ENV_KEY = "IREE_DEFAULT_DRIVER"

# Default value for IREE_DRIVER
DEFAULT_IREE_DRIVER_VALUE = "local-task,vulkan"

# Mapping from IREE target backends to their corresponding drivers.
TARGET_BACKEND_TO_DRIVER = {
"dylib-llvm-aot": "local-task",
Expand All @@ -53,68 +48,30 @@
}


def _create_default_iree_driver(
driver_names: Optional[Sequence[str]] = None) -> _binding.HalDriver:
"""Returns a default driver based on environment settings."""
# TODO(laurenzo): Ideally this should take a VmModule and join any explicitly
# provided driver list with environmental constraints and what the module
# was compiled for.
if driver_names is None:
# Read from environment.
driver_names = os.environ.get(PREFERRED_DRIVER_ENV_KEY)
if driver_names is None:
driver_names = DEFAULT_IREE_DRIVER_VALUE
driver_names = driver_names.split(",")
available_driver_names = _binding.HalDriver.query()
driver_exceptions = {}
for driver_name in driver_names:
if driver_name not in available_driver_names:
logging.error("Could not create driver %s (not registered)", driver_name)
continue
try:
driver = _binding.HalDriver.create(driver_name)
except Exception as ex: # pylint: disable=broad-except
logging.exception("Could not create default driver %s", driver_name)
driver_exceptions[driver_name] = ex
continue

# Sanity check creation of the default device and skip the driver if
# this fails (this works around issues where the driver is present
# but there are no devices). This default initialization scheme needs
# to be improved.
try:
device = driver.create_default_device()
except Exception as ex:
logging.exception("Could not create default driver device %s",
driver_name)
driver_exceptions[driver_name] = ex
continue

logging.debug("Created IREE driver %s: %r", driver_name, driver)
return driver

# All failed.
raise RuntimeError(
f"Could not create any requested driver {repr(driver_names)} (available="
f"{repr(available_driver_names)}) : {repr(driver_exceptions)}")


class Config:
"""System configuration."""

driver: _binding.HalDriver
device: _binding.HalDevice
vm_instance: _binding.VmInstance
default_vm_modules: Tuple[_binding.VmModule, ...]
tracer: Optional[tracing.Tracer]

def __init__(self,
driver_name: Optional[str] = None,
*,
device: Optional[_binding.HalDevice] = None,
tracer: Optional[tracing.Tracer] = None):
# Either use an explicit device or auto config based on driver names.
if device is not None and driver_name is not None:
raise ValueError(
"Either 'device' or 'driver_name' can be specified (not both)")
if device is not None:
self.device = device
else:
self.device = get_first_device(
driver_name.split(",") if driver_name is not None else None)

self.vm_instance = _binding.VmInstance()
self.driver = _create_default_iree_driver(
driver_name.split(",") if driver_name is not None else None)
self.device = self.driver.create_default_device()
hal_module = _binding.create_hal_module(self.device)
self.default_vm_modules = (hal_module,)
self.tracer = tracer or tracing.get_default_tracer()
Expand All @@ -125,16 +82,6 @@ def __init__(self,
self.tracer = None


_global_config = None


def _get_global_config():
global _global_config
if _global_config is None:
_global_config = Config()
return _global_config


def _bool_to_int8(
array: Any) -> Optional[Union[np.ndarray, List[Any], Tuple[Any]]]:
if not isinstance(array, np.ndarray):
Expand Down Expand Up @@ -245,8 +192,7 @@ class SystemContext:
"""Global system."""

def __init__(self, vm_modules=None, config: Optional[Config] = None):
self._config = config if config is not None else _get_global_config()
logging.debug("SystemContext driver=%r", self._config.driver)
self._config = config if config is not None else Config()
self._is_dynamic = vm_modules is None
if self._is_dynamic:
init_vm_modules = None
Expand Down
Loading

0 comments on commit 3813758

Please sign in to comment.