Skip to content

Commit

Permalink
Add a workaround script for arm64 tch-rs build issue (#180) (#182)
Browse files Browse the repository at this point in the history
* Add a workaround script for arm64 tch-rs build issue (#180)
  • Loading branch information
antimora authored Mar 1, 2023
1 parent 37806b5 commit a62738b
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@ target
Cargo.lock
# These are backup files generated by rustfmt
**/*.rs.bk
.DS_Store
burn-tch/.cargo/config.toml
9 changes: 8 additions & 1 deletion burn-tch/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@ burn-tensor = {path = "../burn-tensor"}
half = {workspace = true}
lazy_static = {workspace = true}
rand = {workspace = true, features = ["std"]}
tch = {version = "0.10.1"}

[target.'cfg(not(target_arch = "aarch64"))'.dependencies]
tch = {version = "0.10.3"}

# Temporary workaround for https://github.com/burn-rs/burn/issues/180
# Remove this and build.rs once tch-rs upgrades to Torch 2.0 at least
[target.'cfg(target_arch = "aarch64")'.dependencies]
tch = {version = "0.10.3", default-features = false} # Disables torch downloading

[dev-dependencies]
burn-autodiff = {path = "../burn-autodiff", default-features = false, features = [
Expand Down
16 changes: 16 additions & 0 deletions burn-tch/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use std::env;

fn main() {
// Temporary workaround for https://github.com/burn-rs/burn/issues/180
// Remove this once tch-rs upgrades to Torch 2.0 at least

if cfg!(all(target_arch = "aarch64", target_os = "macos")) {
let message = "Run scripts/fix-tch-build-arm64.py to fix the environment variables for torch.\n See https://github.com/burn-rs/burn/issues/180 ";
env::var("LIBTORCH").expect(message);
env::var("DYLD_LIBRARY_PATH").expect(message);
} else if cfg!(all(target_arch = "aarch64", target_os = "linux")) {
let message = "Libtorch for AARCH64 Linux must be manually installed and set up.\n See https://github.com/burn-rs/burn/issues/180 ";
env::var("LIBTORCH").expect(message);
env::var("DYLD_LIBRARY_PATH").expect(message);
}
}
71 changes: 71 additions & 0 deletions scripts/fix-tch-build-arm64.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
This is a helper script to fix burn-tch build issues on Mac M1/M2 machines.
It's a temporary workaround for https://github.com/burn-rs/burn/issues/180
till tch-rs starts using Torch 2.0 libraries.
This script installs torch via pip3 and creates environment variables in
burn-tch/.cargo/config.toml for tch-rs to link cc libs properly.
"""

import os
import pathlib


def torch_path():
import torch
return pathlib.Path(torch.__file__).parent


def update_toml_config():
import tomli
import tomli_w

cargo_cfg_dir = pathlib.Path(__file__).parent.parent.joinpath(
"burn-tch/.cargo").resolve()
cargo_cfg_dir.exists()
if not cargo_cfg_dir.exists():
os.makedirs(cargo_cfg_dir)

toml_file_path = cargo_cfg_dir.joinpath("config.toml")

# Create toml file if does not exists
with open(toml_file_path, 'a') as f:
pass

with open(toml_file_path, 'rb') as f:
config = tomli.load(f)

config["env"] = config.get("env", dict())

config["env"]["LIBTORCH"] = dict(
value="{}".format(torch_path()),
force=True,
)

config["env"]["DYLD_LIBRARY_PATH"] = dict(
value="{}/lib".format(torch_path()),
force=True,
)

with open(toml_file_path, 'wb') as f:
tomli_w.dump(config, f)


def main():
print("Installing/Upgrading torch via pip install ...")
os.system("pip3 install -U torch")
os.system("pip3 install -U tomli")
os.system("pip3 install -U tomli-w")

print("Updating config.toml with torch library paths ... ")
update_toml_config()


if __name__ == '__main__':
main()

0 comments on commit a62738b

Please sign in to comment.