
Commit

Export and import ParametrizedTruncatedNormal initializer when full compatibility is not required
juliabeliaeva committed Apr 29, 2022
1 parent 3e0676d commit 5116e4c
Showing 6 changed files with 55 additions and 8 deletions.
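
As a quick orientation before the diff: the change lets a model whose kernel initializer is ParametrizedTruncatedNormal be exported to and re-imported from a Keras-style JSON configuration, provided full Keras compatibility is not requested; in fully compatible mode the exporter still throws IllegalStateException, as the changes below show. A minimal usage sketch (the package paths and the saveModelConfiguration call are assumptions for illustration and are not part of this commit):

import java.io.File
import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.initializer.ParametrizedTruncatedNormal
import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input

fun main() {
    val model = Sequential.of(
        Input(10),
        Dense(
            10,
            // the initializer this commit teaches the exporter and importer to handle
            kernelInitializer = ParametrizedTruncatedNormal(mean = 0.1f, stdev = 2f, p1 = -5f, p2 = 5.1f, seed = 10)
        )
    )
    // Assumed API, shown for illustration only: export works when full Keras
    // compatibility is not requested; with isKerasFullyCompatible = true the
    // exporter throws IllegalStateException for this initializer.
    model.saveModelConfiguration(File("model.json"), isKerasFullyCompatible = false)
}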
File 1 of 6
@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
  * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
  */

@@ -20,11 +20,11 @@ import org.tensorflow.op.random.ParameterizedTruncatedNormal
  * @constructor Creates a [ParametrizedTruncatedNormal] initializer.
  */
 public class ParametrizedTruncatedNormal(
-    private val mean: Float = 0.0f,
-    private val stdev: Float = 1.0f,
-    private val p1: Float = -10.0f, // low level edge
-    private val p2: Float = 10.0f, // high level edge
-    private val seed: Long
+    internal val mean: Float = 0.0f,
+    internal val stdev: Float = 1.0f,
+    internal val p1: Float = -10.0f, // low level edge
+    internal val p2: Float = 10.0f, // high level edge
+    internal val seed: Long
 ) :
     Initializer() {
     override fun initialize(
File 2 of 6
@@ -87,6 +87,7 @@ internal const val INITIALIZER_ONES: String = "Ones"
 internal const val INITIALIZER_RANDOM_NORMAL: String = "RandomNormal"
 internal const val INITIALIZER_RANDOM_UNIFORM: String = "RandomUniform"
 internal const val INITIALIZER_TRUNCATED_NORMAL: String = "TruncatedNormal"
+internal const val INITIALIZER_PARAMETRIZED_TRUNCATED_NORMAL: String = "ParametrizedTruncatedNormal"
 internal const val INITIALIZER_CONSTANT: String = "Constant"
 internal const val INITIALIZER_VARIANCE_SCALING: String = "VarianceScaling"
 internal const val INITIALIZER_IDENTITY: String = "Identity"
File 3 of 6
@@ -269,6 +269,13 @@ private fun convertToInitializer(initializer: KerasInitializer): Initializer {
         )
         INITIALIZER_VARIANCE_SCALING -> convertVarianceScalingInitializer(initializer)
         INITIALIZER_TRUNCATED_NORMAL -> TruncatedNormal(seed = seed)
+        INITIALIZER_PARAMETRIZED_TRUNCATED_NORMAL -> ParametrizedTruncatedNormal(
+            mean = config.mean!!.toFloat(),
+            stdev = config.stddev!!.toFloat(),
+            p1 = config.p1!!.toFloat(),
+            p2 = config.p2!!.toFloat(),
+            seed = seed
+        )
         INITIALIZER_ORTHOGONAL -> Orthogonal(seed = seed, gain = config.gain!!.toFloat())
         INITIALIZER_ZEROS -> Zeros()
         INITIALIZER_ONES -> Ones()
File 4 of 6
@@ -176,6 +176,11 @@ private fun convertToKerasInitializer(initializer: Initializer, isKerasFullyComp
         is RandomUniform -> convertToRandomUniformInitializer(initializer)
         is RandomNormal -> convertToRandomNormalInitializer(initializer)
         is TruncatedNormal -> INITIALIZER_TRUNCATED_NORMAL to KerasInitializerConfig(seed = initializer.seed.toInt())
+        is ParametrizedTruncatedNormal -> {
+            if (isKerasFullyCompatible) {
+                throw IllegalStateException("Exporting ${initializer::class.simpleName} is not supported in the fully compatible mode.")
+            } else convertToParametrizedTruncatedNormalInitializer(initializer)
+        }
         is Orthogonal -> convertToOrthogonalInitializer(initializer)
         is Zeros -> INITIALIZER_ZEROS to KerasInitializerConfig()
         is Ones -> INITIALIZER_ONES to KerasInitializerConfig()
@@ -226,6 +231,18 @@ private fun convertToIdentityInitializer(initializer: Identity): Pair<String, Ke
     )
 }
 
+private fun convertToParametrizedTruncatedNormalInitializer(initializer: ParametrizedTruncatedNormal): Pair<String, KerasInitializerConfig> {
+    return Pair(
+        INITIALIZER_PARAMETRIZED_TRUNCATED_NORMAL, KerasInitializerConfig(
+            mean = initializer.mean.toDouble(),
+            stddev = initializer.stdev.toDouble(),
+            p1 = initializer.p1.toDouble(),
+            p2 = initializer.p2.toDouble(),
+            seed = initializer.seed.toInt()
+        )
+    )
+}
+
 private fun convertDistribution(distribution: Distribution): String {
     return when (distribution) {
         Distribution.TRUNCATED_NORMAL -> "truncated_normal"
File 5 of 6
@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
  * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
  */

@@ -27,5 +27,9 @@ internal data class KerasInitializerConfig(
     @Json(serializeNull = false)
     val value: Double? = null,
     @Json(serializeNull = false)
-    val gain: Double? = null
+    val gain: Double? = null,
+    @Json(serializeNull = false)
+    val p1: Double? = null,
+    @Json(serializeNull = false)
+    val p2: Double? = null
 )
File 6 of 6
@@ -62,6 +62,24 @@ class InitializerImportExportTest {
         )
     }
 
+    @Test
+    fun parametrizedTruncatedNormal() {
+        LayerImportExportTest.run(
+            Sequential.of(
+                Input(10),
+                Dense(
+                    10, kernelInitializer = ParametrizedTruncatedNormal(
+                        mean = 0.1f,
+                        stdev = 2f,
+                        p1 = -5f,
+                        p2 = 5.1f,
+                        seed = 10
+                    )
+                )
+            )
+        )
+    }
+
     @Test
     fun glorotNormal() {
         LayerImportExportTest.run(
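For completeness, the reverse direction presumably goes through Sequential.loadModelConfiguration, which maps the stored "ParametrizedTruncatedNormal" entry back to the Kotlin initializer via convertToInitializer above. A minimal sketch (the loadModelConfiguration entry point is an assumption here, not shown in this diff):

import java.io.File
import org.jetbrains.kotlinx.dl.api.core.Sequential

fun restore() {
    // Assumed entry point: parses the JSON config written earlier and rebuilds the layers,
    // reconstructing ParametrizedTruncatedNormal from the stored mean/stddev/p1/p2.
    val restored = Sequential.loadModelConfiguration(File("model.json"))
    println(restored.layers.map { it.name })
}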
