Skip to content

Commit

Permalink
#50 Add Identity initializer (#73)
Browse files Browse the repository at this point in the history
* Create Identity initializer

* Added Jupyter Notebook for producing export of all initializers
  • Loading branch information
hbrammer authored Jun 2, 2021
1 parent 54cd45c commit 90085c3
Show file tree
Hide file tree
Showing 10 changed files with 489 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@
/examples/out/
/cache/
/examples/cache/
/api/cache/
/pretrainedModels/
/examples/savedmodels/
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package org.jetbrains.kotlinx.dl.api.core.exception

/**
 * Exception reporting that an identity matrix is not defined for a tensor of the given order.
 *
 * @param dimensions Order of the offending tensor; it is embedded verbatim in the exception message.
 */
public class IdentityDimensionalityException(dimensions: Long) :
    Exception("Identity matrix is not defined for order $dimensions tensors.")
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.jetbrains.kotlinx.dl.api.core.initializer

import org.jetbrains.kotlinx.dl.api.core.exception.IdentityDimensionalityException
import org.jetbrains.kotlinx.dl.api.core.util.getDType
import org.tensorflow.Operand
import org.tensorflow.op.Ops

/**
 * Initializer that generates an identity matrix multiplied by [gain].
 *
 * Only 2D shapes are supported. Although an identity matrix is strictly defined for square matrices only,
 * a rectangular shape is accepted as well: the main diagonal is filled and every other entry is zero.
 * E.g. a 2 x 3 "identity" matrix:
 * ```
 * [[ 1., 0., 0.],
 *  [ 0., 1., 0.]]
 * ```
 *
 * @property [gain] Factor the identity matrix is multiplied by.
 * @constructor Creates an [Identity] initializer.
 */
public class Identity(
    public val gain: Float = 1.0f
) : Initializer() {
    /**
     * Builds the operand holding the scaled identity matrix of the requested [shape].
     *
     * [fanIn] and [fanOut] are not used by this initializer.
     *
     * @param tf Ops instance used to create the operations.
     * @param shape Operand holding the requested dimensions; it must contain exactly two of them.
     * @param name Name given to the zero-filled base matrix operation.
     * @return Zero matrix of the requested shape with its main diagonal set to [gain].
     * @throws IdentityDimensionalityException if [shape] does not describe exactly two dimensions.
     */
    override fun initialize(fanIn: Int, fanOut: Int, tf: Ops, shape: Operand<Int>, name: String): Operand<Float> {
        // The length of the shape operand is the order of the tensor being requested.
        val order = shape.asOutput().shape().size(0)
        if (order != 2L) {
            throw IdentityDimensionalityException(order)
        }

        // The main diagonal has as many entries as the smaller of the two dimensions.
        val reductionAxis = tf.constant(0)
        val shortestSide = tf.reduceMin(shape, reductionAxis)
        val tileMultiples = tf.reshape(shortestSide, tf.constant(intArrayOf(1)))
        val gainVector = tf.constant(floatArrayOf(gain))
        val diagonal = tf.tile(gainVector, tileMultiples)

        // Start from zeros and overwrite the main diagonal (offset 0).
        val zeroMatrix = tf.withName(name).zeros(shape, getDType())
        val mainDiagonalOffset = tf.constant(0)
        return tf.matrixSetDiagV2(zeroMatrix, diagonal, mainDiagonalOffset)
    }

    // The label reads `scale` although the property is called `gain`; the text is kept as is.
    override fun toString(): String = "Identity(scale=$gain)"
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,16 @@ internal const val INITIALIZER_GLOROT_UNIFORM: String = "GlorotUniform"
// Initializer name constants; judging by their use in the converter functions they are the
// `class_name` strings exchanged with Keras model configs — TODO confirm.
internal const val INITIALIZER_GLOROT_NORMAL: String = "GlorotNormal"
internal const val INITIALIZER_HE_UNIFORM: String = "HeUniform"
internal const val INITIALIZER_HE_NORMAL: String = "HeNormal"
// NOTE(review): INITIALIZER_LECUN_UNIFORM / INITIALIZER_LECUN_NORMAL are declared twice below with
// different values ("LeCun…" vs "Lecun…"). This looks like the removed and added sides of a diff
// rendered without +/- markers; a compilable file can keep only one pair — presumably the later
// "Lecun…" spelling. Confirm against the repository before relying on either value.
internal const val INITIALIZER_LECUN_UNIFORM: String = "LeCunUniform"
internal const val INITIALIZER_LECUN_NORMAL: String = "LeCunNormal"
internal const val INITIALIZER_LECUN_UNIFORM: String = "LecunUniform"
internal const val INITIALIZER_LECUN_NORMAL: String = "LecunNormal"
internal const val INITIALIZER_ZEROS: String = "Zeros"
internal const val INITIALIZER_ONES: String = "Ones"
internal const val INITIALIZER_RANDOM_NORMAL: String = "RandomNormal"
internal const val INITIALIZER_RANDOM_UNIFORM: String = "RandomUniform"
internal const val INITIALIZER_TRUNCATED_NORMAL: String = "TruncatedNormal"
internal const val INITIALIZER_CONSTANT: String = "Constant"
internal const val INITIALIZER_VARIANCE_SCALING: String = "VarianceScaling"
// Name used for the Identity initializer.
internal const val INITIALIZER_IDENTITY: String = "Identity"

// Keras activations
internal const val ACTIVATION_RELU: String = "relu"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,7 @@ private fun convertToInitializer(initializer: KerasInitializer): Initializer {
INITIALIZER_TRUNCATED_NORMAL -> TruncatedNormal(seed = seed)
INITIALIZER_VARIANCE_SCALING -> convertVarianceScaling(initializer)
/*INITIALIZER_CONSTANT -> Constant(initializer.config.value!!.toFloat())*/
INITIALIZER_IDENTITY -> Identity(initializer.config.gain?.toFloat() ?: 1f)
else -> throw IllegalStateException("${initializer.class_name} is not supported yet!")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ private fun convertToKerasInitializer(initializer: Initializer, isKerasFullyComp
LeCunNormal::class -> convertToVarianceScaling(initializer as VarianceScaling)
LeCunUniform::class -> convertToVarianceScaling(initializer as VarianceScaling)
RandomUniform::class -> convertToRandomUniform(initializer as RandomUniform)
Identity::class -> convertToIdentity(initializer as Identity)
else -> throw IllegalStateException("${initializer::class.simpleName} is not supported yet!")
}

Expand All @@ -242,6 +243,7 @@ private fun convertToKerasInitializer(initializer: Initializer, isKerasFullyComp
HeUniform::class -> INITIALIZER_HE_UNIFORM
LeCunNormal::class -> INITIALIZER_LECUN_NORMAL
LeCunUniform::class -> INITIALIZER_LECUN_UNIFORM
Identity::class -> INITIALIZER_IDENTITY
else -> throw IllegalStateException("${initializer::class.simpleName} is not supported yet!")
}
config = KerasInitializerConfig(seed = 12)
Expand Down Expand Up @@ -271,6 +273,15 @@ private fun convertToVarianceScaling(initializer: VarianceScaling): Pair<String,
)
}

/**
 * Converts an [Identity] initializer into its Keras representation.
 *
 * @param initializer Initializer whose `gain` is exported.
 * @return Pair of [INITIALIZER_IDENTITY] and a config carrying the gain widened to `Double`.
 */
private fun convertToIdentity(initializer: Identity): Pair<String, KerasInitializerConfig> =
    INITIALIZER_IDENTITY to KerasInitializerConfig(gain = initializer.gain.toDouble())

private fun convertDistribution(distribution: Distribution): String {
return when (distribution) {
Distribution.TRUNCATED_NORMAL -> "truncated_normal"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
package org.jetbrains.kotlinx.dl.api.core.initializer

import org.jetbrains.kotlinx.dl.api.core.exception.IdentityDimensionalityException
import org.jetbrains.kotlinx.dl.api.core.shape.shapeOperand
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.Assertions.assertThrows
import org.junit.jupiter.api.assertAll
import org.tensorflow.EagerSession
import org.tensorflow.Shape
import org.tensorflow.op.Ops

private const val EPS = 1e-7f
private const val FAN_IN = 10
private const val FAN_OUT = 20

/**
 * Tests for the [Identity] initializer: generated values for square and rectangular 2D shapes,
 * the `toString` representation, and rejection of shapes that are not 2D.
 */
class IdentityTest {
    @Test
    fun initialize() {
        val instance = Identity()

        assertProducesScaledIdentity(instance, rows = 2, columns = 2, diagonalValue = 1f)
        assertEquals("Identity(scale=1.0)", instance.toString())
    }

    @Test
    fun initializeScaled() {
        val instance = Identity(3.4f)

        assertProducesScaledIdentity(instance, rows = 2, columns = 2, diagonalValue = 3.4f)
        assertEquals("Identity(scale=3.4)", instance.toString())
    }

    @Test
    fun initializeWithNonSquare2x3() {
        val instance = Identity()

        // The expected matrix is 2 x 3, matching the requested shape.
        assertProducesScaledIdentity(instance, rows = 2, columns = 3, diagonalValue = 1f)
        assertEquals("Identity(scale=1.0)", instance.toString())
    }

    @Test
    fun initializeWithNonSquare3x2() {
        val instance = Identity()

        assertProducesScaledIdentity(instance, rows = 3, columns = 2, diagonalValue = 1f)
        assertEquals("Identity(scale=1.0)", instance.toString())
    }

    @Test
    fun initializeWith1DShapeFails() {
        assertRejectsShape(Shape.make(5), "Identity matrix is not defined for order 1 tensors.")
    }

    @Test
    fun initializeWith3DShapeFails() {
        assertRejectsShape(Shape.make(5, 5, 5), "Identity matrix is not defined for order 3 tensors.")
    }

    /**
     * Runs [instance] for a `rows x columns` shape in an eager session and checks every row of the result:
     * [diagonalValue] on the main diagonal and zero everywhere else.
     */
    private fun assertProducesScaledIdentity(instance: Identity, rows: Int, columns: Int, diagonalValue: Float) {
        // Derive the expected matrix from the requested shape so that both always have the same dimensions.
        val expected = Array(rows) { row ->
            FloatArray(columns) { column -> if (row == column) diagonalValue else 0f }
        }
        val actual = Array(rows) { FloatArray(columns) { 0f } }

        EagerSession.create().use { session ->
            val tf = Ops.create(session)
            val shape = Shape.make(rows.toLong(), columns.toLong())
            val operand = instance.initialize(FAN_IN, FAN_OUT, tf, shapeOperand(tf, shape), "default_name")
            operand.asOutput().tensor().copyTo(actual)

            for (row in expected.indices) {
                assertArrayEquals(expected[row], actual[row], EPS)
            }
        }
    }

    /**
     * Checks that initializing with [shape] throws [IdentityDimensionalityException] carrying [expectedMessage].
     */
    private fun assertRejectsShape(shape: Shape, expectedMessage: String) {
        EagerSession.create().use { session ->
            val tf = Ops.create(session)

            val exception = assertThrows(IdentityDimensionalityException::class.java) {
                Identity().initialize(FAN_IN, FAN_OUT, tf, shapeOperand(tf, shape), "default_name")
            }

            assertEquals(expectedMessage, exception.message)
        }
    }
}
Loading

0 comments on commit 90085c3

Please sign in to comment.