
Commit

Merge pull request apache#14 from javelinjs/scala-package-cc
fix Adam optimizer
terrytangyuan committed Dec 28, 2015
2 parents 7f7c1c2 + 4917155 commit bed8eaf
Showing 2 changed files with 26 additions and 18 deletions.
LRScheduler.scala
@@ -7,7 +7,7 @@ import org.slf4j.LoggerFactory
 * @author Yuan Tang
 */

-abstract class LRScheduler(protected var baseLR: Float = 0.01f) {
+abstract class LRScheduler(var baseLR: Float = 0.01f) {
 /**
 * Base class of a learning rate scheduler
 *
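The LRScheduler change above only drops the protected modifier, so baseLR becomes assignable from outside the scheduler's class hierarchy; the Adam constructor below relies on this when it sets lrScheduler.baseLR = learningRate. A minimal sketch of how a concrete scheduler would sit on top of that relaxed field (the FactorSchedulerSketch name and the apply(numUpdate) signature are illustrative assumptions, not part of this diff):

// Illustrative sketch only: a base class mirroring the relaxed `var baseLR`
// and a factor-decay scheduler built on top of it.
abstract class LRSchedulerSketch(var baseLR: Float = 0.01f) {
  // Return the learning rate to use at update number `numUpdate`.
  def apply(numUpdate: Int): Float
}

class FactorSchedulerSketch(step: Int, factor: Float) extends LRSchedulerSketch {
  def apply(numUpdate: Int): Float = {
    // Decay baseLR by `factor` every `step` updates.
    baseLR * math.pow(factor, numUpdate / step).toFloat
  }
}

object SchedulerDemo {
  def main(args: Array[String]): Unit = {
    val sched = new FactorSchedulerSketch(step = 100, factor = 0.9f)
    sched.baseLR = 0.002f // legal now that baseLR is a public var, as the optimizer needs
    println(sched(250))   // baseLR * 0.9^2
  }
}

Keeping baseLR as a plain var is the simplest way to let the optimizer seed the scheduler with its own learning rate without adding a separate setter.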
Adam.scala
@@ -1,3 +1,5 @@
+package ml.dmlc.mxnet.optimizer
+
 import ml.dmlc.mxnet.{NDArray, Optimizer, LRScheduler}
 import ml.dmlc.mxnet.NDArrayConversions._

@@ -8,7 +10,7 @@ import ml.dmlc.mxnet.NDArrayConversions._
 * Adam: A Method for Stochastic Optimization,
 * http://arxiv.org/abs/1412.6980
 *
-* @author Yuan Tang
+* @author Yuan Tang, Yizhi Liu
 *
 * @param learningRate Float, Step size.
 * @param beta1 Float, Exponential decay rate for the first moment estimates.
@@ -21,12 +23,17 @@ import ml.dmlc.mxnet.NDArrayConversions._
 * @param lrScheduler The learning rate scheduler
 */
 class Adam(var learningRate: Float = 0.002f, val beta1: Float = 0.9f, val beta2: Float = 0.999f,
-val epsilon: Float = 0.00000001f, val decayFactor: Float = 1-0.00000001f, val wd: Float = 0.0f,
+val epsilon: Float = 1e-8f, val decayFactor: Float = 1-1e-8f, val wd: Float = 0.0f,
 rescaleGrad: Float = 1f, val clipGradient: Float = 0f,
-val lrScheduler: LRScheduler = null) extends Optimizer(rescaleGrad: Float) {
+val lrScheduler: LRScheduler = null) extends Optimizer(rescaleGrad: Float) {

 protected var time: Int = 0
-protected var timeFirstIndex: Int = 0
+protected var timeFirstIndex: Option[Int] = None
+
+if (lrScheduler != null) {
+lrScheduler.baseLR = learningRate
+}
+
 /**
 * Update the parameters.
 * @param index An unique integer key used to index the parameters
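The hunk above replaces the Int sentinel timeFirstIndex with Option[Int], which removes the ambiguity when the first parameter index happens to be 0 (with the old code, time would be reset on every call and never advance). The update method in the next hunk increments time only when that first-seen index comes around again, i.e. once per pass over all parameters. A standalone sketch of that bookkeeping pattern (the object and method names here are illustrative, not part of the package):

// Minimal sketch of the time-step bookkeeping used in this optimizer (assumed
// standalone, not the actual Optimizer API): `time` advances only when the
// first-seen index reappears, so one increment equals one pass over all parameters.
object TimeStepSketch {
  private var time: Int = 0
  private var timeFirstIndex: Option[Int] = None

  def tick(index: Int): Int = {
    if (timeFirstIndex.isEmpty) {
      timeFirstIndex = Some(index)
      time = 0
    } else if (timeFirstIndex.get == index) {
      time += 1
    }
    time
  }

  def main(args: Array[String]): Unit = {
    // Two passes over parameter indices 0, 1, 2: time stays 0 on the first pass,
    // then becomes 1 once index 0 comes around again.
    for (pass <- 0 until 2; idx <- 0 until 3) {
      println(s"pass=$pass index=$idx time=${tick(idx)}")
    }
  }
}

Running it prints time = 0 for every index on the first pass and 1 from the moment the first index reappears, which is exactly the per-pass step count the bias correction needs.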
@@ -45,41 +52,42 @@ class Adam(var learningRate: Float = 0.002f, val beta1: Float = 0.9f, val beta2:
 this.learningRate
 }) * lrScale.getOrElse(index, 1f)

-var (mean, variance) = state
+var (mean, variance) = state.asInstanceOf[(NDArray, NDArray)]

-if (timeFirstIndex == 0) {
-timeFirstIndex = index
+// increment time only when the first parameters is called
+if (timeFirstIndex == None) {
+timeFirstIndex = Option(index)
 time = 0
-} else if (timeFirstIndex == index) {
+} else if (timeFirstIndex.get == index) {
 time += 1
 }

 val t1: Int = time + 1
-learningRate = (lr * math.sqrt(1.0 - math.pow(beta2, t1))/(1.0 - math.pow(beta1, t1))) toFloat
-val beta1t = beta1 * math.pow(decayFactor, t1 - 1) toFloat
+learningRate = (lr * math.sqrt(1.0 - math.pow(beta2, t1)) / (1.0 - math.pow(beta1, t1))).toFloat
+val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat

 var resdGrad = grad * rescaleGrad
 if (clipGradient != 0f) {
 resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient)
 }

-val meanT = beta1t * mean.asInstanceOf[NDArray] + (1.0 - beta1t) * resdGrad toScalar
-val varianceT = beta2 * variance.asInstanceOf[NDArray] + (1.0f - beta2) * resdGrad * resdGrad toScalar
+val meanT = beta1t * mean + (1.0 - beta1t) * resdGrad
+val varianceT = beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad

-var step = learningRate * meanT / (math.sqrt(varianceT) + epsilon)
+var step = learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon)

 if (wd > 0.0f) {
-step += (lr * wd * weight).toScalar
+step += lr * wd * weight
 }

-weight += -step.toFloat
+weight += -step
 mean = meanT
 variance = varianceT
 }

 // Create additional optimizer state: mean, variance
-override def createState(index: Int, weight: NDArray): AnyRef = {
-timeFirstIndex = 0
+override def createState(index: Int, weight: NDArray): (NDArray, NDArray) = {
+timeFirstIndex = None // time is incremented only on the first index
 (NDArray.zeros(weight.shape, weight.context), // mean
 NDArray.zeros(weight.shape, weight.context)) // variance
 }
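Taken together, the fixed update keeps the moments and the step as NDArrays and folds the bias correction into the step size. For reference, here is the same arithmetic on plain Float scalars instead of NDArrays (a sketch only; the names, the scalar types, and the omission of weight decay and gradient clipping are assumptions for illustration):

// Scalar sketch of one Adam step as implemented above:
//   m_t  = beta1t * m + (1 - beta1t) * g
//   v_t  = beta2  * v + (1 - beta2)  * g^2
//   lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t)   (bias correction folded into the step size)
//   w   -= lr_t * m_t / (sqrt(v_t) + epsilon)
object AdamStepSketch {
  def step(weight: Float, grad: Float, mean: Float, variance: Float, t: Int,
           lr: Float = 0.002f, beta1: Float = 0.9f, beta2: Float = 0.999f,
           epsilon: Float = 1e-8f, decayFactor: Float = 1 - 1e-8f): (Float, Float, Float) = {
    val lrT = (lr * math.sqrt(1.0 - math.pow(beta2, t)) / (1.0 - math.pow(beta1, t))).toFloat
    val beta1t = beta1 * math.pow(decayFactor, t - 1).toFloat
    val meanT = beta1t * mean + (1.0f - beta1t) * grad
    val varianceT = beta2 * variance + (1.0f - beta2) * grad * grad
    val newWeight = weight - lrT * meanT / (math.sqrt(varianceT).toFloat + epsilon)
    (newWeight, meanT, varianceT)
  }

  def main(args: Array[String]): Unit = {
    // One update on a toy scalar parameter starting from zero state (t starts at 1,
    // matching t1 = time + 1 in the diff above).
    println(step(weight = 1.0f, grad = 0.5f, mean = 0.0f, variance = 0.0f, t = 1))
  }
}

With zero initial state the first step has magnitude close to the base learning rate regardless of the gradient's scale (about 0.002 in the demo, so the weight moves from 1.0 to roughly 0.998), which is the behaviour the NDArray version reproduces elementwise.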
