From 0a8a32ddf9543e03f08df049a2e004957eea34bc Mon Sep 17 00:00:00 2001 From: Julia Beliaeva Date: Mon, 14 Feb 2022 02:07:34 +0300 Subject: [PATCH] Compute transposed convolution output size the same way as keras does --- .../core/layer/convolutional/ConvTranspose.kt | 11 +++-- .../kotlinx/dl/api/core/shape/ConvUtil.kt | 40 +++++++++++++++---- .../dl/api/core/layer/Conv1DTransposeTest.kt | 2 +- .../dl/api/core/layer/Conv2DTransposeTest.kt | 7 ++-- .../dl/api/core/layer/Conv3DTransposeTest.kt | 13 ++++-- 5 files changed, 54 insertions(+), 19 deletions(-) diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/ConvTranspose.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/ConvTranspose.kt index d411b47e1..28042992e 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/ConvTranspose.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/ConvTranspose.kt @@ -9,7 +9,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.toLongArray import org.jetbrains.kotlinx.dl.api.core.layer.toLongList import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape import org.jetbrains.kotlinx.dl.api.core.shape.convTransposeOutputLength -import org.jetbrains.kotlinx.dl.api.core.shape.convTransposeSingleSidePadding +import org.jetbrains.kotlinx.dl.api.core.shape.convTransposePadding import org.jetbrains.kotlinx.dl.api.core.shape.shapeFromDims import org.jetbrains.kotlinx.dl.api.core.util.convTransposeBiasVarName import org.jetbrains.kotlinx.dl.api.core.util.convTransposeKernelVarName @@ -94,9 +94,12 @@ public abstract class ConvTranspose( dilations: IntArray ): IntArray { val withStandardPadding = kernelSize.indices.flatMap { dim -> - listOf( - convTransposeSingleSidePadding(padding, this[2 * dim], kernelSize[dim], dilations[dim + 1]), - convTransposeSingleSidePadding(padding, this[2 * dim + 1], kernelSize[dim], dilations[dim + 1]) + convTransposePadding( + padding, + this[2 * dim], + this[2 * dim + 1], + kernelSize[dim], + dilations[dim + 1] ) } return intArrayOf(0, 0, *(withStandardPadding.toIntArray()), 0, 0) diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/shape/ConvUtil.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/shape/ConvUtil.kt index 67920e488..57708e325 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/shape/ConvUtil.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/shape/ConvUtil.kt @@ -48,16 +48,40 @@ internal fun convTransposeOutputLength( dilation: Int ): Long { val dilatedFilterSize = dilatedFilterSize(filterSize, dilation) - val totalPadding = (outputPaddingStart ?: 0) + (outputPaddingEnd ?: 0) - padding.value(dilatedFilterSize) + if (outputPaddingEnd == null || outputPaddingStart == null) { + // https://github.com/keras-team/keras/blob/cff8cc93305d1c4a54385fb623fe895dafa0845c/keras/utils/conv_utils.py#L185 + return when (padding) { + ConvPadding.VALID -> inputLength * stride + max(dilatedFilterSize - stride, 0) + ConvPadding.SAME -> inputLength * stride + ConvPadding.FULL -> inputLength * stride - (stride + dilatedFilterSize - 2) + } + } + val totalPadding = convTransposePadding(padding, outputPaddingStart, outputPaddingEnd, filterSize, dilation).sum() return (inputLength - 1) * stride + dilatedFilterSize + totalPadding } -internal fun convTransposeSingleSidePadding(padding: ConvPadding, - outputPadding: Int, - filterSize: Int, - dilation: Int -): Int { +private fun convTransposePadding( + padding: ConvPadding, + outputPaddingStart: Int, + outputPaddingEnd: Int, + dilatedKernelSize: Int +): List { + // https://github.com/keras-team/keras/blob/cff8cc93305d1c4a54385fb623fe895dafa0845c/keras/utils/conv_utils.py#L194 + val automaticPadding = when (padding) { + ConvPadding.VALID -> 0 + ConvPadding.SAME -> dilatedKernelSize / 2 + ConvPadding.FULL -> dilatedKernelSize - 1 + } + return listOf(outputPaddingStart - automaticPadding, outputPaddingEnd - automaticPadding) +} + +internal fun convTransposePadding( + padding: ConvPadding, + outputPaddingStart: Int, + outputPaddingEnd: Int, + filterSize: Int, + dilation: Int +): List { val dilatedKernelSize = dilatedFilterSize(filterSize, dilation) - val automaticPadding = padding.value(dilatedKernelSize) / 2 - return max(0, outputPadding - automaticPadding) + return convTransposePadding(padding, outputPaddingStart, outputPaddingEnd, dilatedKernelSize).map { max(0, it) } } \ No newline at end of file diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Conv1DTransposeTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Conv1DTransposeTest.kt index aba268300..e837408b2 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Conv1DTransposeTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Conv1DTransposeTest.kt @@ -106,7 +106,7 @@ class Conv1DTransposeTest : ConvLayerTest() { @Test fun samePaddingStrides() { val input = create1DTensor(batchSize = 1, size = 3, channels = 32, initValue = 1f) - val expected = create1DTensor(batchSize = 1, channels = 3, floatArrayOf(32f, 64f, 32f, 64f, 32f)) + val expected = create1DTensor(batchSize = 1, channels = 3, floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f)) assertTensorsEquals( Conv1DTranspose( diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Conv2DTransposeTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Conv2DTransposeTest.kt index 1b04c2736..8892e517b 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Conv2DTransposeTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Conv2DTransposeTest.kt @@ -130,9 +130,10 @@ class Conv2DTransposeTest : ConvLayerTest() { val input = create2DTensor(batchSize = 1, height = 2, width = 3, channels = 32, initValue = 1f) val expected = create2DTensor( batchSize = 1, channels = 3, - floatArrayOf(32f, 64f, 32f, 64f, 32f), - floatArrayOf(64f, 128f, 64f, 128f, 64f), - floatArrayOf(32f, 64f, 32f, 64f, 32f) + floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f), + floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f), + floatArrayOf(64f, 64f, 128f, 64f, 128f, 64f), + floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f) ) assertTensorsEquals( diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Conv3DTransposeTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Conv3DTransposeTest.kt index 62ce6da13..7e2e8337f 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Conv3DTransposeTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Conv3DTransposeTest.kt @@ -138,9 +138,16 @@ class Conv3DTransposeTest : ConvLayerTest() { val expected = create3DTensor( batchSize = 1, channels = 3, arrayOf( - floatArrayOf(32f, 64f, 32f, 64f, 32f), - floatArrayOf(64f, 128f, 64f, 128f, 64f), - floatArrayOf(32f, 64f, 32f, 64f, 32f) + floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f), + floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f), + floatArrayOf(64f, 64f, 128f, 64f, 128f, 64f), + floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f), + ), + arrayOf( + floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f), + floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f), + floatArrayOf(64f, 64f, 128f, 64f, 128f, 64f), + floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f), ) )