Skip to content

Commit

Permalink
Generalize reshape layer to higher dimensions (#249)
Browse files Browse the repository at this point in the history
  • Loading branch information
knok16 authored Oct 12, 2021
1 parent e47c476 commit 374fd5b
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,47 +36,18 @@ public class Reshape(
override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {
val tensorShape = TensorShape(inputShape)
val amountOfNeuronsInFlattenLayer = (tensorShape.numElements() / abs(tensorShape.size(0))).toInt()
units = when (targetShape.size) {
1 -> tf.constant(intArrayOf(-1, targetShape[0]))
2 -> tf.constant(intArrayOf(-1, targetShape[0], targetShape[1]))
3 -> tf.constant(intArrayOf(-1, targetShape[0], targetShape[1], targetShape[2]))
4 -> tf.constant(intArrayOf(-1, targetShape[0], targetShape[1], targetShape[2], targetShape[3]))
5 -> tf.constant(
intArrayOf(
-1,
targetShape[0],
targetShape[1],
targetShape[2],
targetShape[3],
targetShape[4]
)
)
else -> throw UnsupportedOperationException("Reshaping for ${targetShape.size} is not supported yet!")
}

units = tf.constant(IntArray(targetShape.size + 1) {
if (it == 0) -1 else targetShape[it - 1]
})

fanIn = tensorShape.numElements().toInt()
fanOut = amountOfNeuronsInFlattenLayer
}

override fun computeOutputShape(inputShape: Shape): Shape {
// leaves unknown dimensions unknown
val tensorShape = TensorShape(inputShape)
return when (targetShape.size) {
3 -> Shape.make(
tensorShape.head(),
targetShape[0].toLong(),
targetShape[1].toLong(),
targetShape[2].toLong()
)
2 -> Shape.make(
tensorShape.head(),
targetShape[0].toLong(),
targetShape[1].toLong(),
)
1 -> Shape.make(tensorShape.head(), targetShape[0].toLong())
else -> throw UnsupportedOperationException("Input shape with ${targetShape.size} dimensions is not supported.")
}
}
// leaves unknown dimensions unknown
override fun computeOutputShape(inputShape: Shape): Shape =
Shape.make(inputShape.size(0), *targetShape.map { it.toLong() }.toLongArray())

override fun forward(
tf: Ops,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package org.jetbrains.kotlinx.dl.api.core.layer.reshaping

import org.jetbrains.kotlinx.dl.api.core.layer.LayerTest
import org.jetbrains.kotlinx.dl.api.core.layer.RunMode
import org.junit.jupiter.api.Test

internal class ReshapeTest : LayerTest() {
@Test
fun default() {
val layer = Reshape(targetShape = listOf(2, 5))

val input = arrayOf(
FloatArray(10) { it + 10.0f },
FloatArray(10) { it + 30.0f }
)
val expected = arrayOf(
arrayOf(
FloatArray(5) { it + 10.0f },
FloatArray(5) { it + 15.0f }
),
arrayOf(
FloatArray(5) { it + 30.0f },
FloatArray(5) { it + 35.0f }
)
)
assertLayerOutputIsCorrect(layer, input, expected)
assertLayerOutputIsCorrect(layer, input, expected, RunMode.GRAPH)
}

@Test
fun flatten() {
val layer = Reshape(targetShape = listOf(60))

var i = 0.0f
val input = Array(10) { Array(2) { Array(3) { Array(2) { FloatArray(5) { i++ } } } } }

i = 0.0f
val expected = Array(10) { FloatArray(60) { i++ } }

assertLayerOutputIsCorrect(layer, input, expected)
assertLayerOutputIsCorrect(layer, input, expected, RunMode.GRAPH)
}

@Test
fun fromFlatToStructured() {
val layer = Reshape(targetShape = listOf(2, 3, 2, 5))

var i = 0.0f
val input = Array(10) { FloatArray(60) { i++ } }

i = 0.0f
val expected = Array(10) { Array(2) { Array(3) { Array(2) { FloatArray(5) { i++ } } } } }

assertLayerOutputIsCorrect(layer, input, expected)
assertLayerOutputIsCorrect(layer, input, expected, RunMode.GRAPH)
}

@Test
fun computeOutputShape() {
assertLayerComputedOutputShape(
layer = Reshape(targetShape = listOf(1, 2, 3)),
inputShapeArray = longArrayOf(100, 3, 2, 1),
expectedOutputShape = longArrayOf(100, 1, 2, 3)
)

assertLayerComputedOutputShape(
layer = Reshape(targetShape = listOf(1, 2, 3)),
inputShapeArray = longArrayOf(-1, 3, 2, 1),
expectedOutputShape = longArrayOf(-1, 1, 2, 3)
)

assertLayerComputedOutputShape(
layer = Reshape(targetShape = listOf(6)),
inputShapeArray = longArrayOf(100, 3, 2, 1),
expectedOutputShape = longArrayOf(100, 6)
)

assertLayerComputedOutputShape(
layer = Reshape(targetShape = listOf(4, 5)),
inputShapeArray = longArrayOf(100, 20),
expectedOutputShape = longArrayOf(100, 4, 5)
)

assertLayerComputedOutputShape(
layer = Reshape(targetShape = listOf(4, 5, 6, 7, 8, 9, 10)),
inputShapeArray = longArrayOf(100, 4 * 5 * 6, 7 * 10, 8 * 9),
expectedOutputShape = longArrayOf(100, 4, 5, 6, 7, 8, 9, 10)
)
}
}

0 comments on commit 374fd5b

Please sign in to comment.