From 5ffa9b7517c9fd4d2d3e472897c845d5b042f7f9 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 27 Feb 2016 20:44:26 +0800 Subject: [PATCH 1/7] add finalize() for Symbol & KVStore --- .../core/src/main/scala/ml/dmlc/mxnet/KVStore.scala | 6 ++++++ .../core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala | 2 ++ .../core/src/main/scala/ml/dmlc/mxnet/Symbol.scala | 6 ++++++ .../src/main/native/ml_dmlc_mxnet_native_c_api.cc | 10 ++++++++++ 4 files changed, 24 insertions(+) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala index 11f3376874ac..9f772bf92da8 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala @@ -23,9 +23,14 @@ object KVStore { } } +// scalastyle:off finalize class KVStore(private val handle: KVStoreHandle) { private var updaterFunc: MXKVStoreUpdater = null + override def finalize(): Unit = { + checkCall(_LIB.mxKVStoreFree(handle)) + } + /** * Initialize a single or a sequence of key-value pairs into the store. * For each key, one must init it before push and pull. @@ -202,3 +207,4 @@ class KVStore(private val handle: KVStoreHandle) { checkCall(_LIB.mxKVStoreSendCommmandToServers(handle, head, body)) } } +// scalastyle:off finalize diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 9fbdbe1dc1d8..c39dbc3365c0 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -84,6 +84,7 @@ class LibInfo { @native def mxKVStoreBarrier(handle: KVStoreHandle): Int @native def mxKVStoreGetGroupSize(handle: KVStoreHandle, size: RefInt): Int @native def mxKVStoreGetRank(handle: KVStoreHandle, size: RefInt): Int + @native def mxKVStoreFree(handle: KVStoreHandle): Int // DataIter Funcs @native def mxListDataIters(handles: ListBuffer[DataIterCreator]): Int @@ -187,6 +188,7 @@ class LibInfo { // scalastyle:on parameterNum @native def mxSymbolSaveToFile(handle: SymbolHandle, fname: String): Int @native def mxSymbolCreateFromFile(fname: String, handle: SymbolHandleRef): Int + @native def mxSymbolFree(handle: SymbolHandle): Int // Random @native def mxRandomSeed(seed: Int): Int diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index 80cc41c3decd..1f74fc7614ad 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -9,7 +9,12 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} * Symbolic configuration API of mxnet. * @author Yizhi Liu */ +// scalastyle:off finalize class Symbol(private[mxnet] val handle: SymbolHandle) { + override def finalize(): Unit = { + checkCall(_LIB.mxSymbolFree(handle)) + } + def +(other: Symbol): Symbol = Symbol.createFromListedSymbols("_Plus")(Array(this, other)) def +[@specialized(Int, Float, Double) V](other: V): Symbol = { Symbol.createFromListedSymbols("_PlusScalar")(Array(this), Map("scalar" -> other.toString)) @@ -713,6 +718,7 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { jsonStr.value } } +// scalastyle:on finalize object Symbol { private type SymbolCreateNamedFunc = Map[String, Any] => Symbol diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index 12965581dc86..96cab0ae5ff5 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -420,6 +420,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetRank return ret; } +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreFree + (JNIEnv * env, jobject obj, jlong ptr) { + return MXKVStoreFree((KVStoreHandle) ptr); +} + JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorOutputs (JNIEnv *env, jobject obj, jlong executorPtr, jobject outputs) { mx_uint outSize; @@ -676,6 +681,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetPadNum } // Symbol functions +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolFree + (JNIEnv * env, jobject obj, jlong ptr) { + return MXSymbolFree((SymbolHandle) ptr); +} + JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListAtomicSymbolCreators (JNIEnv *env, jobject obj, jobject symbolList) { mx_uint outSize; From 96a1bebe73599f09557e467f757b84dfee1d8ec6 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 28 Feb 2016 12:45:52 +0800 Subject: [PATCH 2/7] Fix missing DeleteLocalRef. Get JNIEnv from global jvm --- .../src/main/scala/ml/dmlc/mxnet/Base.scala | 1 + .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 1 + .../native/src/main/native/jni_helper_func.h | 7 -- .../main/native/ml_dmlc_mxnet_native_c_api.cc | 92 +++++++++++++------ 4 files changed, 65 insertions(+), 36 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index 9f02c6e4e2e0..8bbfd18fe8a7 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -31,6 +31,7 @@ object Base { System.loadLibrary("mxnet-scala") val _LIB = new LibInfo + checkCall(_LIB.nativeLibInit()) // helper function definitions /** diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index c39dbc3365c0..39056bd86f07 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -9,6 +9,7 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} * @author Yizhi Liu */ class LibInfo { + @native def nativeLibInit(): Int // NDArray @native def mxNDArrayFree(handle: NDArrayHandle): Int @native def mxGetLastError(): String diff --git a/scala-package/native/src/main/native/jni_helper_func.h b/scala-package/native/src/main/native/jni_helper_func.h index c86a451bbdce..cce9cb0efe22 100644 --- a/scala-package/native/src/main/native/jni_helper_func.h +++ b/scala-package/native/src/main/native/jni_helper_func.h @@ -3,13 +3,6 @@ #ifndef MXNET_SCALA_JNI_HELPER_FUNC_H #define MXNET_SCALA_JNI_HELPER_FUNC_H -// Define an env closure -// e.g. it can be used to implement java callback -typedef struct { - JNIEnv *env; - jobject obj; -} JNIClosure; - jlong getLongField(JNIEnv *env, jobject obj) { jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); jfieldID refFid = env->GetFieldID(refClass, "value", "J"); diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index 96cab0ae5ff5..a1ccb78709a8 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -9,6 +9,13 @@ #include "jni_helper_func.h" #include "ml_dmlc_mxnet_native_c_api.h" // generated by javah +JavaVM *cached_jvm; + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_nativeLibInit + (JNIEnv *env, jobject obj) { + return env->GetJavaVM(&cached_jvm); +} + JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayCreateNone (JNIEnv *env, jobject obj, jobject ndArrayHandle) { NDArrayHandle out; @@ -154,6 +161,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayGetShape for (int i = 0; i < ndim; ++i) { jobject data = env->NewObject(integerClass, newInteger, pdata[i]); env->CallObjectMethod(dataBuf, arrayAppend, data); + env->DeleteLocalRef(data); } // set ndimRef @@ -237,12 +245,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayLoad for (int i = 0; i < outSize; ++i) { jobject handle = env->NewObject(longCls, longConst, outArr[i]); env->CallObjectMethod(jhandles, arrayAppend, handle); + env->DeleteLocalRef(handle); } // fill names for (int i = 0; i < outNameSize; ++i) { jstring jname = env->NewStringUTF(outNames[i]); env->CallObjectMethod(jnames, arrayAppend, jname); + env->DeleteLocalRef(jname); } return ret; @@ -258,6 +268,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySave jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i); const char *key = env->GetStringUTFChars(jkey, 0); keys[i] = key; + env->DeleteLocalRef(jkey); } } @@ -274,6 +285,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySave for (int i = 0; i < numArgs; i++) { jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i); env->ReleaseStringUTFChars(jkey, keys[i]); + env->DeleteLocalRef(jkey); } delete[] keys; } @@ -283,9 +295,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySave extern "C" void KVStoreUpdaterCallbackFunc (int key, NDArrayHandle recv, NDArrayHandle local, void *handle) { - JNIClosure *closure = (JNIClosure *) handle; - JNIEnv *env = closure->env; - jobject updaterFuncObjGlb = closure->obj; + jobject updaterFuncObjGlb = (jobject) handle; + + JNIEnv *env; + cached_jvm->AttachCurrentThread((void **)&env, NULL); // find java updater method jclass updtClass = env->GetObjectClass(updaterFuncObjGlb); @@ -300,6 +313,9 @@ extern "C" void KVStoreUpdaterCallbackFunc jobject ndLocal = env->NewObject(ndObjClass, ndObjConstructor, (jlong)local, true); env->CallVoidMethod(updaterFuncObjGlb, updtFunc, key, ndRecv, ndLocal); + + env->DeleteLocalRef(ndLocal); + env->DeleteLocalRef(ndRecv); // FIXME: This function can be called multiple times, // can we find a way to safely destroy these two objects ? // env->DeleteGlobalRef(updaterFuncObjGlb); @@ -309,11 +325,8 @@ extern "C" void KVStoreUpdaterCallbackFunc JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSetUpdater (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject updaterFuncObj) { jobject updaterFuncObjGlb = env->NewGlobalRef(updaterFuncObj); - JNIClosure *closure = new JNIClosure(); - closure->env = env; - closure->obj = updaterFuncObjGlb; return MXKVStoreSetUpdater((KVStoreHandle) kvStorePtr, - KVStoreUpdaterCallbackFunc, (void *) closure); + KVStoreUpdaterCallbackFunc, (void *) updaterFuncObjGlb); } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreIsWorkerNode @@ -518,23 +531,23 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env //IO funcs JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListDataIters (JNIEnv * env, jobject obj, jobject creators) { - jclass longCls = env->FindClass("java/lang/Long"); - jmethodID longConst = env->GetMethodID(longCls, "", "(J)V"); - - // scala.collection.mutable.ListBuffer append method - jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer"); - jmethodID listAppend = env->GetMethodID(listClass, - "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;"); - - // Get function list - DataIterCreator *outArray; - mx_uint outSize; - int ret = MXListDataIters(&outSize, &outArray); - for (int i = 0; i < outSize; ++i) { - env->CallObjectMethod(creators, listAppend, - env->NewObject(longCls, longConst, (long)outArray[i])); - } - return ret; + jclass longCls = env->FindClass("java/lang/Long"); + jmethodID longConst = env->GetMethodID(longCls, "", "(J)V"); + + // scala.collection.mutable.ListBuffer append method + jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer"); + jmethodID listAppend = env->GetMethodID(listClass, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;"); + + // Get function list + DataIterCreator *outArray; + mx_uint outSize; + int ret = MXListDataIters(&outSize, &outArray); + for (int i = 0; i < outSize; ++i) { + env->CallObjectMethod(creators, listAppend, + env->NewObject(longCls, longConst, (long)outArray[i])); + } + return ret; } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterCreateIter @@ -546,18 +559,20 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterCreateIter char** vals = new char*[paramSize]; jstring jkey, jval; //use strcpy and release char* created by JNI inplace - for(int i=0; iGetObjectArrayElement(jkeys, i); const char* ckey = env->GetStringUTFChars(jkey, 0); keys[i] = new char[env->GetStringLength(jkey)]; strcpy(keys[i], ckey); env->ReleaseStringUTFChars(jkey, ckey); + env->DeleteLocalRef(jkey); jval = (jstring) env->GetObjectArrayElement(jvals, i); const char* cval = env->GetStringUTFChars(jval, 0); vals[i] = new char[env->GetStringLength(jval)]; strcpy(vals[i], cval); env->ReleaseStringUTFChars(jval, cval); + env->DeleteLocalRef(jval); } //create iter @@ -758,10 +773,12 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCreateAtomicSymbol jstring key = (jstring) env->GetObjectArrayElement(paramKeys, i); const char *rawKey = env->GetStringUTFChars(key, 0); keys[i] = rawKey; + env->DeleteLocalRef(key); jstring value = (jstring) env->GetObjectArrayElement(paramVals, i); const char *rawValue = env->GetStringUTFChars(value, 0); vals[i] = rawValue; + env->DeleteLocalRef(value); } SymbolHandle out; @@ -773,8 +790,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCreateAtomicSymbol for (int i = 0; i < paramSize; i++) { jstring key = (jstring) env->GetObjectArrayElement(paramKeys, i); env->ReleaseStringUTFChars(key, keys[i]); + env->DeleteLocalRef(key); + jstring value = (jstring) env->GetObjectArrayElement(paramVals, i); env->ReleaseStringUTFChars(value, vals[i]); + env->DeleteLocalRef(value); } delete[] keys; delete[] vals; @@ -803,6 +823,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCompose jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i); const char *key = env->GetStringUTFChars(jkey, 0); keys[i] = key; + env->DeleteLocalRef(jkey); } } jlong *args = env->GetLongArrayElements(jargs, NULL); @@ -810,16 +831,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCompose int ret = MXSymbolCompose((SymbolHandle) symbolPtr, name, (mx_uint) argSize, keys, (SymbolHandle*) args); + env->ReleaseStringUTFChars(jname, name); + env->ReleaseLongArrayElements(jargs, args, 0); // release allocated memory if (jkeys != NULL) { for (int i = 0; i < argSize; i++) { jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i); env->ReleaseStringUTFChars(jkey, keys[i]); + env->DeleteLocalRef(jkey); } delete[] keys; } - env->ReleaseStringUTFChars(jname, name); - env->ReleaseLongArrayElements(jargs, args, 0); return ret; } @@ -859,6 +881,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListArguments for (int i = 0; i < outSize; i++) { jstring argument = env->NewStringUTF(outStrArray[i]); env->CallObjectMethod(arguments, arrayAppend, argument); + env->DeleteLocalRef(argument); } return ret; @@ -876,6 +899,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListOutputs for (int i = 0; i < outSize; i++) { jstring output = env->NewStringUTF(outStrArray[i]); env->CallObjectMethod(outputs, arrayAppend, output); + env->DeleteLocalRef(output); } return ret; @@ -893,6 +917,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListAuxiliaryStates for (int i = 0; i < outSize; i++) { jstring output = env->NewStringUTF(outStrArray[i]); env->CallObjectMethod(outputs, arrayAppend, output); + env->DeleteLocalRef(output); } return ret; @@ -952,6 +977,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolInferType jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i); const char *key = env->GetStringUTFChars(jkey, 0); keys[i] = key; + env->DeleteLocalRef(jkey); } } @@ -982,14 +1008,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolInferType for (int i = 0; i < inTypeSize; ++i) { jobject data = env->NewObject(integerClass, newInteger, inTypeData[i]); env->CallObjectMethod(jargTypeData, listAppend, data); + env->DeleteLocalRef(data); } for (int i = 0; i < outTypeSize; ++i) { jobject data = env->NewObject(integerClass, newInteger, outTypeData[i]); env->CallObjectMethod(joutTypeData, listAppend, data); + env->DeleteLocalRef(data); } for (int i = 0; i < auxTypeSize; ++i) { jobject data = env->NewObject(integerClass, newInteger, auxTypeData[i]); env->CallObjectMethod(jauxTypeData, listAppend, data); + env->DeleteLocalRef(data); } setIntField(env, jcomplete, complete); @@ -999,6 +1028,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolInferType for (int i = 0; i < numArgs; i++) { jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i); env->ReleaseStringUTFChars(jkey, keys[i]); + env->DeleteLocalRef(jkey); } delete[] keys; } @@ -1036,14 +1066,14 @@ int FillSymbolInferShape (JNIEnv *env, jmethodID listAppend, jobject joutData, mx_uint shapeSize, const mx_uint *shapeNdim, const mx_uint **shapeData) { for (int i = 0; i < shapeSize; ++i) { - jintArray jshape; - jshape = env->NewIntArray(shapeNdim[i]); + jintArray jshape = env->NewIntArray(shapeNdim[i]); if (jshape == NULL) { // TODO: out of memory error thrown, return a specific error code ? return -1; } env->SetIntArrayRegion(jshape, 0, shapeNdim[i], (const jint *) shapeData[i]); env->CallObjectMethod(joutData, listAppend, jshape); + env->DeleteLocalRef(jshape); } return 0; } @@ -1058,6 +1088,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolInferShape jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i); const char *key = env->GetStringUTFChars(jkey, 0); keys[i] = key; + env->DeleteLocalRef(jkey); } } @@ -1121,6 +1152,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolInferShape for (int i = 0; i < jnumArgs; i++) { jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i); env->ReleaseStringUTFChars(jkey, keys[i]); + env->DeleteLocalRef(jkey); } delete[] keys; } @@ -1142,6 +1174,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorBindX jstring jkey = (jstring) env->GetObjectArrayElement(jctxMapKeys, i); const char *key = env->GetStringUTFChars(jkey, 0); mapKeys[i] = key; + env->DeleteLocalRef(jkey); } jlong *auxStates = env->GetLongArrayElements(jauxArgsHandle, NULL); jint *gradReqType = env->GetIntArrayElements(jreqsArray, NULL); @@ -1172,6 +1205,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorBindX for (int i = 0; i < numCtx; i++) { jstring jkey = (jstring) env->GetObjectArrayElement(jctxMapKeys, i); env->ReleaseStringUTFChars(jkey, mapKeys[i]); + env->DeleteLocalRef(jkey); } delete[] mapKeys; From c5cf6b95dd4ce0d5ecfc234977c015aff48877c8 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 29 Feb 2016 00:19:47 +0800 Subject: [PATCH 3/7] add notify shutdown --- .../src/main/scala/ml/dmlc/mxnet/Base.scala | 5 +++++ .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 2 ++ .../imclassification/TrainMnist.scala | 1 + .../native/src/main/native/jni_helper_func.h | 11 ++++++++-- .../main/native/ml_dmlc_mxnet_native_c_api.cc | 20 +++++++++---------- 5 files changed, 27 insertions(+), 12 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index 8bbfd18fe8a7..5bf86cdb62d0 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -47,6 +47,11 @@ object Base { } } + // Notify MXNet about a shutdown + def notifyShutdown(): Unit = { + checkCall(_LIB.mxNotifyShutdown()) + } + // Convert ctypes returned doc string information into parameters docstring. def ctypes2docstring( argNames: Seq[String], diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 39056bd86f07..788b5ae26506 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -193,4 +193,6 @@ class LibInfo { // Random @native def mxRandomSeed(seed: Int): Int + + @native def mxNotifyShutdown(): Int } diff --git a/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/TrainMnist.scala b/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/TrainMnist.scala index 44facfab5820..cd3ad1c5c207 100644 --- a/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/TrainMnist.scala +++ b/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/TrainMnist.scala @@ -97,6 +97,7 @@ object TrainMnist { kvStore = inst.kvStore, numEpochs = inst.numEpochs, modelPrefix = inst.modelPrefix, loadEpoch = inst.loadEpoch, lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch = inst.lrFactorEpoch) + Base.notifyShutdown() logger.info("Finish fit ...") } catch { case ex: Exception => { diff --git a/scala-package/native/src/main/native/jni_helper_func.h b/scala-package/native/src/main/native/jni_helper_func.h index cce9cb0efe22..0001fd3825d1 100644 --- a/scala-package/native/src/main/native/jni_helper_func.h +++ b/scala-package/native/src/main/native/jni_helper_func.h @@ -6,30 +6,37 @@ jlong getLongField(JNIEnv *env, jobject obj) { jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); jfieldID refFid = env->GetFieldID(refClass, "value", "J"); - return env->GetLongField(obj, refFid); + jlong ret = env->GetLongField(obj, refFid); + env->DeleteLocalRef(refClass); + return ret; } jint getIntField(JNIEnv *env, jobject obj) { jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefInt"); jfieldID refFid = env->GetFieldID(refClass, "value", "I"); - return env->GetIntField(obj, refFid); + jint ret = env->GetIntField(obj, refFid); + env->DeleteLocalRef(refClass); + return ret; } void setIntField(JNIEnv *env, jobject obj, jint value) { jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefInt"); jfieldID refFid = env->GetFieldID(refClass, "value", "I"); env->SetIntField(obj, refFid, value); + env->DeleteLocalRef(refClass); } void setLongField(JNIEnv *env, jobject obj, jlong value) { jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); jfieldID refFid = env->GetFieldID(refClass, "value", "J"); env->SetLongField(obj, refFid, value); + env->DeleteLocalRef(refClass); } void setStringField(JNIEnv *env, jobject obj, const char *value) { jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefString"); jfieldID refFid = env->GetFieldID(refClass, "value", "Ljava/lang/String;"); env->SetObjectField(obj, refFid, env->NewStringUTF(value)); + env->DeleteLocalRef(refClass); } #endif diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index a1ccb78709a8..9fb750f5b55b 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -20,9 +20,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayCreateNone (JNIEnv *env, jobject obj, jobject ndArrayHandle) { NDArrayHandle out; int ret = MXNDArrayCreateNone(&out); - jclass ndClass = env->GetObjectClass(ndArrayHandle); - jfieldID ptr = env->GetFieldID(ndClass, "value", "J"); - env->SetLongField(ndArrayHandle, ptr, (long)out); + setLongField(env, ndArrayHandle, (jlong) out); return ret; } @@ -37,9 +35,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayCreate(JNIEnv *env, j NDArrayHandle out; int ret = MXNDArrayCreate((mx_uint *)shapeArr, (mx_uint)ndim, devType, devId, delayAlloc, &out); env->ReleaseIntArrayElements(shape, shapeArr, 0); - jclass ndClass = env->GetObjectClass(ndArrayHandle); - jfieldID ptr = env->GetFieldID(ndClass, "value", "J"); - env->SetLongField(ndArrayHandle, ptr, (long)out); + setLongField(env, ndArrayHandle, (jlong) out); return ret; } @@ -316,10 +312,11 @@ extern "C" void KVStoreUpdaterCallbackFunc env->DeleteLocalRef(ndLocal); env->DeleteLocalRef(ndRecv); + env->DeleteLocalRef(ndObjClass); + env->DeleteLocalRef(updtClass); // FIXME: This function can be called multiple times, // can we find a way to safely destroy these two objects ? // env->DeleteGlobalRef(updaterFuncObjGlb); - // delete closure; } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSetUpdater @@ -523,9 +520,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorSetMonitorCallback } JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env, jobject obj) { - char *tmpstr = "MXNetError"; - jstring rtstr = env->NewStringUTF(tmpstr); - return rtstr; + return env->NewStringUTF(MXGetLastError()); } //IO funcs @@ -1217,3 +1212,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxRandomSeed (JNIEnv *env, jobject obj, jint seed) { return MXRandomSeed(seed); } + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNotifyShutdown + (JNIEnv *env, jobject obj) { + return MXNotifyShutdown(); +} From 35b2607cbbf4fc8b114ad62a481bcbc00ec4b6f7 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 29 Feb 2016 00:39:04 +0800 Subject: [PATCH 4/7] jvm shutdown hook --- .../core/src/main/scala/ml/dmlc/mxnet/Base.scala | 8 +++++++- .../dmlc/mxnet/examples/imclassification/TrainMnist.scala | 1 - 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index 5bf86cdb62d0..26bceb7ce7ae 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -33,6 +33,12 @@ object Base { val _LIB = new LibInfo checkCall(_LIB.nativeLibInit()) + Runtime.getRuntime.addShutdownHook(new Thread() { + override def run(): Unit = { + notifyShutdown() + } + }) + // helper function definitions /** * Check the return value of C API call @@ -48,7 +54,7 @@ object Base { } // Notify MXNet about a shutdown - def notifyShutdown(): Unit = { + private def notifyShutdown(): Unit = { checkCall(_LIB.mxNotifyShutdown()) } diff --git a/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/TrainMnist.scala b/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/TrainMnist.scala index cd3ad1c5c207..44facfab5820 100644 --- a/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/TrainMnist.scala +++ b/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/TrainMnist.scala @@ -97,7 +97,6 @@ object TrainMnist { kvStore = inst.kvStore, numEpochs = inst.numEpochs, modelPrefix = inst.modelPrefix, loadEpoch = inst.loadEpoch, lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch = inst.lrFactorEpoch) - Base.notifyShutdown() logger.info("Finish fit ...") } catch { case ex: Exception => { From 2d7e4fe3dfd896507a8a04455dda8e024fe3124a Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 29 Feb 2016 21:55:44 +0800 Subject: [PATCH 5/7] add destroy method to NDArray, Executor, Symbol, KVStore. And destroy useless objects automatically during training --- .../core/src/main/scala/ml/dmlc/mxnet/Base.scala | 1 + .../src/main/scala/ml/dmlc/mxnet/Executor.scala | 15 ++++++++++++++- .../core/src/main/scala/ml/dmlc/mxnet/IO.scala | 12 ++++++++++-- .../src/main/scala/ml/dmlc/mxnet/KVStore.scala | 10 +++++++++- .../core/src/main/scala/ml/dmlc/mxnet/Model.scala | 2 ++ .../src/main/scala/ml/dmlc/mxnet/NDArray.scala | 11 ++++++++++- .../src/main/scala/ml/dmlc/mxnet/Symbol.scala | 11 ++++++++++- .../examples/imclassification/ModelTrain.scala | 1 + 8 files changed, 57 insertions(+), 6 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index 26bceb7ce7ae..1e0b1de96169 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -36,6 +36,7 @@ object Base { Runtime.getRuntime.addShutdownHook(new Thread() { override def run(): Unit = { notifyShutdown() + System.gc() } }) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala index 14f0c4b1c2a4..3b30392cb55b 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala @@ -119,8 +119,17 @@ class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val sym protected var _auxDict: Map[String, NDArray] = null protected var monitorCallback: MXMonitorCallback = null + private var destroyed = false + override def finalize(): Unit = { - checkCall(_LIB.mxExecutorFree(handle)) + destroy() + } + + def destroy(): Unit = { + if (!destroyed) { + _LIB.mxExecutorFree(handle) + destroyed = true + } } /** @@ -338,6 +347,10 @@ class DataParallelExecutorManager(symbol: Symbol, } private[mxnet] val cpuOutputArrays = outputShapes.map(NDArray.zeros(_)) + def destroy(): Unit = { + trainExecs.foreach(_.destroy()) + } + // Install monitor on all executors def installMonitor(monitor: Monitor): Unit = { trainExecs.foreach(monitor.install) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index 8bdf9213ab81..83f09df7263e 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -104,8 +104,16 @@ object IO { case class DataBatch(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray], index: IndexedSeq[Long], - pad: Int) - + pad: Int) { + def destroy(): Unit = { + if (data != null) { + data.foreach(arr => if (arr != null) arr.destroy()) + } + if (label != null) { + label.foreach(arr => if (arr != null) arr.destroy()) + } + } +} /** * DataIter object in mxnet. diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala index 9f772bf92da8..2fd845249ee9 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala @@ -26,9 +26,17 @@ object KVStore { // scalastyle:off finalize class KVStore(private val handle: KVStoreHandle) { private var updaterFunc: MXKVStoreUpdater = null + private var destroyed = false override def finalize(): Unit = { - checkCall(_LIB.mxKVStoreFree(handle)) + destroy() + } + + def destroy(): Unit = { + if (!destroyed) { + _LIB.mxKVStoreFree(handle) + destroyed = true + } } /** diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala index a23ce54d489c..b6e7a95a4006 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala @@ -270,6 +270,7 @@ object Model { if (epochSize != -1 && nBatch >= epochSize) { doReset = false } + dataBatch.destroy() dataBatch = trainData.next() } if (doReset) { @@ -306,6 +307,7 @@ object Model { } epochEndCallback.foreach(_.invoke(epoch, symbol, argParams, auxParams)) } + executorManager.destroy() } // scalastyle:on parameterNum } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala index 9c040ea4c4b4..73646e343c9e 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala @@ -569,8 +569,17 @@ object NDArray { */ // scalastyle:off finalize class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = true) { + private var destroyed = false + override def finalize(): Unit = { - checkCall(_LIB.mxNDArrayFree(handle)) + destroy() + } + + def destroy(): Unit = { + if (!destroyed) { + _LIB.mxNDArrayFree(handle) + destroyed = true + } } /** diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index 1f74fc7614ad..02aab193ade8 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -11,8 +11,17 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} */ // scalastyle:off finalize class Symbol(private[mxnet] val handle: SymbolHandle) { + private var destroyed = false + override def finalize(): Unit = { - checkCall(_LIB.mxSymbolFree(handle)) + destroy() + } + + def destroy(): Unit = { + if (!destroyed) { + _LIB.mxSymbolFree(handle) + destroyed = true + } } def +(other: Symbol): Symbol = Symbol.createFromListedSymbols("_Plus")(Array(this, other)) diff --git a/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/ModelTrain.scala b/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/ModelTrain.scala index af01f73a8b3e..e2dd81d1fda2 100644 --- a/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/ModelTrain.scala +++ b/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/ModelTrain.scala @@ -77,6 +77,7 @@ object ModelTrain { kvStore = kv, batchEndCallback = new Speedometer(batchSize, 50), epochEndCallback = checkpoint) + kv.destroy() } // scalastyle:on parameterNum } From bc9fbe0f289a0465449469e202ddf9d71a52e226 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 1 Mar 2016 01:18:33 +0800 Subject: [PATCH 6/7] fix memory leak in SGD. destroy updater when training finishes --- .../main/scala/ml/dmlc/mxnet/Executor.scala | 3 ++- .../main/scala/ml/dmlc/mxnet/KVStore.scala | 2 +- .../src/main/scala/ml/dmlc/mxnet/Model.scala | 3 +++ .../main/scala/ml/dmlc/mxnet/NDArray.scala | 2 +- .../main/scala/ml/dmlc/mxnet/Optimizer.scala | 7 ++++++ .../src/main/scala/ml/dmlc/mxnet/Symbol.scala | 2 +- .../scala/ml/dmlc/mxnet/io/MXDataIter.scala | 2 +- .../scala/ml/dmlc/mxnet/optimizer/SGD.scala | 22 +++++++++++++++++-- .../scala/ml/dmlc/mxnet/KVStoreSuite.scala | 1 + 9 files changed, 37 insertions(+), 7 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala index 3b30392cb55b..e972fc066017 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala @@ -121,12 +121,13 @@ class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val sym private var destroyed = false - override def finalize(): Unit = { + override protected def finalize(): Unit = { destroy() } def destroy(): Unit = { if (!destroyed) { + outputs.foreach(_.destroy()) _LIB.mxExecutorFree(handle) destroyed = true } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala index 2fd845249ee9..cdf9c4a8c528 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala @@ -28,7 +28,7 @@ class KVStore(private val handle: KVStoreHandle) { private var updaterFunc: MXKVStoreUpdater = null private var destroyed = false - override def finalize(): Unit = { + override protected def finalize(): Unit = { destroy() } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala index b6e7a95a4006..e5d7ffb99e90 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala @@ -295,6 +295,7 @@ object Model { executorManager.loadDataBatch(evalBatch) executorManager.forward(isTrain = false) evalMetric.update(evalBatch.label, executorManager.cpuOutputArrays) + evalBatch.destroy() evalBatch = evalDataIter.next() } @@ -307,6 +308,8 @@ object Model { } epochEndCallback.foreach(_.invoke(epoch, symbol, argParams, auxParams)) } + + updaterLocal.destroy() executorManager.destroy() } // scalastyle:on parameterNum diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala index 73646e343c9e..effde099d52a 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala @@ -571,7 +571,7 @@ object NDArray { class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = true) { private var destroyed = false - override def finalize(): Unit = { + override protected def finalize(): Unit = { destroy() } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala index 16b1f2502965..ca5b8774c447 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala @@ -10,6 +10,12 @@ object Optimizer { val state = states.getOrElseUpdate(index, optimizer.createState(index, weight)) optimizer.update(index, weight, grad, state) } + override def destroy(): Unit = { + states.values.foreach { + case array: NDArray => array.destroy() + case _ => + } + } } } } @@ -83,4 +89,5 @@ trait MXKVStoreUpdater { * @param local the value stored on local on this key */ def update(key: Int, recv: NDArray, local: NDArray): Unit + def destroy(): Unit } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index 02aab193ade8..0d5bc185cb0c 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -13,7 +13,7 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} class Symbol(private[mxnet] val handle: SymbolHandle) { private var destroyed = false - override def finalize(): Unit = { + override protected def finalize(): Unit = { destroy() } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/MXDataIter.scala index 06c0569be6bb..33878ef0760c 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/MXDataIter.scala @@ -31,7 +31,7 @@ class MXDataIter(private[mxnet] val handle: DataIterHandle, val _provideLabel: Map[String, Shape] = Map(labelName -> label.shape) override val batchSize = data.shape(0) - override def finalize(): Unit = { + override protected def finalize(): Unit = { checkCall(_LIB.mxDataIterFree(handle)) } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGD.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGD.scala index c39ec5383665..0deb3e29ca13 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGD.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGD.scala @@ -42,17 +42,35 @@ class SGD(private val learningRate: Float = 0.01f, private val momentum: Float = var resdGrad = grad * this.rescaleGrad if (clipGradient != 0f) { + // to get rid of memory leak + val oldResdGrad = resdGrad resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient) + oldResdGrad.destroy() } + if (state != null) { val mom = state.asInstanceOf[NDArray] mom *= momentum - mom += -lr * (resdGrad + wd * weight) + // to get rid of memory leak + val adder = wd * weight + adder += resdGrad + adder *= (-lr) + // mom += -lr * (resdGrad + wd * weight) + mom += adder weight += mom + adder.destroy() } else { require(momentum == 0f) - weight += -lr * (resdGrad + this.wd * weight) + // to get rid of memory leak + val adder = this.wd * weight + adder += resdGrad + adder *= (-lr) + // weight += -lr * (resdGrad + this.wd * weight) + weight += adder + adder.destroy() } + + resdGrad.destroy() } // Create additional optimizer state such as momentum. diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala index 7a3dc940a737..743c2376cfbc 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala @@ -33,6 +33,7 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll { // scalastyle:on println stored += input * 2 } + override def destroy(): Unit = {} } kv.setUpdater(updater) From 789ac5ef4398e50c778248e4d51e22b8495642e8 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 1 Mar 2016 23:57:13 +0800 Subject: [PATCH 7/7] Add javadocs for native memory dispose. Release kvstore at the end if it is created in fit(). --- .../src/main/scala/ml/dmlc/mxnet/Base.scala | 1 - .../main/scala/ml/dmlc/mxnet/Executor.scala | 30 ++++++--- .../src/main/scala/ml/dmlc/mxnet/IO.scala | 54 ++++++++-------- .../main/scala/ml/dmlc/mxnet/KVStore.scala | 20 ++++-- .../src/main/scala/ml/dmlc/mxnet/Model.scala | 9 +-- .../main/scala/ml/dmlc/mxnet/NDArray.scala | 24 ++++--- .../main/scala/ml/dmlc/mxnet/Optimizer.scala | 9 ++- .../src/main/scala/ml/dmlc/mxnet/Symbol.scala | 22 ++++--- .../scala/ml/dmlc/mxnet/io/MXDataIter.scala | 62 +++++++++++-------- .../scala/ml/dmlc/mxnet/io/NDArrayIter.scala | 52 ++++++++-------- .../ml/dmlc/mxnet/io/PrefetchingIter.scala | 46 +++++++------- .../scala/ml/dmlc/mxnet/optimizer/Adam.scala | 4 ++ .../scala/ml/dmlc/mxnet/optimizer/SGD.scala | 16 ++--- .../scala/ml/dmlc/mxnet/KVStoreSuite.scala | 2 +- .../imclassification/ModelTrain.scala | 2 +- 15 files changed, 203 insertions(+), 150 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index 1e0b1de96169..26bceb7ce7ae 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -36,7 +36,6 @@ object Base { Runtime.getRuntime.addShutdownHook(new Thread() { override def run(): Unit = { notifyShutdown() - System.gc() } }) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala index e972fc066017..32d4f8e8954e 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala @@ -101,7 +101,12 @@ object Executor { } /** - * Symbolic Executor component of MXNet + * Symbolic Executor component of MXNet
+ * + * WARNING: it is your responsibility to clear this object through dispose(). + * NEVER rely on the GC strategy + * + * * @author Yizhi Liu * * Constructor: please use Symbol.bind and Symbol.simpleBind instead. @@ -110,7 +115,8 @@ object Executor { * @see Symbol.bind : to create executor */ // scalastyle:off finalize -class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val symbol: Symbol) { +class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, + private[mxnet] val symbol: Symbol) { private[mxnet] var argArrays: Array[NDArray] = null private[mxnet] var gradArrays: Array[NDArray] = null private[mxnet] var auxArrays: Array[NDArray] = null @@ -119,17 +125,17 @@ class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val sym protected var _auxDict: Map[String, NDArray] = null protected var monitorCallback: MXMonitorCallback = null - private var destroyed = false + private var disposed = false override protected def finalize(): Unit = { - destroy() + dispose() } - def destroy(): Unit = { - if (!destroyed) { - outputs.foreach(_.destroy()) + def dispose(): Unit = { + if (!disposed) { + outputs.foreach(_.dispose()) _LIB.mxExecutorFree(handle) - destroyed = true + disposed = true } } @@ -348,8 +354,12 @@ class DataParallelExecutorManager(symbol: Symbol, } private[mxnet] val cpuOutputArrays = outputShapes.map(NDArray.zeros(_)) - def destroy(): Unit = { - trainExecs.foreach(_.destroy()) + /** + * Release the related executors. + * The object shall never be used after it is disposed. + */ + def dispose(): Unit = { + trainExecs.foreach(_.dispose()) } // Install monitor on all executors diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index 83f09df7263e..b8a0eeb4548d 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -15,7 +15,7 @@ object IO { type PackCreateFunc = (Map[String, String]) => DataPack private val logger = LoggerFactory.getLogger(classOf[DataIter]) - private val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule() + private val iterCreateFuncs: Map[String, IterCreateFunc] = initIOModule() def MNISTIter: IterCreateFunc = iterCreateFuncs("MNISTIter") def ImageRecordIter: IterCreateFunc = iterCreateFuncs("ImageRecordIter") @@ -30,7 +30,7 @@ object IO { * create iterator via iterName and params * @param iterName name of iterator; "MNISTIter" or "ImageRecordIter" * @param params parameters for create iterator - * @return + * @return created data iterator */ def createIterator(iterName: String, params: Map[String, String]): DataIter = { iterCreateFuncs(iterName)(params) @@ -40,23 +40,23 @@ object IO { * create dataPack for iterator via itername and params * @param iterName name of iterator: "MNISTIter" or "ImageRecordIter" * @param params parameters for create iterator - * @return + * @return created dataPack */ def createMXDataPack(iterName: String)(params: Map[String, String]): DataPack = { new MXDataPack(iterName, params) } /** - * initi all IO creator Functions - * @return + * initialize all IO creator Functions + * @return Map from name to iter creator function */ - private def _initIOModule(): Map[String, IterCreateFunc] = { + private def initIOModule(): Map[String, IterCreateFunc] = { val IterCreators = new ListBuffer[DataIterCreator] checkCall(_LIB.mxListDataIters(IterCreators)) - IterCreators.map(_makeIOIterator).toMap + IterCreators.map(makeIOIterator).toMap } - private def _makeIOIterator(handle: DataIterCreator): (String, IterCreateFunc) = { + private def makeIOIterator(handle: DataIterCreator): (String, IterCreateFunc) = { val name = new RefString val desc = new RefString val argNames = new ListBuffer[String] @@ -71,12 +71,12 @@ object IO { /** * DataIter creator - * @param handle - * @param params - * @return + * @param handle native memory ptr for the iterator + * @param params parameter passed to the iterator + * @return created DataIter */ private def creator(handle: DataIterCreator)( - params: Map[String, String]): DataIter = { + params: Map[String, String]): DataIter = { val out = new DataIterHandleRef val keys = params.keys.toArray val vals = params.values.toArray @@ -96,21 +96,21 @@ object IO { /** * class batch of data - * @param data - * @param label - * @param index - * @param pad */ case class DataBatch(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray], index: IndexedSeq[Long], pad: Int) { - def destroy(): Unit = { + /** + * Dispose its data and labels + * The object shall never be used after it is disposed. + */ + def dispose(): Unit = { if (data != null) { - data.foreach(arr => if (arr != null) arr.destroy()) + data.foreach(arr => if (arr != null) arr.dispose()) } if (label != null) { - label.foreach(arr => if (arr != null) arr.destroy()) + label.foreach(arr => if (arr != null) arr.dispose()) } } } @@ -145,15 +145,15 @@ abstract class DataIter(val batchSize: Int = 0) extends Iterator[DataBatch] { def getLabel(): IndexedSeq[NDArray] /** - * get the number of padding examples + * Get the number of padding examples * in current batch * @return number of padding examples in current batch */ def getPad(): Int /** - * the index of current batch - * @return + * Get the index of current batch + * @return the index of current batch */ def getIndex(): IndexedSeq[Long] @@ -165,13 +165,13 @@ abstract class DataIter(val batchSize: Int = 0) extends Iterator[DataBatch] { } /** - * pack of DataIter, use as Iterable class - */ + * pack of DataIter, use as Iterable class + */ abstract class DataPack() extends Iterable[DataBatch] { /** - * get data iterator - * @return DataIter - */ + * get data iterator + * @return DataIter + */ def iterator: DataIter } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala index cdf9c4a8c528..324b6e50d470 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala @@ -8,7 +8,11 @@ import ml.dmlc.mxnet.Base._ */ object KVStore { /** - * Create a new KVStore. + * Create a new KVStore.
+ * + * WARNING: it is your responsibility to clear this object through dispose(). + * NEVER rely on the GC strategy + * * * @param name : {'local', 'dist'} * The type of KVStore @@ -26,16 +30,20 @@ object KVStore { // scalastyle:off finalize class KVStore(private val handle: KVStoreHandle) { private var updaterFunc: MXKVStoreUpdater = null - private var destroyed = false + private var disposed = false override protected def finalize(): Unit = { - destroy() + dispose() } - def destroy(): Unit = { - if (!destroyed) { + /** + * Release the native memory. + * The object shall never be used after it is disposed. + */ + def dispose(): Unit = { + if (!disposed) { _LIB.mxKVStoreFree(handle) - destroyed = true + disposed = true } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala index e5d7ffb99e90..83294e57f59b 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala @@ -270,7 +270,7 @@ object Model { if (epochSize != -1 && nBatch >= epochSize) { doReset = false } - dataBatch.destroy() + dataBatch.dispose() dataBatch = trainData.next() } if (doReset) { @@ -295,7 +295,7 @@ object Model { executorManager.loadDataBatch(evalBatch) executorManager.forward(isTrain = false) evalMetric.update(evalBatch.label, executorManager.cpuOutputArrays) - evalBatch.destroy() + evalBatch.dispose() evalBatch = evalDataIter.next() } @@ -309,8 +309,8 @@ object Model { epochEndCallback.foreach(_.invoke(epoch, symbol, argParams, auxParams)) } - updaterLocal.destroy() - executorManager.destroy() + updaterLocal.dispose() + executorManager.dispose() } // scalastyle:on parameterNum } @@ -519,6 +519,7 @@ class FeedForward(val symbol: Symbol, val ctx: Array[Context] = Array(Context.cp val (kvStore, updateOnKVStore) = Model.createKVStore(kvStoreType, ctx.length, _argParams) fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore, epochEndCallback, batchEndCallback, logger, workLoadList) + kvStore.foreach(_.dispose()) } def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric, diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala index effde099d52a..0261488994ae 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala @@ -565,20 +565,28 @@ object NDArray { /** * NDArray object in mxnet. - * NDArray is basic ndarray/Tensor like data structure in mxnet. + * NDArray is basic ndarray/Tensor like data structure in mxnet.
+ * + * WARNING: it is your responsibility to clear this object through dispose(). + * NEVER rely on the GC strategy + * */ // scalastyle:off finalize -class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = true) { - private var destroyed = false - +class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, + val writable: Boolean = true) { + private var disposed = false override protected def finalize(): Unit = { - destroy() + dispose() } - def destroy(): Unit = { - if (!destroyed) { + /** + * Release the native memory. + * The object shall never be used after it is disposed. + */ + def dispose(): Unit = { + if (!disposed) { _LIB.mxNDArrayFree(handle) - destroyed = true + disposed = true } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala index ca5b8774c447..5aeba8a4f880 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala @@ -10,9 +10,12 @@ object Optimizer { val state = states.getOrElseUpdate(index, optimizer.createState(index, weight)) optimizer.update(index, weight, grad, state) } - override def destroy(): Unit = { + override def dispose(): Unit = { states.values.foreach { - case array: NDArray => array.destroy() + case array: NDArray => array.dispose() + case sym: Symbol => sym.dispose() + case exec: Executor => exec.dispose() + case kv: KVStore => kv.dispose() case _ => } } @@ -89,5 +92,5 @@ trait MXKVStoreUpdater { * @param local the value stored on local on this key */ def update(key: Int, recv: NDArray, local: NDArray): Unit - def destroy(): Unit + def dispose(): Unit } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index 0d5bc185cb0c..ba9dea6d7a12 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -6,21 +6,29 @@ import org.slf4j.LoggerFactory import scala.collection.mutable.{ArrayBuffer, ListBuffer} /** - * Symbolic configuration API of mxnet. + * Symbolic configuration API of mxnet.
+ * + * WARNING: it is your responsibility to clear this object through dispose(). + * NEVER rely on the GC strategy + * * @author Yizhi Liu */ // scalastyle:off finalize -class Symbol(private[mxnet] val handle: SymbolHandle) { - private var destroyed = false +class Symbol private(private[mxnet] val handle: SymbolHandle) { + private var disposed = false override protected def finalize(): Unit = { - destroy() + dispose() } - def destroy(): Unit = { - if (!destroyed) { + /** + * Release the native memory. + * The object shall never be used after it is disposed. + */ + def dispose(): Unit = { + if (!disposed) { _LIB.mxSymbolFree(handle) - destroyed = true + disposed = true } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/MXDataIter.scala index 33878ef0760c..c71f9dfa3afd 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/MXDataIter.scala @@ -8,13 +8,13 @@ import org.slf4j.LoggerFactory import scala.collection.mutable.ListBuffer /** - * DataIter built in MXNet. - * @param handle the handle to the underlying C++ Data Iterator - */ + * DataIter built in MXNet. + * @param handle the handle to the underlying C++ Data Iterator + */ // scalastyle:off finalize -class MXDataIter(private[mxnet] val handle: DataIterHandle, - private val dataName: String = "data", - private val labelName: String = "label") extends DataIter { +class MXDataIter private[mxnet](private[mxnet] val handle: DataIterHandle, + private val dataName: String = "data", + private val labelName: String = "label") extends DataIter { private val logger = LoggerFactory.getLogger(classOf[MXDataIter]) // use currentBatch to implement hasNext @@ -30,14 +30,26 @@ class MXDataIter(private[mxnet] val handle: DataIterHandle, val _provideData: Map[String, Shape] = Map(dataName -> data.shape) val _provideLabel: Map[String, Shape] = Map(labelName -> label.shape) override val batchSize = data.shape(0) + private var disposed = false override protected def finalize(): Unit = { - checkCall(_LIB.mxDataIterFree(handle)) + dispose() } /** - * reset the iterator - */ + * Release the native memory. + * The object shall never be used after it is disposed. + */ + def dispose(): Unit = { + if (!disposed) { + _LIB.mxDataIterFree(handle) + disposed = true + } + } + + /** + * reset the iterator + */ override def reset(): Unit = { // TODO: self._debug_at_begin = True currentBatch = null @@ -64,9 +76,9 @@ class MXDataIter(private[mxnet] val handle: DataIterHandle, } /** - * Iterate to next batch - * @return whether the move is successful - */ + * Iterate to next batch + * @return whether the move is successful + */ private def iterNext(): Boolean = { val next = new RefInt checkCall(_LIB.mxDataIterNext(handle, next)) @@ -79,9 +91,9 @@ class MXDataIter(private[mxnet] val handle: DataIterHandle, } /** - * get data of current batch - * @return the data of current batch - */ + * get data of current batch + * @return the data of current batch + */ override def getData(): IndexedSeq[NDArray] = { val out = new NDArrayHandleRef checkCall(_LIB.mxDataIterGetData(handle, out)) @@ -89,9 +101,9 @@ class MXDataIter(private[mxnet] val handle: DataIterHandle, } /** - * Get label of current batch - * @return the label of current batch - */ + * Get label of current batch + * @return the label of current batch + */ override def getLabel(): IndexedSeq[NDArray] = { val out = new NDArrayHandleRef checkCall(_LIB.mxDataIterGetLabel(handle, out)) @@ -99,9 +111,9 @@ class MXDataIter(private[mxnet] val handle: DataIterHandle, } /** - * the index of current batch - * @return - */ + * Get the index of current batch + * @return the index of current batch + */ override def getIndex(): IndexedSeq[Long] = { val outIndex = new ListBuffer[Long] val outSize = new RefLong @@ -110,10 +122,10 @@ class MXDataIter(private[mxnet] val handle: DataIterHandle, } /** - * get the number of padding examples - * in current batch - * @return number of padding examples in current batch - */ + * get the number of padding examples + * in current batch + * @return number of padding examples in current batch + */ override def getPad(): MXUint = { val out = new MXUintRef checkCall(_LIB.mxDataIterGetPadNum(handle, out)) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/NDArrayIter.scala index f0169b87bcb8..32997dadf33d 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/NDArrayIter.scala @@ -4,49 +4,49 @@ import ml.dmlc.mxnet.Base._ import ml.dmlc.mxnet.{DataIter, NDArray, Shape} /** - * TODO - * NDArrayIter object in mxnet. Taking NDArray or numpy array to get dataiter. - * @param data NDArrayIter supports single or multiple data and label. - * @param label Same as data, but is not fed to the model during testing. - * @param batchSize Batch Size - * @param shuffle Whether to shuffle the data - * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch - * @note + * TODO + * NDArrayIter object in mxnet. Taking NDArray or numpy array to get dataiter. + * @param data NDArrayIter supports single or multiple data and label. + * @param label Same as data, but is not fed to the model during testing. + * @param batchSize Batch Size + * @param shuffle Whether to shuffle the data + * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch + * @note * This iterator will pad, discard or roll over the last batch if - * the size of data does not match batch_size. Roll over is intended - * for training and can cause problems if used for prediction. - */ + * the size of data does not match batch_size. Roll over is intended + * for training and can cause problems if used for prediction. + */ class NDArrayIter(data: NDArray, label: NDArray = null, batchSize: Int = 1, shuffle: Boolean = false, lastBatchHandle: String = "pad") extends DataIter(batchSize) { /** - * reset the iterator - */ + * reset the iterator + */ override def reset(): Unit = ??? /** - * get data of current batch - * @return the data of current batch - */ + * get data of current batch + * @return the data of current batch + */ override def getData(): IndexedSeq[NDArray] = ??? /** - * Get label of current batch - * @return the label of current batch - */ + * Get label of current batch + * @return the label of current batch + */ override def getLabel(): IndexedSeq[NDArray] = ??? /** - * the index of current batch - * @return - */ + * the index of current batch + * @return + */ override def getIndex(): IndexedSeq[Long] = ??? /** - * get the number of padding examples - * in current batch - * @return number of padding examples in current batch - */ + * get the number of padding examples + * in current batch + * @return number of padding examples in current batch + */ override def getPad(): MXUint = ??? // The name and shape of data provided by this iterator diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/PrefetchingIter.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/PrefetchingIter.scala index 270584079cfa..1419e76ff98a 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/PrefetchingIter.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/PrefetchingIter.scala @@ -3,48 +3,48 @@ package ml.dmlc.mxnet.io import ml.dmlc.mxnet.{DataIter, NDArray, Shape} /** - * TODO - * Base class for prefetching iterators. Takes one or more DataIters - * (or any class with "reset" and "read" methods) and combine them with - * prefetching. - * @param iters list of DataIters - * @param dataNames - * @param labelNames - */ + * TODO + * Base class for prefetching iterators. Takes one or more DataIters + * (or any class with "reset" and "read" methods) and combine them with + * prefetching. + * @param iters list of DataIters + * @param dataNames + * @param labelNames + */ class PrefetchingIter(val iters: List[DataIter], val dataNames: Map[String, String] = null, val labelNames: Map[String, String] = null) extends DataIter { /** - * reset the iterator - */ + * reset the iterator + */ override def reset(): Unit = ??? /** - * get data of current batch - * @return the data of current batch - */ + * get data of current batch + * @return the data of current batch + */ override def getData(): IndexedSeq[NDArray] = ??? /** - * Get label of current batch - * @return the label of current batch - */ + * Get label of current batch + * @return the label of current batch + */ override def getLabel(): IndexedSeq[NDArray] = ??? /** - * the index of current batch - * @return - */ + * the index of current batch + * @return + */ override def getIndex(): IndexedSeq[Long] = ??? // The name and shape of label provided by this iterator override def provideLabel: Map[String, Shape] = ??? /** - * get the number of padding examples - * in current batch - * @return number of padding examples in current batch - */ + * get the number of padding examples + * in current batch + * @return number of padding examples in current batch + */ override def getPad(): Int = ??? // The name and shape of data provided by this iterator diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/Adam.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/Adam.scala index 9077ccf20ceb..8957e8263ab7 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/Adam.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/Adam.scala @@ -10,6 +10,10 @@ import ml.dmlc.mxnet.NDArrayConversions._ * Adam: A Method for Stochastic Optimization, * http://arxiv.org/abs/1412.6980 * + * WARNING + * TODO: This class has NOT been tested yet. + * And there exists potential memory leak in the implementation + * * @author Yuan Tang, Yizhi Liu * * @param learningRate Float, Step size. diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGD.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGD.scala index 0deb3e29ca13..e7255f5588d1 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGD.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGD.scala @@ -45,32 +45,32 @@ class SGD(private val learningRate: Float = 0.01f, private val momentum: Float = // to get rid of memory leak val oldResdGrad = resdGrad resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient) - oldResdGrad.destroy() + oldResdGrad.dispose() } if (state != null) { val mom = state.asInstanceOf[NDArray] mom *= momentum - // to get rid of memory leak + // adder = -lr * (resdGrad + wd * weight) + // we write in this way to get rid of memory leak val adder = wd * weight adder += resdGrad adder *= (-lr) - // mom += -lr * (resdGrad + wd * weight) mom += adder weight += mom - adder.destroy() + adder.dispose() } else { require(momentum == 0f) - // to get rid of memory leak + // adder = -lr * (resdGrad + this.wd * weight) + // we write in this way to get rid of memory leak val adder = this.wd * weight adder += resdGrad adder *= (-lr) - // weight += -lr * (resdGrad + this.wd * weight) weight += adder - adder.destroy() + adder.dispose() } - resdGrad.destroy() + resdGrad.dispose() } // Create additional optimizer state such as momentum. diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala index 743c2376cfbc..2ddfabeab77c 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala @@ -33,7 +33,7 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll { // scalastyle:on println stored += input * 2 } - override def destroy(): Unit = {} + override def dispose(): Unit = {} } kv.setUpdater(updater) diff --git a/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/ModelTrain.scala b/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/ModelTrain.scala index e2dd81d1fda2..2311d1e2a15b 100644 --- a/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/ModelTrain.scala +++ b/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/ModelTrain.scala @@ -77,7 +77,7 @@ object ModelTrain { kvStore = kv, batchEndCallback = new Speedometer(batchSize, 50), epochEndCallback = checkpoint) - kv.destroy() + kv.dispose() } // scalastyle:on parameterNum }