From 24a2447f5f2ae49189d4d3e1c2876819b6e2b4e2 Mon Sep 17 00:00:00 2001 From: Julia Beliaeva Date: Tue, 25 Jan 2022 10:09:43 +0300 Subject: [PATCH] Allow for multi-dimensional predictions (#328) --- .../kotlinx/dl/api/core/GraphTrainableModel.kt | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt index ba845e68b..c3227423b 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -691,11 +691,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel val tensors = formPredictionAndActivationsTensors(predictionTensorName, testImages, visualizationIsEnabled) - val predictionsTensor = tensors[0] - - val dst = Array(1) { FloatArray(numberOfClasses.toInt()) { 0.0f } } - - predictionsTensor.copyTo(dst) + val prediction = tensors[0].convertTensorToFlattenFloatArray() val activations = mutableListOf() if (visualizationIsEnabled && tensors.size > 1) { @@ -705,7 +701,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel } tensors.forEach { it.close() } - return Pair(dst[0], activations.toList()) + return Pair(prediction, activations.toList()) } }