-
Notifications
You must be signed in to change notification settings - Fork 105
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Create Identity initializer * Added Jupyter Notebook for producing export of all initializers
- Loading branch information
Showing
10 changed files
with
489 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,5 +14,6 @@ | |
/examples/out/ | ||
/cache/ | ||
/examples/cache/ | ||
/api/cache/ | ||
/pretrainedModels/ | ||
/examples/savedmodels/ |
4 changes: 4 additions & 0 deletions
4
...ain/kotlin/org/jetbrains/kotlinx/dl/api/core/exception/IdentityDimensionalityException.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
package org.jetbrains.kotlinx.dl.api.core.exception | ||
|
||
public class IdentityDimensionalityException(dimensions: Long): | ||
Exception("Identity matrix is not defined for order $dimensions tensors.") |
38 changes: 38 additions & 0 deletions
38
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/initializer/Identity.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
package org.jetbrains.kotlinx.dl.api.core.initializer | ||
|
||
import org.jetbrains.kotlinx.dl.api.core.exception.IdentityDimensionalityException | ||
import org.jetbrains.kotlinx.dl.api.core.util.getDType | ||
import org.tensorflow.Operand | ||
import org.tensorflow.op.Ops | ||
|
||
/** | ||
* Initializer that generates the identity matrix. | ||
* This initializer is only usable for generating 2D matrices. | ||
* Although identity matrices are undefined for non square matrices an "identity" matrix is generated. | ||
* E.g. 2 x 3 "identity" matrix | ||
* ==> [[ 1., 0., 0.], | ||
* [ 0., 1., 0.]] | ||
* | ||
* @property [gain] Identity matrix is multiply by this factor | ||
* @constructor Creates a [Identity] initializer. | ||
*/ | ||
|
||
public class Identity( | ||
public val gain: Float = 1.0f | ||
): Initializer() { | ||
override fun initialize(fanIn: Int, fanOut: Int, tf: Ops, shape: Operand<Int>, name: String): Operand<Float> { | ||
val dimensions = shape.asOutput().shape().size(0) | ||
if(dimensions != 2L) throw IdentityDimensionalityException(dimensions) | ||
|
||
val minSize = tf.reduceMin(shape,tf.constant(0)) | ||
val reshapedMinSize = tf.reshape(minSize, tf.constant(intArrayOf(1))) | ||
val diag = tf.tile(tf.constant(floatArrayOf(gain)), reshapedMinSize) | ||
|
||
val zeros = tf.withName(name).zeros(shape, getDType()) | ||
return tf.matrixSetDiagV2(zeros , diag, tf.constant(0)) | ||
} | ||
|
||
override fun toString(): String { | ||
return "Identity(scale=$gain)" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
197 changes: 197 additions & 0 deletions
197
api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/initializer/IdentityTest.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
package org.jetbrains.kotlinx.dl.api.core.initializer | ||
|
||
import org.jetbrains.kotlinx.dl.api.core.exception.IdentityDimensionalityException | ||
import org.jetbrains.kotlinx.dl.api.core.shape.shapeOperand | ||
import org.junit.jupiter.api.Assertions | ||
import org.junit.jupiter.api.Assertions.assertArrayEquals | ||
import org.junit.jupiter.api.Assertions.assertEquals | ||
import org.junit.jupiter.api.Test | ||
import org.junit.jupiter.api.Assertions.assertThrows | ||
import org.junit.jupiter.api.assertAll | ||
import org.tensorflow.EagerSession | ||
import org.tensorflow.Shape | ||
import org.tensorflow.op.Ops | ||
|
||
private const val EPS = 1e-7f | ||
private const val FAN_IN = 10 | ||
private const val FAN_OUT = 20 | ||
|
||
class IdentityTest { | ||
@Test | ||
fun initialize() { | ||
val actual = Array(2) { FloatArray(2) { 0f } } | ||
val expected = Array(2) { FloatArray(2) { 0f } } | ||
expected[0][0] = 1f | ||
expected[1][1] = 1f | ||
|
||
val shape = Shape.make(2, 2) | ||
|
||
EagerSession.create().use { session -> | ||
val tf = Ops.create(session) | ||
val instance = Identity() | ||
val operand = instance.initialize(FAN_IN, FAN_OUT, tf, shapeOperand(tf, shape), "default_name") | ||
operand.asOutput().tensor().copyTo(actual) | ||
|
||
assertArrayEquals( | ||
expected[0], | ||
actual[0], | ||
EPS | ||
) | ||
|
||
assertArrayEquals( | ||
expected[1], | ||
actual[1], | ||
EPS | ||
) | ||
|
||
assertEquals( | ||
"Identity(scale=1.0)", | ||
instance.toString() | ||
) | ||
} | ||
} | ||
|
||
@Test | ||
fun initializeScaled() { | ||
val actual = Array(2) { FloatArray(2) { 0f } } | ||
val expected = Array(2) { FloatArray(2) { 0f } } | ||
expected[0][0] = 3.4f | ||
expected[1][1] = 3.4f | ||
|
||
val shape = Shape.make(2, 2) | ||
|
||
EagerSession.create().use { session -> | ||
val tf = Ops.create(session) | ||
val instance = Identity(3.4f) | ||
val operand = instance.initialize(FAN_IN, FAN_OUT, tf, shapeOperand(tf, shape), "default_name") | ||
operand.asOutput().tensor().copyTo(actual) | ||
|
||
assertArrayEquals( | ||
expected[0], | ||
actual[0], | ||
EPS | ||
) | ||
|
||
assertArrayEquals( | ||
expected[1], | ||
actual[1], | ||
EPS | ||
) | ||
|
||
assertEquals( | ||
"Identity(scale=3.4)", | ||
instance.toString() | ||
) | ||
} | ||
} | ||
|
||
@Test | ||
fun initializeWithNonSquare2x3() { | ||
val actual = Array(2) { FloatArray(3) { 0f } } | ||
val expected = Array(3) { FloatArray(3) { 0f } } | ||
expected[0][0] = 1f | ||
expected[1][1] = 1f | ||
|
||
val shape = Shape.make(2, 3) | ||
|
||
EagerSession.create().use { session -> | ||
val tf = Ops.create(session) | ||
val instance = Identity() | ||
val operand = instance.initialize(FAN_IN, FAN_OUT, tf, shapeOperand(tf, shape), "default_name") | ||
operand.asOutput().tensor().copyTo(actual) | ||
|
||
|
||
assertArrayEquals( | ||
expected[0], | ||
actual[0], | ||
EPS | ||
) | ||
|
||
assertArrayEquals( | ||
expected[1], | ||
actual[1], | ||
EPS | ||
) | ||
|
||
assertEquals( | ||
"Identity(scale=1.0)", | ||
instance.toString() | ||
) | ||
} | ||
} | ||
|
||
@Test | ||
fun initializeWithNonSquare3x2() { | ||
val actual = Array(3) { FloatArray(2) { 0f } } | ||
val expected = Array(3) { FloatArray(2) { 0f } } | ||
expected[0][0] = 1f | ||
expected[1][1] = 1f | ||
|
||
val shape = Shape.make(3, 2) | ||
|
||
EagerSession.create().use { session -> | ||
val tf = Ops.create(session) | ||
val instance = Identity() | ||
val operand = instance.initialize(FAN_IN, FAN_OUT, tf, shapeOperand(tf, shape), "default_name") | ||
operand.asOutput().tensor().copyTo(actual) | ||
|
||
|
||
assertArrayEquals( | ||
expected[0], | ||
actual[0], | ||
EPS | ||
) | ||
|
||
assertArrayEquals( | ||
expected[1], | ||
actual[1], | ||
EPS | ||
) | ||
|
||
assertArrayEquals( | ||
expected[2], | ||
actual[2], | ||
EPS | ||
) | ||
|
||
assertEquals( | ||
"Identity(scale=1.0)", | ||
instance.toString() | ||
) | ||
} | ||
} | ||
|
||
@Test | ||
fun initializeWith1DShapeFails(){ | ||
EagerSession.create().use { session -> | ||
val tf = Ops.create(session) | ||
val shape = Shape.make(5) | ||
|
||
val exception = assertThrows(IdentityDimensionalityException::class.java) { | ||
Identity().initialize(FAN_IN, FAN_OUT, tf, shapeOperand(tf, shape), "default_name") | ||
} | ||
|
||
assertEquals( | ||
"Identity matrix is not defined for order 1 tensors.", | ||
exception.message | ||
) | ||
} | ||
} | ||
|
||
@Test | ||
fun initializeWith3DShapeFails(){ | ||
EagerSession.create().use { session -> | ||
val tf = Ops.create(session) | ||
val shape = Shape.make(5, 5, 5) | ||
|
||
val exception = assertThrows(IdentityDimensionalityException::class.java) { | ||
Identity().initialize(FAN_IN, FAN_OUT, tf, shapeOperand(tf, shape), "default_name") | ||
} | ||
|
||
assertEquals( | ||
"Identity matrix is not defined for order 3 tensors.", | ||
exception.message | ||
) | ||
} | ||
} | ||
} |
Oops, something went wrong.