Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[FEATURE] Load libcuda with dlopen instead of dynamic linking (#20484)
Browse files Browse the repository at this point in the history
* Load libcuda with dlopen instead of dynamic linking

* Fix a lint error

* Make naming style consistent

* Use correct the CUDA library names on Windows and Unix-like system
  • Loading branch information
TristonC authored Aug 4, 2021
1 parent 66fff25 commit e7866d0
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 14 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -585,11 +585,11 @@ if(USE_CUDA)

string(REPLACE ";" " " CUDA_ARCH_FLAGS_SPACES "${CUDA_ARCH_FLAGS}")

find_package(CUDAToolkit REQUIRED cublas cufft cusolver curand nvrtc cuda_driver
find_package(CUDAToolkit REQUIRED cublas cufft cusolver curand nvrtc
OPTIONAL_COMPONENTS nvToolsExt)

list(APPEND mxnet_LINKER_LIBS CUDA::cudart CUDA::cublas CUDA::cufft CUDA::cusolver CUDA::curand
CUDA::nvrtc CUDA::cuda_driver)
CUDA::nvrtc)
list(APPEND SOURCE ${CUDA})
add_definitions(-DMXNET_USE_CUDA=1)

Expand Down
52 changes: 43 additions & 9 deletions src/common/cuda/rtc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <algorithm>

#include "rtc.h"
#include "../../initialize.h"
#include "rtc/half-inl.h"
#include "rtc/util-inl.h"
#include "rtc/forward_functions-inl.h"
Expand All @@ -41,12 +42,30 @@
#include "rtc/reducer-inl.h"
#include "utils.h"

typedef CUresult (*cuDeviceGetPtr) (CUdevice* device, int ordinal);
typedef CUresult (*cuDevicePrimaryCtxRetainPtr) (CUcontext* pctx, CUdevice dev);
typedef CUresult (*cuModuleLoadDataExPtr) (CUmodule* module, const void* image,
unsigned int numOptions, CUjit_option* options, void** optionValues);
typedef CUresult (*cuModuleGetFunctionPtr) (CUfunction* hfunc, CUmodule hmod,
const char* name);
typedef CUresult (*cuLaunchKernelPtr) (CUfunction f, unsigned int gridDimX,
unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams,
void** extra);
typedef CUresult (*cuGetErrorStringPtr) (CUresult error, const char** pStr);

namespace mxnet {
namespace common {
namespace cuda {
namespace rtc {

#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
const char cuda_lib_name[] = "nvcuda.dll";
#else
const char cuda_lib_name[] = "libcuda.so";
#endif

std::mutex lock;

namespace util {
Expand Down Expand Up @@ -149,6 +168,8 @@ CUfunction get_function(const std::string &parameters,
std::string ptx;
std::vector<CUfunction> functions;
};
void* cuda_lib_handle = LibraryInitializer::Get()->lib_load(cuda_lib_name);

// Maps from the kernel name and parameters to the ptx and jit-compiled CUfunctions.
using KernelCache = std::unordered_map<std::string, KernelInfo>;
// Per-gpu-architecture compiled kernel cache with jit-compiled function for each device context
Expand Down Expand Up @@ -233,8 +254,12 @@ CUfunction get_function(const std::string &parameters,
// Make sure driver context is set to the proper device
CUdevice cu_device;
CUcontext context;
CUDA_DRIVER_CALL(cuDeviceGet(&cu_device, dev_id));
CUDA_DRIVER_CALL(cuDevicePrimaryCtxRetain(&context, cu_device));
cuDeviceGetPtr device_get_ptr = get_func<cuDeviceGetPtr>(cuda_lib_handle, "cuDeviceGet");
CUDA_DRIVER_CALL((*device_get_ptr)(&cu_device, dev_id));
cuDevicePrimaryCtxRetainPtr device_primary_ctx_retain_ptr =
get_func<cuDevicePrimaryCtxRetainPtr>(cuda_lib_handle, "cuDevicePrimaryCtxRetain");
CUDA_DRIVER_CALL((*device_primary_ctx_retain_ptr)(&context, cu_device));

// Jit-compile ptx for the driver's current context
CUmodule module;

Expand All @@ -250,10 +275,15 @@ CUfunction get_function(const std::string &parameters,
void* jit_opt_values[] = {reinterpret_cast<void*>(debug_info),
reinterpret_cast<void*>(line_info)};

CUDA_DRIVER_CALL(cuModuleLoadDataEx(&module, kinfo.ptx.c_str(), 2, jit_opts, jit_opt_values));
CUDA_DRIVER_CALL(cuModuleGetFunction(&kinfo.functions[dev_id],
module,
kinfo.mangled_name.c_str()));
cuModuleLoadDataExPtr module_load_data_ex_ptr =
get_func<cuModuleLoadDataExPtr>(cuda_lib_handle, "cuModuleLoadDataEx");
CUDA_DRIVER_CALL((*module_load_data_ex_ptr)(&module, kinfo.ptx.c_str(), 2,
jit_opts, jit_opt_values));
cuModuleGetFunctionPtr module_get_function_ptr =
get_func<cuModuleGetFunctionPtr>(cuda_lib_handle, "cuModuleGetFunction");
CUDA_DRIVER_CALL((*module_get_function_ptr)(&kinfo.functions[dev_id],
module,
kinfo.mangled_name.c_str()));
}
return kinfo.functions[dev_id];
}
Expand All @@ -266,8 +296,10 @@ void launch(CUfunction function,
std::vector<const void*> *args) {
CHECK(args->size() != 0) <<
"Empty argument list passed to a kernel.";
// CUDA_DRIVER_CALL(
CUresult err = cuLaunchKernel(function, // function to launch
void* cuda_lib_handle = LibraryInitializer::Get()->lib_load(cuda_lib_name);
cuLaunchKernelPtr launch_kernel_ptr =
get_func<cuLaunchKernelPtr>(cuda_lib_handle, "cuLaunchKernel");
CUresult err = (*launch_kernel_ptr)(function, // function to launch
grid_dim.x, grid_dim.y, grid_dim.z, // grid dim
block_dim.x, block_dim.y, block_dim.z, // block dim
shared_mem_bytes, // shared memory
Expand All @@ -276,7 +308,9 @@ void launch(CUfunction function,
nullptr); // );
if (err != CUDA_SUCCESS) {
const char* error_string;
cuGetErrorString(err, &error_string);
cuGetErrorStringPtr get_error_string_ptr =
get_func<cuGetErrorStringPtr>(cuda_lib_handle, "cuGetErrorString");
(*get_error_string_ptr)(err, &error_string);
LOG(FATAL) << "cuLaunchKernel failed: "
<< err << " " << error_string << ": "
<< reinterpret_cast<void*>(function) << " "
Expand Down
2 changes: 1 addition & 1 deletion src/initialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ void LibraryInitializer::lib_close(void* handle) {
* \param func function pointer that gets output address
* \param name function name to be fetched
*/
void LibraryInitializer::get_sym(void* handle, void** func, char* name) {
void LibraryInitializer::get_sym(void* handle, void** func, const char* name) {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
*func = GetProcAddress((HMODULE)handle, name);
if (!(*func)) {
Expand Down
4 changes: 2 additions & 2 deletions src/initialize.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class LibraryInitializer {
bool lib_is_loaded(const std::string& path) const;
void* lib_load(const char* path);
void lib_close(void* handle);
static void get_sym(void* handle, void** func, char* name);
static void get_sym(void* handle, void** func, const char* name);

/**
* Original pid of the process which first loaded and initialized the library
Expand Down Expand Up @@ -114,7 +114,7 @@ class LibraryInitializer {
* \return func a function pointer
*/
template<typename T>
T get_func(void *lib, char *func_name) {
T get_func(void *lib, const char *func_name) {
T func;
LibraryInitializer::Get()->get_sym(lib, reinterpret_cast<void**>(&func), func_name);
if (!func)
Expand Down

0 comments on commit e7866d0

Please sign in to comment.