Persist Constant, Ones, Zeros, Orthogonal and TruncatedNormal initializers
juliabeliaeva committed Apr 15, 2022
1 parent 2cbd2ee commit 6d5155f
Showing 7 changed files with 182 additions and 36 deletions.
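The net effect of the commit: these five initializers now survive a round trip through the saved model configuration instead of being silently replaced by a zero-valued RandomUniform stand-in on load. A rough usage sketch follows; the config file path, the saveModelConfiguration / loadModelConfiguration entry points and the layers property are assumptions based on the usual KotlinDL API, not something this diff shows.

import java.io.File
import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.initializer.Constant
import org.jetbrains.kotlinx.dl.api.core.initializer.Orthogonal
import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.inference.keras.loadModelConfiguration
import org.jetbrains.kotlinx.dl.api.inference.keras.saveModelConfiguration

fun main() {
    val model = Sequential.of(
        Input(10),
        Dense(10, kernelInitializer = Orthogonal(gain = 0.5f, seed = 10)),
        Dense(10, kernelInitializer = Constant(constantValue = 2.0f))
    )

    val configFile = File("model.json") // hypothetical path
    model.saveModelConfiguration(configFile)

    // With this commit the restored layers keep their real initializers.
    val restored = Sequential.loadModelConfiguration(configFile)
    val lastDense = restored.layers.filterIsInstance<Dense>().last()
    println(lastDense.kernelInitializer is Constant) // expected to print true once the initializer is persisted
}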
@@ -16,7 +16,7 @@ import org.tensorflow.op.Ops
  * @property constantValue Constant value to fill the tensor.
  * @constructor Creates a [Constant] initializer with a given [constantValue].
  */
-public class Constant(private val constantValue: Float) : Initializer() {
+public class Constant(public val constantValue: Float) : Initializer() {
     override fun initialize(
         fanIn: Int,
         fanOut: Int,
@@ -22,8 +22,8 @@ import kotlin.math.min
  */

 public class Orthogonal(
-    private val gain: Float = 1.0f,
-    private val seed: Long = 12L
+    public val gain: Float = 1.0f,
+    public val seed: Long = 12L
 ) : Initializer() {
     override fun initialize(
         fanIn: Int,
@@ -21,7 +21,7 @@ import org.tensorflow.op.random.TruncatedNormal
  * @property seed Seed.
  * @constructor Creates [TruncatedNormal] initializer.
  */
-public class TruncatedNormal(private val seed: Long = 12L) :
+public class TruncatedNormal(public val seed: Long = 12L) :
     Initializer() {
     override fun initialize(
         fanIn: Int,
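The only change in the three initializer classes above (Constant, Orthogonal, TruncatedNormal) is the visibility of their constructor parameters: constantValue, gain and seed become public properties so the Keras export code further down in this commit can read the values back when writing the JSON configuration. A minimal illustration of what that enables (values chosen arbitrarily):

val constant = Constant(constantValue = 2.0f)
val orthogonal = Orthogonal(gain = 0.5f, seed = 10L)
val truncated = TruncatedNormal(seed = 10L)

// These reads failed to compile while the properties were private:
println(constant.constantValue) // 2.0
println(orthogonal.gain)        // 0.5
println(truncated.seed)         // 10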
@@ -257,21 +257,6 @@ private fun convertToInitializer(initializer: KerasInitializer): Initializer {
         INITIALIZER_HE_UNIFORM -> HeUniform(seed = seed)
         INITIALIZER_LECUN_NORMAL -> LeCunNormal(seed = seed)
         INITIALIZER_LECUN_UNIFORM -> LeCunUniform(seed = seed)
-        INITIALIZER_ZEROS -> RandomUniform(
-            seed = seed,
-            minVal = 0.0f,
-            maxVal = 0.0f
-        ) // instead of real initializers, because it doesn't influence on nothing
-        INITIALIZER_CONSTANT -> RandomUniform(
-            seed = seed,
-            minVal = 0.0f,
-            maxVal = 0.0f
-        ) // instead of real initializers, because it doesn't influence on nothing
-        INITIALIZER_ONES -> RandomUniform(
-            seed = seed,
-            minVal = 1.0f,
-            maxVal = 1.0f
-        ) // instead of real initializers, because it doesn't influence on nothing*/
         INITIALIZER_RANDOM_NORMAL -> RandomNormal(
             seed = seed,
             mean = initializer.config.mean!!.toFloat(),
@@ -282,10 +267,12 @@ private fun convertToInitializer(initializer: KerasInitializer): Initializer {
             minVal = initializer.config.minval!!.toFloat(),
             maxVal = initializer.config.maxval!!.toFloat()
         )
-        INITIALIZER_TRUNCATED_NORMAL -> TruncatedNormal(seed = seed)
         INITIALIZER_VARIANCE_SCALING -> convertVarianceScalingInitializer(initializer)
+        INITIALIZER_TRUNCATED_NORMAL -> TruncatedNormal(seed = seed)
         INITIALIZER_ORTHOGONAL -> Orthogonal(seed = seed, gain = initializer.config.gain!!.toFloat())
-        /*INITIALIZER_CONSTANT -> Constant(initializer.config.value!!.toFloat())*/
+        INITIALIZER_ZEROS -> Zeros()
+        INITIALIZER_ONES -> Ones()
+        INITIALIZER_CONSTANT -> Constant(initializer.config.value!!.toFloat())
         INITIALIZER_IDENTITY -> Identity(initializer.config.gain?.toFloat() ?: 1f)
         else -> throw IllegalStateException("${initializer.class_name} is not supported yet!")
     }
@@ -173,6 +173,11 @@ private fun convertToKerasInitializer(initializer: Initializer, isKerasFullyComp
         }
         is RandomUniform -> convertToRandomUniformInitializer(initializer)
         is RandomNormal -> convertToRandomNormalInitializer(initializer)
+        is TruncatedNormal -> INITIALIZER_TRUNCATED_NORMAL to KerasInitializerConfig(seed = initializer.seed.toInt())
+        is Orthogonal -> convertToOrthogonalInitializer(initializer)
+        is Zeros -> INITIALIZER_ZEROS to KerasInitializerConfig()
+        is Ones -> INITIALIZER_ONES to KerasInitializerConfig()
+        is Constant -> INITIALIZER_CONSTANT to KerasInitializerConfig(value = initializer.constantValue.toDouble())
         is Identity -> convertToIdentityInitializer(initializer)
         else -> throw IllegalStateException("${initializer::class.simpleName} is not supported yet!")
     }
@@ -235,6 +240,15 @@ private fun convertMode(mode: Mode): String {
     }
 }

+private fun convertToOrthogonalInitializer(initializer: Orthogonal): Pair<String, KerasInitializerConfig> {
+    return Pair(
+        INITIALIZER_ORTHOGONAL, KerasInitializerConfig(
+            gain = initializer.gain.toDouble(),
+            seed = initializer.seed.toInt()
+        )
+    )
+}
+
 private fun convertToKerasPadding(padding: ConvPadding): KerasPadding {
     return when (padding) {
         ConvPadding.SAME -> KerasPadding.Same
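The loader (convertToInitializer) and the saver (convertToKerasInitializer) above operate on the same Keras-style representation: an initializer class name plus a KerasInitializerConfig holding its parameters, which is why Orthogonal gets a dedicated helper for its two fields. For orientation, the persisted fragment for an Orthogonal kernel initializer should look roughly like the sketch below; the exact field names follow the Keras convention and are an assumption, not something shown verbatim in this diff.

// Approximate shape of the saved kernel_initializer entry for Orthogonal(gain = 0.5f, seed = 10):
val expectedFragment = """
    "kernel_initializer": {
        "class_name": "Orthogonal",
        "config": { "gain": 0.5, "seed": 10 }
    }
""".trimIndent()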
@@ -0,0 +1,154 @@
+/*
+ * Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
+ */
+
+package org.jetbrains.kotlinx.dl.api.inference.keras
+
+import org.jetbrains.kotlinx.dl.api.core.Sequential
+import org.jetbrains.kotlinx.dl.api.core.initializer.*
+import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense
+import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
+import org.junit.jupiter.api.Test
+
+class InitializerPersistenceTest {
+    @Test
+    fun ones() {
+        LayerPersistenceTest.run(
+            Sequential.of(
+                Input(10),
+                Dense(10, kernelInitializer = Ones())
+            )
+        )
+    }
+
+    @Test
+    fun zeros() {
+        LayerPersistenceTest.run(
+            Sequential.of(
+                Input(10),
+                Dense(10, kernelInitializer = Zeros())
+            )
+        )
+    }
+
+    @Test
+    fun `random normal`() {
+        LayerPersistenceTest.run(
+            Sequential.of(
+                Input(10),
+                Dense(10, kernelInitializer = RandomNormal(mean = 10f, stdev = 0.1f, seed = 10))
+            )
+        )
+    }
+
+    @Test
+    fun `random uniform`() {
+        LayerPersistenceTest.run(
+            Sequential.of(
+                Input(10),
+                Dense(10, kernelInitializer = RandomUniform(maxVal = 10f, minVal = -10f, seed = 10))
+            )
+        )
+    }
+
+    @Test
+    fun `truncated normal`() {
+        LayerPersistenceTest.run(
+            Sequential.of(
+                Input(10),
+                Dense(10, kernelInitializer = TruncatedNormal(seed = 10))
+            )
+        )
+    }
+
+    @Test
+    fun `glorot normal`() {
+        LayerPersistenceTest.run(
+            Sequential.of(
+                Input(10),
+                Dense(10, kernelInitializer = GlorotNormal(seed = 10))
+            )
+        )
+    }
+
+    @Test
+    fun `glorot uniform`() {
+        LayerPersistenceTest.run(
+            Sequential.of(
+                Input(10),
+                Dense(10, kernelInitializer = GlorotUniform(seed = 10))
+            )
+        )
+    }
+
+    @Test
+    fun `he normal`() {
+        LayerPersistenceTest.run(
+            Sequential.of(
+                Input(10),
+                Dense(10, kernelInitializer = HeNormal(seed = 10))
+            )
+        )
+    }
+
+    @Test
+    fun `he uniform`() {
+        LayerPersistenceTest.run(
+            Sequential.of(
+                Input(10),
+                Dense(10, kernelInitializer = HeUniform(seed = 10))
+            )
+        )
+    }
+
+    @Test
+    fun `lecun normal`() {
+        LayerPersistenceTest.run(
+            Sequential.of(
+                Input(10),
+                Dense(10, kernelInitializer = LeCunNormal(seed = 10))
+            )
+        )
+    }
+
+    @Test
+    fun `lecun uniform`() {
+        LayerPersistenceTest.run(
+            Sequential.of(
+                Input(10),
+                Dense(10, kernelInitializer = LeCunUniform(seed = 10))
+            )
+        )
+    }
+
+    @Test
+    fun identity() {
+        LayerPersistenceTest.run(
+            Sequential.of(
+                Input(10),
+                Dense(10, kernelInitializer = Identity(gain = 2.0f))
+            )
+        )
+    }
+
+    @Test
+    fun constant() {
+        LayerPersistenceTest.run(
+            Sequential.of(
+                Input(10),
+                Dense(10, kernelInitializer = Constant(constantValue = 2.0f))
+            )
+        )
+    }
+
+    @Test
+    fun orthogonal() {
+        LayerPersistenceTest.run(
+            Sequential.of(
+                Input(2),
+                Dense(2, kernelInitializer = Orthogonal(gain = 0.5f, seed = 10))
+            )
+        )
+    }
+}
@@ -49,31 +49,22 @@ class ModelLoaderTest {
             }

             with((this[7] as Dense).kernelInitializer) {
-                // This should be 'Zeros' initializer
-                assertTrue(this is RandomUniform)
-                assertEquals(0.0f, (this as RandomUniform).maxVal)
-                assertEquals(0.0f, (this).minVal)
+                assertTrue(this is Zeros)
             }

             with((this[8] as Dense).kernelInitializer) {
-                // This should be 'Constant' initializer
-                assertTrue(this is RandomUniform)
-                assertEquals(0.0f, (this as RandomUniform).maxVal)
-                assertEquals(0.0f, (this).minVal)
+                assertTrue(this is Constant)
+                assertEquals(2f, (this as Constant).constantValue)
             }

             with((this[9] as Dense).kernelInitializer) {
-                // This should be 'Ones' initializer
-                assertTrue(this is RandomUniform)
-                assertEquals(1.0f, (this as RandomUniform).maxVal)
-                assertEquals(1.0f, this.minVal)
+                assertTrue(this is Ones)
             }

             with((this[10] as Dense).kernelInitializer) {
                 assertTrue(this is RandomNormal)
-                // Mean and stdev are not public. Should they be?
-                // assertEquals(0.0f, (this as RandomNormal).mean)
-                // assertEquals(1.0f, (this as RandomNormal).stdev)
+                assertEquals(0.0f, (this as RandomNormal).mean)
+                assertEquals(1.0f, this.stdev)
             }

             with((this[11] as Dense).kernelInitializer) {
