diff --git a/impl/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/bitmap/Resize.kt b/impl/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/bitmap/Resize.kt index f1dd89483..6e0515a24 100644 --- a/impl/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/bitmap/Resize.kt +++ b/impl/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/bitmap/Resize.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2023 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -35,8 +35,8 @@ public class Resize( public companion object { internal fun createOutputImageShape(inputShape: TensorShape, outputWidth: Int, outputHeight: Int): TensorShape { return when (inputShape.rank()) { - 2 -> TensorShape(outputWidth.toLong(), outputHeight.toLong()) - 3 -> TensorShape(outputWidth.toLong(), outputHeight.toLong(), inputShape[2]) + 2 -> TensorShape(outputHeight.toLong(), outputWidth.toLong()) + 3 -> TensorShape(outputHeight.toLong(), outputWidth.toLong(), inputShape[2]) else -> throw IllegalArgumentException("Input shape is expected to be 2D or 3D") } } diff --git a/impl/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/image/Cropping.kt b/impl/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/image/Cropping.kt index 8fd2b6ab6..b5acdcf8a 100644 --- a/impl/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/image/Cropping.kt +++ b/impl/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/image/Cropping.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2023 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -29,7 +29,7 @@ public class Cropping( ) : Operation { override fun apply(input: BufferedImage): BufferedImage { val croppedImageShape = getOutputShape(input.getShape()) - val (width, height, _) = croppedImageShape.dims() + val (height, width, _) = croppedImageShape.dims() return input.getSubimage( left, top, @@ -39,12 +39,12 @@ public class Cropping( } override fun getOutputShape(inputShape: TensorShape): TensorShape { - val outputWidth = if (inputShape[0] == -1L) -1 else inputShape[0] - left - right - val outputHeight = if (inputShape[1] == -1L) -1 else inputShape[1] - top - bottom + val outputWidth = if (inputShape[1] == -1L) -1 else inputShape[1] - left - right + val outputHeight = if (inputShape[0] == -1L) -1 else inputShape[0] - top - bottom return when (inputShape.rank()) { - 2 -> TensorShape(outputWidth, outputHeight) - 3 -> TensorShape(outputWidth, outputHeight, inputShape[2]) + 2 -> TensorShape(outputHeight, outputWidth) + 3 -> TensorShape(outputHeight, outputWidth, inputShape[2]) else -> throw IllegalArgumentException("Cropping operation is applicable only to images with rank 2 or 3") } } diff --git a/impl/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/image/ImageUtil.kt b/impl/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/image/ImageUtil.kt index 6a82b1c1a..9398548b5 100644 --- a/impl/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/image/ImageUtil.kt +++ b/impl/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/image/ImageUtil.kt @@ -25,5 +25,5 @@ internal fun BufferedImage.copy(): BufferedImage { } internal fun BufferedImage.getShape(): TensorShape { - return TensorShape(width.toLong(), height.toLong(), colorModel.numComponents.toLong()) + return TensorShape(height.toLong(), width.toLong(), colorModel.numComponents.toLong()) } diff --git a/impl/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/image/Padding.kt b/impl/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/image/Padding.kt index 82ab6a594..2b6951843 100644 --- a/impl/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/image/Padding.kt +++ b/impl/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/image/Padding.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2023 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -47,12 +47,12 @@ public class Padding( } override fun getOutputShape(inputShape: TensorShape): TensorShape { - val outputWidth = if (inputShape[0] == -1L) -1 else inputShape[0] + left + right - val outputHeight = if (inputShape[1] == -1L) -1 else inputShape[1] + top + bottom + val outputWidth = if (inputShape[1] == -1L) -1 else inputShape[1] + left + right + val outputHeight = if (inputShape[0] == -1L) -1 else inputShape[0] + top + bottom return when (inputShape.rank()) { - 2 -> TensorShape(outputWidth, outputHeight) - 3 -> TensorShape(outputWidth, outputHeight, inputShape[2]) + 2 -> TensorShape(outputHeight, outputWidth) + 3 -> TensorShape(outputHeight, outputWidth, inputShape[2]) else -> throw IllegalArgumentException("Padding operation is supported only for 2D and 3D tensors") } } diff --git a/impl/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/image/Resize.kt b/impl/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/image/Resize.kt index 5f29aae91..d0d6b7722 100644 --- a/impl/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/image/Resize.kt +++ b/impl/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/image/Resize.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2023 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -65,8 +65,8 @@ public class Resize( override fun getOutputShape(inputShape: TensorShape): TensorShape { return when (inputShape.rank()) { - 2 -> TensorShape(outputWidth.toLong(), outputHeight.toLong()) - 3 -> TensorShape(outputWidth.toLong(), outputHeight.toLong(), inputShape[2]) + 2 -> TensorShape(outputHeight.toLong(), outputWidth.toLong()) + 3 -> TensorShape(outputHeight.toLong(), outputWidth.toLong(), inputShape[2]) else -> throw IllegalArgumentException("Resize operation is only supported for 2D and 3D tensors.") } } diff --git a/impl/src/jvmTest/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/PreprocessingFinalShapeTest.kt b/impl/src/jvmTest/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/PreprocessingFinalShapeTest.kt index ddc7f642b..0f9d0f14f 100644 --- a/impl/src/jvmTest/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/PreprocessingFinalShapeTest.kt +++ b/impl/src/jvmTest/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/PreprocessingFinalShapeTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2023 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -47,7 +47,7 @@ class PreprocessingFinalShapeTest { bottom = 7 } .toFloatArray { } - assertEquals(TensorShape(186, 188, 3), preprocess.getOutputShape(TensorShape(200, 200, 3))) + assertEquals(TensorShape(178, 186, 3), preprocess.getOutputShape(TensorShape(190, 200, 3))) } @Test @@ -66,7 +66,7 @@ class PreprocessingFinalShapeTest { bottom = 3 } .toFloatArray { } - assertEquals(TensorShape(180, 180, 3), preprocess.getOutputShape(TensorShape(200, 200, 3))) + assertEquals(TensorShape(380, 180, 3), preprocess.getOutputShape(TensorShape(400, 200, 3))) } @Test @@ -84,7 +84,7 @@ class PreprocessingFinalShapeTest { bottom = 5 } .toFloatArray { } - assertEquals(TensorShape(140, 90, 3), preprocess.getOutputShape(TensorShape(200, 200, 3))) + assertEquals(TensorShape(90, 140, 3), preprocess.getOutputShape(TensorShape(200, 200, 3))) } @Test @@ -107,7 +107,7 @@ class PreprocessingFinalShapeTest { right = 13 } .toFloatArray { } - assertEquals(TensorShape(324, 212, 1), preprocess.getOutputShape(TensorShape(300, 200, 1))) + assertEquals(TensorShape(312, 224, 1), preprocess.getOutputShape(TensorShape(300, 200, 1))) } @Test @@ -127,7 +127,7 @@ class PreprocessingFinalShapeTest { val image = BufferedImage(10, 20, BufferedImage.TYPE_3BYTE_BGR) val (_, actualShape) = preprocess.apply(image) - assertEquals(actualShape, preprocess.getOutputShape(TensorShape(10, 20, 1))) + assertEquals(actualShape, preprocess.getOutputShape(TensorShape(20, 10, 1))) } @Test @@ -139,6 +139,6 @@ class PreprocessingFinalShapeTest { val image = BufferedImage(10, 20, BufferedImage.TYPE_3BYTE_BGR) val (_, actualShape) = preprocess.apply(image) - assertEquals(actualShape, preprocess.getOutputShape(TensorShape(10, 20, 3))) + assertEquals(actualShape, preprocess.getOutputShape(TensorShape(20, 10, 3))) } } diff --git a/impl/src/jvmTest/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/PreprocessingImageTest.kt b/impl/src/jvmTest/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/PreprocessingImageTest.kt index de7de6968..041796720 100644 --- a/impl/src/jvmTest/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/PreprocessingImageTest.kt +++ b/impl/src/jvmTest/kotlin/org/jetbrains/kotlinx/dl/impl/preprocessing/PreprocessingImageTest.kt @@ -105,7 +105,7 @@ class PreprocessingImageTest { inputImage.setRGB(1, 1, Color.RED.rgb) val (imageFloats, tensorShape) = preprocess.apply(inputImage) - Assertions.assertEquals(TensorShape(9, 5, 3), tensorShape) + Assertions.assertEquals(TensorShape(5, 9, 3), tensorShape) val expectedImage = FloatArray(tensorShape.numElements().toInt()) { Color.GRAY.red / 255f } expectedImage.setRGB(3, 1, Color.BLUE, tensorShape, ColorMode.BGR) @@ -198,7 +198,7 @@ class PreprocessingImageTest { } } for (i in colorComponents.indices) { - set3D(y, x, i, tensorShape[0].toInt(), colorMode.channels, colorComponents[i]) + set3D(y, x, i, tensorShape[1].toInt(), colorMode.channels, colorComponents[i]) } } }