Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add utility functions to draw detection results on android Canvas #450

Merged
merged 3 commits into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
kotlin.code.style=official
org.gradle.daemon=false

majorVersion=0.5
minorVersion=SNAPSHOT
minorVersion=SNAPSHOT

android.useAndroidX=true
2 changes: 1 addition & 1 deletion gradle/publish.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
apply plugin: 'maven-publish'
apply from: project.rootProject.file('gradle/maven-metadata.gradle')

def isMultiplatform = project.name in ["dataset", "api", "onnx"]
def isMultiplatform = project.name in ["dataset", "api", "onnx", "visualization"]

task javadocJar(type: Jar) {
classifier 'javadoc'
Expand Down
61 changes: 40 additions & 21 deletions visualization/build.gradle
Original file line number Diff line number Diff line change
@@ -1,29 +1,48 @@
project.setDescription("This module contains the Kotlin API for visualization of the Deep Learning models built with the KotlinDL.")

apply plugin: 'kotlin'
apply plugin: 'kotlin-multiplatform'
apply plugin: 'com.android.library'

dependencies {
api project(":api")
api project(":tensorflow")
}

dependencies {
def excludeSLF4JImpl = { exclude group: 'org.slf4j', module: 'slf4j-simple' }
kotlin {
jvm {
compilations.all {
kotlinOptions.jvmTarget = '1.8'
}
}
android {
publishLibraryVariants("release")
}
sourceSets {
jvmMain {
dependencies {
api project(":api")
api project(":tensorflow")

api 'org.jetbrains.lets-plot:lets-plot-batik:2.3.0', excludeSLF4JImpl
api 'org.jetbrains.lets-plot:lets-plot-common:2.3.0', excludeSLF4JImpl
api 'org.jetbrains.lets-plot:lets-plot-kotlin-api:2.0.1', excludeSLF4JImpl
}

compileKotlin {
kotlinOptions.jvmTarget = "1.8"
}
def excludeSLF4JImpl = { exclude group: 'org.slf4j', module: 'slf4j-simple' }

task sourcesJar(type: Jar) {
classifier 'sources'
from sourceSets.main.allSource
api 'org.jetbrains.lets-plot:lets-plot-batik:2.3.0', excludeSLF4JImpl
api 'org.jetbrains.lets-plot:lets-plot-common:2.3.0', excludeSLF4JImpl
api 'org.jetbrains.lets-plot:lets-plot-kotlin-api:2.0.1', excludeSLF4JImpl
}
}
androidMain {
dependencies {
api project(":api")
api "androidx.camera:camera-view:1.0.0-alpha22"
}
}
}
}

artifacts {
archives sourcesJar
android {
compileSdkVersion 31
namespace = 'org.jetbrains.kotlinx.dl.visualization'
defaultConfig {
minSdkVersion 24
targetSdkVersion 31
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.visualization

import android.content.Context
import android.graphics.Canvas
import android.util.AttributeSet
import android.view.View

/**
* Base class for [View] implementations which visualize detected results on top of the image preview.
* Derived classes should implement [drawDetection] method to perform actual drawing.
*/
abstract class DetectorViewBase<T>(context: Context, attrs: AttributeSet) : View(context, attrs) {
/**
* Detection result to visualize
*/
private var _detection: T? = null

/**
* Draw given detection result on the [Canvas].
*/
abstract fun Canvas.drawDetection(detection: T)

/**
* Called when a new detection result is set.
*/
open fun onDetectionSet(detection: T?) = Unit

/**
* Set current detection result or null if nothing was detected.
*/
fun setDetection(detection: T?) {
synchronized(this) {
_detection = detection

onDetectionSet(detection)
postInvalidate()
}
}

override fun onDraw(canvas: Canvas) {
super.onDraw(canvas)

synchronized(this) {
val detection = _detection
if (detection != null) {
canvas.drawDetection(detection)
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.visualization

import android.graphics.Canvas
import android.graphics.Paint
import android.graphics.RectF
import android.text.TextPaint
import org.jetbrains.kotlinx.dl.api.inference.facealignment.Landmark
import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject
import org.jetbrains.kotlinx.dl.api.inference.posedetection.DetectedPose
import org.jetbrains.kotlinx.dl.api.inference.posedetection.MultiPoseDetectionResult

/**
* Draw given [detectedObject] on the [Canvas] using [paint] for the bounding box and [labelPaint] for the label.
*
* If the preview image coordinates do not match the [Canvas] coordinates,
* [bounds] of the image preview should be provided.
*
* @see [PreviewImageBounds]
*/
fun Canvas.drawObject(
detectedObject: DetectedObject,
paint: Paint,
labelPaint: TextPaint,
bounds: PreviewImageBounds = bounds()
) {
val rect = RectF(
bounds.toViewX(detectedObject.xMin), bounds.toViewY(detectedObject.yMin),
bounds.toViewX(detectedObject.xMax), bounds.toViewY(detectedObject.yMax)
)
val frameWidth = paint.strokeWidth * detectedObject.probability

drawRect(rect, Paint(paint).apply { strokeWidth = frameWidth })

val label = "${detectedObject.classLabel} : " + "%.2f".format(detectedObject.probability)
drawText(label, rect.left, rect.top - labelPaint.fontMetrics.descent - frameWidth / 2, labelPaint)
}

/**
* Draw given [detectedObjects] on the [Canvas] using [paint] for the bounding box and [labelPaint] for the label.
*
* If the preview image coordinates do not match the [Canvas] coordinates,
* [bounds] of the image preview should be provided.
*
* @see [PreviewImageBounds]
*/
fun Canvas.drawObjects(
detectedObjects: List<DetectedObject>,
paint: Paint,
labelPaint: TextPaint,
bounds: PreviewImageBounds = bounds()
) {
detectedObjects.forEach { drawObject(it, paint, labelPaint, bounds) }
}

/**
* Draw given [detectedPose] on the [Canvas] using [landmarkPaint] and [landmarkRadius] for the pose vertices,
* and [edgePaint] for the pose edges.
*
* If the preview image coordinates do not match the [Canvas] coordinates,
* [bounds] of the image preview should be provided.
*
* @see [PreviewImageBounds]
*/
fun Canvas.drawPose(
detectedPose: DetectedPose,
landmarkPaint: Paint,
edgePaint: Paint, landmarkRadius: Float,
bounds: PreviewImageBounds = bounds()
) {
detectedPose.edges.forEach { edge ->
drawLine(
bounds.toViewX(edge.start.x), bounds.toViewY(edge.start.y),
bounds.toViewX(edge.end.x), bounds.toViewY(edge.end.y),
edgePaint
)
}

detectedPose.poseLandmarks.forEach { landmark ->
drawCircle(bounds.toViewX(landmark.x), bounds.toViewY(landmark.y), landmarkRadius, landmarkPaint)
}
}

/**
* Draw given [detectedPoses] on the [Canvas] using [landmarkPaint] and [landmarkRadius] for the pose vertices,
* [edgePaint] for the poses edges, [objectPaint] for the bounding box and [labelPaint] for the label.
*
* If the preview image coordinates do not match the [Canvas] coordinates,
* [bounds] of the image preview should be provided.
*
* @see [PreviewImageBounds]
*/
fun Canvas.drawMultiplePoses(
detectedPoses: MultiPoseDetectionResult,
landmarkPaint: Paint,
edgePaint: Paint,
objectPaint: Paint,
labelPaint: TextPaint,
landmarkRadius: Float,
bounds: PreviewImageBounds = bounds()
) {
detectedPoses.multiplePoses.forEach { (detectedObject, detectedPose) ->
drawPose(detectedPose, landmarkPaint, edgePaint, landmarkRadius, bounds)
drawObject(detectedObject, objectPaint, labelPaint, bounds)
}
}

/**
* Draw given [landmarks] on the [Canvas] using [paint] and [radius].
*
* If the preview image coordinates do not match the [Canvas] coordinates,
* [bounds] of the image preview should be provided.
*
* @see [PreviewImageBounds]
*/
fun Canvas.drawLandmarks(landmarks: List<Landmark>,
paint: Paint, radius: Float,
bounds: PreviewImageBounds = bounds()
) {
landmarks.forEach { landmark ->
drawCircle(bounds.toViewX(landmark.xRate), bounds.toViewY(landmark.yRate), radius, paint)
}
}

/**
* Create [PreviewImageBounds] originating in the top-left corner of this [Canvas] object and matching its dimensions.
*/
fun Canvas.bounds() = PreviewImageBounds(0f, 0f, width.toFloat(), height.toFloat())
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.visualization

import androidx.camera.view.PreviewView
import kotlin.math.max
import kotlin.math.min

/**
* Defines location and size of the actual preview image relative to the [PreviewView].
* This information can be used to convert from the image coordinate system to the view coordinate system.
*
* @property [x] x-coordinate of the top-left corner of the preview image relative to the [PreviewView] component
* @property [y] y-coordinate of the top-left corner of the preview image relative to the [PreviewView] component
* @property [width] width of the preview image
* @property [height] height of the preview image
*
* @see getPreviewImageBounds
*/
data class PreviewImageBounds(val x: Float, val y: Float, val width: Float, val height: Float) {
fun toViewX(imageX: Float) = imageX * width + x
fun toViewY(imageY: Float) = imageY * height + y
}

/**
* Calculate the location of the preview image top-left corner (relative to the component top-left corner)
* and dimensions, to be used for displaying detected objects, for example with the [DetectorViewBase].
*
* When camera preview resolution differs from the dimensions of the [PreviewView] used to display camera input,
* image is scaled and cropped or padded according to the provided [PreviewView.ScaleType]. Because of this,
* in order to display detected objects on the [PreviewView], their coordinates need to be converted.
* This method returns [PreviewImageBounds] object containing the necessary information to preform the conversion
* from the image coordinate system to the view coordinate system.
*
* @param [sourceImageWidth] width of the image from the camera
* @param [sourceImageHeight] height of the image from the camera
* @param [viewWidth] width of the target [PreviewView]
* @param [viewHeight] height of the target [PreviewView]
* @param [scaleType] scaling option used in the target [PreviewView]
*
* @see <a href="https://developer.android.com/training/camerax/preview#scale-type">Scale type</a>
*/
fun getPreviewImageBounds(
sourceImageWidth: Int,
sourceImageHeight: Int,
viewWidth: Int,
viewHeight: Int,
scaleType: PreviewView.ScaleType
): PreviewImageBounds {
val scale = if (scaleType == PreviewView.ScaleType.FILL_START ||
scaleType == PreviewView.ScaleType.FILL_END ||
scaleType == PreviewView.ScaleType.FILL_CENTER
) {
max(viewWidth.toFloat() / sourceImageWidth, viewHeight.toFloat() / sourceImageHeight)
} else {
min(viewWidth.toFloat() / sourceImageWidth, viewHeight.toFloat() / sourceImageHeight)
}
val previewImageWidth = sourceImageWidth * scale
val previewImageHeight = sourceImageHeight * scale
return when (scaleType) {
PreviewView.ScaleType.FILL_START, PreviewView.ScaleType.FIT_START -> {
PreviewImageBounds(0f, 0f, previewImageWidth, previewImageHeight)
}
PreviewView.ScaleType.FILL_END, PreviewView.ScaleType.FIT_END -> {
PreviewImageBounds(
viewWidth - previewImageWidth, viewHeight - previewImageHeight,
previewImageWidth, previewImageHeight
)
}
else -> {
PreviewImageBounds(
viewWidth / 2 - previewImageWidth / 2, viewHeight / 2 - previewImageHeight / 2,
previewImageWidth, previewImageHeight
)
}
}
}