diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Functional.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Functional.kt index 44a1e0d15..4290df894 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Functional.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Functional.kt @@ -7,6 +7,8 @@ package org.jetbrains.kotlinx.dl.api.core import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.core.Input +import org.jetbrains.kotlinx.dl.api.core.layer.weights +import org.jetbrains.kotlinx.dl.api.core.layer.freeze import org.jetbrains.kotlinx.dl.api.core.util.sortTopologically import org.jetbrains.kotlinx.dl.api.inference.keras.* import org.tensorflow.Operand @@ -82,7 +84,7 @@ public class Functional(vararg layers: Layer) : GraphTrainableModel(*layers) { // TODO: make a honest copy of models (pretrained and topModel) or of all layers val pretrainedLayers = pretrainedModel.layers // Freezing - pretrainedLayers.forEach { it.isTrainable = false } + pretrainedLayers.forEach { it.freeze() } val layers = mutableListOf() layers += pretrainedLayers @@ -255,7 +257,7 @@ public class Functional(vararg layers: Layer) : GraphTrainableModel(*layers) { inputLayer.computeOutputShape() layers.filter { it !is Input }.forEach { - it.buildFromInboundLayers(tf, kGraph) + it.buildFromInboundLayers(tf) val outputShape = it.computeOutputShapeFromInboundLayers() val dims = outputShape.dims() diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt index 81f45c048..e0b33495a 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt @@ -10,8 +10,7 @@ import mu.KotlinLogging import org.jetbrains.kotlinx.dl.api.core.callback.Callback import org.jetbrains.kotlinx.dl.api.core.exception.RepeatableLayerNameException import org.jetbrains.kotlinx.dl.api.core.history.* -import org.jetbrains.kotlinx.dl.api.core.layer.Layer -import org.jetbrains.kotlinx.dl.api.core.layer.NoGradients +import org.jetbrains.kotlinx.dl.api.core.layer.* import org.jetbrains.kotlinx.dl.api.core.layer.core.ActivationLayer import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense import org.jetbrains.kotlinx.dl.api.core.layer.core.Input @@ -36,6 +35,7 @@ import org.jetbrains.kotlinx.dl.dataset.Dataset import org.tensorflow.* import org.tensorflow.op.Ops import org.tensorflow.op.core.Placeholder +import org.tensorflow.op.core.Variable import java.io.File import java.io.FileNotFoundException import java.nio.FloatBuffer @@ -114,6 +114,9 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel session = Session(kGraph.tfGraph) } + override fun variables(): List> = layers.variables().map { it.variable } + override fun frozenVariables(): List> = layers.frozenVariables().map { it.variable } + /** Helper method for preprocessing layer names and layer validation. */ internal companion object { internal fun preProcessLayerNames(layers: Array) { @@ -143,8 +146,6 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel override fun compile(optimizer: Optimizer, loss: LossFunction, metrics: List) { check(!isModelCompiled) { "The model is compiled already. Graph is created. Create new model and compile it." 
} - validateModelArchitecture() - this.loss = loss this.metrics = metrics this.optimizer = optimizer @@ -172,7 +173,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel yPredOp = forward(xOp, inputLayer) lossOp = buildLossFunction(loss) - targets = optimizer.prepareTargets(kGraph, tf, lossOp) + targets = optimizer.prepareTargets(kGraph, layers.trainableVariables().map { it.variable }, tf, lossOp) predictionOp = when (loss) { is SoftmaxCrossEntropyWithLogits -> tf.withName(OUTPUT_NAME).nn.softmax(yPredOp) @@ -190,9 +191,9 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel val basicLoss = loss.apply(tf, yPredOp, yTrueOp, numberOfLossesOp) var totalLoss = basicLoss // TODO: probably regularization output should be divided on numberOfLossesOp and changed together with loss before averaging - kGraph.variableRegularizers.forEach { (variable, regularizer) -> - run { - totalLoss = tf.math.add(totalLoss, regularizer.apply(tf, variable)) + layers.trainableVariables().forEach { variable -> + variable.regularizer?.let { regularizer -> + totalLoss = tf.math.add(totalLoss, regularizer.apply(tf, variable.variable)) } } return tf.withName(TRAINING_LOSS).identity(totalLoss) @@ -207,23 +208,6 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel compile(optimizer, loss, Metrics.convert(metric)) } - /** Validates architecture. */ - private fun validateModelArchitecture() { - layerValidation(layers.toList()) - - require(layers.none { it is NoGradients && it.isTrainable }) - { - "All layers that implements NoGradient interface should be frozen (status isTrainable==false). " + - "But the following layers violates this rule: ${ - layers.filter { it is NoGradients && it.isTrainable }.map { it.name }.toTypedArray() - .contentDeepToString() - }" - } - // require(layers.last() is Dense) { "DL architectures are not finished with Dense layer are not supported yet!" } - // require(layers.last().hasActivation()) { "Last layer must have an activation function." } - // require((layers.last() as Dense).activation != Activations.Sigmoid) { "The last dense layer should have Linear activation, alternative activations are not supported yet!" } - } - /** Common method for building the initial part of the model static graph layer by layer via calling build() method on each layer in correct order. */ protected abstract fun buildLayers() @@ -277,7 +261,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel check(!isOptimizerVariableInitialized) { "Optimizer variables are initialized already!" } logger.debug { "Initialization of TensorFlow Graph variables." } - kGraph.initializeGraphVariables(session) + layers.initializeVariables(session) isModelInitialized = true } @@ -292,7 +276,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel check(isModelCompiled) { "The model is not compiled yet. Compile the model to use this method." } logger.debug { "Initialization of TensorFlow Graph variables." } - kGraph.initializeGraphVariables(session) + layers.initializeVariables(session) isModelInitialized = true isOptimizerVariableInitialized = false } @@ -308,16 +292,9 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel ): TrainingHistory { check(isModelCompiled) { "The model is not compiled yet. Compile the model to use this method." 
} - for (layer in layers) { - check(!(layer is NoGradients && layer.isTrainable)) { - "Layer $layer has no gradients implementations in TensorFlow and should be non-trainable only. " + - "Set 'isTrainable' to 'false'." - } - } - if (!isModelInitialized) { logger.debug { "Initialization of TensorFlow Graph variables." } - kGraph.initializeGraphVariables(session) + layers.initializeVariables(session) isModelInitialized = true } @@ -981,3 +958,11 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel ) } } + +/** + * Freezes weights in all layers in this model, so they won't be changed during training. + * @see [Layer.freeze] + */ +public fun GraphTrainableModel.freeze() { + layers.forEach(Layer::freeze) +} diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/KGraph.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/KGraph.kt index bf4ab7985..17fe726f8 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/KGraph.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/KGraph.kt @@ -5,10 +5,8 @@ package org.jetbrains.kotlinx.dl.api.core -import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer import org.tensorflow.Graph import org.tensorflow.GraphOperation -import org.tensorflow.Operand import org.tensorflow.Session import org.tensorflow.op.core.Assign import org.tensorflow.op.core.AssignAdd @@ -39,18 +37,9 @@ public class KGraph(graphDef: ByteArray, prefix: String) : AutoCloseable { /** A list of initializer to initialize the trainableVariables. */ private val optimizerAssignAddInitializers: MutableList> = mutableListOf() - /** A list of variables to train. */ - private val layerVariables: MutableMap, Boolean> = mutableMapOf() - - /** A list of variables to train. */ - internal val variableRegularizers: MutableMap, Regularizer> = mutableMapOf() - /** A list of optimizers' variables. */ private val optimizerVariables: MutableList> = mutableListOf() - /** A list of initializer to initialize the trainableVariables. */ - private val initializers: MutableMap> = mutableMapOf() - init { if (prefix.isEmpty()) { tfGraph.importGraphDef(graphDef) @@ -102,28 +91,6 @@ public class KGraph(graphDef: ByteArray, prefix: String) : AutoCloseable { return KGraph(tfGraph.toGraphDef()) } - /** - * Adds a variable used in layer to the pool of tracked variables. - * - * @param variable Variable to track in KGraph. - * @param isTrainable When passed isTrainable = true, the KGraph automatically adds new variables to the tracked collection of graph variables. - */ - public fun addLayerVariable(variable: Variable, isTrainable: Boolean) { - check(!layerVariables.contains(variable)) { "$variable is added to graph already. Analyze and fix the static graph building process." } - layerVariables[variable] = isTrainable - } - - /** - * Adds a variable related to the given [regularizer] to the pool of tracked variables. - * - * @param [variable] TensorFlow variable to track in KGraph. - * @param [regularizer] The regularizer related to the given [variable]. - */ - public fun addVariableRegularizer(variable: Variable, regularizer: Regularizer) { - check(!variableRegularizers.contains(variable)) { "$variable is added to graph already. Analyze and fix the static graph building process." } - variableRegularizers[variable] = regularizer - } - /** * Adds a variable used in optimizer to the pool of tracked variables. 
* @@ -134,17 +101,6 @@ public class KGraph(graphDef: ByteArray, prefix: String) : AutoCloseable { optimizerVariables.add(variable) } - /** - * Adds an initializer for layer variable tracked in KGraph. - * - * @param variableName Variable name to bind the variable and initializer. - * @param initializer Assign TensorFlow operand to initialize variable with name: [variableName]. - */ - public fun addInitializer(variableName: String, initializer: Assign) { - check(!initializers.contains(variableName)) { "$variableName has initializer already. Analyze and fix the static graph building process." } - initializers[variableName] = initializer - } - /** * Adds an optimizer initializer for optimizer variable tracked in KGraph. * @@ -163,39 +119,6 @@ public class KGraph(graphDef: ByteArray, prefix: String) : AutoCloseable { optimizerAssignAddInitializers += initializer } - /** - * Returns a list of trainable variables used in layers. - */ - public fun trainableLayerVariables(): List> { - return layerVariables.filter { it.value }.keys.toList() - } - - /** - * Returns the variable's ability to be changed during the training. - * - * @param variableName Variable name. - */ - public fun isVariableTrainable(variableName: String): Boolean { - return layerVariables - .filter { it.value } // only trainable - .map { it.key.ref().op().name() } // extract names - .any { it.contains(variableName) } - } - - /** - * Returns a list of non-trainable, 'frozen' variables used in layers. - */ - public fun frozenLayerVariables(): List> { - return layerVariables.filterNot { it.value }.keys.toList() - } - - /** - * Returns all variables used in all model layers. - */ - public fun layerVariables(): List> { - return layerVariables.keys.toList() - } - /** * Returns all variables used in optimizer and initialized by Assign TensorFlow operand. */ @@ -203,20 +126,6 @@ public class KGraph(graphDef: ByteArray, prefix: String) : AutoCloseable { return optimizerVariables.toList() } - /** - * Initializes TensorFlow graph variables used in model layers. - */ - public fun initializeGraphVariables(session: Session) { - val runner = session.runner() - // Only run the session if there is initializer ops. - if (initializers.isNotEmpty()) { - initializers.forEach { - runner.addTarget(it.value as Operand) - } - runner.run() - } - } - /** * Initializes TensorFlow graph variables used in optimizer. */ diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Sequential.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Sequential.kt index a08b8cdd5..9d023e0a5 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Sequential.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Sequential.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ @@ -7,6 +7,7 @@ package org.jetbrains.kotlinx.dl.api.core import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.core.Input +import org.jetbrains.kotlinx.dl.api.core.layer.weights import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape import org.jetbrains.kotlinx.dl.api.inference.keras.deserializeSequentialModel import org.jetbrains.kotlinx.dl.api.inference.keras.loadSequentialModelLayers @@ -148,7 +149,7 @@ public class Sequential(vararg layers: Layer) : GraphTrainableModel(*layers) { var inputShape: Shape = inputLayer.computeOutputShape() layers.filter { it !is Input }.forEach { - it.build(tf, kGraph, inputShape) + it.build(tf, inputShape) inputShape = it.computeOutputShape(inputShape) val tensorShape = TensorShape(inputShape) diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/KVariable.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/KVariable.kt index ea0544883..18422dec5 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/KVariable.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/KVariable.kt @@ -5,7 +5,6 @@ package org.jetbrains.kotlinx.dl.api.core.layer -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer import org.jetbrains.kotlinx.dl.api.core.util.getDType @@ -33,9 +32,7 @@ public data class KVariable( internal fun createVariable( tf: Ops, - kGraph: KGraph, variableName: String, - isTrainable: Boolean, shape: Shape, fanIn: Int, fanOut: Int, @@ -43,12 +40,7 @@ internal fun createVariable( regularizer: Regularizer? ): KVariable { val tfVariable = tf.withName(variableName).variable(shape, getDType()) - val initOp = initializer.apply(fanIn, fanOut, tf, tfVariable, variableName) - kGraph.addLayerVariable(tfVariable, isTrainable) - kGraph.addInitializer(variableName, initOp) - if (regularizer != null) kGraph.addVariableRegularizer(tfVariable, regularizer) - return KVariable( name = variableName, shape = shape, diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Layer.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Layer.kt index 649c277b5..74f65bbe4 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Layer.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Layer.kt @@ -5,10 +5,8 @@ package org.jetbrains.kotlinx.dl.api.core.layer -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.TrainableModel import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape -import org.jetbrains.kotlinx.dl.api.extension.convertTensorToMultiDimArray import org.tensorflow.Operand import org.tensorflow.Shape import org.tensorflow.op.Ops @@ -19,17 +17,6 @@ import org.tensorflow.op.Ops * @param [name] Layer name. Would be changed if empty during model compilation. */ public abstract class Layer(public var name: String) { - /** - * True, if layer's weights could be changed during training. - * If false, layer's weights are frozen and could be changed during the training. - */ - public var isTrainable: Boolean = true - set(value) { - if (value && this is NoGradients) - throw IllegalStateException("${this.javaClass.simpleName} layer can not set `isTrainable` to `true`.") - field = value - } - /** Output data tensor shape. 
*/ public lateinit var outputShape: TensorShape @@ -46,10 +33,9 @@ public abstract class Layer(public var name: String) { * Extend this function to define variables in layer. * * @param [tf] TensorFlow graph API for building operations. - * @param [kGraph] [KGraph] to update it. * @param [inputShape] Input shape, result of [computeOutputShape] call from previous layer. */ - public abstract fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) + public abstract fun build(tf: Ops, inputShape: Shape) /** @@ -59,11 +45,10 @@ public abstract class Layer(public var name: String) { * NOTE: Used in Functional API * * @param [tf] TensorFlow graph API for building operations. - * @param [kGraph] [KGraph] to update it. */ - public fun buildFromInboundLayers(tf: Ops, kGraph: KGraph) { + public fun buildFromInboundLayers(tf: Ops) { require(inboundLayers.isNotEmpty()) { "There is no inbound layers to compute output shape" } - build(tf, kGraph, inboundLayers[0].outputShape.toShape()) + build(tf, inboundLayers[0].outputShape.toShape()) } /** @@ -110,39 +95,8 @@ public abstract class Layer(public var name: String) { return this } - /** Extract weights values for provided variables. */ - protected fun extractWeights(vararg variables: KVariable?): Map> { - require(parentModel != null) { "Layer $name is not related to any model!" } - - val session = parentModel!!.session - val runner = session.runner() - - val variableNames = variables.mapNotNull { it?.name } - for (variableName in variableNames) { - runner.fetch(variableName) - } - - val weights = runner.run().map { it.convertTensorToMultiDimArray() } - return variableNames.zip(weights).toMap() - } - - /** Extract weights values by variable names. */ - protected fun assignWeights(weights: Map>) { - require(parentModel != null) { "Layer $name is not related to any model!" } - - for (variableName in weights.keys) { - parentModel!!.assignVariable(variableName, weights[variableName]!!) - } - } - - /** Layer's weights. */ - public abstract var weights: Map> - /** Returns True, if layer has internal activation function. */ public abstract val hasActivation: Boolean - - /** Returns amount of neurons. */ - public abstract val paramCount: Int } internal fun requireArraySize(array: IntArray, size: Int, name: String) = diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/ParametrizedLayer.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/ParametrizedLayer.kt new file mode 100644 index 000000000..4ec8719ac --- /dev/null +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/ParametrizedLayer.kt @@ -0,0 +1,69 @@ +/* + * Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. + */ + +package org.jetbrains.kotlinx.dl.api.core.layer + +import org.jetbrains.kotlinx.dl.api.core.shape.numElements +import org.tensorflow.Session + +/** + * Represents a [Layer] with parameters. + */ +public interface ParametrizedLayer { + /** Variables used in this layer. */ + public val variables: List + + /** Number of parameters in this layer. */ + public val paramCount: Int + get() = variables.sumOf { it.shape.numElements() }.toInt() +} + +/** + * Returns the number of parameters in this layer. If layer is not a [ParametrizedLayer], returns zero. + */ +public val Layer.paramCount: Int + get() = if (this is ParametrizedLayer) paramCount else 0 + +/** + * Returns all variables used in all layers. 
+ */
+internal fun List<Layer>.variables(): List<KVariable> {
+    return filterIsInstance<ParametrizedLayer>().flatMap { it.variables }
+}
+
+/**
+ * Returns a list of trainable variables used in the layers.
+ */
+internal fun List<Layer>.trainableVariables(): List<KVariable> {
+    return filterIsInstance<TrainableLayer>().filter { it.isTrainable }.flatMap { it.variables }
+}
+
+/**
+ * Returns a list of non-trainable, 'frozen' variables used in the layers.
+ */
+internal fun List<Layer>.frozenVariables(): List<KVariable> {
+    return filterIsInstance<ParametrizedLayer>()
+        .filter { it !is TrainableLayer || !it.isTrainable }
+        .flatMap { it.variables }
+}
+
+/**
+ * Initializes this layer's variables using the provided initializer operands.
+ */
+public fun ParametrizedLayer.initialize(session: Session) {
+    // Only run the session if there are initializer ops.
+    if (variables.isNotEmpty()) {
+        val runner = session.runner()
+        variables.map { it.initializerOperand }.forEach(runner::addTarget)
+        runner.run()
+    }
+}
+
+/**
+ * Initializes variables for [ParametrizedLayer] instances using the provided initializer operands.
+ */
+public fun List<Layer>.initializeVariables(session: Session) {
+    filterIsInstance<ParametrizedLayer>().forEach { it.initialize(session) }
+}
\ No newline at end of file
diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/TrainableLayer.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/TrainableLayer.kt
new file mode 100644
index 000000000..578e715a2
--- /dev/null
+++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/TrainableLayer.kt
@@ -0,0 +1,41 @@
+/*
+ * Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
+ */
+
+package org.jetbrains.kotlinx.dl.api.core.layer
+
+/**
+ * Represents a [ParametrizedLayer] that can be trained.
+ */
+public interface TrainableLayer : ParametrizedLayer {
+    /**
+     * True, if layer's weights could be changed during training.
+     * If false, layer's weights are frozen and could not be changed during training.
+     */
+    public var isTrainable: Boolean
+}
+
+/**
+ * Returns true if this [Layer] is trainable, false otherwise.
+ * Always returns false for layers not implementing [TrainableLayer].
+ */
+public val Layer.isTrainable: Boolean
+    get() = if (this is TrainableLayer) isTrainable else false
+
+/**
+ * Freezes layer weights, so they won't be changed during training.
+ */
+public fun Layer.freeze() {
+    if (this is TrainableLayer) isTrainable = false
+}
+
+/**
+ * Unfreezes layer weights, allowing them to be changed during training.
+ */
+public fun Layer.unfreeze() {
+    require(this is TrainableLayer) {
+        "Layer ${this.name} does not support training"
+    }
+    isTrainable = true
+}
\ No newline at end of file
diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Weights.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Weights.kt
new file mode 100644
index 000000000..1b2285e4b
--- /dev/null
+++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Weights.kt
@@ -0,0 +1,35 @@
+/*
+ * Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
+ */
+
+package org.jetbrains.kotlinx.dl.api.core.layer
+
+import org.jetbrains.kotlinx.dl.api.extension.convertTensorToMultiDimArray
+
+/**
+ * Weights of this layer. Requires the parent model to be set on the layer.
+ */ +public var Layer.weights: Map> + get() { + if (this !is ParametrizedLayer) return emptyMap() + + val model = parentModel + requireNotNull(model) { "Layer '$name' is not related to any model" } + + val runner = model.session.runner() + variables.map { it.variable }.forEach(runner::fetch) + val weights = runner.run().map { it.convertTensorToMultiDimArray() } + + return variables.map { it.name }.zip(weights).toMap() + } + set(weights) { + if (this !is ParametrizedLayer) return + + val model = parentModel + requireNotNull(model) { "Layer '$name' is not related to any model" } + + for ((name, value) in weights) { + model.assignVariable(name, value) + } + } \ No newline at end of file diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/AbstractActivationLayer.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/AbstractActivationLayer.kt index 4ad1e44d9..d29a41e5f 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/AbstractActivationLayer.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/AbstractActivationLayer.kt @@ -1,11 +1,10 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package org.jetbrains.kotlinx.dl.api.core.layer.activation -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape import org.tensorflow.Operand @@ -24,10 +23,6 @@ import org.tensorflow.op.Ops * @param [name] Layer name. Would be changed if empty during model compilation. */ public abstract class AbstractActivationLayer(name: String) : Layer(name) { - init { - isTrainable = false - } - /** * Applies the activation functions to the [input] to produce the output. * @@ -46,18 +41,12 @@ public abstract class AbstractActivationLayer(name: String) : Layer(name) { numberOfLosses: Operand? ): Operand = forward(tf, input) - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape): Unit = Unit + override fun build(tf: Ops, inputShape: Shape): Unit = Unit override fun computeOutputShape(inputShape: Shape): Shape { this.outputShape = TensorShape(inputShape) return inputShape } - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = true - - override val paramCount: Int get() = 0 } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/ELU.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/ELU.kt index c8765f33e..50aa2ec5f 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/ELU.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/ELU.kt @@ -1,5 +1,5 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ @@ -47,5 +47,5 @@ public class ELU( } override fun toString(): String = - "ELU(name = $name, isTrainable=$isTrainable, alpha=$alpha)" + "ELU(name = $name, alpha=$alpha)" } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/LeakyReLU.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/LeakyReLU.kt index 641ed1b4c..1b44d84e9 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/LeakyReLU.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/LeakyReLU.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -35,5 +35,5 @@ public class LeakyReLU( } override fun toString(): String = - "LeakyReLU(name = $name, isTrainable=$isTrainable, alpha=$alpha)" + "LeakyReLU(name = $name, alpha=$alpha)" } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/PReLU.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/PReLU.kt index f0a6ffd9a..d1d73a829 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/PReLU.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/PReLU.kt @@ -5,13 +5,12 @@ package org.jetbrains.kotlinx.dl.api.core.layer.activation -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer import org.jetbrains.kotlinx.dl.api.core.initializer.Zeros import org.jetbrains.kotlinx.dl.api.core.layer.KVariable +import org.jetbrains.kotlinx.dl.api.core.layer.TrainableLayer import org.jetbrains.kotlinx.dl.api.core.layer.createVariable import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer -import org.jetbrains.kotlinx.dl.api.core.shape.numElements import org.jetbrains.kotlinx.dl.api.core.shape.toLongArray import org.tensorflow.Operand import org.tensorflow.Shape @@ -36,7 +35,7 @@ public class PReLU( public val alphaRegularizer: Regularizer? = null, public val sharedAxes: IntArray? 
= null, name: String = "" -) : AbstractActivationLayer(name) { +) : AbstractActivationLayer(name), TrainableLayer { /** * TODO: support for constraint (alphaConstraint) should be added */ @@ -45,17 +44,12 @@ public class PReLU( private fun alphaVariableName(): String = if (name.isNotEmpty()) "${name}_alpha" else "alpha" - override var weights: Map> - get() = extractWeights(alpha) - set(value) = assignWeights(value) - override val paramCount: Int - get() = alpha.shape.numElements().toInt() + override val variables: List + get() = listOf(alpha) - init { - isTrainable = true - } + override var isTrainable: Boolean = true - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) { + override fun build(tf: Ops, inputShape: Shape) { val alphaShapeArray = inputShape.toLongArray().drop(1).toLongArray() if (sharedAxes != null) { for (axis in sharedAxes) { @@ -69,9 +63,7 @@ public class PReLU( val alphaShape = Shape.make(alphaShapeArray[0], *alphaShapeArray.drop(1).toLongArray()) alpha = createVariable( tf, - kGraph, alphaVariableName(), - isTrainable, alphaShape, fanIn, fanOut, diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/ReLU.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/ReLU.kt index 58bcdb746..b00ba74e1 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/ReLU.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/ReLU.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -45,5 +45,5 @@ public class ReLU( } override fun toString(): String = - "ReLU(name = $name, isTrainable=$isTrainable, maxValue=$maxValue, negativeSlope=$negativeSlope, threshold=$threshold)" + "ReLU(name = $name, maxValue=$maxValue, negativeSlope=$negativeSlope, threshold=$threshold)" } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/Softmax.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/Softmax.kt index 5f0851976..8e7d9dbf8 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/Softmax.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/Softmax.kt @@ -1,10 +1,14 @@ +/* + * Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
+ */ + package org.jetbrains.kotlinx.dl.api.core.layer.activation import org.tensorflow.Operand import org.tensorflow.op.Ops -import org.tensorflow.op.core.ReduceSum - import org.tensorflow.op.core.ReduceMax +import org.tensorflow.op.core.ReduceSum /** * Softmax activation layer @@ -44,6 +48,6 @@ public class Softmax( } override fun toString(): String { - return "Softmax(name = $name, isTrainable=$isTrainable, axis=$axis)" + return "Softmax(name = $name, axis=$axis)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/ThresholdedReLU.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/ThresholdedReLU.kt index 9bcc66ee9..9934a19ab 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/ThresholdedReLU.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/ThresholdedReLU.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -35,5 +35,5 @@ public class ThresholdedReLU( } override fun toString(): String = - "ThresholdedReLU(name = $name, isTrainable=$isTrainable, theta=$theta)" + "ThresholdedReLU(name = $name, theta=$theta)" } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/AbstractConv.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/AbstractConv.kt index 45de0c943..4a7d13538 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/AbstractConv.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/AbstractConv.kt @@ -5,16 +5,11 @@ package org.jetbrains.kotlinx.dl.api.core.layer.convolutional -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.activation.Activations import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer -import org.jetbrains.kotlinx.dl.api.core.layer.KVariable -import org.jetbrains.kotlinx.dl.api.core.layer.Layer -import org.jetbrains.kotlinx.dl.api.core.layer.createVariable -import org.jetbrains.kotlinx.dl.api.core.layer.toLongArray +import org.jetbrains.kotlinx.dl.api.core.layer.* import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape -import org.jetbrains.kotlinx.dl.api.core.shape.numElements import org.jetbrains.kotlinx.dl.api.core.shape.shapeFromDims import org.tensorflow.Operand import org.tensorflow.Shape @@ -46,8 +41,8 @@ import kotlin.math.roundToInt */ public abstract class AbstractConv( name: String -) : Layer(name) { - +) : Layer(name), ParametrizedLayer { + protected abstract val filters: Int protected abstract val kernelSize: IntArray protected abstract val strides: IntArray @@ -67,7 +62,10 @@ public abstract class AbstractConv( /** Tensor with bias weights */ internal var bias: KVariable? 
= null - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) { + public override val variables: List + get() = listOfNotNull(kernel, bias) + + override fun build(tf: Ops, inputShape: Shape) { // Amount of channels should be the last value in the inputShape val numberOfChannels = inputShape.size(inputShape.numDimensions() - 1) @@ -79,9 +77,7 @@ public abstract class AbstractConv( kernel = createVariable( tf, - kGraph, kernelVarName(name), - isTrainable, computeKernelShape(numberOfChannels), fanIn, fanOut, @@ -92,9 +88,7 @@ public abstract class AbstractConv( if (useBias) { bias = createVariable( tf, - kGraph, biasVarName(name), - isTrainable, computeBiasShape(numberOfChannels), fanIn, fanOut, @@ -123,15 +117,8 @@ public abstract class AbstractConv( return Activations.convert(activation).apply(tf, withBias, name) } - override var weights: Map> - get() = extractWeights(kernel, bias) - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = true - override val paramCount: Int - get() = (kernel.shape.numElements() + (bias?.shape?.numElements() ?: 0)).toInt() - /** Define the number of output channels given the number of input channels. * Defaults to the number of filter in convolutional layer. */ protected open fun getOutputDepth(numberOfChannels: Long): Long = filters.toLong() diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv1D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv1D.kt index 080d4f5f4..a70533852 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv1D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv1D.kt @@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.dl.api.core.activation.Activations import org.jetbrains.kotlinx.dl.api.core.initializer.HeNormal import org.jetbrains.kotlinx.dl.api.core.initializer.HeUniform import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer +import org.jetbrains.kotlinx.dl.api.core.layer.TrainableLayer import org.jetbrains.kotlinx.dl.api.core.layer.requireArraySize import org.jetbrains.kotlinx.dl.api.core.layer.toLongList import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer @@ -69,7 +70,7 @@ public class Conv1D( public override val padding: ConvPadding = ConvPadding.SAME, public override val useBias: Boolean = true, name: String = "", -) : AbstractConv(name = name) { +) : AbstractConv(name = name), TrainableLayer { public constructor( filters: Int = 32, kernelLength: Int = 3, @@ -105,6 +106,8 @@ public class Conv1D( requireArraySize(dilations, 3, "dilations") } + public override var isTrainable: Boolean = true + override val kernelSize: IntArray = intArrayOf(kernelLength) override fun kernelVarName(name: String): String = convKernelVarName(name, dim = 1) diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv1DTranspose.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv1DTranspose.kt index 9bf491b9b..3de1e712a 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv1DTranspose.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv1DTranspose.kt @@ -74,7 +74,6 @@ public class Conv1DTranspose( requireArraySize(strides, dimensions + 2, "strides") requireArraySize(dilations, dimensions + 2, "dilations") if (outputPadding != null) requireArraySize(outputPadding, 2 * (dimensions + 2), "outputPadding") - isTrainable = false } 
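// Editor's sketch (not part of the patch): a minimal illustration of the trainability API
// introduced above. `isTrainable` now lives on `TrainableLayer` implementations such as
// Conv1D and Dense, while `freeze()`, `unfreeze()` and the read-only `Layer.isTrainable`
// extension work on any layer; NoGradients layers simply never implement `TrainableLayer`.
// The layer names and parameters below are arbitrary examples.
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv1D
import org.jetbrains.kotlinx.dl.api.core.layer.freeze
import org.jetbrains.kotlinx.dl.api.core.layer.isTrainable
import org.jetbrains.kotlinx.dl.api.core.layer.merge.Concatenate
import org.jetbrains.kotlinx.dl.api.core.layer.unfreeze

fun main() {
    val conv = Conv1D(filters = 16, kernelLength = 3, name = "conv1")
    println(conv.isTrainable)   // true: Conv1D implements TrainableLayer

    conv.freeze()               // the layer's weights won't be updated during fit()
    println(conv.isTrainable)   // false

    conv.unfreeze()             // re-enable training for a TrainableLayer
    println(conv.isTrainable)   // true

    val concat = Concatenate(name = "concat1")
    println(concat.isTrainable) // false: not a TrainableLayer, so never trainable

    concat.freeze()             // no-op for layers without trainable weights
    // concat.unfreeze()        // would fail: "Layer concat1 does not support training"
}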
override val kernelSize: IntArray = intArrayOf(kernelLength) diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv2D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv2D.kt index 7431a5fa1..e1aa28cb4 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv2D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv2D.kt @@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.dl.api.core.activation.Activations import org.jetbrains.kotlinx.dl.api.core.initializer.HeNormal import org.jetbrains.kotlinx.dl.api.core.initializer.HeUniform import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer +import org.jetbrains.kotlinx.dl.api.core.layer.TrainableLayer import org.jetbrains.kotlinx.dl.api.core.layer.requireArraySize import org.jetbrains.kotlinx.dl.api.core.layer.toLongList import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer @@ -65,7 +66,10 @@ public class Conv2D( public override val padding: ConvPadding = ConvPadding.SAME, public override val useBias: Boolean = true, name: String = "" -) : AbstractConv(name = name) { +) : AbstractConv(name = name), TrainableLayer { + + public override var isTrainable: Boolean = true + public constructor( filters: Int = 32, kernelSize: Int = 3, diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv2DTranspose.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv2DTranspose.kt index e476a0c28..f9ace6ed6 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv2DTranspose.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv2DTranspose.kt @@ -102,7 +102,6 @@ public class Conv2DTranspose( requireArraySize(strides, dimensions + 2, "strides") requireArraySize(dilations, dimensions + 2, "dilations") if (outputPadding != null) requireArraySize(outputPadding, 2 * (dimensions + 2), "outputPadding") - isTrainable = false } override fun convImplementation(tf: Ops, input: Operand): Operand { diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv3D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv3D.kt index 1095f521c..083995471 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv3D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv3D.kt @@ -104,7 +104,6 @@ public class Conv3D( requireArraySize(kernelSize, 3, "kernelSize") requireArraySize(strides, 5, "strides") requireArraySize(dilations, 5, "dilations") - isTrainable = false } override fun kernelVarName(name: String): String = convKernelVarName(name, dim = 3) @@ -157,7 +156,7 @@ public class Conv3D( } override fun toString(): String { - return "Conv3D(name = $name, isTrainable=$isTrainable, filters=$filters, kernelSize=${kernelSize.contentToString()}, " + + return "Conv3D(name = $name, filters=$filters, kernelSize=${kernelSize.contentToString()}, " + "strides=${strides.contentToString()}, dilations=${dilations.contentToString()}, " + "activation=$activation, " + "kernelInitializer=$kernelInitializer, biasInitializer=$biasInitializer, " + diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv3DTranspose.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv3DTranspose.kt index 9de55599a..d5a5e4686 100644 --- 
a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv3DTranspose.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv3DTranspose.kt @@ -99,7 +99,6 @@ public class Conv3DTranspose( requireArraySize(kernelSize, dimensions, "kernelSize") requireArraySize(strides, dimensions + 2, "strides") requireArraySize(dilations, dimensions + 2, "dilations") - isTrainable = false } override val outputPadding: IntArray? get() = null diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/DepthwiseConv2D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/DepthwiseConv2D.kt index 1de1d747a..d861e9d25 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/DepthwiseConv2D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/DepthwiseConv2D.kt @@ -98,7 +98,6 @@ public class DepthwiseConv2D( requireArraySize(kernelSize, 2, "kernelSize") requireArraySize(strides, 4, "strides") requireArraySize(dilations, 4, "dilations") - isTrainable = false } // filters is not used in any place of this implementation of AbstractConv because @@ -154,7 +153,7 @@ public class DepthwiseConv2D( } override fun toString(): String { - return "DepthwiseConv2D(name = $name, isTrainable=$isTrainable, " + + return "DepthwiseConv2D(name = $name, " + "kernelSize=${kernelSize.contentToString()}, " + "strides=${strides.contentToString()}, " + "dilations=${dilations.contentToString()}, " + diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/SeparableConv2D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/SeparableConv2D.kt index b995b21f3..67392de87 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/SeparableConv2D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/SeparableConv2D.kt @@ -5,7 +5,6 @@ package org.jetbrains.kotlinx.dl.api.core.layer.convolutional -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.activation.Activations import org.jetbrains.kotlinx.dl.api.core.initializer.HeNormal import org.jetbrains.kotlinx.dl.api.core.initializer.HeUniform @@ -14,7 +13,6 @@ import org.jetbrains.kotlinx.dl.api.core.layer.* import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape import org.jetbrains.kotlinx.dl.api.core.shape.convOutputLength -import org.jetbrains.kotlinx.dl.api.core.shape.numElements import org.jetbrains.kotlinx.dl.api.core.shape.shapeFromDims import org.jetbrains.kotlinx.dl.api.core.util.separableConv2dBiasVarName import org.jetbrains.kotlinx.dl.api.core.util.separableConv2dDepthwiseKernelVarName @@ -76,7 +74,8 @@ public class SeparableConv2D( public val padding: ConvPadding = ConvPadding.SAME, public val useBias: Boolean = true, name: String = "" -) : Layer(name), NoGradients { +) : Layer(name), NoGradients, ParametrizedLayer { + public constructor( filters: Int = 32, kernelSize: Int = 3, @@ -118,14 +117,16 @@ public class SeparableConv2D( internal lateinit var pointwiseKernel: KVariable internal var bias: KVariable? 
= null + override val variables: List + get() = listOfNotNull(depthwiseKernel, pointwiseKernel, bias) + init { requireArraySize(kernelSize, 2, "kernelSize") requireArraySize(strides, 4, "strides") requireArraySize(dilations, 4, "dilations") - isTrainable = false } - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) { + override fun build(tf: Ops, inputShape: Shape) { // Amount of channels should be the last value in the inputShape (make warning here) val numberOfChannels = inputShape.size(inputShape.numDimensions() - 1) @@ -139,9 +140,7 @@ public class SeparableConv2D( val depthwiseKernelShape = shapeFromDims(*kernelSize.toLongArray(), numberOfChannels, depthMultiplier.toLong()) depthwiseKernel = createVariable( tf, - kGraph, separableConv2dDepthwiseKernelVarName(name), - isTrainable, depthwiseKernelShape, fanIn, fanOut, @@ -151,9 +150,7 @@ public class SeparableConv2D( val pointwiseKernelShape = shapeFromDims(1, 1, numberOfChannels * depthMultiplier, filters.toLong()) pointwiseKernel = createVariable( tf, - kGraph, separableConv2dPointwiseKernelVarName(name), - isTrainable, pointwiseKernelShape, fanIn, fanOut, @@ -164,9 +161,7 @@ public class SeparableConv2D( val biasShape = Shape.make(filters.toLong()) bias = createVariable( tf, - kGraph, separableConv2dBiasVarName(name), - isTrainable, biasShape, fanIn, fanOut, @@ -235,14 +230,5 @@ public class SeparableConv2D( "hasActivation=$hasActivation)" } - override var weights: Map> - get() = extractWeights(depthwiseKernel, pointwiseKernel, bias) - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = true - - override val paramCount: Int - get() = listOfNotNull(depthwiseKernel, pointwiseKernel, bias).sumOf { it.shape.numElements() }.toInt() - - } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/core/Dense.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/core/Dense.kt index 56fe9c375..8e7d17d3f 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/core/Dense.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/core/Dense.kt @@ -5,17 +5,16 @@ package org.jetbrains.kotlinx.dl.api.core.layer.core -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.activation.Activations import org.jetbrains.kotlinx.dl.api.core.initializer.HeNormal import org.jetbrains.kotlinx.dl.api.core.initializer.HeUniform import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer import org.jetbrains.kotlinx.dl.api.core.layer.KVariable import org.jetbrains.kotlinx.dl.api.core.layer.Layer +import org.jetbrains.kotlinx.dl.api.core.layer.TrainableLayer import org.jetbrains.kotlinx.dl.api.core.layer.createVariable import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape -import org.jetbrains.kotlinx.dl.api.core.shape.numElements import org.jetbrains.kotlinx.dl.api.core.util.denseBiasVarName import org.jetbrains.kotlinx.dl.api.core.util.denseKernelVarName import org.tensorflow.Operand @@ -54,20 +53,23 @@ public class Dense( public val activityRegularizer: Regularizer? = null, public val useBias: Boolean = true, name: String = "" -) : Layer(name) { +) : Layer(name), TrainableLayer { internal lateinit var kernel: KVariable internal var bias: KVariable? 
= null - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) { + override val variables: List + get() = listOfNotNull(kernel, bias) + + override var isTrainable: Boolean = true + + override fun build(tf: Ops, inputShape: Shape) { val fanIn = inputShape.size(inputShape.numDimensions() - 1).toInt() val fanOut = outputSize val kernelShape = Shape.make(inputShape.size(inputShape.numDimensions() - 1), outputSize.toLong()) kernel = createVariable( tf, - kGraph, denseKernelVarName(name), - isTrainable, kernelShape, fanIn, fanOut, @@ -79,9 +81,7 @@ public class Dense( val biasShape = Shape.make(outputSize.toLong()) bias = createVariable( tf, - kGraph, denseBiasVarName(name), - isTrainable, biasShape, fanIn, fanOut, @@ -113,12 +113,5 @@ public class Dense( "useBias=$useBias, hasActivation=$hasActivation, kernelShapeArray=${kernel.shape}, biasShapeArray=${bias?.shape})" } - override var weights: Map> - get() = extractWeights(kernel, bias) - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = true - - override val paramCount: Int - get() = listOfNotNull(kernel, bias).sumOf { it.shape.numElements() }.toInt() } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/core/Input.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/core/Input.kt index a0a9d1f7d..7a8d4ed87 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/core/Input.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/core/Input.kt @@ -1,11 +1,10 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package org.jetbrains.kotlinx.dl.api.core.layer.core -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape import org.jetbrains.kotlinx.dl.api.core.util.DATA_PLACEHOLDER @@ -30,7 +29,7 @@ public class Input(vararg dims: Long, name: String = "") : Layer(name) { /** Input data dimensions. Rank = 3 or 4 for most popular supported cases. */ public var packedDims: LongArray = dims - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {} + override fun build(tf: Ops, inputShape: Shape) {} /** * Extend this function to define placeholder in layer. @@ -69,15 +68,9 @@ public class Input(vararg dims: Long, name: String = "") : Layer(name) { return inputShape } - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = false - override val paramCount: Int get() = 0 - override fun toString(): String { - return "Input(name = $name, isTrainable=$isTrainable, shape=${packedDims.contentToString()})" + return "Input(name = $name, shape=${packedDims.contentToString()})" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/AbstractMerge.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/AbstractMerge.kt index 17acb4119..1c7f95142 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/AbstractMerge.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/AbstractMerge.kt @@ -1,11 +1,10 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. 
+ * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package org.jetbrains.kotlinx.dl.api.core.layer.merge -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape import org.tensorflow.Operand @@ -20,7 +19,7 @@ import org.tensorflow.op.Ops * @property [layerTypeName] Specified layer name used for tf operation alias building. */ public abstract class AbstractMerge(public val layerTypeName: String, name: String = "") : Layer(name) { - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) { + override fun build(tf: Ops, inputShape: Shape) { } @@ -86,13 +85,6 @@ public abstract class AbstractMerge(public val layerTypeName: String, name: Stri } } - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = false - - override val paramCount: Int - get() = 0 } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Add.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Add.kt index 56e347fad..ffe6b4cf8 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Add.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Add.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -25,6 +25,6 @@ public class Add(name: String = "") : AbstractMerge("AddLayer", name) { } override fun toString(): String { - return "Add(name = $name, isTrainable=$isTrainable)" + return "Add(name = $name)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Average.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Average.kt index c211a2688..d850f1101 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Average.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Average.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -27,6 +27,6 @@ public class Average(name: String = "") : AbstractMerge("AverageLayer", name) { } override fun toString(): String { - return "Average(name = $name, isTrainable=$isTrainable,)" + return "Average(name = $name)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Concatenate.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Concatenate.kt index 4c77b9453..0aac9c3a1 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Concatenate.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Concatenate.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. 
and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -22,10 +22,6 @@ public class Concatenate( public var axis: Int = 3, name: String = "" ) : AbstractMerge("ConcatenateLayer", name), NoGradients { - init { - isTrainable = false - } - override fun computeOutputShapeFromInboundLayers(): TensorShape { val inputShapes = mutableListOf() inboundLayers.forEach { inboundLayer -> inputShapes.add(inboundLayer.outputShape) } @@ -68,6 +64,6 @@ public class Concatenate( } override fun toString(): String { - return "Concatenate(name = $name, isTrainable=$isTrainable, axis=$axis)" + return "Concatenate(name = $name, axis=$axis)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Dot.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Dot.kt index 0ae513fde..d10eabb14 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Dot.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Dot.kt @@ -1,3 +1,8 @@ +/* + * Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. + */ + package org.jetbrains.kotlinx.dl.api.core.layer.merge import org.tensorflow.Operand @@ -45,7 +50,7 @@ public class Dot( } override fun toString(): String { - return "Dot(name = $name, isTrainable=$isTrainable, axis=${axis.contentToString()}, normalize=$normalize)" + return "Dot(name = $name, axis=${axis.contentToString()}, normalize=$normalize)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Maximum.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Maximum.kt index 2108ff80c..2e63363c9 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Maximum.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Maximum.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -25,6 +25,6 @@ public class Maximum(name: String = "") : AbstractMerge("MaximumLayer", name) { } override fun toString(): String { - return "Maximum(name = $name, isTrainable=$isTrainable,)" + return "Maximum(name = $name)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Minimum.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Minimum.kt index 0ba5046c8..6a2156b21 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Minimum.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Minimum.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ @@ -25,6 +25,6 @@ public class Minimum(name: String = "") : AbstractMerge("MinimumLayer", name) { } override fun toString(): String { - return "Minimum(name = $name, isTrainable=$isTrainable,)" + return "Minimum(name = $name)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Multiply.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Multiply.kt index c2728ee24..c0f991210 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Multiply.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Multiply.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -25,6 +25,6 @@ public class Multiply(name: String = "") : AbstractMerge("MultiplyLayer", name) } override fun toString(): String { - return "Multiply(name = $name, isTrainable=$isTrainable,)" + return "Multiply(name = $name)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Subtract.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Subtract.kt index 3ad6d879c..cc649e7f1 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Subtract.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/merge/Subtract.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/
@@ -26,6 +26,6 @@ public class Subtract(name: String = "") : AbstractMerge("SubtractLayer", name)
}
override fun toString(): String {
- return "Subtract(name = $name, isTrainable=$isTrainable,)"
+ return "Subtract(name = $name)"
}
}
diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/normalization/BatchNorm.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/normalization/BatchNorm.kt
index d8feacb0b..a67b0d9a2 100644
--- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/normalization/BatchNorm.kt
+++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/normalization/BatchNorm.kt
@@ -5,16 +5,11 @@
package org.jetbrains.kotlinx.dl.api.core.layer.normalization
-import org.jetbrains.kotlinx.dl.api.core.KGraph
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer
import org.jetbrains.kotlinx.dl.api.core.initializer.Ones
import org.jetbrains.kotlinx.dl.api.core.initializer.Zeros
-import org.jetbrains.kotlinx.dl.api.core.layer.KVariable
-import org.jetbrains.kotlinx.dl.api.core.layer.Layer
-import org.jetbrains.kotlinx.dl.api.core.layer.NoGradients
-import org.jetbrains.kotlinx.dl.api.core.layer.createVariable
+import org.jetbrains.kotlinx.dl.api.core.layer.*
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
-import org.jetbrains.kotlinx.dl.api.core.shape.numElements
import org.jetbrains.kotlinx.dl.api.core.util.batchNormBetaVarName
import org.jetbrains.kotlinx.dl.api.core.util.batchNormGammaVarName
import org.jetbrains.kotlinx.dl.api.core.util.batchNormMovingMeanVarName
@@ -56,31 +51,27 @@ public class BatchNorm(
public val movingMeanInitializer: Initializer = Zeros(),
public val movingVarianceInitializer: Initializer = Ones(),
name: String = "",
-) : Layer(name), NoGradients {
+) : Layer(name), NoGradients, ParametrizedLayer {
internal var gamma: KVariable? = null
internal var beta: KVariable? = null
internal lateinit var movingMean: KVariable
internal lateinit var movingVariance: KVariable
- init {
- isTrainable = false
- }
+ override val variables: List<KVariable>
+ get() = listOfNotNull(gamma, beta, movingMean, movingVariance)
- override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {
+ override fun build(tf: Ops, inputShape: Shape) {
// Compute shapes of kernel and bias matrices
val weightShape = Shape.make(inputShape.size(axis[0]))
if (name.isEmpty()) throw RuntimeException("Cannot build BatchNorm layer, because of empty name")
- isTrainable = false // TODO: add isTrainable to addWeight method as a flag
val fanIn = Int.MIN_VALUE
val fanOut = Int.MIN_VALUE
movingMean = createVariable(
tf,
- kGraph,
batchNormMovingMeanVarName(name),
- isTrainable,
weightShape,
fanIn,
fanOut,
@@ -90,9 +81,7 @@ public class BatchNorm(
movingVariance = createVariable(
tf,
- kGraph,
batchNormMovingVarianceVarName(name),
- isTrainable,
weightShape,
fanIn,
fanOut,
@@ -103,9 +92,7 @@ public class BatchNorm(
if (scale) {
gamma = createVariable(
tf,
- kGraph,
batchNormGammaVarName(name),
- isTrainable,
weightShape,
fanIn,
fanOut,
@@ -117,9 +104,7 @@ public class BatchNorm(
if (center) {
beta = createVariable(
tf,
- kGraph,
batchNormBetaVarName(name),
- isTrainable,
weightShape,
fanIn,
fanOut,
@@ -191,13 +176,6 @@ public class BatchNorm(
"movingMeanShapeArray=${movingMean.shape}, movingVarianceShapeArray=${movingVariance.shape})"
}
- override var weights: Map<String, Array<*>>
- get() = extractWeights(gamma, beta, movingMean, movingVariance)
- set(value) = assignWeights(value)
-
override val hasActivation: Boolean get() = false
-
- override val paramCount: Int
- get() = listOfNotNull(gamma, beta, movingMean, movingVariance).sumOf { it.shape.numElements() }.toInt()
}
diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/AvgPool1D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/AvgPool1D.kt
index f04539f89..f52087c8f 100644
--- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/AvgPool1D.kt
+++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/AvgPool1D.kt
@@ -1,11 +1,10 @@
/*
-* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
-* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
-*/
+ * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
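With BatchNorm implementing ParametrizedLayer, its gamma, beta, and moving statistics are reachable through a single variables list instead of a layer-specific weights/paramCount pair. A sketch of the generic inspection this enables; the helper below is illustrative, and the exact package of ParametrizedLayer is assumed (the hunk above pulls it in through the wildcard layer import):

import org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel
import org.jetbrains.kotlinx.dl.api.core.layer.ParametrizedLayer
import org.jetbrains.kotlinx.dl.api.core.shape.numElements

// Sketch: walk the variables that parametrized layers (such as BatchNorm) now expose directly.
fun printVariableShapes(model: GraphTrainableModel) {
    model.layers.forEach { layer ->
        if (layer is ParametrizedLayer) {
            layer.variables.forEach { variable ->
                println("${layer.name}: shape=${variable.shape}, elements=${variable.shape.numElements()}")
            }
        }
    }
}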
+ */ package org.jetbrains.kotlinx.dl.api.core.layer.pooling -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.ConvPadding import org.jetbrains.kotlinx.dl.api.core.layer.requireArraySize @@ -46,11 +45,6 @@ public class AvgPool1D( override val hasActivation: Boolean get() = false - override val paramCount: Int - get() = 0 - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) init { requireArraySize(poolSize, 3, "poolSize") @@ -60,7 +54,7 @@ public class AvgPool1D( } } - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape): Unit = Unit + override fun build(tf: Ops, inputShape: Shape): Unit = Unit override fun computeOutputShape(inputShape: Shape): Shape { var steps = inputShape.size(1) @@ -95,6 +89,6 @@ public class AvgPool1D( } override fun toString(): String { - return "AvgPool1D(name = $name, isTrainable=$isTrainable, poolSize=${poolSize.contentToString()}, strides=${strides.contentToString()}, padding=$padding, hasActivation=$hasActivation)" + return "AvgPool1D(name = $name, poolSize=${poolSize.contentToString()}, strides=${strides.contentToString()}, padding=$padding, hasActivation=$hasActivation)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/AvgPool2D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/AvgPool2D.kt index 6d5a698d0..b6ae46be9 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/AvgPool2D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/AvgPool2D.kt @@ -1,11 +1,10 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ package org.jetbrains.kotlinx.dl.api.core.layer.pooling -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.ConvPadding import org.jetbrains.kotlinx.dl.api.core.layer.toLongList @@ -44,7 +43,7 @@ public class AvgPool2D( name = name ) - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape): Unit = Unit + override fun build(tf: Ops, inputShape: Shape): Unit = Unit override fun computeOutputShape(inputShape: Shape): Shape { var rows = inputShape.size(1) @@ -76,14 +75,8 @@ public class AvgPool2D( } override fun toString(): String { - return "AvgPool2D(name = $name, isTrainable=$isTrainable, poolSize=${poolSize.contentToString()}, strides=${strides.contentToString()}, padding=$padding, hasActivation=$hasActivation)" + return "AvgPool2D(name = $name, poolSize=${poolSize.contentToString()}, strides=${strides.contentToString()}, padding=$padding, hasActivation=$hasActivation)" } - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = false - - override val paramCount: Int get() = 0 } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/AvgPool3D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/AvgPool3D.kt index bc5748501..a44e58314 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/AvgPool3D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/AvgPool3D.kt @@ -1,11 +1,10 @@ /* -* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. -* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. -*/ + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
+ */ package org.jetbrains.kotlinx.dl.api.core.layer.pooling -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.ConvPadding import org.jetbrains.kotlinx.dl.api.core.layer.requireArraySize @@ -46,11 +45,6 @@ public class AvgPool3D( override val hasActivation: Boolean get() = false - override val paramCount: Int - get() = 0 - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) init { requireArraySize(poolSize, 5, "poolSize") @@ -60,7 +54,7 @@ public class AvgPool3D( } } - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape): Unit = Unit + override fun build(tf: Ops, inputShape: Shape): Unit = Unit override fun computeOutputShape(inputShape: Shape): Shape { val dim1 = convOutputLength(inputShape.size(1), poolSize[1], padding, strides[1]) @@ -85,6 +79,6 @@ public class AvgPool3D( } override fun toString(): String { - return "AvgPool3D(name = $name, isTrainable=$isTrainable, poolSize=${poolSize.contentToString()}, strides=${strides.contentToString()}, padding=$padding, hasActivation=$hasActivation)" + return "AvgPool3D(name = $name, poolSize=${poolSize.contentToString()}, strides=${strides.contentToString()}, padding=$padding, hasActivation=$hasActivation)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalAvgPool1D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalAvgPool1D.kt index 0bd8825c5..19a5f1c94 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalAvgPool1D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalAvgPool1D.kt @@ -1,11 +1,10 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ package org.jetbrains.kotlinx.dl.api.core.layer.pooling -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.util.TF import org.tensorflow.Operand @@ -23,7 +22,7 @@ import org.tensorflow.op.Ops public class GlobalAvgPool1D( name: String = "" ) : Layer(name) { - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {} + override fun build(tf: Ops, inputShape: Shape) {} override fun computeOutputShape(inputShape: Shape): Shape { return Shape.make(inputShape.size(0), inputShape.size(2)) @@ -41,14 +40,8 @@ public class GlobalAvgPool1D( } override fun toString(): String { - return "GlobalAvgPool1D(name = $name, isTrainable=$isTrainable, hasActivation=$hasActivation)" + return "GlobalAvgPool1D(name = $name, hasActivation=$hasActivation)" } - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = false - - override val paramCount: Int get() = 0 } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalAvgPool2D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalAvgPool2D.kt index 527b4e1bf..2ecec0ee0 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalAvgPool2D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalAvgPool2D.kt @@ -1,11 +1,10 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package org.jetbrains.kotlinx.dl.api.core.layer.pooling -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.util.TF import org.tensorflow.Operand @@ -29,7 +28,7 @@ import org.tensorflow.op.Ops public class GlobalAvgPool2D( name: String = "" ) : Layer(name) { - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {} + override fun build(tf: Ops, inputShape: Shape) {} override fun computeOutputShape(inputShape: Shape): Shape { return Shape.make(inputShape.size(0), inputShape.size(3)) // if (this.dataFormat == 'channelsLast') @@ -45,14 +44,8 @@ public class GlobalAvgPool2D( } override fun toString(): String { - return "GlobalAvgPool2D(name = $name, isTrainable=$isTrainable, hasActivation=$hasActivation)" + return "GlobalAvgPool2D(name = $name, hasActivation=$hasActivation)" } - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = false - - override val paramCount: Int get() = 0 } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalAvgPool3D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalAvgPool3D.kt index 75d6a195a..8c548b86c 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalAvgPool3D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalAvgPool3D.kt @@ -1,11 +1,10 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. 
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package org.jetbrains.kotlinx.dl.api.core.layer.pooling -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.util.TF import org.tensorflow.Operand @@ -26,7 +25,7 @@ import org.tensorflow.op.Ops public class GlobalAvgPool3D( name: String = "" ) : Layer(name) { - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {} + override fun build(tf: Ops, inputShape: Shape) {} override fun computeOutputShape(inputShape: Shape): Shape { return Shape.make(inputShape.size(0), inputShape.size(4)) @@ -42,14 +41,8 @@ public class GlobalAvgPool3D( } override fun toString(): String { - return "GlobalAvgPool3D(name = $name, isTrainable=$isTrainable, hasActivation=$hasActivation)" + return "GlobalAvgPool3D(name = $name, hasActivation=$hasActivation)" } - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = false - - override val paramCount: Int get() = 0 } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalMaxPool1D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalMaxPool1D.kt index 77c31581a..11a592f35 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalMaxPool1D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalMaxPool1D.kt @@ -1,11 +1,10 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package org.jetbrains.kotlinx.dl.api.core.layer.pooling -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.tensorflow.Operand import org.tensorflow.Shape @@ -22,13 +21,8 @@ public class GlobalMaxPool1D( override val hasActivation: Boolean get() = false - override val paramCount: Int - get() = 0 - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {} + override fun build(tf: Ops, inputShape: Shape) {} override fun computeOutputShape(inputShape: Shape): Shape { return Shape.make(inputShape.size(0), inputShape.size(2)) @@ -44,6 +38,6 @@ public class GlobalMaxPool1D( } override fun toString(): String { - return "GlobalMaxPool1D(name = $name, isTrainable=$isTrainable, hasActivation=$hasActivation)" + return "GlobalMaxPool1D(name = $name, hasActivation=$hasActivation)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalMaxPool2D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalMaxPool2D.kt index 5aa148dcd..ca8cfa256 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalMaxPool2D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalMaxPool2D.kt @@ -1,11 +1,10 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. 
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package org.jetbrains.kotlinx.dl.api.core.layer.pooling -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.tensorflow.Operand import org.tensorflow.Shape @@ -22,13 +21,8 @@ public class GlobalMaxPool2D( override val hasActivation: Boolean get() = false - override val paramCount: Int - get() = 0 - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {} + override fun build(tf: Ops, inputShape: Shape) {} override fun computeOutputShape(inputShape: Shape): Shape { return Shape.make(inputShape.size(0), inputShape.size(3)) @@ -44,6 +38,6 @@ public class GlobalMaxPool2D( } override fun toString(): String { - return "GlobalMaxPool2D(name = $name, isTrainable=$isTrainable, hasActivation=$hasActivation)" + return "GlobalMaxPool2D(name = $name, hasActivation=$hasActivation)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalMaxPool3D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalMaxPool3D.kt index 8fe83d357..62b5fb3b1 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalMaxPool3D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/GlobalMaxPool3D.kt @@ -1,11 +1,10 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package org.jetbrains.kotlinx.dl.api.core.layer.pooling -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.tensorflow.Operand import org.tensorflow.Shape @@ -22,13 +21,8 @@ public class GlobalMaxPool3D( override val hasActivation: Boolean get() = false - override val paramCount: Int - get() = 0 - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {} + override fun build(tf: Ops, inputShape: Shape) {} override fun computeOutputShape(inputShape: Shape): Shape { return Shape.make(inputShape.size(0), inputShape.size(4)) @@ -44,6 +38,6 @@ public class GlobalMaxPool3D( } override fun toString(): String { - return "GlobalMaxPool3D(name = $name, isTrainable=$isTrainable, hasActivation=$hasActivation)" + return "GlobalMaxPool3D(name = $name, hasActivation=$hasActivation)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/MaxPool1D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/MaxPool1D.kt index 42ecb1fc7..f9927aeef 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/MaxPool1D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/MaxPool1D.kt @@ -1,11 +1,10 @@ /* -* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. -* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. -*/ + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. 
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. + */ package org.jetbrains.kotlinx.dl.api.core.layer.pooling -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.ConvPadding import org.jetbrains.kotlinx.dl.api.core.layer.requireArraySize @@ -46,11 +45,6 @@ public class MaxPool1D( override val hasActivation: Boolean get() = false - override val paramCount: Int - get() = 0 - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) init { requireArraySize(poolSize, 3, "poolSize") @@ -60,7 +54,7 @@ public class MaxPool1D( } } - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {} + override fun build(tf: Ops, inputShape: Shape) {} override fun computeOutputShape(inputShape: Shape): Shape { val steps = convOutputLength(inputShape.size(1), poolSize[1], padding, strides[1]) @@ -106,6 +100,6 @@ public class MaxPool1D( } override fun toString(): String { - return "MaxPool1D(name = $name, isTrainable=$isTrainable, poolSize=${poolSize.contentToString()}, strides=${strides.contentToString()}, padding=$padding, hasActivation=$hasActivation)" + return "MaxPool1D(name = $name, poolSize=${poolSize.contentToString()}, strides=${strides.contentToString()}, padding=$padding, hasActivation=$hasActivation)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/MaxPool2D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/MaxPool2D.kt index 8fd83a8c3..de9cbfe23 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/MaxPool2D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/MaxPool2D.kt @@ -1,11 +1,10 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ package org.jetbrains.kotlinx.dl.api.core.layer.pooling -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.ConvPadding import org.jetbrains.kotlinx.dl.api.core.shape.convOutputLength @@ -43,7 +42,7 @@ public class MaxPool2D( name = name ) - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {} + override fun build(tf: Ops, inputShape: Shape) {} override fun computeOutputShape(inputShape: Shape): Shape { var rows = inputShape.size(1) @@ -77,14 +76,8 @@ public class MaxPool2D( } override fun toString(): String { - return "MaxPool2D(name = $name, isTrainable=$isTrainable, poolSize=${poolSize.contentToString()}, strides=${strides.contentToString()}, padding=$padding, hasActivation=$hasActivation)" + return "MaxPool2D(name = $name, poolSize=${poolSize.contentToString()}, strides=${strides.contentToString()}, padding=$padding, hasActivation=$hasActivation)" } - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = false - - override val paramCount: Int get() = 0 } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/MaxPool3D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/MaxPool3D.kt index e6e11c12a..7c97c4930 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/MaxPool3D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/pooling/MaxPool3D.kt @@ -1,6 +1,10 @@ +/* + * Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. + */ + package org.jetbrains.kotlinx.dl.api.core.layer.pooling -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.ConvPadding import org.jetbrains.kotlinx.dl.api.core.shape.convOutputLength @@ -37,7 +41,7 @@ public class MaxPool3D( name = name ) - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {} + override fun build(tf: Ops, inputShape: Shape) {} override fun computeOutputShape(inputShape: Shape): Shape { var lenDim1: Long = inputShape.size(1) @@ -66,14 +70,8 @@ public class MaxPool3D( } override fun toString(): String { - return "MaxPool3D(name = $name, isTrainable=$isTrainable, poolSize=${poolSize.contentToString()}, strides=${strides.contentToString()}, padding=$padding, hasActivation=$hasActivation)" + return "MaxPool3D(name = $name, poolSize=${poolSize.contentToString()}, strides=${strides.contentToString()}, padding=$padding, hasActivation=$hasActivation)" } - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = false - - override val paramCount: Int get() = 0 } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/regularization/Dropout.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/regularization/Dropout.kt index bccdc75e3..fd0d2794d 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/regularization/Dropout.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/regularization/Dropout.kt @@ -1,11 +1,10 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. 
and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package org.jetbrains.kotlinx.dl.api.core.layer.regularization -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.NoGradients import org.tensorflow.Operand @@ -33,11 +32,8 @@ public class Dropout( public val seed: Long = 12L, name: String = "" ) : Layer(name), NoGradients { - init { - isTrainable = false - } - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) { + override fun build(tf: Ops, inputShape: Shape) { } override fun computeOutputShape(inputShape: Shape): Shape { @@ -75,14 +71,8 @@ public class Dropout( } override fun toString(): String { - return "Dropout(name = $name, isTrainable=$isTrainable, rate=$rate, seed=$seed, hasActivation=$hasActivation)" + return "Dropout(name = $name, rate=$rate, seed=$seed, hasActivation=$hasActivation)" } - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = false - - override val paramCount: Int get() = 0 } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/AbstractCropping.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/AbstractCropping.kt index 20a32d822..59a090e78 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/AbstractCropping.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/AbstractCropping.kt @@ -1,11 +1,10 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package org.jetbrains.kotlinx.dl.api.core.layer.reshaping -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.tensorflow.Operand import org.tensorflow.Shape @@ -23,13 +22,8 @@ public abstract class AbstractCropping( ) : Layer(name) { override val hasActivation: Boolean get() = false - override val paramCount: Int - get() = 0 - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {} + override fun build(tf: Ops, inputShape: Shape) {} override fun forward( tf: Ops, diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/AbstractUpSampling.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/AbstractUpSampling.kt index d382115d5..8263b2594 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/AbstractUpSampling.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/AbstractUpSampling.kt @@ -1,11 +1,10 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
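Dropout above (like Concatenate earlier) is a NoGradients layer and no longer forces isTrainable = false in its constructor; trainability is now controlled from the outside. A sketch of how a caller can still enforce the old invariant explicitly before compiling, using the freeze() and isTrainable extensions imported elsewhere in this change (how isTrainable reports layers without variables is assumed here, not defined by this diff):

import org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel
import org.jetbrains.kotlinx.dl.api.core.layer.NoGradients
import org.jetbrains.kotlinx.dl.api.core.layer.freeze
import org.jetbrains.kotlinx.dl.api.core.layer.isTrainable

// Sketch: make sure no NoGradients layer is left in a trainable state before compiling.
fun freezeNoGradientLayers(model: GraphTrainableModel) {
    model.layers
        .filter { it is NoGradients && it.isTrainable }
        .forEach { it.freeze() }
}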
*/ package org.jetbrains.kotlinx.dl.api.core.layer.reshaping -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.tensorflow.Operand import org.tensorflow.Shape @@ -27,13 +26,8 @@ public abstract class AbstractUpSampling( override val hasActivation: Boolean get() = false - override val paramCount: Int - get() = 0 - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {} + override fun build(tf: Ops, inputShape: Shape) {} override fun forward( tf: Ops, diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/AbstractZeroPadding.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/AbstractZeroPadding.kt index ed6233eeb..a2547280e 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/AbstractZeroPadding.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/AbstractZeroPadding.kt @@ -1,5 +1,5 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -15,14 +15,8 @@ import org.tensorflow.op.Ops public abstract class AbstractZeroPadding( name: String ) : Layer(name) { - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = false - override val paramCount: Int get() = 0 - override fun forward( tf: Ops, input: Operand, diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Cropping1D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Cropping1D.kt index 3e1b6e3bb..6263561a4 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Cropping1D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Cropping1D.kt @@ -1,5 +1,5 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ @@ -59,6 +59,6 @@ public class Cropping1D( } override fun toString(): String { - return "Cropping1D(name = $name, isTrainable=$isTrainable, cropping=${cropping.contentToString()}, hasActivation=$hasActivation)" + return "Cropping1D(name = $name, cropping=${cropping.contentToString()}, hasActivation=$hasActivation)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Cropping2D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Cropping2D.kt index 9bde6249b..7c533531f 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Cropping2D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Cropping2D.kt @@ -67,6 +67,6 @@ public class Cropping2D( } override fun toString(): String { - return "Cropping2D(name = $name, isTrainable=$isTrainable, cropping=${cropping.contentDeepToString()}, hasActivation = $hasActivation)" + return "Cropping2D(name = $name, cropping=${cropping.contentDeepToString()}, hasActivation = $hasActivation)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Cropping3D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Cropping3D.kt index 6fef803ba..b45b3d5a3 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Cropping3D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Cropping3D.kt @@ -70,6 +70,6 @@ public class Cropping3D( } override fun toString(): String { - return "Cropping3D(name = $name, isTrainable=$isTrainable, cropping=${cropping.contentDeepToString()}, hasActivation=$hasActivation)" + return "Cropping3D(name = $name, cropping=${cropping.contentDeepToString()}, hasActivation=$hasActivation)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Flatten.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Flatten.kt index 08819130c..cb850c805 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Flatten.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Flatten.kt @@ -1,11 +1,10 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ package org.jetbrains.kotlinx.dl.api.core.layer.reshaping -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape import org.tensorflow.Operand @@ -23,7 +22,7 @@ import kotlin.math.abs public class Flatten(name: String = "") : Layer(name) { private lateinit var units: Constant - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) { + override fun build(tf: Ops, inputShape: Shape) { val tensorShape = TensorShape(inputShape) val amountOfNeuronsInFlattenLayer = (tensorShape.numElements() / abs(tensorShape.size(0))).toInt() units = tf.constant(intArrayOf(-1, amountOfNeuronsInFlattenLayer)) @@ -45,14 +44,8 @@ public class Flatten(name: String = "") : Layer(name) { } override fun toString(): String { - return "Flatten(name = $name, isTrainable=$isTrainable, hasActivation=$hasActivation)" + return "Flatten(name = $name, hasActivation=$hasActivation)" } - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = false - - override val paramCount: Int get() = 0 } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Permute.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Permute.kt index e4deff457..b51b1f7eb 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Permute.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Permute.kt @@ -1,10 +1,9 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package org.jetbrains.kotlinx.dl.api.core.layer.reshaping -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.shape.shapeFromDims import org.jetbrains.kotlinx.dl.api.core.shape.toLongArray @@ -31,7 +30,7 @@ public class Permute( } } - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) {} + override fun build(tf: Ops, inputShape: Shape) {} override fun computeOutputShape(inputShape: Shape): Shape { val outputShape = inputShape.toLongArray() @@ -54,14 +53,8 @@ public class Permute( } override fun toString(): String { - return "Permute(name = $name, isTrainable=$isTrainable, dims=${dims.contentToString()}, hasActivation=$hasActivation)" + return "Permute(name = $name, dims=${dims.contentToString()}, hasActivation=$hasActivation)" } - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = false - - override val paramCount: Int get() = 0 } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt index c04c9afc2..e7b079386 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/RepeatVector.kt @@ -1,11 +1,10 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. 
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package org.jetbrains.kotlinx.dl.api.core.layer.reshaping -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.tensorflow.Operand import org.tensorflow.Shape @@ -33,7 +32,7 @@ public class RepeatVector( require(n >= 1) { "Number of repetitions (n) in RepeatVector should be positive but got $n" } } - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape): Unit = Unit + override fun build(tf: Ops, inputShape: Shape): Unit = Unit override fun computeOutputShape(inputShape: Shape): Shape { require(inputShape.numDimensions() == 2) { @@ -54,14 +53,8 @@ public class RepeatVector( } override fun toString(): String { - return "RepeatVector(name = $name, isTrainable=$isTrainable, n=$n, hasActivation=$hasActivation)" + return "RepeatVector(name = $name, n=$n, hasActivation=$hasActivation)" } - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = false - - override val paramCount: Int get() = 0 } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Reshape.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Reshape.kt index 1dc1bfac2..8b872c5cc 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Reshape.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/Reshape.kt @@ -1,11 +1,10 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package org.jetbrains.kotlinx.dl.api.core.layer.reshaping -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.tensorflow.Operand import org.tensorflow.Shape @@ -31,7 +30,7 @@ public class Reshape( ) : Layer(name) { private lateinit var units: Constant - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) { + override fun build(tf: Ops, inputShape: Shape) { units = tf.constant(IntArray(targetShape.size + 1) { if (it == 0) -1 else targetShape[it - 1] }) @@ -51,14 +50,8 @@ public class Reshape( } override fun toString(): String { - return "Reshape(name = $name, isTrainable=$isTrainable, targetShape=$targetShape, hasActivation=$hasActivation)" + return "Reshape(name = $name, targetShape=$targetShape, hasActivation=$hasActivation)" } - override var weights: Map> - get() = emptyMap() - set(value) = assignWeights(value) - override val hasActivation: Boolean get() = false - - override val paramCount: Int get() = 0 } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/UpSampling1D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/UpSampling1D.kt index e4cb195d2..731baf8b3 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/UpSampling1D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/UpSampling1D.kt @@ -1,5 +1,5 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. 
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -47,6 +47,6 @@ public class UpSampling1D( } override fun toString(): String { - return "UpSampling1D(name = $name, isTrainable=$isTrainable, size=$size, hasActivation=$hasActivation)" + return "UpSampling1D(name = $name, size=$size, hasActivation=$hasActivation)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/UpSampling2D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/UpSampling2D.kt index 1219ef7c0..eff9c2e27 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/UpSampling2D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/UpSampling2D.kt @@ -1,5 +1,5 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -73,6 +73,6 @@ public class UpSampling2D( } override fun toString(): String { - return "UpSampling2D(name = $name, isTrainable=$isTrainable, size=${size.contentToString()}, interpolation=$interpolation, hasActivation=$hasActivation)" + return "UpSampling2D(name = $name, size=${size.contentToString()}, interpolation=$interpolation, hasActivation=$hasActivation)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/UpSampling3D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/UpSampling3D.kt index 49ec62dd6..a634d6a88 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/UpSampling3D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/UpSampling3D.kt @@ -1,5 +1,5 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -59,7 +59,7 @@ public class UpSampling3D( } override fun toString(): String { - return "UpSampling3D(name = $name, isTrainable=$isTrainable, size=${size.contentToString()}, hasActivation=$hasActivation)" + return "UpSampling3D(name = $name, size=${size.contentToString()}, hasActivation=$hasActivation)" } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/ZeroPadding1D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/ZeroPadding1D.kt index fba797b5c..77f179a5e 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/ZeroPadding1D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/ZeroPadding1D.kt @@ -1,11 +1,10 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ package org.jetbrains.kotlinx.dl.api.core.layer.reshaping -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.tensorflow.Shape import org.tensorflow.op.Ops @@ -58,7 +57,7 @@ public class ZeroPadding1D : AbstractZeroPadding { this.padding = padding } - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) { + override fun build(tf: Ops, inputShape: Shape) { this.inputShape = inputShape } @@ -72,6 +71,6 @@ public class ZeroPadding1D : AbstractZeroPadding { } override fun toString(): String { - return "ZeroPadding1D(name = $name, isTrainable=$isTrainable, padding=${padding.contentToString()}, hasActivation=$hasActivation)" + return "ZeroPadding1D(name = $name, padding=${padding.contentToString()}, hasActivation=$hasActivation)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/ZeroPadding2D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/ZeroPadding2D.kt index 3e7454045..a4ac4bc39 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/ZeroPadding2D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/ZeroPadding2D.kt @@ -1,11 +1,10 @@ /* - * Copyright 2020-2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package org.jetbrains.kotlinx.dl.api.core.layer.reshaping -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.inference.keras.CHANNELS_FIRST import org.jetbrains.kotlinx.dl.api.inference.keras.CHANNELS_LAST import org.tensorflow.Shape @@ -77,7 +76,7 @@ public class ZeroPadding2D : AbstractZeroPadding { this.dataFormat = dataFormat ?: CHANNELS_LAST } - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) { + override fun build(tf: Ops, inputShape: Shape) { this.inputShape = inputShape } @@ -147,6 +146,6 @@ public class ZeroPadding2D : AbstractZeroPadding { } override fun toString(): String { - return "ZeroPadding2D(name = $name, isTrainable=$isTrainable, padding=${padding.contentToString()}, hasActivation=$hasActivation)" + return "ZeroPadding2D(name = $name, padding=${padding.contentToString()}, hasActivation=$hasActivation)" } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/ZeroPadding3D.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/ZeroPadding3D.kt index b27bb0206..32292c2ba 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/ZeroPadding3D.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/reshaping/ZeroPadding3D.kt @@ -1,11 +1,10 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ package org.jetbrains.kotlinx.dl.api.core.layer.reshaping -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.tensorflow.Shape import org.tensorflow.op.Ops @@ -80,7 +79,7 @@ public class ZeroPadding3D : AbstractZeroPadding { this.padding = padding } - override fun build(tf: Ops, kGraph: KGraph, inputShape: Shape) { + override fun build(tf: Ops, inputShape: Shape) { this.inputShape = inputShape } @@ -131,7 +130,7 @@ public class ZeroPadding3D : AbstractZeroPadding { } override fun toString(): String { - return "ZeroPadding3D(name = $name, isTrainable=$isTrainable, padding=${padding.contentToString()}, hasActivation=$hasActivation)" + return "ZeroPadding3D(name = $name, padding=${padding.contentToString()}, hasActivation=$hasActivation)" } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Optimizer.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Optimizer.kt index ea6ac5ef0..23095d9a0 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Optimizer.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Optimizer.kt @@ -39,11 +39,10 @@ public abstract class Optimizer(public val clipGradient: ClipGradientAction) { */ internal fun prepareTargets( graph: KGraph, + weights: List>, tf: Ops, loss: Operand ): List> { - val weights = graph.trainableLayerVariables() - slots = mutableMapOf() val gradients: Gradients = computeGradients(tf, loss, weights) diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel.kt index e89aa36eb..c4c02a6ae 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel.kt @@ -61,6 +61,15 @@ public open class TensorFlowInferenceModel : InferenceModel() { override val inputDimensions: LongArray get() = TODO("Not yet implemented") + /** + * Returns a list of layer variables in this model. + */ + protected open fun variables(): List> = emptyList() + /** + * Returns a list of non-trainable, 'frozen' layer variables in this model. + */ + protected open fun frozenVariables(): List> = emptyList() + /** * Generates output prediction for the input sample. * @@ -197,14 +206,14 @@ public open class TensorFlowInferenceModel : InferenceModel() { /** Checks that the variable with the name [variableName] is an optimizer variable and belongs to the frozen layer. */ protected fun isOptimizerNameAndRelatedToFrozenLayer(variableName: String): Boolean { - return variableName.startsWith("optimizer") && kGraph.frozenLayerVariables() + return variableName.startsWith("optimizer") && frozenVariables() .map { it.ref().op().name() } // extract names .any { variableName.contains(it) } } /** Returns a list of variables paired with their data. 
*/ protected fun getVariablesAndTensors(saveOptimizerState: Boolean): List, Tensor<*>>> { - var variables = kGraph.layerVariables() + var variables = variables() if (saveOptimizerState) { variables = variables + kGraph.optimizerVariables() } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt index 807f87c8b..8f534b65d 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt @@ -11,7 +11,7 @@ import org.jetbrains.kotlinx.dl.api.core.Sequential import org.jetbrains.kotlinx.dl.api.core.activation.Activations import org.jetbrains.kotlinx.dl.api.core.initializer.* import org.jetbrains.kotlinx.dl.api.core.layer.Layer -import org.jetbrains.kotlinx.dl.api.core.layer.NoGradients +import org.jetbrains.kotlinx.dl.api.core.layer.TrainableLayer import org.jetbrains.kotlinx.dl.api.core.layer.activation.* import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.* import org.jetbrains.kotlinx.dl.api.core.layer.core.ActivationLayer @@ -138,8 +138,9 @@ private fun convertToLayer( LAYER_SOFTMAX -> createSoftmaxLayer(kerasLayer.config!!) else -> throw IllegalStateException("${kerasLayer.class_name} is not supported yet!") }.apply { - isTrainable = if (this !is NoGradients) kerasLayer.config?.trainable?:isTrainable - else false + if (this is TrainableLayer) { + isTrainable = kerasLayer.config?.trainable?:isTrainable + } name = kerasLayer.config?.name?:name } } diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt index 3f7fd6427..91502fc56 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt @@ -16,6 +16,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.* import org.jetbrains.kotlinx.dl.api.core.layer.core.ActivationLayer import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense import org.jetbrains.kotlinx.dl.api.core.layer.core.Input +import org.jetbrains.kotlinx.dl.api.core.layer.isTrainable import org.jetbrains.kotlinx.dl.api.core.layer.merge.* import org.jetbrains.kotlinx.dl.api.core.layer.normalization.BatchNorm import org.jetbrains.kotlinx.dl.api.core.layer.pooling.* diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/WeightLoader.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/WeightLoader.kt index 7e5a1e302..ea58d331b 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/WeightLoader.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/WeightLoader.kt @@ -12,6 +12,8 @@ import io.jhdf.dataset.DatasetBase import org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel import org.jetbrains.kotlinx.dl.api.core.layer.KVariable import org.jetbrains.kotlinx.dl.api.core.layer.Layer +import org.jetbrains.kotlinx.dl.api.core.layer.isTrainable +import org.jetbrains.kotlinx.dl.api.core.layer.paramCount import org.jetbrains.kotlinx.dl.api.core.shape.toIntArray import org.jetbrains.kotlinx.dl.api.inference.keras.WeightMappings.BIAS_DATA_PATH_TEMPLATE import org.jetbrains.kotlinx.dl.api.inference.keras.WeightMappings.KERNEL_DATA_PATH_TEMPLATE diff --git 
a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/loaders/TFModelHub.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/loaders/TFModelHub.kt index aa9cbbc9a..9cdbfb61e 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/loaders/TFModelHub.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/loaders/TFModelHub.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -14,6 +14,7 @@ import mu.KotlinLogging import org.jetbrains.kotlinx.dl.api.core.Functional import org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel import org.jetbrains.kotlinx.dl.api.core.Sequential +import org.jetbrains.kotlinx.dl.api.core.freeze import org.jetbrains.kotlinx.dl.api.inference.InferenceModel import java.io.File import java.net.URL @@ -178,9 +179,7 @@ public class TFModelHub(cacheDirectory: File) : ModelHub(cacheDirectory) { } private fun freezeAllLayers(model: GraphTrainableModel): GraphTrainableModel { - for (layer in model.layers) { - layer.isTrainable = false - } + model.freeze() return model } diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/SequentialBasicTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/SequentialBasicTest.kt index 3ae13aaff..4818c8c2e 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/SequentialBasicTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/SequentialBasicTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -16,6 +16,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.pooling.AvgPool2D import org.jetbrains.kotlinx.dl.api.core.layer.pooling.MaxPool2D import org.jetbrains.kotlinx.dl.api.core.layer.regularization.Dropout import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten +import org.jetbrains.kotlinx.dl.api.core.layer.weights import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.loss.SoftmaxCrossEntropyWithLogits import org.jetbrains.kotlinx.dl.api.core.metric.Accuracy diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/SequentialInferenceTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/SequentialInferenceTest.kt index e9661f52e..cca3b583a 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/SequentialInferenceTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/SequentialInferenceTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ @@ -11,10 +11,14 @@ import org.jetbrains.kotlinx.dl.api.core.WritingMode import org.jetbrains.kotlinx.dl.api.core.activation.Activations import org.jetbrains.kotlinx.dl.api.core.initializer.HeNormal import org.jetbrains.kotlinx.dl.api.core.initializer.HeUniform +import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.ConvPadding import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense import org.jetbrains.kotlinx.dl.api.core.layer.core.Input +import org.jetbrains.kotlinx.dl.api.core.layer.freeze +import org.jetbrains.kotlinx.dl.api.core.layer.isTrainable +import org.jetbrains.kotlinx.dl.api.core.layer.paramCount import org.jetbrains.kotlinx.dl.api.core.layer.pooling.MaxPool2D import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten import org.jetbrains.kotlinx.dl.api.core.loss.Losses @@ -145,7 +149,7 @@ class SequentialInferenceTest { assertTrue(model.getLayer("conv2d_3") is Conv2D) assertTrue(model.getLayer("conv2d_1").isTrainable) assertTrue(model.getLayer("conv2d_1").hasActivation) - assertTrue(model.getLayer("flatten_5").isTrainable) + Assertions.assertFalse(model.getLayer("flatten_5").isTrainable) Assertions.assertFalse(model.getLayer("flatten_5").hasActivation) assertTrue(model.getLayer("maxPool_2") is MaxPool2D) assertTrue(model.getLayer("maxPool_4") is MaxPool2D) @@ -370,10 +374,7 @@ class SequentialInferenceTest { val model = Sequential.loadDefaultModelConfiguration(tempDir.toFile()) model.use { - for (layer in it.layers) { - if (layer::class == Conv2D::class) - layer.isTrainable = false - } + it.layers.filterIsInstance().forEach(Layer::freeze) it.compile( optimizer = RMSProp(), @@ -567,10 +568,7 @@ class SequentialInferenceTest { model.use { // Freeze conv2d layers, keep dense layers trainable - for (layer in it.layers) { - if (layer::class == Conv2D::class) - layer.isTrainable = false - } + it.layers.filterIsInstance().forEach(Layer::freeze) it.compile( optimizer = optimizer, diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/TransferLearningFromKerasTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/TransferLearningFromKerasTest.kt index 2dee97de6..66680b811 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/TransferLearningFromKerasTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/TransferLearningFromKerasTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020-2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ @@ -13,6 +13,9 @@ import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.ConvPadding import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense +import org.jetbrains.kotlinx.dl.api.core.layer.freeze +import org.jetbrains.kotlinx.dl.api.core.layer.isTrainable +import org.jetbrains.kotlinx.dl.api.core.layer.weights import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.Metrics import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam @@ -133,7 +136,7 @@ class TransferLearningTest : IntegrationTest() { val dense1LayerName = "dense_1" assertEquals(testModel.layers.size, 9) - assertTrue(testModel.getLayer(flattenLayerName).isTrainable) + assertTrue(!testModel.getLayer(flattenLayerName).isTrainable) assertFalse(testModel.getLayer(flattenLayerName).hasActivation) assertTrue(testModel.getLayer(conv2dLayerName) is Conv2D) assertTrue((testModel.getLayer(conv2dLayerName) as Conv2D).kernelInitializer is GlorotNormal) @@ -485,10 +488,7 @@ class TransferLearningTest : IntegrationTest() { val (train, test) = fashionMnist() testModel.use { - for (layer in it.layers) { - if (layer is Conv2D) - layer.isTrainable = false - } + it.layers.filterIsInstance().forEach(Layer::freeze) it.compile( optimizer = Adam(), @@ -552,14 +552,8 @@ class TransferLearningTest : IntegrationTest() { val (train, test) = fashionMnist() testModel.use { - val layerList = mutableListOf() - - for (layer in it.layers) { - if (layer is Conv2D) { - layer.isTrainable = false - layerList.add(layer) - } - } + val layerList = it.layers.filterIsInstance() + layerList.forEach(Layer::freeze) it.compile( optimizer = Adam(), @@ -623,14 +617,7 @@ class TransferLearningTest : IntegrationTest() { val (train, test) = fashionMnist() testModel.use { - val layerList = mutableListOf() - - for (layer in it.layers) { - if (layer is Conv2D) { - layer.isTrainable = false - layerList.add(layer) - } - } + it.layers.filterIsInstance().forEach(Layer::freeze) it.compile( optimizer = Adam(), diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/ActivationLayerTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/ActivationLayerTest.kt index 8db32c612..bea2c0bc1 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/ActivationLayerTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/ActivationLayerTest.kt @@ -1,16 +1,18 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ package org.jetbrains.kotlinx.dl.api.core.layer -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.activation.EPS import org.jetbrains.kotlinx.dl.api.core.shape.shapeFromDims import org.junit.jupiter.api.Assertions.assertArrayEquals import org.junit.jupiter.api.Assertions.assertEquals -import org.tensorflow.* +import org.tensorflow.EagerSession +import org.tensorflow.Operand +import org.tensorflow.Shape +import org.tensorflow.Tensor import org.tensorflow.op.Ops internal const val IRRELEVANT_INPUT_SIZE = 8 @@ -27,7 +29,7 @@ open class ActivationLayerTest { val tf = Ops.create(it) val inputOp = tf.constant(input) val inputShape = inputOp.asOutput().shape() - layer.build(tf, KGraph(Graph().toGraphDef()), inputShape) + layer.build(tf, inputShape) val isTraining = tf.constant(true) val numberOfLosses = tf.constant(1.0f) @@ -62,7 +64,7 @@ open class ActivationLayerTest { EagerSession.create().use { val tf = Ops.create(it) val inputOp = tf.constant(input) - layer.build(tf, KGraph(Graph().toGraphDef()), inputShape) + layer.build(tf, inputShape) val isTraining = tf.constant(true) val numberOfLosses = tf.constant(1.0f) val output = layer.forward(tf, inputOp, isTraining, numberOfLosses).asOutput().tensor() diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/AvgPool1DTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/AvgPool1DTest.kt index 4d89a99b8..b8127cb38 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/AvgPool1DTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/AvgPool1DTest.kt @@ -1,18 +1,16 @@ /* -* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. -* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. -*/ + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
+ */ package org.jetbrains.kotlinx.dl.api.core.layer -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.ConvPadding import org.jetbrains.kotlinx.dl.api.core.layer.pooling.AvgPool1D import org.jetbrains.kotlinx.dl.api.core.shape.toIntArray import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test import org.tensorflow.EagerSession -import org.tensorflow.Graph import org.tensorflow.Shape import org.tensorflow.op.Ops @@ -50,7 +48,7 @@ internal class AvgPool1DTest { EagerSession.create().use { val tf = Ops.create() val inputShape = Shape.make(input.size.toLong(), input[0].size.toLong(), input[0][0].size.toLong()) - layer.build(tf, KGraph(Graph().toGraphDef()), inputShape) + layer.build(tf, inputShape) val inputOp = tf.constant(input) val isTraining = tf.constant(true) @@ -114,7 +112,7 @@ internal class AvgPool1DTest { EagerSession.create().use { val tf = Ops.create() val inputShape = Shape.make(input.size.toLong(), input[0].size.toLong(), input[0][0].size.toLong()) - layer.build(tf, KGraph(Graph().toGraphDef()), inputShape) + layer.build(tf, inputShape) val inputOp = tf.constant(input) val isTraining = tf.constant(true) @@ -174,7 +172,7 @@ internal class AvgPool1DTest { EagerSession.create().use { val tf = Ops.create() val inputShape = Shape.make(input.size.toLong(), input[0].size.toLong(), input[0][0].size.toLong()) - layer.build(tf, KGraph(Graph().toGraphDef()), inputShape) + layer.build(tf, inputShape) val inputOp = tf.constant(input) val isTraining = tf.constant(true) diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/AvgPool3DTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/AvgPool3DTest.kt index 62926874b..1f3af1f07 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/AvgPool3DTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/AvgPool3DTest.kt @@ -1,11 +1,10 @@ /* -* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. -* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. -*/ + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
+ */ package org.jetbrains.kotlinx.dl.api.core.layer -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.ConvPadding import org.jetbrains.kotlinx.dl.api.core.layer.pooling.AvgPool3D import org.jetbrains.kotlinx.dl.api.core.shape.shape @@ -13,7 +12,6 @@ import org.jetbrains.kotlinx.dl.api.core.shape.toIntArray import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test import org.tensorflow.EagerSession -import org.tensorflow.Graph import org.tensorflow.Shape import org.tensorflow.op.Ops @@ -72,7 +70,7 @@ internal class AvgPool3DTest { EagerSession.create().use { val tf = Ops.create() - layer.build(tf, KGraph(Graph().toGraphDef()), inputShape) + layer.build(tf, inputShape) val inputOp = tf.constant(input) val isTraining = tf.constant(true) @@ -145,7 +143,7 @@ internal class AvgPool3DTest { EagerSession.create().use { val tf = Ops.create() - layer.build(tf, KGraph(Graph().toGraphDef()), inputShape) + layer.build(tf, inputShape) val inputOp = tf.constant(input) val isTraining = tf.constant(true) @@ -202,7 +200,7 @@ internal class AvgPool3DTest { EagerSession.create().use { val tf = Ops.create() - layer.build(tf, KGraph(Graph().toGraphDef()), inputShape) + layer.build(tf, inputShape) val inputOp = tf.constant(input) val isTraining = tf.constant(true) diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/ConvLayerTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/ConvLayerTest.kt index d19f4f2c4..200fbeb36 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/ConvLayerTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/ConvLayerTest.kt @@ -5,7 +5,6 @@ package org.jetbrains.kotlinx.dl.api.core.layer -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.activation.EPS import org.jetbrains.kotlinx.dl.api.core.shape.flattenFloats import org.jetbrains.kotlinx.dl.api.core.shape.shape @@ -26,24 +25,22 @@ open class ConvLayerTest { Graph().use { graph -> Session(graph).use { session -> val tf = Ops.create(graph) - KGraph(graph.toGraphDef()).use { kGraph -> - val inputOp = tf.constant(input.shape.toLongArray(), FloatBuffer.wrap(input.flattenFloats())) - val isTraining = tf.constant(true) - val numberOfLosses = tf.constant(1.0f) + val inputOp = tf.constant(input.shape.toLongArray(), FloatBuffer.wrap(input.flattenFloats())) + val isTraining = tf.constant(true) + val numberOfLosses = tf.constant(1.0f) - layer.build(tf, kGraph, input.shape) - layer.computeOutputShape(input.shape) - val output = layer.forward(tf, inputOp, isTraining, numberOfLosses).asOutput() - kGraph.initializeGraphVariables(session) - session.runner().fetch(output).run().first().use { outputTensor -> - val outputShape = outputTensor.shape() - val expectedShape = expectedOutput.shape.toLongArray() - assertArrayEquals(expectedShape, outputShape) + layer.build(tf, input.shape) + layer.computeOutputShape(input.shape) + val output = layer.forward(tf, inputOp, isTraining, numberOfLosses).asOutput() + (layer as? 
ParametrizedLayer)?.initialize(session) + session.runner().fetch(output).run().first().use { outputTensor -> + val outputShape = outputTensor.shape() + val expectedShape = expectedOutput.shape.toLongArray() + assertArrayEquals(expectedShape, outputShape) - val result = outputTensor.convertTensorToFlattenFloatArray() - val expected = expectedOutput.flattenFloats() - assertArrayEquals(expected, result, EPS) - } + val result = outputTensor.convertTensorToFlattenFloatArray() + val expected = expectedOutput.flattenFloats() + assertArrayEquals(expected, result, EPS) } } } diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/GlobalMaxPool1DTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/GlobalMaxPool1DTest.kt index 9f04dfcf3..769af368f 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/GlobalMaxPool1DTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/GlobalMaxPool1DTest.kt @@ -1,17 +1,15 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package org.jetbrains.kotlinx.dl.api.core.layer -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.pooling.GlobalMaxPool1D import org.jetbrains.kotlinx.dl.api.core.shape.toIntArray import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test import org.tensorflow.EagerSession -import org.tensorflow.Graph import org.tensorflow.Shape import org.tensorflow.op.Ops @@ -37,7 +35,7 @@ internal class GlobalMaxPool1DTest { EagerSession.create().use { val tf = Ops.create() val inputShape = Shape.make(input.size.toLong(), input[0].size.toLong(), input[0][0].size.toLong()) - layer.build(tf, KGraph(Graph().toGraphDef()), inputShape) + layer.build(tf, inputShape) val inputOp = tf.constant(input) val isTraining = tf.constant(true) diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/LayerTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/LayerTest.kt index bb65023c9..8e5daad28 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/LayerTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/LayerTest.kt @@ -1,6 +1,10 @@ +/* + * Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
+ */ + package org.jetbrains.kotlinx.dl.api.core.layer -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.shape.flattenFloats import org.jetbrains.kotlinx.dl.api.core.shape.shape import org.jetbrains.kotlinx.dl.api.core.shape.shapeFromDims @@ -24,10 +28,9 @@ open class LayerTest { tf: Ops, layer: Layer, input: Array<*>, - kGraph: KGraph, ): Output<*> { val inputShape = input.shape - layer.build(tf, kGraph, inputShape) + layer.build(tf, inputShape) val inputOp = getInputOp(tf, input) val isTraining = tf.constant(true) val numberOfLosses = tf.constant(1.0f) @@ -41,8 +44,7 @@ open class LayerTest { ): Tensor<*> { EagerSession.create().use { val tf = Ops.create() - val kGraph = KGraph(Graph().toGraphDef()) - val outputOp = getLayerOutputOp(tf, layer, input, kGraph) + val outputOp = getLayerOutputOp(tf, layer, input) return outputOp.tensor() } } @@ -54,12 +56,9 @@ open class LayerTest { Graph().use { graph -> Session(graph).use { session -> val tf = Ops.create(graph) - KGraph(graph.toGraphDef()).use { kGraph -> - val outputOp = getLayerOutputOp(tf, layer, input, kGraph) - kGraph.initializeGraphVariables(session) - val outputTensor = session.runner().fetch(outputOp).run().first() - return outputTensor - } + val outputOp = getLayerOutputOp(tf, layer, input) + (layer as? ParametrizedLayer)?.initialize(session) + return session.runner().fetch(outputOp).run().first() } } } diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/MaxPool1DTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/MaxPool1DTest.kt index 26de94ca7..3fe4f1ba4 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/MaxPool1DTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/MaxPool1DTest.kt @@ -1,18 +1,16 @@ /* -* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. -* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. -*/ + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
+ */ package org.jetbrains.kotlinx.dl.api.core.layer -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.ConvPadding import org.jetbrains.kotlinx.dl.api.core.layer.pooling.MaxPool1D import org.jetbrains.kotlinx.dl.api.core.shape.toIntArray import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test import org.tensorflow.EagerSession -import org.tensorflow.Graph import org.tensorflow.Shape import org.tensorflow.op.Ops @@ -48,7 +46,7 @@ internal class MaxPool1DTest { EagerSession.create().use { val tf = Ops.create() val inputShape = Shape.make(input.size.toLong(), input[0].size.toLong(), input[0][0].size.toLong()) - layer.build(tf, KGraph(Graph().toGraphDef()), inputShape) + layer.build(tf, inputShape) val inputOp = tf.constant(input) val isTraining = tf.constant(true) @@ -111,7 +109,7 @@ internal class MaxPool1DTest { EagerSession.create().use { val tf = Ops.create() val inputShape = Shape.make(input.size.toLong(), input[0].size.toLong(), input[0][0].size.toLong()) - layer.build(tf, KGraph(Graph().toGraphDef()), inputShape) + layer.build(tf, inputShape) val inputOp = tf.constant(input) val isTraining = tf.constant(true) @@ -170,7 +168,7 @@ internal class MaxPool1DTest { EagerSession.create().use { val tf = Ops.create() val inputShape = Shape.make(input.size.toLong(), input[0].size.toLong(), input[0][0].size.toLong()) - layer.build(tf, KGraph(Graph().toGraphDef()), inputShape) + layer.build(tf, inputShape) val inputOp = tf.constant(input) val isTraining = tf.constant(true) diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/MaxPool3DTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/MaxPool3DTest.kt index a88fc4106..db6861fcd 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/MaxPool3DTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/MaxPool3DTest.kt @@ -1,16 +1,19 @@ +/* + * Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
+ */ + package org.jetbrains.kotlinx.dl.api.core.layer -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.pooling.MaxPool3D import org.jetbrains.kotlinx.dl.api.core.loss.EPS import org.jetbrains.kotlinx.dl.api.core.shape.shapeFromDims import org.junit.jupiter.api.Assertions.assertArrayEquals +import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test import org.tensorflow.EagerSession -import org.tensorflow.Graph -import org.tensorflow.op.Ops import org.tensorflow.Shape -import org.junit.jupiter.api.Assertions.assertEquals +import org.tensorflow.op.Ops internal class MaxPool3DTest { @@ -73,7 +76,7 @@ internal class MaxPool3DTest { ) EagerSession.create().use { val tf = Ops.create() - layer.build(tf, KGraph(Graph().toGraphDef()), inputShape) + layer.build(tf, inputShape) val inputOp = tf.constant(input) val isTraining = tf.constant(true) val numOfLosses = tf.constant(1.0f) diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt index cf9f40e84..0e2945802 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/RepeatVectorLayerTest.kt @@ -1,16 +1,14 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package org.jetbrains.kotlinx.dl.api.core.layer -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.RepeatVector import org.jetbrains.kotlinx.dl.api.core.shape.toIntArray import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test -import org.tensorflow.Graph import org.tensorflow.Output import org.tensorflow.Shape import org.tensorflow.op.Ops @@ -48,7 +46,7 @@ internal class RepeatVectorLayerTest { // TODO: generalise this for Layer, see https://github.com/JetBrains/KotlinDL/issues/145 private operator fun RepeatVector.invoke(input: Array): Output = Ops.create().let { tf -> - build(tf, KGraph(Graph().toGraphDef()), Shape.make(10, 10)) + build(tf, Shape.make(10, 10)) val inputOp = tf.constant(input) val isTraining = tf.constant(true) val numberOfLosses = tf.constant(1.0f) diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/ZeroPadding2DTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/ZeroPadding2DTest.kt index d7f077d82..fb40b585f 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/ZeroPadding2DTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/ZeroPadding2DTest.kt @@ -1,11 +1,10 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ package org.jetbrains.kotlinx.dl.api.core.layer -import org.jetbrains.kotlinx.dl.api.core.KGraph import org.jetbrains.kotlinx.dl.api.core.initializer.Ones import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.ZeroPadding2D import org.jetbrains.kotlinx.dl.api.core.shape.shapeFromDims @@ -13,7 +12,6 @@ import org.jetbrains.kotlinx.dl.api.inference.keras.CHANNELS_LAST import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test import org.tensorflow.EagerSession -import org.tensorflow.Graph import org.tensorflow.Shape import org.tensorflow.op.Ops @@ -35,7 +33,7 @@ internal class ZeroPadding2DTest { val paddingLayer = ZeroPadding2D(padding, dataFormat = CHANNELS_LAST) val inputDimensions = tf.constant(inputDimensionsArray) val input = Ones().initialize(1, 1, tf, inputDimensions, "test_input") - paddingLayer.build(tf, KGraph(Graph().toGraphDef()), inputShape) + paddingLayer.build(tf, inputShape) val isTraining = tf.constant(true) val numberOfLosses = tf.constant(1.0f) val output = paddingLayer.forward(tf, input, isTraining, numberOfLosses).asOutput().tensor() @@ -92,7 +90,7 @@ internal class ZeroPadding2DTest { val paddingLayer = ZeroPadding2D(paddingArray, dataFormat = CHANNELS_LAST) val inputDimensions = tf.constant(inputDimensionsArray) val input = Ones().initialize(1, 1, tf, inputDimensions, "test_input") - paddingLayer.build(tf, KGraph(Graph().toGraphDef()), inputShape) + paddingLayer.build(tf, inputShape) val isTraining = tf.constant(true) val numberOfLosses = tf.constant(1.0f) val output = paddingLayer.forward(tf, input, isTraining, numberOfLosses).asOutput().tensor() @@ -151,7 +149,7 @@ internal class ZeroPadding2DTest { val paddingLayer = ZeroPadding2D(paddingArray, dataFormat = CHANNELS_LAST) val inputDimensions = tf.constant(inputDimensionsArray) val input = Ones().initialize(1, 1, tf, inputDimensions, "test_input") - paddingLayer.build(tf, KGraph(Graph().toGraphDef()), inputShape) + paddingLayer.build(tf, inputShape) val isTraining = tf.constant(true) val numberOfLosses = tf.constant(1.0f) val output = paddingLayer.forward(tf, input, isTraining, numberOfLosses).asOutput().tensor() diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/ml/RegressionTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/ml/RegressionTest.kt index 3b866a3d7..d699fc232 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/ml/RegressionTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/ml/RegressionTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ @@ -11,6 +11,7 @@ import org.jetbrains.kotlinx.dl.api.core.initializer.HeNormal import org.jetbrains.kotlinx.dl.api.core.initializer.Zeros import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense import org.jetbrains.kotlinx.dl.api.core.layer.core.Input +import org.jetbrains.kotlinx.dl.api.core.layer.weights import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.Metrics import org.jetbrains.kotlinx.dl.api.core.optimizer.SGD diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/models/FunctionalModelTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/models/FunctionalModelTest.kt index 1020bbf71..805f05638 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/models/FunctionalModelTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/models/FunctionalModelTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -13,6 +13,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.ConvPadding import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense import org.jetbrains.kotlinx.dl.api.core.layer.core.Input +import org.jetbrains.kotlinx.dl.api.core.layer.freeze import org.jetbrains.kotlinx.dl.api.core.layer.merge.Add import org.jetbrains.kotlinx.dl.api.core.layer.pooling.GlobalAvgPool2D import org.jetbrains.kotlinx.dl.api.core.layer.pooling.MaxPool2D @@ -71,7 +72,7 @@ internal class FunctionalModelTest { @Test fun summary() { - dense_2.isTrainable = false + dense_2.freeze() checkFunctionalModelAfterCompilation(correctTestModel) checkFunctionalModelAfterCompilation(correctTestModelWithUnsortedLayers) diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/models/SequentialCompilationTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/models/SequentialCompilationTest.kt index 5b7874805..6e6cdee4a 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/models/SequentialCompilationTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/models/SequentialCompilationTest.kt @@ -15,6 +15,8 @@ import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.ConvPadding import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense import org.jetbrains.kotlinx.dl.api.core.layer.core.Input +import org.jetbrains.kotlinx.dl.api.core.layer.isTrainable +import org.jetbrains.kotlinx.dl.api.core.layer.paramCount import org.jetbrains.kotlinx.dl.api.core.layer.pooling.MaxPool2D import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten import org.jetbrains.kotlinx.dl.api.core.loss.Losses @@ -96,7 +98,7 @@ internal class SequentialModelTest { assertTrue(correctTestModel.getLayer("conv2d_2") is Conv2D) assertTrue(correctTestModel.getLayer("conv2d_1").isTrainable) assertTrue(correctTestModel.getLayer("conv2d_1").hasActivation) - assertTrue(correctTestModel.getLayer("flatten_1").isTrainable) + assertFalse(correctTestModel.getLayer("flatten_1").isTrainable) assertFalse(correctTestModel.getLayer("flatten_1").hasActivation) assertArrayEquals(correctTestModel.inputLayer.packedDims, longArrayOf(IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)) } diff --git 
a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/MergeLayersImportExportTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/MergeLayersImportExportTest.kt index a4ca9f445..1c1b4ec3e 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/MergeLayersImportExportTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/MergeLayersImportExportTest.kt @@ -8,6 +8,7 @@ package org.jetbrains.kotlinx.dl.api.inference.keras import org.jetbrains.kotlinx.dl.api.core.dsl.functional import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense import org.jetbrains.kotlinx.dl.api.core.layer.core.Input +import org.jetbrains.kotlinx.dl.api.core.layer.freeze import org.jetbrains.kotlinx.dl.api.core.layer.merge.* import org.junit.jupiter.api.Test @@ -102,8 +103,8 @@ class MergeLayersImportExportTest { functional { layers { val input = +Input(10) - val dense1 = +Dense(name = "Dense1", outputSize = 5)(input).apply { isTrainable = false } - val dense2 = +Dense(name = "Dense2", outputSize = 5)(input).apply { isTrainable = false } + val dense1 = +Dense(name = "Dense1", outputSize = 5)(input).apply { freeze() } + val dense2 = +Dense(name = "Dense2", outputSize = 5)(input).apply { freeze() } +Concatenate(name = "test_concatenate", axis = 1)(dense1, dense2) } } diff --git a/docs/transfer_learning.md b/docs/transfer_learning.md index 0c2df1dd3..72cee4a86 100644 --- a/docs/transfer_learning.md +++ b/docs/transfer_learning.md @@ -79,13 +79,8 @@ So this is what we will do: - The last layer of the original model classifies 1000 classes, but we only have two, so we'll dispose of it, and add another final prediction layer (and one intermediate dense layer to achieve better accuracy). ```kotlin -val layers = mutableListOf() - -for (layer in model.layers.dropLast(1)) { - layer.isTrainable = false - layers.add(layer) -} -layers.forEach { it.isTrainable = false } +val layers = model.layers.dropLast(1).toMutableList() +layers.forEach(Layer::freeze) layers.add( Dense( diff --git a/examples/src/main/kotlin/examples/experimental/batchnorm/TransferLearningWithBatchNorm.kt b/examples/src/main/kotlin/examples/experimental/batchnorm/TransferLearningWithBatchNorm.kt index cadbbd2a3..1c2795351 100644 --- a/examples/src/main/kotlin/examples/experimental/batchnorm/TransferLearningWithBatchNorm.kt +++ b/examples/src/main/kotlin/examples/experimental/batchnorm/TransferLearningWithBatchNorm.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ @@ -7,6 +7,8 @@ package examples.experimental.batchnorm import io.jhdf.HdfFile import org.jetbrains.kotlinx.dl.api.core.Sequential +import org.jetbrains.kotlinx.dl.api.core.freeze +import org.jetbrains.kotlinx.dl.api.core.layer.unfreeze import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.Metrics import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam @@ -36,10 +38,8 @@ fun main() { model.use { - for (layer in it.layers) { - layer.isTrainable = false - } - it.layers.last().isTrainable = true + it.freeze() + it.layers.last().unfreeze() it.compile( optimizer = Adam(), diff --git a/examples/src/main/kotlin/examples/inference/mnist/LeNetMnistExportImportWithJSON.kt b/examples/src/main/kotlin/examples/inference/mnist/LeNetMnistExportImportWithJSON.kt index 8d5864c0d..ee2cb9eae 100644 --- a/examples/src/main/kotlin/examples/inference/mnist/LeNetMnistExportImportWithJSON.kt +++ b/examples/src/main/kotlin/examples/inference/mnist/LeNetMnistExportImportWithJSON.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -9,7 +9,9 @@ import examples.inference.lenet5 import org.jetbrains.kotlinx.dl.api.core.SavingFormat import org.jetbrains.kotlinx.dl.api.core.Sequential import org.jetbrains.kotlinx.dl.api.core.WritingMode +import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D +import org.jetbrains.kotlinx.dl.api.core.layer.freeze import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.Metrics import org.jetbrains.kotlinx.dl.api.core.optimizer.RMSProp @@ -67,10 +69,7 @@ fun lenetOnMnistExportImportToJson() { model.use { // Freeze conv2d layers, keep dense layers trainable - for (layer in it.layers) { - if (layer::class == Conv2D::class) - layer.isTrainable = false - } + it.layers.filterIsInstance().forEach(Layer::freeze) it.compile( optimizer = RMSProp(), diff --git a/examples/src/main/kotlin/examples/inference/optimizers/LeNetMnistExportImportToJSONWithAdamOptimizerState.kt b/examples/src/main/kotlin/examples/inference/optimizers/LeNetMnistExportImportToJSONWithAdamOptimizerState.kt index 5cd01efaa..21ef4634c 100644 --- a/examples/src/main/kotlin/examples/inference/optimizers/LeNetMnistExportImportToJSONWithAdamOptimizerState.kt +++ b/examples/src/main/kotlin/examples/inference/optimizers/LeNetMnistExportImportToJSONWithAdamOptimizerState.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ @@ -9,7 +9,9 @@ import examples.inference.lenet5 import org.jetbrains.kotlinx.dl.api.core.SavingFormat import org.jetbrains.kotlinx.dl.api.core.Sequential import org.jetbrains.kotlinx.dl.api.core.WritingMode +import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D +import org.jetbrains.kotlinx.dl.api.core.layer.freeze import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.Metrics import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam @@ -70,10 +72,7 @@ fun lenetOnMnistExportImportToJSONWithAdamOptimizerState() { model.use { // Freeze conv2d layers, keep dense layers trainable - for (layer in it.layers) { - if (layer::class == Conv2D::class) - layer.isTrainable = false - } + it.layers.filterIsInstance().forEach(Layer::freeze) it.compile( optimizer = optimizer, diff --git a/examples/src/main/kotlin/examples/ml/LinearRegression.kt b/examples/src/main/kotlin/examples/ml/LinearRegression.kt index 6553d4812..b2bdc6d69 100644 --- a/examples/src/main/kotlin/examples/ml/LinearRegression.kt +++ b/examples/src/main/kotlin/examples/ml/LinearRegression.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -11,6 +11,7 @@ import org.jetbrains.kotlinx.dl.api.core.initializer.GlorotNormal import org.jetbrains.kotlinx.dl.api.core.initializer.Zeros import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense import org.jetbrains.kotlinx.dl.api.core.layer.core.Input +import org.jetbrains.kotlinx.dl.api.core.layer.weights import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.Metrics import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam diff --git a/examples/src/main/kotlin/examples/ml/LinearRegressionWithTwoMetrics.kt b/examples/src/main/kotlin/examples/ml/LinearRegressionWithTwoMetrics.kt index 01755b2e5..11d6a668c 100644 --- a/examples/src/main/kotlin/examples/ml/LinearRegressionWithTwoMetrics.kt +++ b/examples/src/main/kotlin/examples/ml/LinearRegressionWithTwoMetrics.kt @@ -11,6 +11,7 @@ import org.jetbrains.kotlinx.dl.api.core.initializer.GlorotNormal import org.jetbrains.kotlinx.dl.api.core.initializer.Zeros import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense import org.jetbrains.kotlinx.dl.api.core.layer.core.Input +import org.jetbrains.kotlinx.dl.api.core.layer.weights import org.jetbrains.kotlinx.dl.api.core.loss.LossFunction import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.MAE diff --git a/examples/src/main/kotlin/examples/ml/SineRegression.kt b/examples/src/main/kotlin/examples/ml/SineRegression.kt index cde072b2f..47c3a8ea4 100644 --- a/examples/src/main/kotlin/examples/ml/SineRegression.kt +++ b/examples/src/main/kotlin/examples/ml/SineRegression.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ @@ -11,6 +11,7 @@ import org.jetbrains.kotlinx.dl.api.core.initializer.HeNormal import org.jetbrains.kotlinx.dl.api.core.initializer.HeUniform import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense import org.jetbrains.kotlinx.dl.api.core.layer.core.Input +import org.jetbrains.kotlinx.dl.api.core.layer.weights import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.Metrics import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam diff --git a/examples/src/main/kotlin/examples/transferlearning/lenet/Example_4_Additional_training_and_freezing.kt b/examples/src/main/kotlin/examples/transferlearning/lenet/Example_4_Additional_training_and_freezing.kt index 005bf7cc6..1b2f6651c 100644 --- a/examples/src/main/kotlin/examples/transferlearning/lenet/Example_4_Additional_training_and_freezing.kt +++ b/examples/src/main/kotlin/examples/transferlearning/lenet/Example_4_Additional_training_and_freezing.kt @@ -1,12 +1,14 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ package examples.transferlearning.lenet import org.jetbrains.kotlinx.dl.api.core.Sequential +import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D +import org.jetbrains.kotlinx.dl.api.core.layer.freeze import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.Metrics import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam @@ -31,10 +33,7 @@ fun additionalTrainingAndFreezing() { model.use { // Freeze conv2d layers, keep dense layers trainable - for (layer in it.layers) { - if (layer is Conv2D) - layer.isTrainable = false - } + it.layers.filterIsInstance().forEach(Layer::freeze) it.compile( optimizer = Adam(), diff --git a/examples/src/main/kotlin/examples/transferlearning/lenet/Example_5_Additional_training_and_freezing_and_init.kt b/examples/src/main/kotlin/examples/transferlearning/lenet/Example_5_Additional_training_and_freezing_and_init.kt index 807e35bcd..b133a6a33 100644 --- a/examples/src/main/kotlin/examples/transferlearning/lenet/Example_5_Additional_training_and_freezing_and_init.kt +++ b/examples/src/main/kotlin/examples/transferlearning/lenet/Example_5_Additional_training_and_freezing_and_init.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ @@ -8,6 +8,7 @@ package examples.transferlearning.lenet import org.jetbrains.kotlinx.dl.api.core.Sequential import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D +import org.jetbrains.kotlinx.dl.api.core.layer.freeze import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.Metrics import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam @@ -30,14 +31,9 @@ fun additionalTrainingAndPartialFreezingAndPartialInitialization() { val model = Sequential.loadModelConfiguration(jsonConfigFile) model.use { - val layerList = mutableListOf() // Freeze conv2d layers, keep dense layers trainable - for (layer in it.layers) { - if (layer::class == Conv2D::class) { - layer.isTrainable = false - layerList.add(layer) - } - } + val layerList = it.layers.filterIsInstance() + layerList.forEach(Layer::freeze) it.compile( optimizer = Adam(), diff --git a/examples/src/main/kotlin/examples/transferlearning/lenet/Example_6_Additional_training_and_new_dense_layers.kt b/examples/src/main/kotlin/examples/transferlearning/lenet/Example_6_Additional_training_and_new_dense_layers.kt index 48d9481c4..8f983b3e1 100644 --- a/examples/src/main/kotlin/examples/transferlearning/lenet/Example_6_Additional_training_and_new_dense_layers.kt +++ b/examples/src/main/kotlin/examples/transferlearning/lenet/Example_6_Additional_training_and_new_dense_layers.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ @@ -12,6 +12,7 @@ import org.jetbrains.kotlinx.dl.api.core.initializer.HeNormal import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense +import org.jetbrains.kotlinx.dl.api.core.layer.freeze import org.jetbrains.kotlinx.dl.api.core.layer.pooling.MaxPool2D import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten import org.jetbrains.kotlinx.dl.api.core.loss.Losses @@ -39,7 +40,7 @@ fun additionalTrainingAndNewTopDenseLayers() { layers.add(input) for (layer in otherLayers) { if (layer is Conv2D || layer is MaxPool2D) { - layer.isTrainable = false + layer.freeze() layers.add(layer) } } diff --git a/examples/src/main/kotlin/examples/transferlearning/modelhub/mobilenet/additionalTrainingForMobileNet.kt b/examples/src/main/kotlin/examples/transferlearning/modelhub/mobilenet/additionalTrainingForMobileNet.kt index 08f955fb8..4be44d616 100644 --- a/examples/src/main/kotlin/examples/transferlearning/modelhub/mobilenet/additionalTrainingForMobileNet.kt +++ b/examples/src/main/kotlin/examples/transferlearning/modelhub/mobilenet/additionalTrainingForMobileNet.kt @@ -10,6 +10,7 @@ import org.jetbrains.kotlinx.dl.api.core.activation.Activations import org.jetbrains.kotlinx.dl.api.core.initializer.GlorotUniform import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense +import org.jetbrains.kotlinx.dl.api.core.layer.freeze import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.Metrics import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam @@ -78,8 +79,6 @@ fun mobilenetWithAdditionalTraining() { val hdfFile = modelHub.loadWeights(modelType) model.use { - it.layers.last().isTrainable = true - it.compile( optimizer = Adam(), loss = Losses.MAE, @@ -89,12 +88,8 @@ fun mobilenetWithAdditionalTraining() { it.logSummary() } - val layers = mutableListOf() - - for (layer in model.layers) { - layer.isTrainable = false - layers.add(layer) - } + val layers = model.layers.toMutableList() + layers.forEach(Layer::freeze) val lastLayer = layers.last() for (outboundLayer in lastLayer.inboundLayers) diff --git a/examples/src/main/kotlin/examples/transferlearning/modelhub/resnet/Example_2_ResNet50_prediction_additional_training.kt b/examples/src/main/kotlin/examples/transferlearning/modelhub/resnet/Example_2_ResNet50_prediction_additional_training.kt index 58cf85d1b..ee0b6188f 100644 --- a/examples/src/main/kotlin/examples/transferlearning/modelhub/resnet/Example_2_ResNet50_prediction_additional_training.kt +++ b/examples/src/main/kotlin/examples/transferlearning/modelhub/resnet/Example_2_ResNet50_prediction_additional_training.kt @@ -10,6 +10,7 @@ import org.jetbrains.kotlinx.dl.api.core.activation.Activations import org.jetbrains.kotlinx.dl.api.core.initializer.GlorotUniform import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense +import org.jetbrains.kotlinx.dl.api.core.layer.freeze import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.Metrics import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam @@ -74,12 +75,8 @@ fun resnet50additionalTraining() { val (train, test) = dataset.split(TRAIN_TEST_SPLIT_RATIO) val hdfFile = modelHub.loadWeights(modelType) - val layers = mutableListOf() - - for (layer in model.layers) { - layer.isTrainable = false - layers.add(layer) - } + val layers = 
model.layers.toMutableList() + layers.forEach(Layer::freeze) val lastLayer = layers.last() for (outboundLayer in lastLayer.inboundLayers) diff --git a/examples/src/main/kotlin/examples/transferlearning/modelhub/vgg19/Example_3_VGG19_additional_training.kt b/examples/src/main/kotlin/examples/transferlearning/modelhub/vgg19/Example_3_VGG19_additional_training.kt index 4229457c8..a9ed5ab8a 100644 --- a/examples/src/main/kotlin/examples/transferlearning/modelhub/vgg19/Example_3_VGG19_additional_training.kt +++ b/examples/src/main/kotlin/examples/transferlearning/modelhub/vgg19/Example_3_VGG19_additional_training.kt @@ -11,6 +11,7 @@ import org.jetbrains.kotlinx.dl.api.core.activation.Activations import org.jetbrains.kotlinx.dl.api.core.initializer.HeNormal import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense +import org.jetbrains.kotlinx.dl.api.core.layer.freeze import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.Metrics import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam @@ -78,13 +79,8 @@ fun vgg19additionalTraining() { ).shuffle() val (train, test) = dataset.split(TRAIN_TEST_SPLIT_RATIO) - val layers = mutableListOf() - - for (layer in model.layers.dropLast(1)) { - layer.isTrainable = false - layers.add(layer) - } - layers.forEach { it.isTrainable = false } + val layers = model.layers.dropLast(1).toMutableList() + layers.forEach(Layer::freeze) layers.add( Dense( diff --git a/examples/src/main/kotlin/examples/transferlearning/toyresnet/Example_2_ToyResNet_additional_training.kt b/examples/src/main/kotlin/examples/transferlearning/toyresnet/Example_2_ToyResNet_additional_training.kt index 3ec731aa6..8d133b8f3 100644 --- a/examples/src/main/kotlin/examples/transferlearning/toyresnet/Example_2_ToyResNet_additional_training.kt +++ b/examples/src/main/kotlin/examples/transferlearning/toyresnet/Example_2_ToyResNet_additional_training.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ @@ -7,6 +7,8 @@ package examples.transferlearning.toyresnet import org.jetbrains.kotlinx.dl.api.core.Functional +import org.jetbrains.kotlinx.dl.api.core.freeze +import org.jetbrains.kotlinx.dl.api.core.layer.unfreeze import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.Metrics import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam @@ -44,10 +46,8 @@ fun main() { println("Accuracy before: $accuracy") - for (layer in it.layers) { - layer.isTrainable = false - } - it.layers.last().isTrainable = true + it.freeze() + it.layers.last().unfreeze() it.fit(dataset = train, epochs = 1, batchSize = 1000) diff --git a/examples/src/main/kotlin/examples/transferlearning/transferLearningRunner.kt b/examples/src/main/kotlin/examples/transferlearning/transferLearningRunner.kt index a22810e1e..aaaa44225 100644 --- a/examples/src/main/kotlin/examples/transferlearning/transferLearningRunner.kt +++ b/examples/src/main/kotlin/examples/transferlearning/transferLearningRunner.kt @@ -11,6 +11,7 @@ import org.jetbrains.kotlinx.dl.api.core.activation.Activations import org.jetbrains.kotlinx.dl.api.core.initializer.GlorotUniform import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense +import org.jetbrains.kotlinx.dl.api.core.layer.freeze import org.jetbrains.kotlinx.dl.api.core.layer.pooling.GlobalAvgPool2D import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.Metrics @@ -165,12 +166,8 @@ fun runImageRecognitionTransferLearningOnTopModel( val (train, test) = dataset.split(TRAIN_TEST_SPLIT_RATIO) val hdfFile = modelHub.loadWeights(modelType) - val layers = mutableListOf() - - for (layer in model.layers) { - layer.isTrainable = false - layers.add(layer) - } + val layers = model.layers.toMutableList() + layers.forEach(Layer::freeze) val lastLayer = layers.last() for (outboundLayer in lastLayer.inboundLayers) diff --git a/examples/src/main/kotlin/examples/visualization/LeNetFashionMnistVisualization.kt b/examples/src/main/kotlin/examples/visualization/LeNetFashionMnistVisualization.kt index b6215cfbd..b1a79ac0a 100644 --- a/examples/src/main/kotlin/examples/visualization/LeNetFashionMnistVisualization.kt +++ b/examples/src/main/kotlin/examples/visualization/LeNetFashionMnistVisualization.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ @@ -7,12 +7,14 @@ package examples.visualization import examples.inference.lenet5 import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D +import org.jetbrains.kotlinx.dl.api.core.layer.weights import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.Metrics import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam import org.jetbrains.kotlinx.dl.dataset.fashionMnist import org.jetbrains.kotlinx.dl.visualization.letsplot.* -import org.jetbrains.kotlinx.dl.visualization.swing.* +import org.jetbrains.kotlinx.dl.visualization.swing.drawActivations +import org.jetbrains.kotlinx.dl.visualization.swing.drawFilters private const val EPOCHS = 1 private const val TRAINING_BATCH_SIZE = 500 diff --git a/examples/src/main/kotlin/examples/visualization/LeNetMnistVisualization.kt b/examples/src/main/kotlin/examples/visualization/LeNetMnistVisualization.kt index e6401e581..14c52e6c1 100644 --- a/examples/src/main/kotlin/examples/visualization/LeNetMnistVisualization.kt +++ b/examples/src/main/kotlin/examples/visualization/LeNetMnistVisualization.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -7,13 +7,18 @@ package examples.visualization import examples.inference.lenet5 import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D +import org.jetbrains.kotlinx.dl.api.core.layer.weights import org.jetbrains.kotlinx.dl.api.core.loss.Losses import org.jetbrains.kotlinx.dl.api.core.metric.Metrics import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam import org.jetbrains.kotlinx.dl.api.core.summary.logSummary import org.jetbrains.kotlinx.dl.dataset.mnist -import org.jetbrains.kotlinx.dl.visualization.letsplot.* -import org.jetbrains.kotlinx.dl.visualization.swing.* +import org.jetbrains.kotlinx.dl.visualization.letsplot.columnPlot +import org.jetbrains.kotlinx.dl.visualization.letsplot.filtersPlot +import org.jetbrains.kotlinx.dl.visualization.letsplot.flattenImagePlot +import org.jetbrains.kotlinx.dl.visualization.letsplot.modelActivationOnLayersPlot +import org.jetbrains.kotlinx.dl.visualization.swing.drawActivations +import org.jetbrains.kotlinx.dl.visualization.swing.drawFilters private const val EPOCHS = 1 private const val TRAINING_BATCH_SIZE = 500 diff --git a/visualization/src/main/kotlin/org/jetbrains/kotlinx/dl/visualization/letsplot/PlotConv2D.kt b/visualization/src/main/kotlin/org/jetbrains/kotlinx/dl/visualization/letsplot/PlotConv2D.kt index d5f70685a..7bcb0d1f5 100644 --- a/visualization/src/main/kotlin/org/jetbrains/kotlinx/dl/visualization/letsplot/PlotConv2D.kt +++ b/visualization/src/main/kotlin/org/jetbrains/kotlinx/dl/visualization/letsplot/PlotConv2D.kt @@ -1,5 +1,5 @@ /* - * Copyright 2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2021-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
*/ @@ -8,6 +8,7 @@ package org.jetbrains.kotlinx.dl.visualization.letsplot import jetbrains.letsPlot.Figure import org.jetbrains.kotlinx.dl.api.core.TrainableModel import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D +import org.jetbrains.kotlinx.dl.api.core.layer.weights internal val FILTER_LAYERS_PERMUTATION = intArrayOf(1, 0, 2, 3)
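Note: below is a minimal usage sketch of the `freeze` extension introduced by this change set, patterned after the updated LeNet transfer-learning examples above. The JSON configuration path is a placeholder, and the compile settings are illustrative rather than prescriptive.

```kotlin
import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.layer.Layer
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D
import org.jetbrains.kotlinx.dl.api.core.layer.freeze
import org.jetbrains.kotlinx.dl.api.core.loss.Losses
import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam
import java.io.File

fun main() {
    // Placeholder path: point this at any saved Sequential JSON configuration.
    val jsonConfigFile = File("savedModels/lenet5/modelConfig.json")
    val model = Sequential.loadModelConfiguration(jsonConfigFile)

    model.use {
        // Freeze the convolutional base; dense layers stay trainable.
        it.layers.filterIsInstance<Conv2D>().forEach(Layer::freeze)

        it.compile(
            optimizer = Adam(),
            loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
            metric = Metrics.ACCURACY
        )
        // From here the workflow matches the examples above:
        // load weights, fit on the new dataset, evaluate.
    }
}
```

As in the updated examples, layers are frozen before `compile()`, since the training targets are prepared at compilation time.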