From 84737c27f63e8b57839bf53f97d1ce4f0db8e150 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 29 Mar 2023 09:14:12 -0700 Subject: [PATCH] [serving] Fixes NeuronUtils for SageMaker --- .../main/java/ai/djl/serving/ModelServer.java | 3 +- .../java/ai/djl/serving/util/NeuronUtils.java | 73 +++------ .../ai/djl/serving/util/NeuronUtilsTest.java | 144 ++---------------- 3 files changed, 42 insertions(+), 178 deletions(-) diff --git a/serving/src/main/java/ai/djl/serving/ModelServer.java b/serving/src/main/java/ai/djl/serving/ModelServer.java index 01e22ecdf3..c4d81e83f8 100644 --- a/serving/src/main/java/ai/djl/serving/ModelServer.java +++ b/serving/src/main/java/ai/djl/serving/ModelServer.java @@ -577,7 +577,8 @@ private String[] parseDevices(String devices, Engine engine, Path modelDir) { String v = Utils.getenv("TENSOR_PARALLEL_DEGREE", "-1"); v = prop.getProperty("option.tensor_parallel_degree", v); int tensorParallelDegree = Integer.parseInt(v); - if (tensorParallelDegree > 0 && NeuronUtils.isInf2()) { + if (tensorParallelDegree > 0) { + // Assume the user understands that TP only works on inf2 int procs = neurons / tensorParallelDegree; if (procs == 0) { throw new EngineException( diff --git a/serving/src/main/java/ai/djl/serving/util/NeuronUtils.java b/serving/src/main/java/ai/djl/serving/util/NeuronUtils.java index 59b19241c7..91f6d75710 100644 --- a/serving/src/main/java/ai/djl/serving/util/NeuronUtils.java +++ b/serving/src/main/java/ai/djl/serving/util/NeuronUtils.java @@ -12,50 +12,30 @@ */ package ai.djl.serving.util; -import ai.djl.util.Ec2Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Iterator; import java.util.stream.Stream; /** A utility class to detect number of nueron cores. 
*/ public final class NeuronUtils { - private static int instanceType = -1; + private static final Logger logger = LoggerFactory.getLogger(NeuronUtils.class); private NeuronUtils() {} - static void setInstanceType(int type) { - instanceType = type; - } - /** * Returns whether Neuron runtime library is in the system. * * @return {@code true} if Neuron runtime library is in the system */ public static boolean hasNeuron() { - return isInf1() || isInf2(); - } - - /** - * Returns whether the instance is an inf1. - * - * @return {code true} if the instance is an inf1 - */ - public static boolean isInf1() { - return getInstanceType() == 1; - } - - /** - * Returns whether the instance is an inf2 or trn1. - * - * @return {code true} if the instance is an inf2 or trn1 - */ - public static boolean isInf2() { - return getInstanceType() == 2; + return getNeuronCores() > 0; } /** @@ -64,36 +44,31 @@ public static boolean isInf2() { * @return the number of NeuronCores available in the system */ public static int getNeuronCores() { - if (!hasNeuron()) { + return getNeuronCores("/sys/devices/virtual/neuron_device/"); + } + + static int getNeuronCores(String location) { + Path path = Paths.get(location); + if (!Files.exists(path)) { return 0; } - try (Stream<Path> paths = Files.list(Paths.get("/dev"))) { - long nd = paths.filter(p -> p.getFileName().toString().startsWith("neuron")).count(); - if (isInf1()) { - // inf1 has 4 cores on each device - return (int) nd * 4; + int count = 0; + try (Stream<Path> dev = Files.list(path)) { + for (Iterator<Path> it = dev.iterator(); it.hasNext(); ) { + Path dir = it.next(); + if (dir.getFileName().toString().startsWith("neuron")) { + try (Stream<Path> cores = Files.list(dir)) { + count += Math.toIntExact(cores.filter(NeuronUtils::matches).count()); + } + } } - // inf2 has 2 cores on each device - return (int) nd * 2; } catch (IOException e) { - throw new AssertionError("Failed to list neuron cores", e); + logger.warn("Failed to list neuron cores", e); } + return 
count; } - @SuppressWarnings("PMD.NonThreadSafeSingleton") - private static int getInstanceType() { - if (instanceType == -1) { - String metadata = Ec2Utils.readMetadata("instance-type"); - if (metadata == null) { - NeuronUtils.setInstanceType(0); - } else if (metadata.startsWith("inf1")) { - NeuronUtils.setInstanceType(1); - } else if (metadata.startsWith("inf2") || metadata.startsWith("trn1")) { - NeuronUtils.setInstanceType(2); - } else { - NeuronUtils.setInstanceType(0); - } - } - return instanceType; + private static boolean matches(Path p) { + return p.getFileName().toString().startsWith("neuron_core"); } } diff --git a/serving/src/test/java/ai/djl/serving/util/NeuronUtilsTest.java b/serving/src/test/java/ai/djl/serving/util/NeuronUtilsTest.java index d7778a6b0f..9360e504e7 100644 --- a/serving/src/test/java/ai/djl/serving/util/NeuronUtilsTest.java +++ b/serving/src/test/java/ai/djl/serving/util/NeuronUtilsTest.java @@ -12,145 +12,33 @@ */ package ai.djl.serving.util; +import ai.djl.util.Utils; + import org.testng.Assert; import org.testng.annotations.Test; -import java.io.ByteArrayInputStream; import java.io.IOException; -import java.io.InputStream; -import java.net.HttpURLConnection; -import java.net.Proxy; -import java.net.URL; -import java.net.URLConnection; -import java.net.URLStreamHandler; -import java.net.URLStreamHandlerFactory; -import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; public class NeuronUtilsTest { @Test - public void testNeuronUtils() { - MockURLStreamHandlerFactory factory = new MockURLStreamHandlerFactory(); - + public void testNeuronUtils() throws IOException { + Path dir = Paths.get("build/neuron_device/"); try { - URL.setURLStreamHandlerFactory(factory); - - NeuronUtils.setInstanceType(-1); - factory.setMockMode(0); // inf1 - Assert.assertTrue(NeuronUtils.hasNeuron()); - Assert.assertTrue(NeuronUtils.isInf1()); - if 
(System.getProperty("os.name").startsWith("Linux")) { - NeuronUtils.getNeuronCores(); - } - - NeuronUtils.setInstanceType(-1); - factory.setMockMode(1); // inf2 - Assert.assertTrue(NeuronUtils.hasNeuron()); - Assert.assertTrue(NeuronUtils.isInf2()); - - NeuronUtils.setInstanceType(-1); - factory.setMockMode(2); // inf1 - Assert.assertTrue(NeuronUtils.hasNeuron()); - Assert.assertTrue(NeuronUtils.isInf2()); - - NeuronUtils.setInstanceType(-1); - factory.setMockMode(3); // inf1 - Assert.assertFalse(NeuronUtils.hasNeuron()); - - NeuronUtils.setInstanceType(-1); - factory.setMockMode(4); // inf1 - Assert.assertFalse(NeuronUtils.hasNeuron()); - } finally { - factory.setMock(false); - } - } - - static final class MockURLStreamHandlerFactory implements URLStreamHandlerFactory { - - private boolean mock = true; - private int mockMode; - - public void setMock(boolean mock) { - this.mock = mock; - } - - public void setMockMode(int mockMode) { - this.mockMode = mockMode; - } - - /** {@inheritDoc} */ - @Override - public URLStreamHandler createURLStreamHandler(String protocol) { - if (!mock) { - return null; - } - return new URLStreamHandler() { - - /** {@inheritDoc} */ - @Override - protected URLConnection openConnection(URL u, Proxy proxy) { - return openConnection(u); + for (int i = 0; i < 4; ++i) { + Path nd = dir.resolve("neuron" + i); + for (int j = 0; j < 2; ++j) { + Path nc = nd.resolve("neuron_core" + j); + Files.createDirectories(nc); } - - /** {@inheritDoc} */ - @Override - protected URLConnection openConnection(URL u) { - return new MockHttpURLConnection(u); - } - }; - } - - final class MockHttpURLConnection extends HttpURLConnection { - - public MockHttpURLConnection(URL u) { - super(u); - } - - /** {@inheritDoc} */ - @Override - public void disconnect() {} - - /** {@inheritDoc} */ - @Override - public boolean usingProxy() { - return true; - } - - /** {@inheritDoc} */ - @Override - public void connect() {} - - /** {@inheritDoc} */ - @Override - public InputStream 
getInputStream() throws IOException { - switch (mockMode) { - case 0: - // EC2 inf1.2xlarge - return new ByteArrayInputStream( - "inf1.2xlarge".getBytes(StandardCharsets.UTF_8)); - case 1: - // EC2 inf2.24xlarge - return new ByteArrayInputStream( - "inf2.24xlarge".getBytes(StandardCharsets.UTF_8)); - case 2: - // EC2 trn1.32xlarge - return new ByteArrayInputStream( - "trn1.32xlarge".getBytes(StandardCharsets.UTF_8)); - case 3: - // EC2 c5.xlarge - return new ByteArrayInputStream( - "c5.xlarge".getBytes(StandardCharsets.UTF_8)); - default: - // non-AWS - throw new IOException("Timeout"); - } - } - - /** {@inheritDoc} */ - @Override - public int getResponseCode() { - return HttpURLConnection.HTTP_OK; } + NeuronUtils.hasNeuron(); + Assert.assertEquals(NeuronUtils.getNeuronCores(dir.toString()), 8); + } finally { + Utils.deleteQuietly(dir); } } }