Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor output processing in OnnxInferenceModel prediction methods #465

Merged
merged 6 commits into from
Oct 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/src/main/kotlin/examples/onnx/faces/Fan2D106.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.facealignment.Landmark
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getFloatArray
import org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment.Fan2D106FaceAlignmentModel
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.image.ImageConverter
Expand Down Expand Up @@ -53,11 +54,10 @@ fun main() {
val inputImage = ImageConverter.toBufferedImage(inputFile)
val inputData = preprocessor.apply(inputImage).first

val yhat = it.predictRaw(inputData)
println(yhat.values.toTypedArray().contentDeepToString())
val floats = it.predictRaw(inputData) { output -> output.getFloatArray("fc1") }
println(floats.contentToString())

val landMarks = mutableListOf<Landmark>()
val floats = (yhat["fc1"] as Array<*>)[0] as FloatArray
for (j in floats.indices step 2) {
landMarks.add(Landmark((1 + floats[j]) / 2, (1 + floats[j + 1]) / 2))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package examples.onnx.objectdetection.ssd
import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getFloatArray
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.preprocessing.call
import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline
Expand Down Expand Up @@ -45,10 +46,10 @@ fun ssd() {
val inputData = preprocessing.load(getFileFromResource("datasets/detection/image$i.jpg")).first

val start = System.currentTimeMillis()
val yhat = it.predictRaw(inputData)
val yhat = it.predictRaw(inputData) { output -> output.getFloatArray(0)}
val end = System.currentTimeMillis()
println("Prediction took ${end - start} ms")
println(yhat.values.toTypedArray().contentDeepToString())
println(yhat.contentToString())
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.get2DFloatArray
import org.jetbrains.kotlinx.dl.api.inference.posedetection.DetectedPose
import org.jetbrains.kotlinx.dl.api.inference.posedetection.MultiPoseDetectionResult
import org.jetbrains.kotlinx.dl.api.inference.posedetection.PoseLandmark
Expand Down Expand Up @@ -51,10 +52,11 @@ fun multiPoseDetectionMoveNet() {
.call(modelType.preprocessor)

val inputData = preprocessor.apply(inputImage).first
val yhat = it.predictRaw(inputData)
println(yhat.values.toTypedArray().contentDeepToString())
val rawPoseLandmarks = it.predictRaw(inputData) { result ->
result.get2DFloatArray("output_0")
}
println(rawPoseLandmarks.contentDeepToString())

val rawPoseLandmarks = (yhat["output_0"] as Array<Array<FloatArray>>)[0]
val poses = rawPoseLandmarks.mapNotNull { floats ->
val probability = floats[55]
if (probability < 0.05) return@mapNotNull null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package examples.onnx.posedetection.singlepose
import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.get2DFloatArray
import org.jetbrains.kotlinx.dl.api.inference.posedetection.DetectedPose
import org.jetbrains.kotlinx.dl.api.inference.posedetection.PoseLandmark
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
Expand Down Expand Up @@ -49,10 +50,10 @@ fun poseDetectionMoveNet() {

val inputData = preprocessing.apply(image).first

val yhat = it.predictRaw(inputData)
println(yhat.values.toTypedArray().contentDeepToString())

val rawPoseLandMarks = (yhat["output_0"] as Array<Array<Array<FloatArray>>>)[0][0]
val rawPoseLandMarks = it.predictRaw(inputData) { result ->
result.get2DFloatArray("output_0")
}
println(rawPoseLandMarks.contentDeepToString())

// Dictionary that maps from joint names to keypoint indices.
val keypoints = mapOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package examples.onnx.faces
import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getFloatArray
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.preprocessing.call
import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline
Expand Down Expand Up @@ -56,8 +57,8 @@ class FacesTestSuite {
val imageFile = getFileFromResource("datasets/faces/image$i.jpg")
val inputData = fileDataLoader.load(imageFile).first

val yhat = it.predictRaw(inputData)
assertEquals(212, (yhat.values.toTypedArray()[0] as Array<FloatArray>)[0].size)
val yhat = it.predictRaw(inputData) { output -> output.getFloatArray(0) }
assertEquals(212, yhat.size)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package examples.onnx.posedetection
import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.get2DFloatArray
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.preprocessing.call
import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline
Expand Down Expand Up @@ -84,9 +85,9 @@ class PoseDetectionTestSuite {

val inputData = fileDataLoader.load(imageFile).first

val yhat = it.predictRaw(inputData)

val rawPoseLandMarks = (yhat["output_0"] as Array<Array<Array<FloatArray>>>)[0][0]
val rawPoseLandMarks = it.predictRaw(inputData) { result ->
result.get2DFloatArray("output_0")
}

assertEquals(17, rawPoseLandMarks.size)
}
Expand All @@ -113,9 +114,9 @@ class PoseDetectionTestSuite {

val inputData = preprocessing.load(imageFile).first

val yhat = it.predictRaw(inputData)

val rawPoseLandMarks = (yhat["output_0"] as Array<Array<Array<FloatArray>>>)[0][0]
val rawPoseLandMarks = it.predictRaw(inputData) { result ->
result.get2DFloatArray("output_0")
}

assertEquals(17, rawPoseLandMarks.size)
}
Expand All @@ -142,10 +143,10 @@ class PoseDetectionTestSuite {


val inputData = dataLoader.load(imageFile).first
val yhat = inferenceModel.predictRaw(inputData)
println(yhat.values.toTypedArray().contentDeepToString())

val rawPosesLandMarks = (yhat["output_0"] as Array<Array<FloatArray>>)[0]
val rawPosesLandMarks = inferenceModel.predictRaw(inputData) { result ->
result.get2DFloatArray("output_0")
}
println(rawPosesLandMarks.contentDeepToString())

assertEquals(6, rawPosesLandMarks.size)
rawPosesLandMarks.forEach {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.jetbrains.kotlinx.dl.api.inference.onnx

import ai.onnxruntime.OrtSession
import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
Expand All @@ -29,15 +30,14 @@ public interface OnnxHighLevelModel<I, R> : ExecutionProviderCompatible {
/**
* Converts raw model output to the result.
*/
public fun convert(output: Map<String, Any>): R
public fun convert(output: OrtSession.Result): R

/**
* Makes prediction on the given [input].
*/
public fun predict(input: I): R {
val preprocessedInput = preprocessing.apply(input)
val output = internalModel.predictRaw(preprocessedInput.first)
return convert(output)
return internalModel.predictRaw(preprocessedInput.first) { convert(it) }
}

override fun initializeWith(vararg executionProviders: ExecutionProvider) {
Expand Down
Loading