Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[FEATURE] Load libcuda with dlopen instead of dynamic linking #20484

Merged
merged 4 commits into from
Aug 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -586,11 +586,11 @@ if(USE_CUDA)

string(REPLACE ";" " " CUDA_ARCH_FLAGS_SPACES "${CUDA_ARCH_FLAGS}")

find_package(CUDAToolkit REQUIRED cublas cufft cusolver curand nvrtc cuda_driver
find_package(CUDAToolkit REQUIRED cublas cufft cusolver curand nvrtc
OPTIONAL_COMPONENTS nvToolsExt)

list(APPEND mxnet_LINKER_LIBS CUDA::cudart CUDA::cublas CUDA::cufft CUDA::cusolver CUDA::curand
CUDA::nvrtc CUDA::cuda_driver)
CUDA::nvrtc)
list(APPEND SOURCE ${CUDA})
add_definitions(-DMXNET_USE_CUDA=1)

Expand Down
52 changes: 43 additions & 9 deletions src/common/cuda/rtc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <algorithm>

#include "rtc.h"
#include "../../initialize.h"
#include "rtc/half-inl.h"
#include "rtc/util-inl.h"
#include "rtc/forward_functions-inl.h"
Expand All @@ -41,12 +42,30 @@
#include "rtc/reducer-inl.h"
#include "utils.h"

typedef CUresult (*cuDeviceGetPtr) (CUdevice* device, int ordinal);
typedef CUresult (*cuDevicePrimaryCtxRetainPtr) (CUcontext* pctx, CUdevice dev);
typedef CUresult (*cuModuleLoadDataExPtr) (CUmodule* module, const void* image,
unsigned int numOptions, CUjit_option* options, void** optionValues);
typedef CUresult (*cuModuleGetFunctionPtr) (CUfunction* hfunc, CUmodule hmod,
const char* name);
typedef CUresult (*cuLaunchKernelPtr) (CUfunction f, unsigned int gridDimX,
unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams,
void** extra);
typedef CUresult (*cuGetErrorStringPtr) (CUresult error, const char** pStr);

namespace mxnet {
namespace common {
namespace cuda {
namespace rtc {

// File name of the CUDA driver library. It is handed to
// LibraryInitializer::lib_load so the driver is opened at runtime instead of
// being linked at build time.
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
const char cuda_lib_name[] = "nvcuda.dll";
#else
// Load the versioned soname rather than the bare "libcuda.so": the unversioned
// name is a convenience symlink that may be missing on runtime-only driver
// installations, whereas "libcuda.so.1" is the soname shipped with the driver.
const char cuda_lib_name[] = "libcuda.so.1";
#endif

std::mutex lock;

namespace util {
Expand Down Expand Up @@ -149,6 +168,8 @@ CUfunction get_function(const std::string &parameters,
std::string ptx;
std::vector<CUfunction> functions;
};
void* cuda_lib_handle = LibraryInitializer::Get()->lib_load(cuda_lib_name);

// Maps from the kernel name and parameters to the ptx and jit-compiled CUfunctions.
using KernelCache = std::unordered_map<std::string, KernelInfo>;
// Per-gpu-architecture compiled kernel cache with jit-compiled function for each device context
Expand Down Expand Up @@ -233,8 +254,12 @@ CUfunction get_function(const std::string &parameters,
// Make sure driver context is set to the proper device
CUdevice cu_device;
CUcontext context;
CUDA_DRIVER_CALL(cuDeviceGet(&cu_device, dev_id));
CUDA_DRIVER_CALL(cuDevicePrimaryCtxRetain(&context, cu_device));
cuDeviceGetPtr device_get_ptr = get_func<cuDeviceGetPtr>(cuda_lib_handle, "cuDeviceGet");
CUDA_DRIVER_CALL((*device_get_ptr)(&cu_device, dev_id));
cuDevicePrimaryCtxRetainPtr device_primary_ctx_retain_ptr =
get_func<cuDevicePrimaryCtxRetainPtr>(cuda_lib_handle, "cuDevicePrimaryCtxRetain");
CUDA_DRIVER_CALL((*device_primary_ctx_retain_ptr)(&context, cu_device));

// Jit-compile ptx for the driver's current context
CUmodule module;

Expand All @@ -250,10 +275,15 @@ CUfunction get_function(const std::string &parameters,
void* jit_opt_values[] = {reinterpret_cast<void*>(debug_info),
reinterpret_cast<void*>(line_info)};

CUDA_DRIVER_CALL(cuModuleLoadDataEx(&module, kinfo.ptx.c_str(), 2, jit_opts, jit_opt_values));
CUDA_DRIVER_CALL(cuModuleGetFunction(&kinfo.functions[dev_id],
module,
kinfo.mangled_name.c_str()));
cuModuleLoadDataExPtr module_load_data_ex_ptr =
get_func<cuModuleLoadDataExPtr>(cuda_lib_handle, "cuModuleLoadDataEx");
CUDA_DRIVER_CALL((*module_load_data_ex_ptr)(&module, kinfo.ptx.c_str(), 2,
jit_opts, jit_opt_values));
cuModuleGetFunctionPtr module_get_function_ptr =
get_func<cuModuleGetFunctionPtr>(cuda_lib_handle, "cuModuleGetFunction");
CUDA_DRIVER_CALL((*module_get_function_ptr)(&kinfo.functions[dev_id],
module,
kinfo.mangled_name.c_str()));
}
return kinfo.functions[dev_id];
}
Expand All @@ -266,8 +296,10 @@ void launch(CUfunction function,
std::vector<const void*> *args) {
CHECK(args->size() != 0) <<
"Empty argument list passed to a kernel.";
// CUDA_DRIVER_CALL(
CUresult err = cuLaunchKernel(function, // function to launch
void* cuda_lib_handle = LibraryInitializer::Get()->lib_load(cuda_lib_name);
cuLaunchKernelPtr launch_kernel_ptr =
get_func<cuLaunchKernelPtr>(cuda_lib_handle, "cuLaunchKernel");
CUresult err = (*launch_kernel_ptr)(function, // function to launch
grid_dim.x, grid_dim.y, grid_dim.z, // grid dim
block_dim.x, block_dim.y, block_dim.z, // block dim
shared_mem_bytes, // shared memory
Expand All @@ -276,7 +308,9 @@ void launch(CUfunction function,
nullptr); // );
if (err != CUDA_SUCCESS) {
const char* error_string;
cuGetErrorString(err, &error_string);
cuGetErrorStringPtr get_error_string_ptr =
get_func<cuGetErrorStringPtr>(cuda_lib_handle, "cuGetErrorString");
(*get_error_string_ptr)(err, &error_string);
LOG(FATAL) << "cuLaunchKernel failed: "
<< err << " " << error_string << ": "
<< reinterpret_cast<void*>(function) << " "
Expand Down
2 changes: 1 addition & 1 deletion src/initialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ void LibraryInitializer::lib_close(void* handle) {
* \param func function pointer that gets output address
* \param name function name to be fetched
*/
void LibraryInitializer::get_sym(void* handle, void** func, char* name) {
void LibraryInitializer::get_sym(void* handle, void** func, const char* name) {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
*func = GetProcAddress((HMODULE)handle, name);
if (!(*func)) {
Expand Down
4 changes: 2 additions & 2 deletions src/initialize.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class LibraryInitializer {
bool lib_is_loaded(const std::string& path) const;
void* lib_load(const char* path);
void lib_close(void* handle);
static void get_sym(void* handle, void** func, char* name);
static void get_sym(void* handle, void** func, const char* name);

/**
* Original pid of the process which first loaded and initialized the library
Expand Down Expand Up @@ -114,7 +114,7 @@ class LibraryInitializer {
* \return func a function pointer
*/
template<typename T>
T get_func(void *lib, char *func_name) {
T get_func(void *lib, const char *func_name) {
T func;
LibraryInitializer::Get()->get_sym(lib, reinterpret_cast<void**>(&func), func_name);
if (!func)
Expand Down