Fixed issue #28: Added support for predictSoftly on the whole dataset
zaleslaw committed Jan 14, 2021
1 parent bfc3244 commit 5f4eee8
Showing 4 changed files with 66 additions and 8 deletions.
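
For context, a minimal usage sketch of the API touched by this commit, mirroring the integration tests below. The names model, test, and batchSize are hypothetical placeholders for a compiled, trained Sequential model, a Dataset, and a batch size that evenly divides the dataset size:

// Hypothetical names: model is a compiled, trained Sequential model,
// test is a Dataset, and batchSize evenly divides test.xSize().
model.use {
    // Renamed from predictAll to predict in this commit: one class index per sample.
    val labels: IntArray = it.predict(test, batchSize)

    // New in this commit: one probability vector per sample, with one entry per class.
    val probabilities: Array<FloatArray> = it.predictSoftly(test, batchSize)

    println("Sample 0 -> class ${labels[0]}, distribution: ${probabilities[0].joinToString()}")
}
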
@@ -593,7 +593,7 @@ public class Sequential(input: Input, vararg layers: Layer) : TrainableModel() {
return EvaluationResult(avgLossValue, mapOf(Metrics.convertBack(metric) to avgMetricValue))
}

- override fun predictAll(dataset: Dataset, batchSize: Int): IntArray {
+ override fun predict(dataset: Dataset, batchSize: Int): IntArray {
require(dataset.xSize() % batchSize == 0) { "The amount of images must be a multiple of batch size." }
check(isModelCompiled) { "The model is not compiled yet. Compile the model to use this method." }
check(isModelInitialized) { "The model is not initialized yet. Initialize the model weights with init() method or load weights to use this method." }
@@ -658,6 +658,50 @@ public class Sequential(input: Input, vararg layers: Layer) : TrainableModel() {
return Pair(softPrediction.indexOfFirst { it == softPrediction.maxOrNull()!! }, activations)
}

override fun predictSoftly(dataset: Dataset, batchSize: Int): Array<FloatArray> {
require(dataset.xSize() % batchSize == 0) { "The amount of images must be a multiple of batch size." }
check(isModelCompiled) { "The model is not compiled yet. Compile the model to use this method." }
check(isModelInitialized) { "The model is not initialized yet. Initialize the model weights with init() method or load weights to use this method." }

callback.onPredictBegin()

val imageShape = calculateXShape(batchSize)

val predictions = Array(dataset.xSize()) { FloatArray(amountOfClasses.toInt()) { 0.0f } }

val batchIter: Dataset.BatchIterator = dataset.batchIterator(
batchSize
)

var batchCounter = 0

while (batchIter.hasNext()) {
callback.onPredictBatchBegin(batchCounter, batchSize)

val batch: DataBatch = batchIter.next()

Tensor.create(
imageShape,
batch.x
).use { testImages ->
val predictionsTensor = session.runner()
.fetch(prediction)
.feed(xOp.asOutput(), testImages)
.run()[0]

val dst = Array(imageShape[0].toInt()) { FloatArray(amountOfClasses.toInt()) { 0.0f } }

predictionsTensor.copyTo(dst)

callback.onPredictBatchEnd(batchCounter, batchSize)
batchCounter++
dst.copyInto(predictions, batchSize * (batchCounter - 1))
}
}
callback.onPredictEnd()
return predictions
}

override fun predictSoftly(inputData: FloatArray, predictionTensorName: String): FloatArray {
val (softPrediction, _) = internalPredict(inputData, false, predictionTensorName)
return softPrediction
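A note on the copy step in predictSoftly above: batchCounter is incremented before the copy, so the k-th batch (1-based) lands at offset batchSize * (k - 1) in the preallocated predictions array, which keeps the results in dataset order. A standalone sketch of the same offset arithmetic with dummy data (all names here are local to the sketch):

fun main() {
    val amountOfClasses = 3
    val batchSize = 2
    val datasetSize = 6

    // One probability vector per sample, preallocated up front as in the diff above.
    val predictions = Array(datasetSize) { FloatArray(amountOfClasses) }

    var batchCounter = 0
    for (batchStart in 0 until datasetSize step batchSize) {
        // Dummy per-batch output standing in for the tensor result copied into dst.
        val dst = Array(batchSize) { row ->
            FloatArray(amountOfClasses) { col -> (batchStart + row + col).toFloat() }
        }

        batchCounter++
        // Same offset arithmetic as in Sequential.predictSoftly.
        dst.copyInto(predictions, batchSize * (batchCounter - 1))
    }

    predictions.forEach { println(it.joinToString()) }
}
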
@@ -205,7 +205,17 @@ public abstract class TrainableModel : InferenceModel() {
* @param [batchSize] Number of samples per batch of computation.
* @return Array of labels. The length is equal to the number of samples in the [dataset].
*/
- public abstract fun predictAll(dataset: Dataset, batchSize: Int): IntArray
+ public abstract fun predict(dataset: Dataset, batchSize: Int): IntArray

/**
* Generates output predictions for the input samples.
* Each prediction is a vector of probabilities instead of the specific class returned by the [predict] method.
*
* @param [dataset] Data to predict on.
* @param [batchSize] Number of samples per batch of computation.
* @return Array of labels. Each label is a vector that represents the probability distribution over the list of potential outcomes. The length is equal to the number of samples in the [dataset].
*/
public abstract fun predictSoftly(dataset: Dataset, batchSize: Int): Array<FloatArray>

/**
* Generates output prediction for the input sample.
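Since each row returned by predictSoftly is a probability distribution, a hard class label can still be recovered by taking the index of the largest entry, which is what the internal prediction path in Sequential does. A minimal sketch:

fun main() {
    // A dummy soft prediction over four classes.
    val softPrediction = floatArrayOf(0.1f, 0.7f, 0.15f, 0.05f)

    // Index of the maximum probability, mirroring indexOfFirst { it == softPrediction.maxOrNull()!! } in Sequential.
    val predictedClass = softPrediction.indexOfFirst { it == softPrediction.maxOrNull()!! }

    println("Predicted class: $predictedClass") // prints: Predicted class: 1
}
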
@@ -184,9 +184,13 @@ internal class SequentialBasicTest : IntegrationTest() {
val denseActivations = activations[2] as Array<FloatArray>
assertEquals(denseActivations[0][0].toDouble(), 0.0, EPS)

- val predictions = it.predictAll(test, TEST_BATCH_SIZE)
+ val predictions = it.predict(test, TEST_BATCH_SIZE)
assertEquals(test.xSize(), predictions.size)

val softPredictions = it.predictSoftly(test, TEST_BATCH_SIZE)
assertEquals(test.xSize(), softPredictions.size)
assertEquals(AMOUNT_OF_CLASSES, softPredictions[0].size)

var manualAccuracy = 0
predictions.forEachIndexed { index, lb -> if (lb == test.getLabel(index)) manualAccuracy++ }
assertTrue(manualAccuracy > 0.7)
@@ -349,7 +353,7 @@ internal class SequentialBasicTest : IntegrationTest() {
testModel.use {
val exception =
Assertions.assertThrows(IllegalArgumentException::class.java) {
- it.predictAll(test, 256)
+ it.predict(test, 256)
}
assertEquals(
"The amount of images must be a multiple of batch size.",
@@ -360,7 +364,7 @@
testModel.use {
val exception =
Assertions.assertThrows(IllegalStateException::class.java) {
- it.predictAll(test, 100)
+ it.predict(test, 100)
}
assertEquals(
"The model is not compiled yet. Compile the model to use this method.",
@@ -280,7 +280,7 @@ class SequentialInferenceTest {

val predictAllException =
Assertions.assertThrows(IllegalStateException::class.java) {
- it.predictAll(test, TEST_BATCH_SIZE)
+ it.predict(test, TEST_BATCH_SIZE)
}
assertEquals(
"The model is not compiled yet. Compile the model to use this method.",
@@ -318,7 +318,7 @@

val predictAllException =
Assertions.assertThrows(IllegalStateException::class.java) {
- it.predictAll(test, TEST_BATCH_SIZE)
+ it.predict(test, TEST_BATCH_SIZE)
}
assertEquals(
"The model is not initialized yet. Initialize the model weights with init() method or load weights to use this method.",
@@ -660,4 +660,4 @@ class SequentialInferenceTest {
println("Test metrics: $testMetrics")
return testMetrics.toMap()
}
}
}
