Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor fixes to improve Apple Silicon MPS support #2873

Merged
merged 6 commits into from
Nov 28, 2023

Conversation

petebankhead
Copy link
Contributor

@petebankhead petebankhead commented Nov 25, 2023

Description

This PR contains minor changes to NDArrayEx and Classifications to avoid converting float32 to float64 when not required.

I have been using this workarounds here to enable MPS support on Apple Silicon for image classification tasks using PyTorch, and have seen a major speedup (as expected).

I added tests as a separate commit in case they are useful, but the relevant changes will only be noticed when running on Apple Silicon.

Update: Tests now moved out of separate files and into MpsTest

This PR relates to #2044

Happy to add any required copyright statements to the tests (or remove them entirely), or look for other MPS-related issues if this is useful. I couldn't get the full tests to run on Apple Silicon as it currently stands due to TensorFlow not being supported.

This fixes an issue where calling NDArrayEx.toTensor() fails on Apple Silicon due to a lack of support for float64.
Don't convert probabilities from float32 to float64, because this causes a failure on Apple Silicon.
Calling `NDArray.set(buf)` with PyTorch with MPS as the device can result in a fatal error (SIGSEGV) occurring soon afterwards.
@petebankhead
Copy link
Contributor Author

I have added a commit to address #2504

When combined with the other changes, it is possible to train using Mnist and FashionMnist with MPS as the device.

There seems to be a deeper issue though, related to the use of NDArray.set(buf). This appears to be the source of the reported error. The error occurred only when NDarray.toType(DataType.FLOAT32, false) was called, but toType appears to work if set(buf) is avoided. So the workaround here is to use the buffer at the time the array is created, rather than afterwards.

Here is a quick test to illustrate what happens:

    private static void quickTest() {
        Device device = Device.of("mps", -1);
        try (NDManager manager = NDManager.newBaseManager(device)) {
            // This works on MPS
            try (NDArray array = manager.create(ByteBuffer.wrap(new byte[]{127}), new Shape(1, 1, 1, 1), DataType.UINT8)) {
                NDArray output = array.toType(DataType.FLOAT32, false);
                System.out.println(output);
            }
            // This fails on MPS (SIGSEGV)
            try (NDArray array = manager.create(new Shape(1, 1, 1, 1), DataType.UINT8)) {
                array.set(new byte[]{127});
                NDArray output = array.toType(DataType.FLOAT32, false);
                System.out.println(output);
            }
        }
    }

To train, I used the following code based on the MNIST tutorial:

        Device device = Device.of("mps", 0);
        try (NDManager manager = NDManager.newBaseManager(device)) {
            Mnist mnist = Mnist.builder()
                    .optDevice(device)
                    .optManager(manager)
                    .setSampling(batchSize, true)
                    .build();

            mnist.prepare();

            Model model = Model.newInstance("mlp", device);
            model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));

            DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                    .optDevices(new Device[]{device})
                    .addEvaluator(new Accuracy()) // Use accuracy so we humans can understand how accurate the model is
                    .addTrainingListeners(TrainingListener.Defaults.logging());

            // Now that we have our training configuration, we should create a new trainer for our model
            Trainer trainer = model.newTrainer(config);
            System.out.println("DEVICES: " + Arrays.toString(trainer.getDevices()));

            trainer.setMetrics(new Metrics());
            trainer.initialize(new Shape(1, 28*28));

            // Deep learning is typically trained in epochs where each epoch trains the model on each item in the dataset once.
            int epoch = 20;

            long start = System.currentTimeMillis();
            RandomAccessDataset split[] = mnist.randomSplit(6, 4);
            EasyTrain.fit(trainer, epoch, split[0], split[1]);
            long end = System.currentTimeMillis();

            System.out.println("Time: " + (end - start) / 1000.0);
        } catch (Exception e) {
            e.printStackTrace();
        }

Using Device.of("mps", 0) was important; if I change to Device.of("mps", -1) I see

java.lang.NullPointerException: Cannot invoke "java.lang.Integer.intValue()" because the return value of "java.util.Map.get(Object)" is null
	at ai.djl.training.ParameterStore.getValue(ParameterStore.java:104)
	at ai.djl.nn.core.Linear.forwardInternal(Linear.java:93)
	at ai.djl.nn.AbstractBaseBlock.forwardInternal(AbstractBaseBlock.java:128)
	at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:93)
	at ai.djl.nn.SequentialBlock.forwardInternal(SequentialBlock.java:211)
	at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:93)
	at ai.djl.training.Trainer.forward(Trainer.java:187)
	at ai.djl.training.EasyTrain.trainSplit(EasyTrain.java:122)
	at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:110)
	at ai.djl.training.EasyTrain.fit(EasyTrain.java:58)
	at ai.djl.pytorch.integration.MpsTest.main(MpsTest.java:114)

because ParameterStore.deviceMap expects the value to be 0.

@codecov-commenter
Copy link

codecov-commenter commented Nov 26, 2023

Codecov Report

Attention: 1365 lines in your changes are missing coverage. Please review.

Comparison is base (bb5073f) 72.08% compared to head (a5c9623) 72.31%.
Report is 921 commits behind head on master.

Files Patch % Lines
...va/ai/djl/modality/nlp/generate/TextGenerator.java 2.81% 276 Missing ⚠️
.../java/ai/djl/modality/nlp/generate/SeqBatcher.java 0.75% 132 Missing ⚠️
...ity/nlp/generate/ContrastiveSeqBatchScheduler.java 2.97% 98 Missing ⚠️
...i/djl/modality/nlp/generate/SeqBatchScheduler.java 9.83% 55 Missing ⚠️
.../java/ai/djl/modality/cv/BufferedImageFactory.java 40.96% 47 Missing and 2 partials ⚠️
...a/ai/djl/modality/nlp/generate/StepGeneration.java 2.04% 48 Missing ⚠️
api/src/main/java/ai/djl/ndarray/NDArray.java 43.42% 39 Missing and 4 partials ⚠️
...n/java/ai/djl/modality/cv/output/CategoryMask.java 22.00% 39 Missing ⚠️
...i/src/main/java/ai/djl/ndarray/NDArrayAdapter.java 71.21% 31 Missing and 7 partials ⚠️
.../cv/translator/SemanticSegmentationTranslator.java 37.50% 35 Missing ⚠️
... and 76 more

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files
@@             Coverage Diff              @@
##             master    #2873      +/-   ##
============================================
+ Coverage     72.08%   72.31%   +0.22%     
- Complexity     5126     7166    +2040     
============================================
  Files           473      708     +235     
  Lines         21970    31922    +9952     
  Branches       2351     3317     +966     
============================================
+ Hits          15838    23084    +7246     
- Misses         4925     7260    +2335     
- Partials       1207     1578     +371     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@frankfliu frankfliu merged commit 36d4aec into deepjavalibrary:master Nov 28, 2023
5 checks passed
frankfliu added a commit that referenced this pull request Apr 26, 2024
* Support 32-bit toTensor()

This fixes an issue where calling NDArrayEx.toTensor() fails on Apple Silicon due to a lack of support for float64.

* Avoid float64 conversion in Classifications constructor

Don't convert probabilities from float32 to float64, because this causes a failure on Apple Silicon.

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants