-
Notifications
You must be signed in to change notification settings - Fork 661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Minor fixes to improve Apple Silicon MPS support #2873
Conversation
This fixes an issue where calling NDArrayEx.toTensor() fails on Apple Silicon due to a lack of support for float64.
Don't convert probabilities from float32 to float64, because this causes a failure on Apple Silicon.
Calling `NDArray.set(buf)` with PyTorch with MPS as the device can result in a fatal error (SIGSEGV) occurring soon afterwards.
I have added a commit to address #2504 When combined with the other changes, it is possible to train using Mnist and FashionMnist with MPS as the device. There seems to be a deeper issue though, related to the use of Here is a quick test to illustrate what happens: private static void quickTest() {
Device device = Device.of("mps", -1);
try (NDManager manager = NDManager.newBaseManager(device)) {
// This works on MPS
try (NDArray array = manager.create(ByteBuffer.wrap(new byte[]{127}), new Shape(1, 1, 1, 1), DataType.UINT8)) {
NDArray output = array.toType(DataType.FLOAT32, false);
System.out.println(output);
}
// This fails on MPS (SIGSEGV)
try (NDArray array = manager.create(new Shape(1, 1, 1, 1), DataType.UINT8)) {
array.set(new byte[]{127});
NDArray output = array.toType(DataType.FLOAT32, false);
System.out.println(output);
}
}
} To train, I used the following code based on the MNIST tutorial: Device device = Device.of("mps", 0);
try (NDManager manager = NDManager.newBaseManager(device)) {
Mnist mnist = Mnist.builder()
.optDevice(device)
.optManager(manager)
.setSampling(batchSize, true)
.build();
mnist.prepare();
Model model = Model.newInstance("mlp", device);
model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.optDevices(new Device[]{device})
.addEvaluator(new Accuracy()) // Use accuracy so we humans can understand how accurate the model is
.addTrainingListeners(TrainingListener.Defaults.logging());
// Now that we have our training configuration, we should create a new trainer for our model
Trainer trainer = model.newTrainer(config);
System.out.println("DEVICES: " + Arrays.toString(trainer.getDevices()));
trainer.setMetrics(new Metrics());
trainer.initialize(new Shape(1, 28*28));
// Deep learning is typically trained in epochs where each epoch trains the model on each item in the dataset once.
int epoch = 20;
long start = System.currentTimeMillis();
RandomAccessDataset split[] = mnist.randomSplit(6, 4);
EasyTrain.fit(trainer, epoch, split[0], split[1]);
long end = System.currentTimeMillis();
System.out.println("Time: " + (end - start) / 1000.0);
} catch (Exception e) {
e.printStackTrace();
} Using
because |
Codecov ReportAttention:
❗ Your organization needs to install the Codecov GitHub app to enable full functionality. Additional details and impacted files@@ Coverage Diff @@
## master #2873 +/- ##
============================================
+ Coverage 72.08% 72.31% +0.22%
- Complexity 5126 7166 +2040
============================================
Files 473 708 +235
Lines 21970 31922 +9952
Branches 2351 3317 +966
============================================
+ Hits 15838 23084 +7246
- Misses 4925 7260 +2335
- Partials 1207 1578 +371 ☔ View full report in Codecov by Sentry. |
* Support 32-bit toTensor() This fixes an issue where calling NDArrayEx.toTensor() fails on Apple Silicon due to a lack of support for float64. * Avoid float64 conversion in Classifications constructor Don't convert probabilities from float32 to float64, because this causes a failure on Apple Silicon. Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
Description
This PR contains minor changes to
NDArrayEx
andClassifications
to avoid converting float32 to float64 when not required.I have been using this workarounds here to enable MPS support on Apple Silicon for image classification tasks using PyTorch, and have seen a major speedup (as expected).
I added tests as a separate commit in case they are useful, but the relevant changes will only be noticed when running on Apple Silicon.
This PR relates to #2044
Happy to add any required copyright statements to the tests (or remove them entirely), or look for other MPS-related issues if this is useful. I couldn't get the full tests to run on Apple Silicon as it currently stands due to TensorFlow not being supported.