From f22398740fd8b2968d547d241b47e7bba05ae8a8 Mon Sep 17 00:00:00 2001 From: hrayrm Date: Tue, 27 Sep 2022 18:56:44 +0930 Subject: [PATCH 1/3] Add support for Apples' Metal Performance Shaders (MPS) in pytorch engine. --- api/src/main/java/ai/djl/Device.java | 28 +++++++++++++++---- .../ai/djl/pytorch/engine/PtDeviceType.java | 4 +++ 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/api/src/main/java/ai/djl/Device.java b/api/src/main/java/ai/djl/Device.java index ce9b29ae5ba..b4926e931b8 100644 --- a/api/src/main/java/ai/djl/Device.java +++ b/api/src/main/java/ai/djl/Device.java @@ -21,10 +21,10 @@ import java.util.regex.Pattern; /** - * The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@code + * The {@code Device} class provides the specified assignment for CPU/GPU/MPS processing on the {@code * NDArray}. * - *

Users can use this to specify whether to load/compute the {@code NDArray} on CPU/GPU with + *

Users can use this to specify whether to load/compute the {@code NDArray} on CPU/GPU/MPS with * deviceType and deviceId provided. * * @see The D2L chapter @@ -36,6 +36,7 @@ public final class Device { private static final Device CPU = new Device(Type.CPU, -1); private static final Device GPU = Device.of(Type.GPU, 0); + private static final Device MPS = Device.of(Type.MPS, -1); private static final Pattern DEVICE_NAME = Pattern.compile("([a-z]+)([0-9]*)"); @@ -45,7 +46,7 @@ public final class Device { /** * Creates a {@code Device} with basic information. * - * @param deviceType the device type, typically CPU or GPU + * @param deviceType the device type, typically CPU, GPU, or MPS * @param deviceId the deviceId on the hardware. For example, if you have multiple GPUs, you can * choose which GPU to process the NDArray */ @@ -57,7 +58,7 @@ private Device(String deviceType, int deviceId) { /** * Returns a {@code Device} with device type and device id. * - * @param deviceType the device type, typically CPU or GPU + * @param deviceType the device type, typically CPU, GPU, or MPS * @param deviceId the deviceId on the hardware. * @return a {@code Device} instance */ @@ -83,7 +84,7 @@ public static Device fromName(String deviceName) { /** * Parses a deviceName string into a device. * - *

The main format of a device name string is "cpu", "gpu0", or "nc1". This is simply + *

The main format of a device name string is "cpu", "gpu0","mps", or "nc1". This is simply * deviceType concatenated with the deviceId. If no deviceId is used, -1 will be assumed. * *

There are also several simplified formats. The "-1", deviceNames corresponds to cpu. @@ -150,6 +151,15 @@ public boolean isGpu() { return Type.GPU.equals(deviceType); } + /** + * Returns if the {@code Device} is MPS. + * + * @return if the {@code Device} is MPS. + */ + public boolean isMps() { + return Type.MPS.equals(deviceType); + } + /** {@inheritDoc} */ @Override public String toString() { @@ -209,9 +219,17 @@ public static Device gpu(int deviceId) { return of(Type.GPU, deviceId); } + /** + * Returns the default Metal Performance Shaders (MPS) Device. + * + * @return the default MPS Device + */ + public static Device mps() { return MPS; } + /** Contains device type string constants. */ public interface Type { String CPU = "cpu"; String GPU = "gpu"; + String MPS = "mps"; } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtDeviceType.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtDeviceType.java index 78b62034590..6d4b2002c6b 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtDeviceType.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtDeviceType.java @@ -32,6 +32,8 @@ public static int toDeviceType(Device device) { return 0; } else if (Device.Type.GPU.equals(deviceType)) { return 1; + } else if (Device.Type.MPS.equals(deviceType)) { + return 13; } else { throw new IllegalArgumentException("Unsupported device: " + device.toString()); } @@ -49,6 +51,8 @@ public static String fromDeviceType(int deviceType) { return Device.Type.CPU; case 1: return Device.Type.GPU; + case 13: + return Device.Type.MPS; default: throw new IllegalArgumentException("Unsupported deviceType: " + deviceType); } From e7dd131d64ef9b6dda90814e20fc3e4c9ac369e1 Mon Sep 17 00:00:00 2001 From: hrayrm Date: Tue, 27 Sep 2022 20:21:32 +0930 Subject: [PATCH 2/3] Add support for Apples' Metal Performance Shaders (MPS) in pytorch engine. --- api/src/main/java/ai/djl/Device.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/api/src/main/java/ai/djl/Device.java b/api/src/main/java/ai/djl/Device.java index b4926e931b8..36d18f1f050 100644 --- a/api/src/main/java/ai/djl/Device.java +++ b/api/src/main/java/ai/djl/Device.java @@ -21,8 +21,8 @@ import java.util.regex.Pattern; /** - * The {@code Device} class provides the specified assignment for CPU/GPU/MPS processing on the {@code - * NDArray}. + * The {@code Device} class provides the specified assignment for CPU/GPU/MPS processing on the + * {@code NDArray}. * *

Users can use this to specify whether to load/compute the {@code NDArray} on CPU/GPU/MPS with * deviceType and deviceId provided. @@ -220,11 +220,13 @@ public static Device gpu(int deviceId) { } /** - * Returns the default Metal Performance Shaders (MPS) Device. + * Returns the default Metal Performance Shaders (MPS) Device. * * @return the default MPS Device */ - public static Device mps() { return MPS; } + public static Device mps() { + return MPS; + } /** Contains device type string constants. */ public interface Type { From b9df9639cddddbd2136b9626df48a933f9efd0c5 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 28 Sep 2022 13:06:58 -0700 Subject: [PATCH 3/3] Add Unit test --- api/src/main/java/ai/djl/Device.java | 32 +++------------ .../ai/djl/pytorch/engine/PtDeviceType.java | 6 +-- .../ai/djl/pytorch/integration/MpsTest.java | 39 +++++++++++++++++++ 3 files changed, 48 insertions(+), 29 deletions(-) create mode 100644 engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java diff --git a/api/src/main/java/ai/djl/Device.java b/api/src/main/java/ai/djl/Device.java index 36d18f1f050..ce9b29ae5ba 100644 --- a/api/src/main/java/ai/djl/Device.java +++ b/api/src/main/java/ai/djl/Device.java @@ -21,10 +21,10 @@ import java.util.regex.Pattern; /** - * The {@code Device} class provides the specified assignment for CPU/GPU/MPS processing on the - * {@code NDArray}. + * The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@code + * NDArray}. * - *

Users can use this to specify whether to load/compute the {@code NDArray} on CPU/GPU/MPS with + *

Users can use this to specify whether to load/compute the {@code NDArray} on CPU/GPU with * deviceType and deviceId provided. * * @see The D2L chapter @@ -36,7 +36,6 @@ public final class Device { private static final Device CPU = new Device(Type.CPU, -1); private static final Device GPU = Device.of(Type.GPU, 0); - private static final Device MPS = Device.of(Type.MPS, -1); private static final Pattern DEVICE_NAME = Pattern.compile("([a-z]+)([0-9]*)"); @@ -46,7 +45,7 @@ public final class Device { /** * Creates a {@code Device} with basic information. * - * @param deviceType the device type, typically CPU, GPU, or MPS + * @param deviceType the device type, typically CPU or GPU * @param deviceId the deviceId on the hardware. For example, if you have multiple GPUs, you can * choose which GPU to process the NDArray */ @@ -58,7 +57,7 @@ private Device(String deviceType, int deviceId) { /** * Returns a {@code Device} with device type and device id. * - * @param deviceType the device type, typically CPU, GPU, or MPS + * @param deviceType the device type, typically CPU or GPU * @param deviceId the deviceId on the hardware. * @return a {@code Device} instance */ @@ -84,7 +83,7 @@ public static Device fromName(String deviceName) { /** * Parses a deviceName string into a device. * - *

The main format of a device name string is "cpu", "gpu0","mps", or "nc1". This is simply + *

The main format of a device name string is "cpu", "gpu0", or "nc1". This is simply * deviceType concatenated with the deviceId. If no deviceId is used, -1 will be assumed. * *

There are also several simplified formats. The "-1", deviceNames corresponds to cpu. @@ -151,15 +150,6 @@ public boolean isGpu() { return Type.GPU.equals(deviceType); } - /** - * Returns if the {@code Device} is MPS. - * - * @return if the {@code Device} is MPS. - */ - public boolean isMps() { - return Type.MPS.equals(deviceType); - } - /** {@inheritDoc} */ @Override public String toString() { @@ -219,19 +209,9 @@ public static Device gpu(int deviceId) { return of(Type.GPU, deviceId); } - /** - * Returns the default Metal Performance Shaders (MPS) Device. - * - * @return the default MPS Device - */ - public static Device mps() { - return MPS; - } - /** Contains device type string constants. */ public interface Type { String CPU = "cpu"; String GPU = "gpu"; - String MPS = "mps"; } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtDeviceType.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtDeviceType.java index 6d4b2002c6b..a7396dd8a48 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtDeviceType.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtDeviceType.java @@ -32,10 +32,10 @@ public static int toDeviceType(Device device) { return 0; } else if (Device.Type.GPU.equals(deviceType)) { return 1; - } else if (Device.Type.MPS.equals(deviceType)) { + } else if ("mps".equals(deviceType)) { return 13; } else { - throw new IllegalArgumentException("Unsupported device: " + device.toString()); + throw new IllegalArgumentException("Unsupported device: " + device); } } @@ -52,7 +52,7 @@ public static String fromDeviceType(int deviceType) { case 1: return Device.Type.GPU; case 13: - return Device.Type.MPS; + return "mps"; default: throw new IllegalArgumentException("Unsupported deviceType: " + deviceType); } diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java new file mode 100644 index 00000000000..8b4e2326f26 --- /dev/null +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java @@ -0,0 +1,39 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.pytorch.integration; + +import ai.djl.Device; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.Shape; + +import org.testng.Assert; +import org.testng.SkipException; +import org.testng.annotations.Test; + +public class MpsTest { + + @Test + public void testMps() { + if (!"aarch64".equals(System.getProperty("os.arch")) + || !System.getProperty("os.name").startsWith("Mac")) { + throw new SkipException("MPS test requires M1 macOS."); + } + + Device device = Device.of("mps", -1); + try (NDManager manager = NDManager.newBaseManager(device)) { + NDArray array = manager.zeros(new Shape(1, 2)); + Assert.assertEquals(array.getDevice().getDeviceType(), "mps"); + } + } +}