Skip to content

Commit

Permalink
Add utility functions to draw detection results on android Canvas (#450)
Browse files Browse the repository at this point in the history
* Migrate "visualization" module to multiplatform

* Add utility functions and classes to draw detection results on android

* Refactor function parameters and add kdoc information
  • Loading branch information
juliabeliaeva authored Sep 23, 2022
1 parent 794259d commit dc97543
Show file tree
Hide file tree
Showing 14 changed files with 312 additions and 23 deletions.
5 changes: 4 additions & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
kotlin.code.style=official
org.gradle.daemon=false

majorVersion=0.5
minorVersion=SNAPSHOT
minorVersion=SNAPSHOT

android.useAndroidX=true
2 changes: 1 addition & 1 deletion gradle/publish.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
apply plugin: 'maven-publish'
apply from: project.rootProject.file('gradle/maven-metadata.gradle')

def isMultiplatform = project.name in ["dataset", "api", "onnx"]
def isMultiplatform = project.name in ["dataset", "api", "onnx", "visualization"]

task javadocJar(type: Jar) {
classifier 'javadoc'
Expand Down
61 changes: 40 additions & 21 deletions visualization/build.gradle
Original file line number Diff line number Diff line change
@@ -1,29 +1,48 @@
project.setDescription("This module contains the Kotlin API for visualization of the Deep Learning models built with the KotlinDL.")

apply plugin: 'kotlin'
apply plugin: 'kotlin-multiplatform'
apply plugin: 'com.android.library'

dependencies {
api project(":api")
api project(":tensorflow")
}

dependencies {
def excludeSLF4JImpl = { exclude group: 'org.slf4j', module: 'slf4j-simple' }
kotlin {
jvm {
compilations.all {
kotlinOptions.jvmTarget = '1.8'
}
}
android {
publishLibraryVariants("release")
}
sourceSets {
jvmMain {
dependencies {
api project(":api")
api project(":tensorflow")

api 'org.jetbrains.lets-plot:lets-plot-batik:2.3.0', excludeSLF4JImpl
api 'org.jetbrains.lets-plot:lets-plot-common:2.3.0', excludeSLF4JImpl
api 'org.jetbrains.lets-plot:lets-plot-kotlin-api:2.0.1', excludeSLF4JImpl
}

compileKotlin {
kotlinOptions.jvmTarget = "1.8"
}
def excludeSLF4JImpl = { exclude group: 'org.slf4j', module: 'slf4j-simple' }

task sourcesJar(type: Jar) {
classifier 'sources'
from sourceSets.main.allSource
api 'org.jetbrains.lets-plot:lets-plot-batik:2.3.0', excludeSLF4JImpl
api 'org.jetbrains.lets-plot:lets-plot-common:2.3.0', excludeSLF4JImpl
api 'org.jetbrains.lets-plot:lets-plot-kotlin-api:2.0.1', excludeSLF4JImpl
}
}
androidMain {
dependencies {
api project(":api")
api "androidx.camera:camera-view:1.0.0-alpha22"
}
}
}
}

artifacts {
archives sourcesJar
android {
compileSdkVersion 31
namespace = 'org.jetbrains.kotlinx.dl.visualization'
defaultConfig {
minSdkVersion 24
targetSdkVersion 31
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.visualization

import android.content.Context
import android.graphics.Canvas
import android.util.AttributeSet
import android.view.View

/**
 * Abstract [View] that renders the most recently supplied detection result on top of the image preview.
 * Subclasses provide the actual rendering by implementing [drawDetection].
 *
 * Both [setDetection] and [onDraw] access the stored result while holding the lock on this view.
 */
abstract class DetectorViewBase<T>(context: Context, attrs: AttributeSet) : View(context, attrs) {
    /**
     * Most recently set detection result, or null when there is nothing to draw.
     */
    private var currentDetection: T? = null

    /**
     * Render the given [detection] onto the receiver [Canvas].
     */
    abstract fun Canvas.drawDetection(detection: T)

    /**
     * Hook invoked by [setDetection] right after the new value has been stored.
     * The default implementation does nothing.
     */
    open fun onDetectionSet(detection: T?) = Unit

    /**
     * Store the new detection result (null if nothing was detected),
     * notify [onDetectionSet] and call [postInvalidate].
     */
    fun setDetection(detection: T?) {
        synchronized(this) {
            currentDetection = detection
            onDetectionSet(detection)
            postInvalidate()
        }
    }

    override fun onDraw(canvas: Canvas) {
        super.onDraw(canvas)

        synchronized(this) {
            // Nothing is drawn while no detection result is available.
            currentDetection?.let { canvas.drawDetection(it) }
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.visualization

import android.graphics.Canvas
import android.graphics.Paint
import android.graphics.RectF
import android.text.TextPaint
import org.jetbrains.kotlinx.dl.api.inference.facealignment.Landmark
import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject
import org.jetbrains.kotlinx.dl.api.inference.posedetection.DetectedPose
import org.jetbrains.kotlinx.dl.api.inference.posedetection.MultiPoseDetectionResult

/**
 * Render a single [detectedObject] on this [Canvas]: a bounding box drawn with [paint]
 * and a text label ("class label : probability") drawn with [labelPaint] above the box.
 *
 * The stroke width of the box is the stroke width of [paint] multiplied by the detection probability.
 * A copy of [paint] is used for that, so the given instance is left unmodified.
 *
 * Object coordinates are converted to the [Canvas] coordinates with the help of [bounds].
 * When the preview image does not cover the whole [Canvas], the [bounds] of the image preview
 * should be provided explicitly.
 *
 * @see [PreviewImageBounds]
 */
fun Canvas.drawObject(
    detectedObject: DetectedObject,
    paint: Paint,
    labelPaint: TextPaint,
    bounds: PreviewImageBounds = bounds()
) {
    val left = bounds.toViewX(detectedObject.xMin)
    val top = bounds.toViewY(detectedObject.yMin)
    val right = bounds.toViewX(detectedObject.xMax)
    val bottom = bounds.toViewY(detectedObject.yMax)
    val box = RectF(left, top, right, bottom)

    // More confident detections get a thicker frame.
    val scaledStrokeWidth = paint.strokeWidth * detectedObject.probability
    val boxPaint = Paint(paint)
    boxPaint.strokeWidth = scaledStrokeWidth
    drawRect(box, boxPaint)

    val probabilityText = "%.2f".format(detectedObject.probability)
    val text = "${detectedObject.classLabel} : $probabilityText"
    // Lift the text baseline so that the label does not overlap the frame.
    val baseline = box.top - labelPaint.fontMetrics.descent - scaledStrokeWidth / 2
    drawText(text, box.left, baseline, labelPaint)
}

/**
 * Render every element of [detectedObjects] on this [Canvas], in list order,
 * using [paint] for the bounding boxes and [labelPaint] for the labels.
 *
 * When the preview image does not cover the whole [Canvas], the [bounds] of the image preview
 * should be provided explicitly.
 *
 * @see [PreviewImageBounds]
 */
fun Canvas.drawObjects(
    detectedObjects: List<DetectedObject>,
    paint: Paint,
    labelPaint: TextPaint,
    bounds: PreviewImageBounds = bounds()
) {
    for (detectedObject in detectedObjects) {
        drawObject(detectedObject, paint, labelPaint, bounds)
    }
}

/**
 * Render the given [detectedPose] on this [Canvas]. Pose edges are drawn first as lines with [edgePaint],
 * then pose landmarks are drawn over them as circles of [landmarkRadius] with [landmarkPaint].
 *
 * When the preview image does not cover the whole [Canvas], the [bounds] of the image preview
 * should be provided explicitly.
 *
 * @see [PreviewImageBounds]
 */
fun Canvas.drawPose(
    detectedPose: DetectedPose,
    landmarkPaint: Paint,
    edgePaint: Paint, landmarkRadius: Float,
    bounds: PreviewImageBounds = bounds()
) {
    for (edge in detectedPose.edges) {
        val startX = bounds.toViewX(edge.start.x)
        val startY = bounds.toViewY(edge.start.y)
        val endX = bounds.toViewX(edge.end.x)
        val endY = bounds.toViewY(edge.end.y)
        drawLine(startX, startY, endX, endY, edgePaint)
    }

    for (point in detectedPose.poseLandmarks) {
        val centerX = bounds.toViewX(point.x)
        val centerY = bounds.toViewY(point.y)
        drawCircle(centerX, centerY, landmarkRadius, landmarkPaint)
    }
}

/**
 * Render all poses from [detectedPoses] on this [Canvas]. For each entry the pose is drawn first
 * (using [landmarkPaint] and [landmarkRadius] for the vertices and [edgePaint] for the edges),
 * followed by the corresponding detected object (using [objectPaint] for the bounding box
 * and [labelPaint] for the label).
 *
 * When the preview image does not cover the whole [Canvas], the [bounds] of the image preview
 * should be provided explicitly.
 *
 * @see [PreviewImageBounds]
 */
fun Canvas.drawMultiplePoses(
    detectedPoses: MultiPoseDetectionResult,
    landmarkPaint: Paint,
    edgePaint: Paint,
    objectPaint: Paint,
    labelPaint: TextPaint,
    landmarkRadius: Float,
    bounds: PreviewImageBounds = bounds()
) {
    for ((boundingBox, pose) in detectedPoses.multiplePoses) {
        drawPose(pose, landmarkPaint, edgePaint, landmarkRadius, bounds)
        drawObject(boundingBox, objectPaint, labelPaint, bounds)
    }
}

/**
 * Render each of the [landmarks] on this [Canvas] as a circle of the given [radius] drawn with [paint].
 *
 * When the preview image does not cover the whole [Canvas], the [bounds] of the image preview
 * should be provided explicitly.
 *
 * @see [PreviewImageBounds]
 */
fun Canvas.drawLandmarks(landmarks: List<Landmark>,
                         paint: Paint, radius: Float,
                         bounds: PreviewImageBounds = bounds()
) {
    for (point in landmarks) {
        val centerX = bounds.toViewX(point.xRate)
        val centerY = bounds.toViewY(point.yRate)
        drawCircle(centerX, centerY, radius, paint)
    }
}

/**
 * Build [PreviewImageBounds] which start at the (0, 0) point of this [Canvas]
 * and have the same width and height as the [Canvas] itself.
 */
fun Canvas.bounds(): PreviewImageBounds {
    val canvasWidth = width.toFloat()
    val canvasHeight = height.toFloat()
    return PreviewImageBounds(0f, 0f, canvasWidth, canvasHeight)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.visualization

import androidx.camera.view.PreviewView
import kotlin.math.max
import kotlin.math.min

/**
 * Position and dimensions of the actual preview image inside the [PreviewView].
 * Use [toViewX] and [toViewY] to map image coordinates to the view coordinate system.
 *
 * @property [x] x-coordinate of the preview image top-left corner, relative to the [PreviewView] component
 * @property [y] y-coordinate of the preview image top-left corner, relative to the [PreviewView] component
 * @property [width] width of the preview image
 * @property [height] height of the preview image
 *
 * @see getPreviewImageBounds
 */
data class PreviewImageBounds(val x: Float, val y: Float, val width: Float, val height: Float) {
    /**
     * Map an image x-coordinate to the view: [imageX] is multiplied by [width] and shifted by [x].
     * NOTE(review): presumably [imageX] is a fraction of the image width (0..1) - confirm against callers.
     */
    fun toViewX(imageX: Float): Float = x + imageX * width

    /**
     * Map an image y-coordinate to the view: [imageY] is multiplied by [height] and shifted by [y].
     * NOTE(review): presumably [imageY] is a fraction of the image height (0..1) - confirm against callers.
     */
    fun toViewY(imageY: Float): Float = y + imageY * height
}

/**
 * Calculate the location of the preview image top-left corner (relative to the component top-left corner)
 * and dimensions, to be used for displaying detected objects, for example with the [DetectorViewBase].
 *
 * When camera preview resolution differs from the dimensions of the [PreviewView] used to display camera input,
 * image is scaled and cropped or padded according to the provided [PreviewView.ScaleType]. Because of this,
 * in order to display detected objects on the [PreviewView], their coordinates need to be converted.
 * This method returns [PreviewImageBounds] object containing the necessary information to perform the conversion
 * from the image coordinate system to the view coordinate system.
 *
 * The image is scaled uniformly: "fill" scale types use the larger of the two view-to-image ratios
 * (so the image covers the view and may be cropped), all other scale types use the smaller one
 * (so the image fits inside the view and may be padded). The scaled image is then aligned
 * to the start, to the end or to the center of the view.
 *
 * @param [sourceImageWidth] width of the image from the camera
 * @param [sourceImageHeight] height of the image from the camera
 * @param [viewWidth] width of the target [PreviewView]
 * @param [viewHeight] height of the target [PreviewView]
 * @param [scaleType] scaling option used in the target [PreviewView]
 * @return location and size of the scaled preview image in the view coordinate system;
 * the offsets are negative when the image is cropped
 *
 * @see <a href="https://developer.android.com/training/camerax/preview#scale-type">Scale type</a>
 */
fun getPreviewImageBounds(
    sourceImageWidth: Int,
    sourceImageHeight: Int,
    viewWidth: Int,
    viewHeight: Int,
    scaleType: PreviewView.ScaleType
): PreviewImageBounds {
    val widthRatio = viewWidth.toFloat() / sourceImageWidth
    val heightRatio = viewHeight.toFloat() / sourceImageHeight
    val scale = when (scaleType) {
        PreviewView.ScaleType.FILL_START,
        PreviewView.ScaleType.FILL_END,
        PreviewView.ScaleType.FILL_CENTER -> max(widthRatio, heightRatio)
        else -> min(widthRatio, heightRatio)
    }
    val previewImageWidth = sourceImageWidth * scale
    val previewImageHeight = sourceImageHeight * scale

    // Space left in the view after placing the image; negative when the image is larger than the view.
    val horizontalSpace = viewWidth - previewImageWidth
    val verticalSpace = viewHeight - previewImageHeight

    return when (scaleType) {
        PreviewView.ScaleType.FILL_START, PreviewView.ScaleType.FIT_START -> {
            PreviewImageBounds(0f, 0f, previewImageWidth, previewImageHeight)
        }
        PreviewView.ScaleType.FILL_END, PreviewView.ScaleType.FIT_END -> {
            PreviewImageBounds(horizontalSpace, verticalSpace, previewImageWidth, previewImageHeight)
        }
        else -> {
            // Halve the Float difference: dividing the Int view size by 2 first would truncate
            // and shift the image by half a pixel for odd view dimensions.
            PreviewImageBounds(horizontalSpace / 2, verticalSpace / 2, previewImageWidth, previewImageHeight)
        }
    }
}

0 comments on commit dc97543

Please sign in to comment.