Skip to content

Commit

Permalink
Merge pull request apache#35 from javelinjs/scala-package-mem-leak-fix
Browse files Browse the repository at this point in the history
Great work!
  • Loading branch information
yanqingmen committed Mar 2, 2016
2 parents 4366e9c + 789ac5e commit 2956813
Show file tree
Hide file tree
Showing 18 changed files with 369 additions and 162 deletions.
12 changes: 12 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ object Base {

// Load the native library "mxnet-scala", then create the LibInfo object
// whose @native methods are bound to it.
System.loadLibrary("mxnet-scala")
val _LIB = new LibInfo
// Run the native-side initialization once, when this object is initialized.
// NOTE(review): checkCall presumably raises on a non-zero status code — confirm.
checkCall(_LIB.nativeLibInit())

// Register a JVM shutdown hook that calls notifyShutdown() when the JVM exits.
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = {
notifyShutdown()
}
})

// helper function definitions
/**
Expand All @@ -46,6 +53,11 @@ object Base {
}
}

// Tell MXNet that a shutdown is in progress (invoked from the JVM shutdown hook).
private def notifyShutdown(): Unit =
  checkCall(_LIB.mxNotifyShutdown())

// Convert ctypes returned doc string information into parameters docstring.
def ctypes2docstring(
argNames: Seq[String],
Expand Down
32 changes: 28 additions & 4 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,12 @@ object Executor {
}

/**
* Symbolic Executor component of MXNet
* Symbolic Executor component of MXNet <br />
* <b>
* WARNING: it is your responsibility to clear this object through dispose().
* NEVER rely on the GC strategy
* </b>
*
* @author Yizhi Liu
*
* Constructor: please use Symbol.bind and Symbol.simpleBind instead.
Expand All @@ -110,7 +115,8 @@ object Executor {
* @see Symbol.bind : to create executor
*/
// scalastyle:off finalize
class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val symbol: Symbol) {
class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
private[mxnet] val symbol: Symbol) {
private[mxnet] var argArrays: Array[NDArray] = null
private[mxnet] var gradArrays: Array[NDArray] = null
private[mxnet] var auxArrays: Array[NDArray] = null
Expand All @@ -119,8 +125,18 @@ class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val sym
protected var _auxDict: Map[String, NDArray] = null
protected var monitorCallback: MXMonitorCallback = null

override def finalize(): Unit = {
checkCall(_LIB.mxExecutorFree(handle))
// True once the native executor has been released. Only read and written inside
// dispose(), which is synchronized.
private var disposed = false

/** GC fallback: releases the native executor if dispose() was never called. */
override protected def finalize(): Unit = {
  dispose()
}

/**
 * Release the output arrays and the native executor.
 * The object shall never be used after it is disposed.
 *
 * Synchronized because finalize() calls this from the finalizer thread: without
 * a lock, the check-then-free on `disposed` could run twice and free the native
 * handle more than once.
 */
def dispose(): Unit = synchronized {
  if (!disposed) {
    outputs.foreach(_.dispose())
    // NOTE(review): the status code of mxExecutorFree is ignored here, presumably
    // so that a GC-triggered release never throws — confirm.
    _LIB.mxExecutorFree(handle)
    disposed = true
  }
}

/**
Expand Down Expand Up @@ -338,6 +354,14 @@ class DataParallelExecutorManager(symbol: Symbol,
}
private[mxnet] val cpuOutputArrays = outputShapes.map(NDArray.zeros(_))

/**
 * Release the related executors and the CPU output buffers.
 * The object shall never be used after it is disposed.
 */
def dispose(): Unit = {
  trainExecs.foreach(_.dispose())
  // cpuOutputArrays are allocated by this manager through NDArray.zeros, so they
  // have to be released here as well instead of being left to the GC.
  cpuOutputArrays.foreach(_.dispose())
}

// Install monitor on all executors
def installMonitor(monitor: Monitor): Unit = {
trainExecs.foreach(monitor.install)
Expand Down
60 changes: 34 additions & 26 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ object IO {
type PackCreateFunc = (Map[String, String]) => DataPack

private val logger = LoggerFactory.getLogger(classOf[DataIter])
private val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule()
private val iterCreateFuncs: Map[String, IterCreateFunc] = initIOModule()

def MNISTIter: IterCreateFunc = iterCreateFuncs("MNISTIter")
def ImageRecordIter: IterCreateFunc = iterCreateFuncs("ImageRecordIter")
Expand All @@ -30,7 +30,7 @@ object IO {
* create iterator via iterName and params
* @param iterName name of iterator; "MNISTIter" or "ImageRecordIter"
* @param params parameters for create iterator
* @return
* @return created data iterator
*/
def createIterator(iterName: String, params: Map[String, String]): DataIter = {
iterCreateFuncs(iterName)(params)
Expand All @@ -40,23 +40,23 @@ object IO {
* create dataPack for iterator via itername and params
* @param iterName name of iterator: "MNISTIter" or "ImageRecordIter"
* @param params parameters for create iterator
* @return
* @return created dataPack
*/
def createMXDataPack(iterName: String)(params: Map[String, String]): DataPack =
  new MXDataPack(iterName, params)

/**
* initi all IO creator Functions
* @return
* initialize all IO creator Functions
* @return Map from name to iter creator function
*/
private def _initIOModule(): Map[String, IterCreateFunc] = {
private def initIOModule(): Map[String, IterCreateFunc] = {
val IterCreators = new ListBuffer[DataIterCreator]
checkCall(_LIB.mxListDataIters(IterCreators))
IterCreators.map(_makeIOIterator).toMap
IterCreators.map(makeIOIterator).toMap
}

private def _makeIOIterator(handle: DataIterCreator): (String, IterCreateFunc) = {
private def makeIOIterator(handle: DataIterCreator): (String, IterCreateFunc) = {
val name = new RefString
val desc = new RefString
val argNames = new ListBuffer[String]
Expand All @@ -71,12 +71,12 @@ object IO {

/**
* DataIter creator
* @param handle
* @param params
* @return
* @param handle native memory ptr for the iterator
* @param params parameter passed to the iterator
* @return created DataIter
*/
private def creator(handle: DataIterCreator)(
params: Map[String, String]): DataIter = {
params: Map[String, String]): DataIter = {
val out = new DataIterHandleRef
val keys = params.keys.toArray
val vals = params.values.toArray
Expand All @@ -96,16 +96,24 @@ object IO {

/**
* class batch of data
* @param data
* @param label
* @param index
* @param pad
*/
case class DataBatch(data: IndexedSeq[NDArray],
label: IndexedSeq[NDArray],
index: IndexedSeq[Long],
pad: Int)

pad: Int) {
/**
 * Dispose every data array and every label array held by this batch.
 * The batch must not be used once this has been called.
 */
def dispose(): Unit = {
  // data is handled before label; null collections and null entries are skipped
  for {
    arrays <- Seq(data, label) if arrays != null
    arr <- arrays if arr != null
  } arr.dispose()
}
}

/**
* DataIter object in mxnet.
Expand Down Expand Up @@ -137,15 +145,15 @@ abstract class DataIter(val batchSize: Int = 0) extends Iterator[DataBatch] {
def getLabel(): IndexedSeq[NDArray]

/**
* get the number of padding examples
* Get the number of padding examples
* in current batch
* @return number of padding examples in current batch
*/
def getPad(): Int

/**
* the index of current batch
* @return
* Get the index of current batch
* @return the index of current batch
*/
def getIndex(): IndexedSeq[Long]

Expand All @@ -157,13 +165,13 @@ abstract class DataIter(val batchSize: Int = 0) extends Iterator[DataBatch] {
}

/**
* pack of DataIter, use as Iterable class
*/
* pack of DataIter, use as Iterable class
*/
abstract class DataPack() extends Iterable[DataBatch] {
/**
* get data iterator
* @return DataIter
*/
* get data iterator
* @return DataIter
*/
def iterator: DataIter
}

Expand Down
24 changes: 23 additions & 1 deletion scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ import ml.dmlc.mxnet.Base._
*/
object KVStore {
/**
* Create a new KVStore.
* Create a new KVStore. <br />
* <b>
* WARNING: it is your responsibility to clear this object through dispose().
* NEVER rely on the GC strategy
* </b>
*
* @param name : {'local', 'dist'}
* The type of KVStore
Expand All @@ -23,8 +27,25 @@ object KVStore {
}
}

// scalastyle:off finalize
class KVStore(private val handle: KVStoreHandle) {
private var updaterFunc: MXKVStoreUpdater = null
// True once the native KVStore has been released. Only read and written inside
// dispose(), which is synchronized.
private var disposed = false

/** GC fallback: releases the native KVStore if dispose() was never called. */
override protected def finalize(): Unit = {
  dispose()
}

/**
 * Release the native memory.
 * The object shall never be used after it is disposed.
 *
 * Synchronized because finalize() calls this from the finalizer thread: without
 * a lock, the check-then-free on `disposed` could run twice and free the native
 * handle more than once.
 */
def dispose(): Unit = synchronized {
  if (!disposed) {
    // NOTE(review): the status code of mxKVStoreFree is ignored here, presumably
    // so that a GC-triggered release never throws — confirm.
    _LIB.mxKVStoreFree(handle)
    disposed = true
  }
}

/**
* Initialize a single or a sequence of key-value pairs into the store.
Expand Down Expand Up @@ -202,3 +223,4 @@ class KVStore(private val handle: KVStoreHandle) {
checkCall(_LIB.mxKVStoreSendCommmandToServers(handle, head, body))
}
}
// scalastyle:on finalize
5 changes: 5 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
* @author Yizhi Liu
*/
class LibInfo {
@native def nativeLibInit(): Int
// NDArray
@native def mxNDArrayFree(handle: NDArrayHandle): Int
@native def mxGetLastError(): String
Expand Down Expand Up @@ -84,6 +85,7 @@ class LibInfo {
@native def mxKVStoreBarrier(handle: KVStoreHandle): Int
@native def mxKVStoreGetGroupSize(handle: KVStoreHandle, size: RefInt): Int
@native def mxKVStoreGetRank(handle: KVStoreHandle, size: RefInt): Int
@native def mxKVStoreFree(handle: KVStoreHandle): Int

// DataIter Funcs
@native def mxListDataIters(handles: ListBuffer[DataIterCreator]): Int
Expand Down Expand Up @@ -187,7 +189,10 @@ class LibInfo {
// scalastyle:on parameterNum
@native def mxSymbolSaveToFile(handle: SymbolHandle, fname: String): Int
@native def mxSymbolCreateFromFile(fname: String, handle: SymbolHandleRef): Int
@native def mxSymbolFree(handle: SymbolHandle): Int

// Random
@native def mxRandomSeed(seed: Int): Int

@native def mxNotifyShutdown(): Int
}
6 changes: 6 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ object Model {
if (epochSize != -1 && nBatch >= epochSize) {
doReset = false
}
dataBatch.dispose()
dataBatch = trainData.next()
}
if (doReset) {
Expand All @@ -294,6 +295,7 @@ object Model {
executorManager.loadDataBatch(evalBatch)
executorManager.forward(isTrain = false)
evalMetric.update(evalBatch.label, executorManager.cpuOutputArrays)
evalBatch.dispose()
evalBatch = evalDataIter.next()
}

Expand All @@ -306,6 +308,9 @@ object Model {
}
epochEndCallback.foreach(_.invoke(epoch, symbol, argParams, auxParams))
}

updaterLocal.dispose()
executorManager.dispose()
}
// scalastyle:on parameterNum
}
Expand Down Expand Up @@ -514,6 +519,7 @@ class FeedForward(val symbol: Symbol, val ctx: Array[Context] = Array(Context.cp
val (kvStore, updateOnKVStore) = Model.createKVStore(kvStoreType, ctx.length, _argParams)
fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
epochEndCallback, batchEndCallback, logger, workLoadList)
kvStore.foreach(_.dispose())
}

def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric,
Expand Down
25 changes: 21 additions & 4 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -565,12 +565,29 @@ object NDArray {

/**
* NDArray object in mxnet.
* NDArray is basic ndarray/Tensor like data structure in mxnet.
* NDArray is basic ndarray/Tensor like data structure in mxnet. <br />
* <b>
* WARNING: it is your responsibility to clear this object through dispose().
* NEVER rely on the GC strategy
* </b>
*/
// scalastyle:off finalize
class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = true) {
override def finalize(): Unit = {
checkCall(_LIB.mxNDArrayFree(handle))
class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
val writable: Boolean = true) {
// True once the native array has been released. Only read and written inside
// dispose(), which is synchronized.
private var disposed = false

/** GC fallback: releases the native array if dispose() was never called. */
override protected def finalize(): Unit = {
  dispose()
}

/**
 * Release the native memory.
 * The object shall never be used after it is disposed.
 *
 * Synchronized because finalize() calls this from the finalizer thread: without
 * a lock, the check-then-free on `disposed` could run twice and free the native
 * handle more than once.
 */
def dispose(): Unit = synchronized {
  if (!disposed) {
    // NOTE(review): the status code of mxNDArrayFree is ignored here, presumably
    // so that a GC-triggered release never throws — confirm.
    _LIB.mxNDArrayFree(handle)
    disposed = true
  }
}

/**
Expand Down
10 changes: 10 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ object Optimizer {
val state = states.getOrElseUpdate(index, optimizer.createState(index, weight))
optimizer.update(index, weight, grad, state)
}
/**
 * Dispose every cached optimizer state that has a dispose() method
 * (NDArray, Symbol, Executor, KVStore), then drop the cache.
 */
override def dispose(): Unit = {
  states.values.foreach {
    case array: NDArray => array.dispose()
    case sym: Symbol => sym.dispose()
    case exec: Executor => exec.dispose()
    case kv: KVStore => kv.dispose()
    case _ => // other state types hold nothing that needs releasing here
  }
  // Forget the disposed states so that a later update() cannot get a freed
  // object back from getOrElseUpdate.
  states.clear()
}
}
}
}
Expand Down Expand Up @@ -83,4 +92,5 @@ trait MXKVStoreUpdater {
* @param local the value stored on local on this key
*/
def update(key: Int, recv: NDArray, local: NDArray): Unit
/**
 * Release whatever resources this updater holds.
 * NOTE(review): the updater created in Optimizer disposes its cached optimizer
 * states here; other implementations should confirm what they need to free.
 */
def dispose(): Unit
}
Loading

0 comments on commit 2956813

Please sign in to comment.