Skip to content

Commit

Permalink
[pytorch] Downloads only matching cuda version of native library
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Apr 12, 2023
1 parent afb7d9d commit f6bd261
Showing 1 changed file with 20 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Properties;
Expand Down Expand Up @@ -289,6 +290,7 @@ private static LibTorch findNativeLibrary() {
}

private static LibTorch copyNativeLibraryFromClasspath(Platform platform) {
logger.debug("Found bundled PyTorch package: {}.", platform);
String version = platform.getVersion();
String flavor = platform.getFlavor();
if (!flavor.endsWith("-precxx11")
Expand Down Expand Up @@ -350,6 +352,7 @@ private static LibTorch downloadPyTorch(Platform platform) {
String classifier = platform.getClassifier();
String precxx11;
String flavor = Utils.getEnvOrSystemProperty("PYTORCH_FLAVOR");
boolean override;
if (flavor == null || flavor.isEmpty()) {
flavor = platform.getFlavor();
if (System.getProperty("os.name").startsWith("Linux")
Expand All @@ -360,9 +363,11 @@ private static LibTorch downloadPyTorch(Platform platform) {
precxx11 = "";
}
flavor += precxx11;
override = false;
} else {
logger.info("Uses override PYTORCH_FLAVOR: {}", flavor);
precxx11 = flavor.endsWith("-precxx11") ? "-precxx11" : "";
override = true;
}

Path cacheDir = Utils.getEngineCacheDir("pytorch");
Expand Down Expand Up @@ -399,23 +404,32 @@ private static LibTorch downloadPyTorch(Platform platform) {
Files.createDirectories(cacheDir);
List<String> lines = Utils.readLines(is);
if (flavor.startsWith("cu")) {
String cudaMajor = flavor.substring(0, 4);
int cudaVersion = Integer.parseInt(flavor.substring(2, 5));
Pattern pattern =
Pattern.compile(
'('
+ cudaMajor
+ "\\d"
"(cu(\\d\\d\\d)"
+ precxx11
+ ")/"
+ '/'
+ classifier
+ "/native/lib/"
+ NATIVE_LIB_NAME
+ ".gz");
List<Integer> cudaVersions = new ArrayList<>();
boolean match = false;
for (String line : lines) {
Matcher m = pattern.matcher(line);
if (m.matches()) {
flavor = m.group(1);
cudaVersions.add(Integer.parseInt(m.group(1)));
}
}
// find highest matching CUDA version
cudaVersions.sort(Collections.reverseOrder());
for (int cuda : cudaVersions) {
if (override && cuda == cudaVersion) {
match = true;
break;
} else if (cuda <= cudaVersion) {
flavor = "cu" + cuda + precxx11;
match = true;
break;
}
Expand Down

0 comments on commit f6bd261

Please sign in to comment.