Skip to content

Commit

Permalink
Fix channels ordering for classification models (Kotlin#400) (Kotlin#401
Browse files Browse the repository at this point in the history
)

* Fix channels ordering for classification models (Kotlin#400)

* Add comment for an inputColorMode field (Kotlin#400)
  • Loading branch information
ermolenkodev authored Jul 12, 2022
1 parent 89e4769 commit 7a9d212
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ public class ImageRecognitionModel(

/**
* Predicts [topK] objects for the given [imageFile].
* Default [Preprocessing] is applied to an image.
*
* @param [imageFile] Input image [File].
* @param [topK] Number of top ranked predictions to return
*
* @see preprocessData
*
* @return The list of pairs <label, probability> sorted from the most probable to the least probable.
*/
Expand All @@ -67,8 +73,22 @@ public class ImageRecognitionModel(
return predictTopKImageNetLabels(internalModel, inputData, imageNetClassLabels, topK)
}

/**
 * Runs the supplied [preprocessing] on [imageFile] and returns the [topK] most probable labels for the result.
 *
 * Only the preprocessed data is handed to the model; the second component produced by [preprocessing] is ignored.
 *
 * @param [imageFile] Input image [File].
 * @param [preprocessing] Custom [Preprocessing] instance applied to the image instead of the default one.
 * @param [topK] Number of top ranked predictions to return.
 *
 * @return The list of pairs <label, probability> sorted from the most probable to the least probable.
 */
public fun predictTopKObjects(
    imageFile: File,
    preprocessing: Preprocessing,
    topK: Int = 5
): List<Pair<String, Float>> {
    val (preprocessedInput, _) = preprocessing(imageFile)
    return predictTopKImageNetLabels(internalModel, preprocessedInput, imageNetClassLabels, topK)
}

private fun preprocessData(imageFile: File): FloatArray {
val (weight, height) = if (modelType.channelsFirst)
val (width, height) = if (modelType.channelsFirst)
Pair(internalModel.inputDimensions[1], internalModel.inputDimensions[2])
else
Pair(internalModel.inputDimensions[0], internalModel.inputDimensions[1])
Expand All @@ -77,10 +97,10 @@ public class ImageRecognitionModel(
transformImage {
resize {
outputHeight = height.toInt()
outputWidth = weight.toInt()
outputWidth = width.toInt()
interpolation = InterpolationType.BILINEAR
}
convert { colorMode = ColorMode.BGR }
convert { colorMode = modelType.inputColorMode }
}
}

Expand All @@ -89,11 +109,28 @@ public class ImageRecognitionModel(

/**
 * Predicts object for the given [imageFile].
 * Default [Preprocessing] is applied to an image.
 *
 * @param [imageFile] Input image [File].
 * @see preprocessData
 *
 * @return The label of the recognized object with the highest probability.
 * @throws IllegalStateException if the predicted class index has no corresponding label.
 */
public fun predictObject(imageFile: File): String {
    val inputData = preprocessData(imageFile)
    val classId = internalModel.predict(inputData)
    // Fail with a message naming the offending index rather than a bare NullPointerException from `!!`.
    return imageNetClassLabels[classId]
        ?: error("No ImageNet class label found for predicted class index $classId")
}

/**
 * Predicts object for the given [imageFile] with a custom [Preprocessing] provided.
 *
 * Only the preprocessed data is handed to the model; the second component produced by [preprocessing] is ignored.
 *
 * @param [imageFile] Input image [File].
 * @param [preprocessing] custom [Preprocessing] instance
 *
 * @return The label of the recognized object with the highest probability.
 * @throws IllegalStateException if the predicted class index has no corresponding label.
 */
public fun predictObject(imageFile: File, preprocessing: Preprocessing): String {
    val (inputData, _) = preprocessing(imageFile)
    val classId = internalModel.predict(inputData)
    // Fail with a message naming the offending index rather than a bare NullPointerException from `!!`.
    return imageNetClassLabels[classId]
        ?: error("No ImageNet class label found for predicted class index $classId")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.ImageRecognitionModel
import org.jetbrains.kotlinx.dl.api.inference.keras.loadWeights
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.preprocessor.Preprocessing
import java.io.File

Expand All @@ -27,10 +28,10 @@ public object TFModels {
public sealed class CV<T : GraphTrainableModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean = false,
override val inputColorMode: ColorMode = ColorMode.RGB,
public var inputShape: IntArray? = null,
internal var noTop: Boolean = false
) :
ModelType<T, ImageRecognitionModel> {
) : ModelType<T, ImageRecognitionModel> {

init {
if (inputShape != null) {
Expand Down Expand Up @@ -61,7 +62,12 @@ public object TFModels {
* Official VGG16 model from Keras.applications.</a>
*/
public class VGG16(noTop: Boolean = false, inputShape: IntArray? = null) :
CV<Sequential>("models/tensorflow/cv/vgg16", inputShape = inputShape, noTop = noTop) {
CV<Sequential>(
"models/tensorflow/cv/vgg16",
inputShape = inputShape,
noTop = noTop,
inputColorMode = ColorMode.BGR
) {
override fun preprocessInput(data: FloatArray, tensorShape: LongArray): FloatArray {
return preprocessInput(data, tensorShape, inputType = InputType.CAFFE)
}
Expand All @@ -85,7 +91,12 @@ public object TFModels {
* Official VGG19 model from Keras.applications.</a>
*/
public class VGG19(noTop: Boolean = false, inputShape: IntArray? = null) :
CV<Sequential>("models/tensorflow/cv/vgg19", inputShape = inputShape, noTop = noTop) {
CV<Sequential>(
"models/tensorflow/cv/vgg19",
inputShape = inputShape,
noTop = noTop,
inputColorMode = ColorMode.BGR
) {
override fun preprocessInput(data: FloatArray, tensorShape: LongArray): FloatArray {
return preprocessInput(data, tensorShape, inputType = InputType.CAFFE)
}
Expand Down Expand Up @@ -155,7 +166,12 @@ public object TFModels {
* Official ResNet50 model from Keras.applications.</a>
*/
public class ResNet50(noTop: Boolean = false, inputShape: IntArray? = null) :
CV<Functional>("models/tensorflow/cv/resnet50", inputShape = inputShape, noTop = noTop) {
CV<Functional>(
"models/tensorflow/cv/resnet50",
inputShape = inputShape,
noTop = noTop,
inputColorMode = ColorMode.BGR
) {
override fun preprocessInput(data: FloatArray, tensorShape: LongArray): FloatArray {
return preprocessInput(data, tensorShape, inputType = InputType.CAFFE)
}
Expand All @@ -181,7 +197,12 @@ public object TFModels {
* Official ResNet101 model from Keras.applications.</a>
*/
public class ResNet101(noTop: Boolean = false, inputShape: IntArray? = null) :
CV<Functional>("models/tensorflow/cv/resnet101", inputShape = inputShape, noTop = noTop) {
CV<Functional>(
"models/tensorflow/cv/resnet101",
inputShape = inputShape,
noTop = noTop,
inputColorMode = ColorMode.BGR
) {
override fun preprocessInput(data: FloatArray, tensorShape: LongArray): FloatArray {
return preprocessInput(data, tensorShape, inputType = InputType.CAFFE)
}
Expand All @@ -207,7 +228,12 @@ public object TFModels {
* Official ResNet152 model from Keras.applications.</a>
*/
public class ResNet152(noTop: Boolean = false, inputShape: IntArray? = null) :
CV<Functional>("models/tensorflow/cv/resnet152", inputShape = inputShape, noTop = noTop) {
CV<Functional>(
"models/tensorflow/cv/resnet152",
inputShape = inputShape,
noTop = noTop,
inputColorMode = ColorMode.BGR
) {
override fun preprocessInput(data: FloatArray, tensorShape: LongArray): FloatArray {
return preprocessInput(data, tensorShape, inputType = InputType.CAFFE)
}
Expand Down Expand Up @@ -548,6 +574,12 @@ public interface ModelType<T : InferenceModel, U : InferenceModel> {
*/
public val channelsFirst: Boolean

/**
* The expected channel order of the input image.
* Note: the wrong choice of this parameter can significantly impact the model's performance.
*/
public val inputColorMode: ColorMode

/**
* Common preprocessing function for the Neural Networks trained on ImageNet and whose weights are available with the keras.application.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection.SSDMobileNetV
import org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection.SSDObjectDetectionModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection.MultiPoseDetectionModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection.SinglePoseDetectionModel
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.preprocessor.ImageShape
import org.jetbrains.kotlinx.dl.dataset.preprocessor.Transpose

Expand All @@ -26,10 +27,10 @@ public object ONNXModels {
public sealed class CV<T : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean,
override val inputColorMode: ColorMode = ColorMode.RGB,
/** If true, model is shipped without last few layers and could be used for transfer learning and fine-tuning with TF Runtime. */
internal var noTop: Boolean = false
) :
ModelType<T, ImageRecognitionModel> {
) : ModelType<T, ImageRecognitionModel> {
override fun pretrainedModel(modelHub: ModelHub): ImageRecognitionModel {
return ImageRecognitionModel(modelHub.loadModel(this), this)
}
Expand Down Expand Up @@ -323,7 +324,11 @@ public object ONNXModels {
* Official ResNet model from Keras.applications.</a>
*/
public object ResNet50custom :
CV<OnnxInferenceModel>("models/onnx/cv/custom/resnet50", channelsFirst = false) {
CV<OnnxInferenceModel>(
"models/onnx/cv/custom/resnet50",
channelsFirst = false,
inputColorMode = ColorMode.BGR
) {
override fun preprocessInput(data: FloatArray, tensorShape: LongArray): FloatArray {
return preprocessInput(
data,
Expand Down Expand Up @@ -612,7 +617,8 @@ public object ONNXModels {
/** Object detection models and preprocessing. */
public sealed class ObjectDetection<T : InferenceModel, U : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean = true
override val channelsFirst: Boolean = true,
override val inputColorMode: ColorMode = ColorMode.RGB
) :
ModelType<T, U> {
/**
Expand Down Expand Up @@ -964,7 +970,8 @@ public object ONNXModels {
/** Face alignment models and preprocessing. */
public sealed class FaceAlignment<T : InferenceModel, U : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean = true
override val channelsFirst: Boolean = true,
override val inputColorMode: ColorMode = ColorMode.RGB
) :
ModelType<T, U> {
/**
Expand Down Expand Up @@ -996,7 +1003,8 @@ public object ONNXModels {
/** Pose detection models. */
public sealed class PoseDetection<T : InferenceModel, U : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean = true
override val channelsFirst: Boolean = true,
override val inputColorMode: ColorMode = ColorMode.RGB
) :
ModelType<T, U> {
/**
Expand Down

0 comments on commit 7a9d212

Please sign in to comment.