diff --git a/.gitignore b/.gitignore index a6d2bacf8e..79cf6399f4 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ target Cargo.lock # These are backup files generated by rustfmt **/*.rs.bk +.DS_Store +burn-tch/.cargo/config.toml diff --git a/burn-tch/Cargo.toml b/burn-tch/Cargo.toml index 8cc6117494..69e7bc4c70 100644 --- a/burn-tch/Cargo.toml +++ b/burn-tch/Cargo.toml @@ -18,7 +18,14 @@ burn-tensor = {path = "../burn-tensor"} half = {workspace = true} lazy_static = {workspace = true} rand = {workspace = true, features = ["std"]} -tch = {version = "0.10.1"} + +[target.'cfg(not(target_arch = "aarch64"))'.dependencies] +tch = {version = "0.10.3"} + +# Temporary workaround for https://github.com/burn-rs/burn/issues/180 +# Remove this and build.rs once tch-rs upgrades to Torch 2.0 at least +[target.'cfg(target_arch = "aarch64")'.dependencies] +tch = {version = "0.10.3", default-features = false} # Disables torch downloading [dev-dependencies] burn-autodiff = {path = "../burn-autodiff", default-features = false, features = [ diff --git a/burn-tch/build.rs b/burn-tch/build.rs new file mode 100644 index 0000000000..aed356cfca --- /dev/null +++ b/burn-tch/build.rs @@ -0,0 +1,16 @@ +use std::env; + +fn main() { + // Temporary workaround for https://github.com/burn-rs/burn/issues/180 + // Remove this once tch-rs upgrades to Torch 2.0 at least + + if cfg!(all(target_arch = "aarch64", target_os = "macos")) { + let message = "Run scripts/fix-tch-build-arm64.py to fix the environment variables for torch.\n See https://github.com/burn-rs/burn/issues/180 "; + env::var("LIBTORCH").expect(message); + env::var("DYLD_LIBRARY_PATH").expect(message); + } else if cfg!(all(target_arch = "aarch64", target_os = "linux")) { + let message = "Libtorch for AARCH64 Linux must be manually installed and set up.\n See https://github.com/burn-rs/burn/issues/180 "; + env::var("LIBTORCH").expect(message); + env::var("DYLD_LIBRARY_PATH").expect(message); + } +} diff --git a/scripts/fix-tch-build-arm64.py b/scripts/fix-tch-build-arm64.py new file mode 100755 index 0000000000..a34d3717bf --- /dev/null +++ b/scripts/fix-tch-build-arm64.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +This is a helper script to fix burn-tch build issues on Mac M1/M2 machines. + +It's a temporary workaround for https://github.com/burn-rs/burn/issues/180 +till tch-rs starts using Torch 2.0 libraries. + +This script installs torch via pip3 and creates environment variables in +burn-tch/.cargo/config.toml for tch-rs to link cc libs properly. + + +""" + +import os +import pathlib + + +def torch_path(): + import torch + return pathlib.Path(torch.__file__).parent + + +def update_toml_config(): + import tomli + import tomli_w + + cargo_cfg_dir = pathlib.Path(__file__).parent.parent.joinpath( + "burn-tch/.cargo").resolve() + cargo_cfg_dir.exists() + if not cargo_cfg_dir.exists(): + os.makedirs(cargo_cfg_dir) + + toml_file_path = cargo_cfg_dir.joinpath("config.toml") + + # Create toml file if does not exists + with open(toml_file_path, 'a') as f: + pass + + with open(toml_file_path, 'rb') as f: + config = tomli.load(f) + + config["env"] = config.get("env", dict()) + + config["env"]["LIBTORCH"] = dict( + value="{}".format(torch_path()), + force=True, + ) + + config["env"]["DYLD_LIBRARY_PATH"] = dict( + value="{}/lib".format(torch_path()), + force=True, + ) + + with open(toml_file_path, 'wb') as f: + tomli_w.dump(config, f) + + +def main(): + print("Installing/Upgrading torch via pip install ...") + os.system("pip3 install -U torch") + os.system("pip3 install -U tomli") + os.system("pip3 install -U tomli-w") + + print("Updating config.toml with torch library paths ... ") + update_toml_config() + + +if __name__ == '__main__': + main()