Skip to content

Commit

Permalink
Move flattenFloats function to a separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
juliabeliaeva committed Aug 1, 2022
1 parent 35ebbab commit 1fe55f3
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,22 +99,4 @@ internal fun numElementsInShape(shape: LongArray): Long {
prod *= abs(shape[i])
}
return prod
}

/**
* Flattens the given array of float values.
* @return flattened array
*/
public fun Array<*>.flattenFloats(): FloatArray {
val result = mutableListOf<Float>()

fun flatten(array: Any?): Unit = when (array) {
is FloatArray -> array.forEach { result.add(it) }
is Array<*> -> array.forEach { flatten(it) }
else -> throw IllegalArgumentException("Cannot flatten object: '$array'")
}

flatten(this)

return result.toFloatArray()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.api.core.util

/**
* Flattens the given array of float values.
* @return flattened array
*/
public fun Array<*>.flattenFloats(): FloatArray {
val result = mutableListOf<Float>()

fun flatten(array: Any?): Unit = when (array) {
is FloatArray -> array.forEach { result.add(it) }
is Array<*> -> array.forEach { flatten(it) }
else -> throw IllegalArgumentException("Cannot flatten object: '$array'")
}

flatten(this)

return result.toFloatArray()
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ package org.jetbrains.kotlinx.dl.api.inference.keras.loaders
import com.beust.klaxon.JsonArray
import com.beust.klaxon.JsonObject
import com.beust.klaxon.Parser
import org.jetbrains.kotlinx.dl.api.core.shape.flattenFloats
import org.jetbrains.kotlinx.dl.api.core.util.flattenFloats
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
package org.jetbrains.kotlinx.dl.api.core.layer

import org.jetbrains.kotlinx.dl.api.core.activation.EPS
import org.jetbrains.kotlinx.dl.api.core.shape.flattenFloats
import org.jetbrains.kotlinx.dl.api.core.shape.shape
import org.jetbrains.kotlinx.dl.api.core.shape.toLongArray
import org.jetbrains.kotlinx.dl.api.core.util.flattenFloats
import org.jetbrains.kotlinx.dl.api.extension.convertTensorToFlattenFloatArray
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.tensorflow.Graph
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

package org.jetbrains.kotlinx.dl.api.core.layer

import org.jetbrains.kotlinx.dl.api.core.shape.flattenFloats
import org.jetbrains.kotlinx.dl.api.core.shape.shape
import org.jetbrains.kotlinx.dl.api.core.shape.toLongArray
import org.jetbrains.kotlinx.dl.api.core.util.flattenFloats
import org.jetbrains.kotlinx.dl.api.extension.convertTensorToFlattenFloatArray
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.tensorflow.*
Expand Down

0 comments on commit 1fe55f3

Please sign in to comment.