diff --git a/examples/basics.rs b/examples/basics.rs index 8c0b1fcc..255a646d 100644 --- a/examples/basics.rs +++ b/examples/basics.rs @@ -13,7 +13,6 @@ fn grad_example() { } fn main() { - tch::maybe_init_cuda(); let t = Tensor::of_slice(&[3, 1, 4, 1, 5]); t.print(); let t = Tensor::randn(&[5, 4], kind::FLOAT_CPU); diff --git a/src/lib.rs b/src/lib.rs index 1d5d81fc..a452ce25 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,9 +28,3 @@ pub use tensor::{ pub mod nn; pub mod vision; - -pub fn maybe_init_cuda() { - unsafe { - torch_sys::dummy_cuda_dependency(); - } -} diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 44f4459f..5e707ace 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -3,7 +3,6 @@ use crate::{Device, Kind, Scalar, TchError}; use half::f16; use std::convert::{TryFrom, TryInto}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -use torch_sys::*; pub mod index; mod iter; @@ -750,6 +749,3 @@ try_from_impl!(i32); try_from_impl!(f64); try_from_impl!(i64); try_from_impl!(bool); - -#[used] -static INIT_ARRAY: [unsafe extern "C" fn(); 1] = [dummy_cuda_dependency]; diff --git a/torch-sys/build.rs b/torch-sys/build.rs index 715eb5fd..d9981708 100644 --- a/torch-sys/build.rs +++ b/torch-sys/build.rs @@ -133,14 +133,8 @@ fn prepare_libtorch_dir() -> PathBuf { } } -fn make>(libtorch: P, use_cuda: bool, use_hip: bool) { +fn make>(libtorch: P) { let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); - - let cuda_dependency = if use_cuda || use_hip { - "libtch/dummy_cuda_dependency.cpp" - } else { - "libtch/fake_cuda_dependency.cpp" - }; println!("cargo:rerun-if-changed=libtch/torch_api.cpp"); println!("cargo:rerun-if-changed=libtch/torch_api.h"); println!("cargo:rerun-if-changed=libtch/torch_api_generated.cpp.h"); @@ -165,7 +159,6 @@ fn make>(libtorch: P, use_cuda: bool, use_hip: bool) { .flag("-std=c++14") .flag(&format!("-D_GLIBCXX_USE_CXX11_ABI={}", libtorch_cxx11_abi)) .file("libtch/torch_api.cpp") - .file(cuda_dependency) .compile("tch"); } "windows" => { @@ -179,7 +172,6 @@ fn make>(libtorch: P, use_cuda: bool, use_hip: bool) { .include(libtorch.as_ref().join("include")) .include(libtorch.as_ref().join("include/torch/csrc/api/include")) .file("libtch/torch_api.cpp") - .file(cuda_dependency) .compile("tch"); } _ => panic!("Unsupported OS"), @@ -189,20 +181,6 @@ fn make>(libtorch: P, use_cuda: bool, use_hip: bool) { fn main() { if !cfg!(feature = "doc-only") { let libtorch = prepare_libtorch_dir(); - // use_cuda is a hacky way to detect whether cuda is available and - // if it's the case link to it by explicitly depending on a symbol - // from the torch_cuda library. - // It would be better to use -Wl,--no-as-needed but there is no way - // to specify arbitrary linker flags at the moment. - // - // Once https://github.com/rust-lang/cargo/pull/8441 is available - // we should switch to using rustc-link-arg instead e.g. with the - // following flags: - // -Wl,--no-as-needed -Wl,--copy-dt-needed-entries -ltorch - // - // This will be available starting from cargo 1.50 but will be a nightly - // only option to start with. - // https://github.com/rust-lang/cargo/blob/master/CHANGELOG.md let use_cuda = libtorch.join("lib").join("libtorch_cuda.so").exists() || libtorch.join("lib").join("torch_cuda.dll").exists(); let use_cuda_cu = libtorch.join("lib").join("libtorch_cuda_cu.so").exists() @@ -216,8 +194,11 @@ fn main() { libtorch.join("lib").display() ); - make(&libtorch, use_cuda, use_hip); + make(&libtorch); + // PyTorch uses two separate library for the cpu bits and the cuda bits. + // The following flags ensure that the linker does not remove the + // cuda dependencies. println!("cargo:rustc-link-lib=static=tch"); if use_cuda { println!("cargo:rustc-link-lib=torch_cuda"); @@ -231,7 +212,9 @@ fn main() { if use_hip { println!("cargo:rustc-link-lib=torch_hip"); } - println!("cargo:rustc-link-lib=torch"); + println!("cargo:rustc-link-arg=-Wl,--no-as-needed"); + println!("cargo:rustc-link-arg=-Wl,--copy-dt-needed-entries"); + println!("cargo:rustc-link-arg=-ltorch"); println!("cargo:rustc-link-lib=torch_cpu"); println!("cargo:rustc-link-lib=c10"); if use_hip { diff --git a/torch-sys/libtch/dummy_cuda_dependency.cpp b/torch-sys/libtch/dummy_cuda_dependency.cpp deleted file mode 100644 index f2495a82..00000000 --- a/torch-sys/libtch/dummy_cuda_dependency.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include -#include -using namespace std; -extern "C" { - void dummy_cuda_dependency(); -} - -namespace at { - namespace cuda { - void *getCurrentCUDABlasHandle(); - int warp_size(); - } -} -char * magma_strerror(int err); -void dummy_cuda_dependency() { - at::cuda::getCurrentCUDABlasHandle(); - at::cuda::warp_size(); -} diff --git a/torch-sys/libtch/fake_cuda_dependency.cpp b/torch-sys/libtch/fake_cuda_dependency.cpp deleted file mode 100644 index 9db48039..00000000 --- a/torch-sys/libtch/fake_cuda_dependency.cpp +++ /dev/null @@ -1,6 +0,0 @@ -extern "C" { - void dummy_cuda_dependency(); -} - -void dummy_cuda_dependency() { -} diff --git a/torch-sys/src/lib.rs b/torch-sys/src/lib.rs index 32cf5049..e6b692c1 100644 --- a/torch-sys/src/lib.rs +++ b/torch-sys/src/lib.rs @@ -254,7 +254,3 @@ extern "C" { noutputs: c_int, ); } - -extern "C" { - pub fn dummy_cuda_dependency(); -}