diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtDeviceType.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtDeviceType.java index 78b62034590..a7396dd8a48 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtDeviceType.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtDeviceType.java @@ -32,8 +32,10 @@ public static int toDeviceType(Device device) { return 0; } else if (Device.Type.GPU.equals(deviceType)) { return 1; + } else if ("mps".equals(deviceType)) { + return 13; } else { - throw new IllegalArgumentException("Unsupported device: " + device.toString()); + throw new IllegalArgumentException("Unsupported device: " + device); } } @@ -49,6 +51,8 @@ public static String fromDeviceType(int deviceType) { return Device.Type.CPU; case 1: return Device.Type.GPU; + case 13: + return "mps"; default: throw new IllegalArgumentException("Unsupported deviceType: " + deviceType); } diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java new file mode 100644 index 00000000000..8b4e2326f26 --- /dev/null +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java @@ -0,0 +1,39 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.pytorch.integration; + +import ai.djl.Device; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.Shape; + +import org.testng.Assert; +import org.testng.SkipException; +import org.testng.annotations.Test; + +public class MpsTest { + + @Test + public void testMps() { + if (!"aarch64".equals(System.getProperty("os.arch")) + || !System.getProperty("os.name").startsWith("Mac")) { + throw new SkipException("MPS test requires M1 macOS."); + } + + Device device = Device.of("mps", -1); + try (NDManager manager = NDManager.newBaseManager(device)) { + NDArray array = manager.zeros(new Shape(1, 2)); + Assert.assertEquals(array.getDevice().getDeviceType(), "mps"); + } + } +}