From 49171551d28e6485581771655151bebcf53958f6 Mon Sep 17 00:00:00 2001
From: Yizhi Liu
Date: Mon, 28 Dec 2015 13:56:24 +0800
Subject: [PATCH] fix Adam optimizer

---
 .../scala/ml/dmlc/mxnet/LRScheduler.scala     |  2 +-
 .../scala/ml/dmlc/mxnet/optimizer/Adam.scala  | 42 +++++++++++--------
 2 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
index 756194b94566..8cfee16aae80 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
@@ -7,7 +7,7 @@ import org.slf4j.LoggerFactory
  * @author Yuan Tang
  */
 
-abstract class LRScheduler(protected var baseLR: Float = 0.01f) {
+abstract class LRScheduler(var baseLR: Float = 0.01f) {
   /**
    * Base class of a learning rate scheduler
    *
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/Adam.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/Adam.scala
index 5616506e78e6..7907c2059ae2 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/Adam.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/Adam.scala
@@ -1,3 +1,5 @@
+package ml.dmlc.mxnet.optimizer
+
 import ml.dmlc.mxnet.{NDArray, Optimizer, LRScheduler}
 import ml.dmlc.mxnet.NDArrayConversions._
 
@@ -8,7 +10,7 @@ import ml.dmlc.mxnet.NDArrayConversions._
  * Adam: A Method for Stochastic Optimization,
  * http://arxiv.org/abs/1412.6980
  *
- * @author Yuan Tang
+ * @author Yuan Tang, Yizhi Liu
  *
  * @param learningRate Float, Step size.
  * @param beta1 Float, Exponential decay rate for the first moment estimates.
@@ -21,12 +23,17 @@ import ml.dmlc.mxnet.NDArrayConversions._
  * @param lrScheduler The learning rate scheduler
  */
 class Adam(var learningRate: Float = 0.002f, val beta1: Float = 0.9f, val beta2: Float = 0.999f,
-           val epsilon: Float = 0.00000001f, val decayFactor: Float = 1-0.00000001f, val wd: Float = 0.0f,
+           val epsilon: Float = 1e-8f, val decayFactor: Float = 1-1e-8f, val wd: Float = 0.0f,
            rescaleGrad: Float = 1f, val clipGradient: Float = 0f,
-           val lrScheduler: LRScheduler = null) extends Optimizer(rescaleGrad: Float) {
+           val lrScheduler: LRScheduler = null) extends Optimizer(rescaleGrad: Float) {
 
   protected var time: Int = 0
-  protected var timeFirstIndex: Int = 0
+  protected var timeFirstIndex: Option[Int] = None
+
+  if (lrScheduler != null) {
+    lrScheduler.baseLR = learningRate
+  }
+
   /**
    * Update the parameters.
    * @param index An unique integer key used to index the parameters
@@ -45,41 +52,42 @@ class Adam(var learningRate: Float = 0.002f, val beta1: Float = 0.9f, val beta2:
         this.learningRate
       }) * lrScale.getOrElse(index, 1f)
 
-    var (mean, variance) = state
+    var (mean, variance) = state.asInstanceOf[(NDArray, NDArray)]
 
-    if (timeFirstIndex == 0) {
-      timeFirstIndex = index
+    // increment time only when the first parameters is called
+    if (timeFirstIndex == None) {
+      timeFirstIndex = Option(index)
       time = 0
-    } else if (timeFirstIndex == index) {
+    } else if (timeFirstIndex.get == index) {
       time += 1
     }
     val t1: Int = time + 1
-    learningRate = (lr * math.sqrt(1.0 - math.pow(beta2, t1))/(1.0 - math.pow(beta1, t1))) toFloat
-    val beta1t = beta1 * math.pow(decayFactor, t1 - 1) toFloat
+    learningRate = (lr * math.sqrt(1.0 - math.pow(beta2, t1)) / (1.0 - math.pow(beta1, t1))).toFloat
+    val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat
 
     var resdGrad = grad * rescaleGrad
     if (clipGradient != 0f) {
      resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient)
     }
 
-    val meanT = beta1t * mean.asInstanceOf[NDArray] + (1.0 - beta1t) * resdGrad toScalar
-    val varianceT = beta2 * variance.asInstanceOf[NDArray] + (1.0f - beta2) * resdGrad * resdGrad toScalar
+    val meanT = beta1t * mean + (1.0 - beta1t) * resdGrad
+    val varianceT = beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad
 
-    var step = learningRate * meanT / (math.sqrt(varianceT) + epsilon)
+    var step = learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon)
 
     if (wd > 0.0f) {
-      step += (lr * wd * weight).toScalar
+      step += lr * wd * weight
     }
 
-    weight += -step.toFloat
+    weight += -step
 
     mean = meanT
     variance = varianceT
   }
 
   // Create additional optimizer state: mean, variance
-  override def createState(index: Int, weight: NDArray): AnyRef = {
-    timeFirstIndex = 0
+  override def createState(index: Int, weight: NDArray): (NDArray, NDArray) = {
+    timeFirstIndex = None // time is incremented only on the first index
    (NDArray.zeros(weight.shape, weight.context),  // mean
     NDArray.zeros(weight.shape, weight.context))  // variance
   }
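
Note for reviewers: the patched `update` body above implements the bias-corrected Adam step from Kingma & Ba (http://arxiv.org/abs/1412.6980). The sketch below is a minimal plain-Scala illustration of that arithmetic, kept separate from the patch itself; the names AdamSketch and step are made up for this example, it operates on Float arrays instead of NDArrays, and it omits the gradient rescaling, clipping, and weight-decay handling that the real optimizer performs.

object AdamSketch {
  // One Adam step for a single parameter array.
  // t is the 1-based update count, matching `t1 = time + 1` in the patch.
  def step(weight: Array[Float], grad: Array[Float],
           mean: Array[Float], variance: Array[Float], t: Int,
           lr: Float = 0.002f, beta1: Float = 0.9f, beta2: Float = 0.999f,
           epsilon: Float = 1e-8f, decayFactor: Float = 1 - 1e-8f): Unit = {
    // Bias-corrected step size: lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
    val lrT = (lr * math.sqrt(1.0 - math.pow(beta2, t)) / (1.0 - math.pow(beta1, t))).toFloat
    // Decayed first-moment coefficient, as in the patch: beta1_t = beta1 * decayFactor^(t - 1)
    val beta1t = (beta1 * math.pow(decayFactor, t - 1)).toFloat
    for (i <- weight.indices) {
      // First and second moment estimates
      mean(i) = beta1t * mean(i) + (1 - beta1t) * grad(i)
      variance(i) = beta2 * variance(i) + (1 - beta2) * grad(i) * grad(i)
      // Parameter update: w -= lr_t * m / (sqrt(v) + eps)
      weight(i) -= lrT * mean(i) / (math.sqrt(variance(i)).toFloat + epsilon)
    }
  }
}

Calling step with t = 1, 2, ... for successive updates corresponds to the patched code advancing `time` only when the first parameter index it saw comes around again, so every parameter of the model gets the same bias-correction factor within one pass.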