Skip to content

Commit

Permalink
[serving] Fixes NeuronUtils for SageMaker
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Mar 29, 2023
1 parent 7c63c23 commit 84737c2
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 178 deletions.
3 changes: 2 additions & 1 deletion serving/src/main/java/ai/djl/serving/ModelServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,8 @@ private String[] parseDevices(String devices, Engine engine, Path modelDir) {
String v = Utils.getenv("TENSOR_PARALLEL_DEGREE", "-1");
v = prop.getProperty("option.tensor_parallel_degree", v);
int tensorParallelDegree = Integer.parseInt(v);
if (tensorParallelDegree > 0 && NeuronUtils.isInf2()) {
if (tensorParallelDegree > 0) {
// Assume user understand TP only works on inf2
int procs = neurons / tensorParallelDegree;
if (procs == 0) {
throw new EngineException(
Expand Down
73 changes: 24 additions & 49 deletions serving/src/main/java/ai/djl/serving/util/NeuronUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,50 +12,30 @@
*/
package ai.djl.serving.util;

import ai.djl.util.Ec2Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Iterator;
import java.util.stream.Stream;

/** A utility class to detect number of nueron cores. */
public final class NeuronUtils {

private static int instanceType = -1;
private static final Logger logger = LoggerFactory.getLogger(NeuronUtils.class);

private NeuronUtils() {}

static void setInstanceType(int type) {
instanceType = type;
}

/**
* Returns whether Neuron runtime library is in the system.
*
* @return {@code true} if Neuron runtime library is in the system
*/
public static boolean hasNeuron() {
return isInf1() || isInf2();
}

/**
* Returns whether the instance is an inf1.
*
* @return {code true} if the instance is an inf1
*/
public static boolean isInf1() {
return getInstanceType() == 1;
}

/**
* Returns whether the instance is an inf2 or trn1.
*
* @return {code true} if the instance is an inf2 or trn1
*/
public static boolean isInf2() {
return getInstanceType() == 2;
return getNeuronCores() > 0;
}

/**
Expand All @@ -64,36 +44,31 @@ public static boolean isInf2() {
* @return the number of NeuronCores available in the system
*/
public static int getNeuronCores() {
if (!hasNeuron()) {
return getNeuronCores("/sys/devices/virtual/neuron_device/");
}

static int getNeuronCores(String location) {
Path path = Paths.get(location);
if (!Files.exists(path)) {
return 0;
}
try (Stream<Path> paths = Files.list(Paths.get("/dev"))) {
long nd = paths.filter(p -> p.getFileName().toString().startsWith("neuron")).count();
if (isInf1()) {
// inf1 has 4 cores on each device
return (int) nd * 4;
int count = 0;
try (Stream<Path> dev = Files.list(path)) {
for (Iterator<Path> it = dev.iterator(); it.hasNext(); ) {
Path dir = it.next();
if (dir.getFileName().toString().startsWith("neuron")) {
Stream<Path> cores = Files.list(dir);
count += Math.toIntExact(cores.filter(NeuronUtils::matches).count());
cores.close();
}
}
// inf2 has 2 cores on each device
return (int) nd * 2;
} catch (IOException e) {
throw new AssertionError("Failed to list neuron cores", e);
logger.warn("Failed to list neuron cores", e);
}
return count;
}

@SuppressWarnings("PMD.NonThreadSafeSingleton")
private static int getInstanceType() {
if (instanceType == -1) {
String metadata = Ec2Utils.readMetadata("instance-type");
if (metadata == null) {
NeuronUtils.setInstanceType(0);
} else if (metadata.startsWith("inf1")) {
NeuronUtils.setInstanceType(1);
} else if (metadata.startsWith("inf2") || metadata.startsWith("trn1")) {
NeuronUtils.setInstanceType(2);
} else {
NeuronUtils.setInstanceType(0);
}
}
return instanceType;
private static boolean matches(Path p) {
return p.getFileName().toString().startsWith("neuron_core");
}
}
144 changes: 16 additions & 128 deletions serving/src/test/java/ai/djl/serving/util/NeuronUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,145 +12,33 @@
*/
package ai.djl.serving.util;

import ai.djl.util.Utils;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.Proxy;
import java.net.URL;
import java.net.URLConnection;
import java.net.URLStreamHandler;
import java.net.URLStreamHandlerFactory;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class NeuronUtilsTest {

@Test
public void testNeuronUtils() {
MockURLStreamHandlerFactory factory = new MockURLStreamHandlerFactory();

public void testNeuronUtils() throws IOException {
Path dir = Paths.get("build/neuron_device/");
try {
URL.setURLStreamHandlerFactory(factory);

NeuronUtils.setInstanceType(-1);
factory.setMockMode(0); // inf1
Assert.assertTrue(NeuronUtils.hasNeuron());
Assert.assertTrue(NeuronUtils.isInf1());
if (System.getProperty("os.name").startsWith("Linux")) {
NeuronUtils.getNeuronCores();
}

NeuronUtils.setInstanceType(-1);
factory.setMockMode(1); // inf2
Assert.assertTrue(NeuronUtils.hasNeuron());
Assert.assertTrue(NeuronUtils.isInf2());

NeuronUtils.setInstanceType(-1);
factory.setMockMode(2); // inf1
Assert.assertTrue(NeuronUtils.hasNeuron());
Assert.assertTrue(NeuronUtils.isInf2());

NeuronUtils.setInstanceType(-1);
factory.setMockMode(3); // inf1
Assert.assertFalse(NeuronUtils.hasNeuron());

NeuronUtils.setInstanceType(-1);
factory.setMockMode(4); // inf1
Assert.assertFalse(NeuronUtils.hasNeuron());
} finally {
factory.setMock(false);
}
}

static final class MockURLStreamHandlerFactory implements URLStreamHandlerFactory {

private boolean mock = true;
private int mockMode;

public void setMock(boolean mock) {
this.mock = mock;
}

public void setMockMode(int mockMode) {
this.mockMode = mockMode;
}

/** {@inheritDoc} */
@Override
public URLStreamHandler createURLStreamHandler(String protocol) {
if (!mock) {
return null;
}
return new URLStreamHandler() {

/** {@inheritDoc} */
@Override
protected URLConnection openConnection(URL u, Proxy proxy) {
return openConnection(u);
for (int i = 0; i < 4; ++i) {
Path nd = dir.resolve("neuron" + i);
for (int j = 0; j < 2; ++j) {
Path nc = nd.resolve("neuron_core" + j);
Files.createDirectories(nc);
}

/** {@inheritDoc} */
@Override
protected URLConnection openConnection(URL u) {
return new MockHttpURLConnection(u);
}
};
}

final class MockHttpURLConnection extends HttpURLConnection {

public MockHttpURLConnection(URL u) {
super(u);
}

/** {@inheritDoc} */
@Override
public void disconnect() {}

/** {@inheritDoc} */
@Override
public boolean usingProxy() {
return true;
}

/** {@inheritDoc} */
@Override
public void connect() {}

/** {@inheritDoc} */
@Override
public InputStream getInputStream() throws IOException {
switch (mockMode) {
case 0:
// EC2 inf1.2xlarge
return new ByteArrayInputStream(
"inf1.2xlarge".getBytes(StandardCharsets.UTF_8));
case 1:
// EC2 inf2.24xlarge
return new ByteArrayInputStream(
"inf2.24xlarge".getBytes(StandardCharsets.UTF_8));
case 2:
// EC2 trn1.32xlarge
return new ByteArrayInputStream(
"trn1.32xlarge".getBytes(StandardCharsets.UTF_8));
case 3:
// EC2 c5.xlarge
return new ByteArrayInputStream(
"c5.xlarge".getBytes(StandardCharsets.UTF_8));
default:
// non-AWS
throw new IOException("Timeout");
}
}

/** {@inheritDoc} */
@Override
public int getResponseCode() {
return HttpURLConnection.HTTP_OK;
}
NeuronUtils.hasNeuron();
Assert.assertEquals(NeuronUtils.getNeuronCores(dir.toString()), 8);
} finally {
Utils.deleteQuietly(dir);
}
}
}

0 comments on commit 84737c2

Please sign in to comment.