Segfault from Dataset.prepare() on MPS device #2504

Ubehebe opened this issue Apr 5, 2023 · 1 comment

Ubehebe commented Apr 5, 2023

Description

I am trying to combine the MNIST tutorial with the directions in #2037 to train a model on the MPS PyTorch backend. I have gotten the MPSTest in that PR to pass on my M2 machine: it successfully creates an NDArray on the MPS device.

Expected Behavior

I expected the MNIST dataset to prepare successfully on the MPS device, so that I could create a model from it. (Please let me know if this is unreasonable, or if I have to prepare the dataset on the CPU and then use some other API to load it onto MPS.)
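For reference, here is the kind of CPU-then-MPS flow I had in mind as a possible workaround. This is an untested sketch, and I'm assuming NDList.toDevice is the right API for moving prepared batches onto MPS; please correct me if there is a better way:

import ai.djl.Device
import ai.djl.basicdataset.cv.classification.Mnist
import ai.djl.ndarray.NDManager

fun main() {
  // Prepare the dataset on a CPU-backed manager, then copy each batch to MPS.
  val cpuManager = NDManager.newBaseManager(Device.cpu())
  val mps = Device.fromName("mps")
  val mnist = Mnist.builder()
      .optManager(cpuManager)
      .setSampling(32 /* batchSize */, true /* random */)
      .build()
  mnist.prepare() // should succeed with a CPU-backed manager (untested)
  for (batch in mnist.getData(cpuManager)) {
    val data = batch.data.toDevice(mps, false)     // copy features to MPS
    val labels = batch.labels.toDevice(mps, false) // copy labels to MPS
    // ... train on data/labels ...
    batch.close()
  }
}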

Error Message

#
# A fatal error has been detected by the Java Runtime Environment:
#
#  SIGSEGV (0xb) at pc=0x0000000197d43f00, pid=94738, tid=8707
#
# JRE version: OpenJDK Runtime Environment Zulu17.32+13-CA (17.0.2+8) (build 17.0.2+8-LTS)
# Java VM: OpenJDK 64-Bit Server VM Zulu17.32+13-CA (17.0.2+8-LTS, mixed mode, sharing, tiered, compressed oops, compressed class ptrs, g1 gc, bsd-aarch64)
# Problematic frame:
# C  [libobjc.A.dylib+0x7f00]  objc_retain+0x10
#
# No core dump will be written. Core dumps have been disabled. To enable core dumping, try "ulimit -c unlimited" before starting Java again
#
# If you would like to submit a bug report, please visit:
#   http://www.azul.com/support/
# The crash happened outside the Java Virtual Machine in native code.
# See problematic frame for where to report the bug.
#

---------------  S U M M A R Y ------------

Command Line: ml.djl.DjlSampleKt

Host: "Mac14,5" arm64 1 MHz, 12 cores, 64G, Darwin 22.4.0, macOS 13.3 (22E252)
Time: Wed Apr  5 07:23:05 2023 PDT elapsed time: 0.326062 seconds (0d 0h 0m 0s)

---------------  T H R E A D  ---------------

Current thread (0x000000014500dc00):  JavaThread "main" [_thread_in_native, id=8707, stack(0x000000016b62c000,0x000000016b82f000)]

Stack: [0x000000016b62c000,0x000000016b82f000],  sp=0x000000016b82d7a0,  free space=2053k
Native frames: (J=compiled Java code, j=interpreted, Vv=VM code, C=native code)
C  [libobjc.A.dylib+0x7f00]  objc_retain+0x10
C  [libtorch_cpu.dylib+0x4458830]  at::native::mps::copy_cast_mps(at::Tensor&, at::Tensor const&, id<MTLBuffer>, objc_object<MTLBuffer>, bool)+0x2ec
C  [libtorch_cpu.dylib+0x445a958]  at::native::mps::mps_copy_(at::Tensor&, at::Tensor const&, bool)+0x1e10
C  [libtorch_cpu.dylib+0x49b6d8]  at::native::copy_impl(at::Tensor&, at::Tensor const&, bool)+0x5cc
C  [libtorch_cpu.dylib+0x49b04c]  at::native::copy_(at::Tensor&, at::Tensor const&, bool)+0x64
C  [libtorch_cpu.dylib+0x10d3f14]  at::_ops::copy_::call(at::Tensor&, at::Tensor const&, bool)+0x120
C  [libtorch_cpu.dylib+0x7c0f0c]  at::native::_to_copy(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)+0xbd8
C  [libtorch_cpu.dylib+0xc64120]  at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)+0xbc
C  [libtorch_cpu.dylib+0xc64120]  at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)+0xbc
C  [libtorch_cpu.dylib+0x280348c]  c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>), &torch::autograd::VariableType::(anonymous namespace)::_to_copy(c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>>>, at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)+0x448
C  [libtorch_cpu.dylib+0xc63de8]  at::_ops::_to_copy::call(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)+0x154
C  [libtorch_cpu.dylib+0xe1bed0]  at::_ops::to_device::call(at::Tensor const&, c10::Device, c10::ScalarType, bool, bool, c10::optional<c10::MemoryFormat>)+0x140
C  [0.21.0-libdjl_torch.dylib+0x78974]  at::Tensor::to(c10::Device, c10::ScalarType, bool, bool, c10::optional<c10::MemoryFormat>) const+0x8c
C  [0.21.0-libdjl_torch.dylib+0x786c4]  Java_ai_djl_pytorch_jni_PyTorchLibrary_torchTo+0xb4
j  ai.djl.pytorch.jni.PyTorchLibrary.torchTo(JI[I)J+0
j  ai.djl.pytorch.jni.JniUtils.to(Lai/djl/pytorch/engine/PtNDArray;Lai/djl/ndarray/types/DataType;Lai/djl/Device;)Lai/djl/pytorch/engine/PtNDArray;+61
j  ai.djl.pytorch.engine.PtNDArray.toType(Lai/djl/ndarray/types/DataType;Z)Lai/djl/pytorch/engine/PtNDArray;+23
j  ai.djl.pytorch.engine.PtNDArray.toType(Lai/djl/ndarray/types/DataType;Z)Lai/djl/ndarray/NDArray;+3
j  ai.djl.basicdataset.cv.classification.Mnist.readLabel(Lai/djl/repository/Artifact$Item;)Lai/djl/ndarray/NDArray;+88
j  ai.djl.basicdataset.cv.classification.Mnist.prepare(Lai/djl/util/Progress;)V+146
j  ai.djl.training.dataset.Dataset.prepare()V+2
j  ml.djl.DjlSampleKt.main([Ljava/lang/String;)V+36
v  ~StubRoutines::call_stub
V  [libjvm.dylib+0x46b270]  JavaCalls::call_helper(JavaValue*, methodHandle const&, JavaCallArguments*, JavaThread*)+0x38c
V  [libjvm.dylib+0x4cfa64]  jni_invoke_static(JNIEnv_*, JavaValue*, _jobject*, JNICallType, _jmethodID*, JNI_ArgumentPusher*, JavaThread*)+0x12c
V  [libjvm.dylib+0x4d30f8]  jni_CallStaticVoidMethod+0x130
C  [libjli.dylib+0x5378]  JavaMain+0x9d4
C  [libjli.dylib+0x76e8]  ThreadJavaMain+0xc
C  [libsystem_pthread.dylib+0x6fa8]  _pthread_start+0x94

Java frames: (J=compiled Java code, j=interpreted, Vv=VM code)
j  ai.djl.pytorch.jni.PyTorchLibrary.torchTo(JI[I)J+0
j  ai.djl.pytorch.jni.JniUtils.to(Lai/djl/pytorch/engine/PtNDArray;Lai/djl/ndarray/types/DataType;Lai/djl/Device;)Lai/djl/pytorch/engine/PtNDArray;+61
j  ai.djl.pytorch.engine.PtNDArray.toType(Lai/djl/ndarray/types/DataType;Z)Lai/djl/pytorch/engine/PtNDArray;+23
j  ai.djl.pytorch.engine.PtNDArray.toType(Lai/djl/ndarray/types/DataType;Z)Lai/djl/ndarray/NDArray;+3
j  ai.djl.basicdataset.cv.classification.Mnist.readLabel(Lai/djl/repository/Artifact$Item;)Lai/djl/ndarray/NDArray;+88
j  ai.djl.basicdataset.cv.classification.Mnist.prepare(Lai/djl/util/Progress;)V+146
j  ai.djl.training.dataset.Dataset.prepare()V+2
j  ml.djl.DjlSampleKt.main([Ljava/lang/String;)V+36
v  ~StubRoutines::call_stub

The top of the JVM stack at the moment of the segfault is this code in Mnist.readLabel, where the dataset casts the NDArray to float32:

try (NDArray array = manager.create(new Shape(buf.length), DataType.UINT8)) {
    array.set(buf);
    return array.toType(DataType.FLOAT32, false);
}

The top of the native stack is objc_retain, followed by copy_cast_mps. (I can provide more details from the JVM crash report if that would be valuable.)
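Based on the stack, I suspect the crash might be reproducible with just that cast in isolation, something like the following (an untested guess on my part):

import ai.djl.Device
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.DataType
import ai.djl.ndarray.types.Shape

fun main() {
  val manager = NDManager.newBaseManager(Device.fromName("mps"))
  // Mirror what Mnist.readLabel does: create a UINT8 array, fill it, cast to FLOAT32.
  val array = manager.create(Shape(4L), DataType.UINT8)
  array.set(byteArrayOf(0, 1, 2, 3))
  array.toType(DataType.FLOAT32, false) // presumably the same UINT8 -> FLOAT32 cast that segfaults
}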

How to Reproduce?

Here is a minimal repro. It's in Kotlin, but I'm happy to rewrite in Java if you prefer:

import ai.djl.Device
import ai.djl.basicdataset.cv.classification.Mnist
import ai.djl.ndarray.NDManager

fun main(args: Array<String>) {
  val device = Device.fromName("mps")
  val manager = NDManager.newBaseManager(device)
  Mnist.builder()
      .optManager(manager)
      .setSampling(32 /* batchSize */, true /* random */)
      .build()
      .prepare() // segfault!
}

Steps to reproduce

Run the program above, with the pytorch-engine, pytorch-jni, and pytorch-native-cpu-osx-aarch64 jars on the runtime classpath.
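For completeness, my dependencies look roughly like this (Gradle Kotlin DSL; I'm reconstructing the exact coordinates and versions from memory, so they may be slightly off):

dependencies {
    implementation("ai.djl:api:0.21.0")
    implementation("ai.djl:basicdataset:0.21.0")
    runtimeOnly("ai.djl.pytorch:pytorch-engine:0.21.0")
    runtimeOnly("ai.djl.pytorch:pytorch-jni:1.13.1-0.21.0")
    runtimeOnly("ai.djl.pytorch:pytorch-native-cpu-osx-aarch64:1.13.1")
}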

What have you tried to solve it?

Nothing, beyond isolating this repro. (I'm new to ML in general, but have a lot of JVM experience.)

Environment Info

I'm using DJL v0.21.0 and pytorch-jni v1.13.1. System information:

$ uname -mrsv
Darwin 22.4.0 Darwin Kernel Version 22.4.0: Mon Mar  6 20:59:58 PST 2023; root:xnu-8796.101.5~3/RELEASE_ARM64_T6020 arm64
Ubehebe added the bug label on Apr 5, 2023
frankfliu (Contributor) commented:

MPS has many limitations, and we do observe crashes when using the MPS device. See: #2044
