Skip to content

Commit

Permalink
Compute transposed convolution output size the same way as keras does
Browse files Browse the repository at this point in the history
  • Loading branch information
juliabeliaeva committed Feb 17, 2022
1 parent b19a00c commit 0a8a32d
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.toLongArray
import org.jetbrains.kotlinx.dl.api.core.layer.toLongList
import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.jetbrains.kotlinx.dl.api.core.shape.convTransposeOutputLength
import org.jetbrains.kotlinx.dl.api.core.shape.convTransposeSingleSidePadding
import org.jetbrains.kotlinx.dl.api.core.shape.convTransposePadding
import org.jetbrains.kotlinx.dl.api.core.shape.shapeFromDims
import org.jetbrains.kotlinx.dl.api.core.util.convTransposeBiasVarName
import org.jetbrains.kotlinx.dl.api.core.util.convTransposeKernelVarName
Expand Down Expand Up @@ -94,9 +94,12 @@ public abstract class ConvTranspose(
dilations: IntArray
): IntArray {
val withStandardPadding = kernelSize.indices.flatMap { dim ->
listOf(
convTransposeSingleSidePadding(padding, this[2 * dim], kernelSize[dim], dilations[dim + 1]),
convTransposeSingleSidePadding(padding, this[2 * dim + 1], kernelSize[dim], dilations[dim + 1])
convTransposePadding(
padding,
this[2 * dim],
this[2 * dim + 1],
kernelSize[dim],
dilations[dim + 1]
)
}
return intArrayOf(0, 0, *(withStandardPadding.toIntArray()), 0, 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,40 @@ internal fun convTransposeOutputLength(
dilation: Int
): Long {
val dilatedFilterSize = dilatedFilterSize(filterSize, dilation)
val totalPadding = (outputPaddingStart ?: 0) + (outputPaddingEnd ?: 0) - padding.value(dilatedFilterSize)
if (outputPaddingEnd == null || outputPaddingStart == null) {
// https://github.com/keras-team/keras/blob/cff8cc93305d1c4a54385fb623fe895dafa0845c/keras/utils/conv_utils.py#L185
return when (padding) {
ConvPadding.VALID -> inputLength * stride + max(dilatedFilterSize - stride, 0)
ConvPadding.SAME -> inputLength * stride
ConvPadding.FULL -> inputLength * stride - (stride + dilatedFilterSize - 2)
}
}
val totalPadding = convTransposePadding(padding, outputPaddingStart, outputPaddingEnd, filterSize, dilation).sum()
return (inputLength - 1) * stride + dilatedFilterSize + totalPadding
}

internal fun convTransposeSingleSidePadding(padding: ConvPadding,
outputPadding: Int,
filterSize: Int,
dilation: Int
): Int {
/**
 * Computes the start/end padding adjustments for a transposed convolution along one dimension.
 *
 * Mirrors keras' behavior: the user-specified output padding is reduced by the padding
 * implicitly introduced by the chosen [padding] mode for the given dilated kernel size.
 * https://github.com/keras-team/keras/blob/cff8cc93305d1c4a54385fb623fe895dafa0845c/keras/utils/conv_utils.py#L194
 *
 * @param padding            padding mode of the layer
 * @param outputPaddingStart requested extra padding before the dimension
 * @param outputPaddingEnd   requested extra padding after the dimension
 * @param dilatedKernelSize  kernel size after dilation is applied
 * @return two-element list: [start adjustment, end adjustment]; values may be negative
 */
private fun convTransposePadding(
    padding: ConvPadding,
    outputPaddingStart: Int,
    outputPaddingEnd: Int,
    dilatedKernelSize: Int
): List<Int> {
    // Padding implicitly contributed by the padding mode itself.
    val implicitPadding = when (padding) {
        ConvPadding.FULL -> dilatedKernelSize - 1
        ConvPadding.SAME -> dilatedKernelSize / 2
        ConvPadding.VALID -> 0
    }
    return listOf(outputPaddingStart, outputPaddingEnd).map { it - implicitPadding }
}

/**
 * Computes the non-negative start/end padding for a transposed convolution along one dimension,
 * given the raw (undilated) filter size.
 *
 * Dilates the filter first, then delegates to the dilated-size overload and clamps each side
 * to zero, since the explicit padding op cannot apply negative padding.
 *
 * @param padding            padding mode of the layer
 * @param outputPaddingStart requested extra padding before the dimension
 * @param outputPaddingEnd   requested extra padding after the dimension
 * @param filterSize         kernel size before dilation
 * @param dilation           dilation rate for this dimension
 * @return two-element list: [start padding, end padding], each >= 0
 */
internal fun convTransposePadding(
    padding: ConvPadding,
    outputPaddingStart: Int,
    outputPaddingEnd: Int,
    filterSize: Int,
    dilation: Int
): List<Int> {
    val dilatedKernelSize = dilatedFilterSize(filterSize, dilation)
    // NOTE(review): the diff view interleaved removed lines here; this is the post-commit form —
    // negative adjustments from the mode's implicit padding are clamped to zero.
    return convTransposePadding(padding, outputPaddingStart, outputPaddingEnd, dilatedKernelSize).map { max(0, it) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class Conv1DTransposeTest : ConvLayerTest() {
@Test
fun samePaddingStrides() {
val input = create1DTensor(batchSize = 1, size = 3, channels = 32, initValue = 1f)
val expected = create1DTensor(batchSize = 1, channels = 3, floatArrayOf(32f, 64f, 32f, 64f, 32f))
val expected = create1DTensor(batchSize = 1, channels = 3, floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f))

assertTensorsEquals(
Conv1DTranspose(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,10 @@ class Conv2DTransposeTest : ConvLayerTest() {
val input = create2DTensor(batchSize = 1, height = 2, width = 3, channels = 32, initValue = 1f)
val expected = create2DTensor(
batchSize = 1, channels = 3,
floatArrayOf(32f, 64f, 32f, 64f, 32f),
floatArrayOf(64f, 128f, 64f, 128f, 64f),
floatArrayOf(32f, 64f, 32f, 64f, 32f)
floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f),
floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f),
floatArrayOf(64f, 64f, 128f, 64f, 128f, 64f),
floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f)
)

assertTensorsEquals(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,16 @@ class Conv3DTransposeTest : ConvLayerTest() {
val expected = create3DTensor(
batchSize = 1, channels = 3,
arrayOf(
floatArrayOf(32f, 64f, 32f, 64f, 32f),
floatArrayOf(64f, 128f, 64f, 128f, 64f),
floatArrayOf(32f, 64f, 32f, 64f, 32f)
floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f),
floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f),
floatArrayOf(64f, 64f, 128f, 64f, 128f, 64f),
floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f),
),
arrayOf(
floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f),
floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f),
floatArrayOf(64f, 64f, 128f, 64f, 128f, 64f),
floatArrayOf(32f, 32f, 64f, 32f, 64f, 32f),
)
)

Expand Down

0 comments on commit 0a8a32d

Please sign in to comment.