Skip to content

Commit

Permalink
Make MPS default device for macOS M1
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Jun 1, 2024
1 parent ab1490e commit 10bbe4a
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 3 deletions.
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/engine/Engine.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public abstract class Engine {
private static final Pattern PATTERN =
Pattern.compile("KEY|TOKEN|PASSWORD", Pattern.CASE_INSENSITIVE);

private Device defaultDevice;
protected Device defaultDevice;

// use object to check if it's set
private Integer seed;
Expand Down
2 changes: 1 addition & 1 deletion api/src/test/java/ai/djl/DeviceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public void testDevice() {
if (engine.getGpuCount() > 0) {
Assert.assertEquals(Device.gpu(), engine.defaultDevice());
} else {
Assert.assertEquals(Device.cpu(), engine.defaultDevice());
Assert.assertEquals(engine.defaultDevice().getDeviceId(), -1);
}
Assert.assertEquals(Device.gpu(), Device.of("gpu", 0));
Assert.assertEquals(Device.gpu(3), Device.of("gpu", 3));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,20 @@ public boolean hasCapability(String capability) {
return JniUtils.getFeatures().contains(capability);
}

/** {@inheritDoc} */
@Override
public Device defaultDevice() {
if (defaultDevice != null) {
return defaultDevice;
}
if ("aarch64".equals(System.getProperty("os.arch"))
&& System.getProperty("os.name").startsWith("Mac")) {
defaultDevice = Device.of("mps", -1);
return defaultDevice;
}
return super.defaultDevice();
}

/** {@inheritDoc} */
@Override
public SymbolBlock newSymbolBlock(NDManager manager) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package ai.djl.pytorch.integration;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
Expand Down Expand Up @@ -58,6 +59,7 @@ public void testProfiler()
Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optApplication(Application.CV.IMAGE_CLASSIFICATION)
.optDevice(Device.cpu()) // MPS doesn't support float64 (by profiler)
.optFilter("layers", "18")
.optTranslator(translator)
.optProgress(new ProgressBar())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

package ai.djl.pytorch.integration;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDList;
Expand Down Expand Up @@ -93,9 +94,10 @@ public void testInputOutput() throws IOException, ModelException {
try (InputStream is = Files.newInputStream(modelFile)) {
PtSymbolBlock block = JniUtils.loadModule(manager, is, true, false);
ByteArrayOutputStream os = new ByteArrayOutputStream();
// writeModule with MPS cannot be loaded back on MPS
JniUtils.writeModule(block, os, true);
ByteArrayInputStream bis = new ByteArrayInputStream(os.toByteArray());
JniUtils.loadModule(manager, bis, true, true);
JniUtils.loadModuleHandle(bis, Device.cpu(), true, true);
bis.close();
os.close();
}
Expand Down

0 comments on commit 10bbe4a

Please sign in to comment.