diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Reshape.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Reshape.kt index 2fb65597a..496b8fa30 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Reshape.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Reshape.kt @@ -36,47 +36,18 @@ public class Reshape( override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) { val tensorShape = TensorShape(inputShape) val amountOfNeuronsInFlattenLayer = (tensorShape.numElements() / abs(tensorShape.size(0))).toInt() - units = when (targetShape.size) { - 1 -> tf.constant(intArrayOf(-1, targetShape[0])) - 2 -> tf.constant(intArrayOf(-1, targetShape[0], targetShape[1])) - 3 -> tf.constant(intArrayOf(-1, targetShape[0], targetShape[1], targetShape[2])) - 4 -> tf.constant(intArrayOf(-1, targetShape[0], targetShape[1], targetShape[2], targetShape[3])) - 5 -> tf.constant( - intArrayOf( - -1, - targetShape[0], - targetShape[1], - targetShape[2], - targetShape[3], - targetShape[4] - ) - ) - else -> throw UnsupportedOperationException("Reshaping for ${targetShape.size} is not supported yet!") - } + + units = tf.constant(IntArray(targetShape.size + 1) { + if (it == 0) -1 else targetShape[it - 1] + }) fanIn = tensorShape.numElements().toInt() fanOut = amountOfNeuronsInFlattenLayer } - override fun computeOutputShape(inputShape: Shape): Shape { - // leaves unknown dimensions unknown - val tensorShape = TensorShape(inputShape) - return when (targetShape.size) { - 3 -> Shape.make( - tensorShape.head(), - targetShape[0].toLong(), - targetShape[1].toLong(), - targetShape[2].toLong() - ) - 2 -> Shape.make( - tensorShape.head(), - targetShape[0].toLong(), - targetShape[1].toLong(), - ) - 1 -> Shape.make(tensorShape.head(), targetShape[0].toLong()) - else -> throw UnsupportedOperationException("Input shape with ${targetShape.size} dimensions is not supported.") - } - } + // leaves unknown dimensions unknown + override fun computeOutputShape(inputShape: Shape): Shape = + Shape.make(inputShape.size(0), *targetShape.map { it.toLong() }.toLongArray()) override fun forward( tf: Ops, diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/ReshapeTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/ReshapeTest.kt new file mode 100644 index 000000000..2f1b21864 --- /dev/null +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/ReshapeTest.kt @@ -0,0 +1,90 @@ +package org.jetbrains.kotlinx.dl.api.core.layer.reshaping + +import org.jetbrains.kotlinx.dl.api.core.layer.LayerTest +import org.jetbrains.kotlinx.dl.api.core.layer.RunMode +import org.junit.jupiter.api.Test + +internal class ReshapeTest : LayerTest() { + @Test + fun default() { + val layer = Reshape(targetShape = listOf(2, 5)) + + val input = arrayOf( + FloatArray(10) { it + 10.0f }, + FloatArray(10) { it + 30.0f } + ) + val expected = arrayOf( + arrayOf( + FloatArray(5) { it + 10.0f }, + FloatArray(5) { it + 15.0f } + ), + arrayOf( + FloatArray(5) { it + 30.0f }, + FloatArray(5) { it + 35.0f } + ) + ) + assertLayerOutputIsCorrect(layer, input, expected) + assertLayerOutputIsCorrect(layer, input, expected, RunMode.GRAPH) + } + + @Test + fun flatten() { + val layer = Reshape(targetShape = listOf(60)) + + var i = 0.0f + val input = Array(10) { Array(2) { Array(3) { Array(2) { FloatArray(5) { i++ } } } } } + + i = 0.0f + val expected = Array(10) { FloatArray(60) { i++ } } + + assertLayerOutputIsCorrect(layer, input, expected) + assertLayerOutputIsCorrect(layer, input, expected, RunMode.GRAPH) + } + + @Test + fun fromFlatToStructured() { + val layer = Reshape(targetShape = listOf(2, 3, 2, 5)) + + var i = 0.0f + val input = Array(10) { FloatArray(60) { i++ } } + + i = 0.0f + val expected = Array(10) { Array(2) { Array(3) { Array(2) { FloatArray(5) { i++ } } } } } + + assertLayerOutputIsCorrect(layer, input, expected) + assertLayerOutputIsCorrect(layer, input, expected, RunMode.GRAPH) + } + + @Test + fun computeOutputShape() { + assertLayerComputedOutputShape( + layer = Reshape(targetShape = listOf(1, 2, 3)), + inputShapeArray = longArrayOf(100, 3, 2, 1), + expectedOutputShape = longArrayOf(100, 1, 2, 3) + ) + + assertLayerComputedOutputShape( + layer = Reshape(targetShape = listOf(1, 2, 3)), + inputShapeArray = longArrayOf(-1, 3, 2, 1), + expectedOutputShape = longArrayOf(-1, 1, 2, 3) + ) + + assertLayerComputedOutputShape( + layer = Reshape(targetShape = listOf(6)), + inputShapeArray = longArrayOf(100, 3, 2, 1), + expectedOutputShape = longArrayOf(100, 6) + ) + + assertLayerComputedOutputShape( + layer = Reshape(targetShape = listOf(4, 5)), + inputShapeArray = longArrayOf(100, 20), + expectedOutputShape = longArrayOf(100, 4, 5) + ) + + assertLayerComputedOutputShape( + layer = Reshape(targetShape = listOf(4, 5, 6, 7, 8, 9, 10)), + inputShapeArray = longArrayOf(100, 4 * 5 * 6, 7 * 10, 8 * 9), + expectedOutputShape = longArrayOf(100, 4, 5, 6, 7, 8, 9, 10) + ) + } +} \ No newline at end of file