From e143db7090c8c7d689d68e0bc1dbc99d4a9d1b93 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Mon, 15 Nov 2021 21:09:17 -0800 Subject: [PATCH] [mxnet] Updates MXNet engine to 1.9.0 Change-Id: I17f0fb861093cd2b09c21cf063d5a39d213dc4a9 --- .github/workflows/native_build_mxnet_osx.yml | 2 +- .../main/java/ai/djl/mxnet/jna/LibUtils.java | 36 +++++-------------- engines/mxnet/native/build.gradle | 2 +- gradle.properties | 2 +- 4 files changed, 12 insertions(+), 30 deletions(-) diff --git a/.github/workflows/native_build_mxnet_osx.yml b/.github/workflows/native_build_mxnet_osx.yml index b55dacad3e3..9ff49a937be 100644 --- a/.github/workflows/native_build_mxnet_osx.yml +++ b/.github/workflows/native_build_mxnet_osx.yml @@ -12,7 +12,7 @@ jobs: uses: actions/checkout@v2 with: repository: apache/incubator-mxnet - ref: 1.8.0 + ref: 1.9.0 submodules: 'recursive' - name: Install Dependencies run: | diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/LibUtils.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/LibUtils.java index 54b7b09f513..f318ea50b46 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/LibUtils.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/LibUtils.java @@ -232,31 +232,9 @@ private static String downloadMxnet(Platform platform) { flavor = "mkl"; } } else if ("linux".equals(os)) { - boolean match = - lines.contains(os + '/' + flavor + "/libmxnet.so.gz") - && supported(platform); - if (!match) { - String cudaMajor = flavor.substring(0, 4); - Pattern pattern = - Pattern.compile( - '(' - + cudaMajor - + ".+)/" - + classifier - + "/native/lib/" - + libName - + ".gz"); - for (String line : lines) { - Matcher m = pattern.matcher(line); - if (m.matches()) { - flavor = m.group(1); - match = true; - break; - } - } - } - - if (!match) { + // MXNet must use exactly matched cuda minor version + if (!lines.contains(os + '/' + flavor + "/libmxnet.so.gz") + || !supported(platform)) { logger.warn( "No matching cuda flavor for {} found: {}/sm_{}.", os, @@ -320,11 +298,15 @@ private static String downloadMxnet(Platform platform) { private static boolean supported(Platform platform) { // mxnet-native-cu102mkl:1.8.0: 3.0, 5.0, 6.0, 7.0, 7.5 // mxnet-native-cu110mkl:1.8.0: 5.0, 6.0, 7.0, 8.0 - if (platform.getVersion().startsWith("1.8.")) { + String version = platform.getVersion(); + if (version.startsWith("1.8.") || version.startsWith("1.9.")) { String flavor = platform.getFlavor(); String cudaArch = platform.getCudaArch(); if (flavor.startsWith("cu11")) { - return Arrays.asList("50", "60", "70", "80").contains(cudaArch); + if (version.startsWith("1.8.")) { + return Arrays.asList("50", "60", "70", "80").contains(cudaArch); + } + return Arrays.asList("50", "60", "70", "75", "80").contains(cudaArch); } else if (flavor.startsWith("cu10")) { return Arrays.asList("30", "50", "60", "70", "75").contains(cudaArch); } diff --git a/engines/mxnet/native/build.gradle b/engines/mxnet/native/build.gradle index 52687826494..957978b0a19 100644 --- a/engines/mxnet/native/build.gradle +++ b/engines/mxnet/native/build.gradle @@ -176,7 +176,7 @@ task downloadMxnetNativeLib() { "linux/common/libquadmath.so.0.gz": "mkl/linux/native/lib/libquadmath.so.0", "linux/mkl/libmxnet.so.gz" : "mkl/linux/native/lib/libmxnet.so", "linux/cu102mkl/libmxnet.so.gz" : "cu102mkl/linux/native/lib/libmxnet.so", - "linux/cu110mkl/libmxnet.so.gz" : "cu110mkl/linux/native/lib/libmxnet.so", + "linux/cu112mkl/libmxnet.so.gz" : "cu112mkl/linux/native/lib/libmxnet.so", "osx/mkl/libmxnet.dylib.gz" : "mkl/osx/native/lib/libmxnet.dylib", "win/common/libgcc_s_seh-1.dll.gz": "mkl/win/native/lib/libgcc_s_seh-1.dll", "win/common/libgfortran-3.dll.gz" : "mkl/win/native/lib/libgfortran-3.dll", diff --git a/gradle.properties b/gradle.properties index cfaecf42332..8d3a72ac3e1 100644 --- a/gradle.properties +++ b/gradle.properties @@ -8,7 +8,7 @@ systemProp.org.gradle.internal.http.connectionTimeout=60000 systemProp.org.gradle.internal.publish.checksums.insecure=true djl_version=0.16.0 -mxnet_version=1.8.0 +mxnet_version=1.9.0 pytorch_version=1.10.0 tensorflow_version=2.7.0 tflite_version=2.6.2