Introduce ParametrizedLayer and TrainableLayer interfaces (#345)
* Add ParametrizedLayer and TrainableLayer interfaces

* Simplify freezing model layers

* Do not try to unfreeze ActivationLayer since it is not trainable

* Keep variable information in the layers

* Use only trainable variable regularizers for computing loss

* Remove KGraph argument from Layer::build

* Move Layer#weights to extension field

Co-authored-by: Veniamin Viflyantsev <knok16@gmail.com>
juliabeliaeva and knok16 authored Apr 26, 2022
1 parent d21ea93 commit 30c45ae
Showing 107 changed files with 524 additions and 775 deletions.
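
The declarations of the new ParametrizedLayer and TrainableLayer interfaces live in files that are not among the diffs shown below. The following Kotlin sketch of their likely shape is inferred from the call sites in this commit (layers.variables(), layers.trainableVariables(), layers.frozenVariables(), Layer.freeze()); the property names and bodies are assumptions, not the library's exact code.

import org.jetbrains.kotlinx.dl.api.core.layer.KVariable
import org.jetbrains.kotlinx.dl.api.core.layer.Layer

// A layer that owns variables now keeps them itself instead of registering them in KGraph.
public interface ParametrizedLayer {
    public val variables: List<KVariable>
}

// A parametrized layer whose variables may be updated during training.
public interface TrainableLayer : ParametrizedLayer {
    public var isTrainable: Boolean
}

// Freezing is a no-op for layers without trainable state, e.g. ActivationLayer.
public fun Layer.freeze() {
    if (this is TrainableLayer) isTrainable = false
}

// Mirrors the extensions used by GraphTrainableModel below to collect variables from all layers.
public fun List<Layer>.variables(): List<KVariable> =
    filterIsInstance<ParametrizedLayer>().flatMap { it.variables }

public fun List<Layer>.trainableVariables(): List<KVariable> =
    filterIsInstance<TrainableLayer>().filter { it.isTrainable }.flatMap { it.variables }

public fun List<Layer>.frozenVariables(): List<KVariable> =
    filterIsInstance<ParametrizedLayer>()
        .filterNot { it is TrainableLayer && it.isTrainable }
        .flatMap { it.variables }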
@@ -7,6 +7,8 @@ package org.jetbrains.kotlinx.dl.api.core

import org.jetbrains.kotlinx.dl.api.core.layer.Layer
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.weights
import org.jetbrains.kotlinx.dl.api.core.layer.freeze
import org.jetbrains.kotlinx.dl.api.core.util.sortTopologically
import org.jetbrains.kotlinx.dl.api.inference.keras.*
import org.tensorflow.Operand
@@ -82,7 +84,7 @@ public class Functional(vararg layers: Layer) : GraphTrainableModel(*layers) {
// TODO: make an honest copy of models (pretrained and topModel) or of all layers
val pretrainedLayers = pretrainedModel.layers
// Freezing
pretrainedLayers.forEach { it.isTrainable = false }
pretrainedLayers.forEach { it.freeze() }

val layers = mutableListOf<Layer>()
layers += pretrainedLayers
@@ -255,7 +257,7 @@ public class Functional(vararg layers: Layer) : GraphTrainableModel(*layers) {
inputLayer.computeOutputShape()

layers.filter { it !is Input }.forEach {
it.buildFromInboundLayers(tf, kGraph)
it.buildFromInboundLayers(tf)

val outputShape = it.computeOutputShapeFromInboundLayers()
val dims = outputShape.dims()
@@ -10,8 +10,7 @@ import mu.KotlinLogging
import org.jetbrains.kotlinx.dl.api.core.callback.Callback
import org.jetbrains.kotlinx.dl.api.core.exception.RepeatableLayerNameException
import org.jetbrains.kotlinx.dl.api.core.history.*
import org.jetbrains.kotlinx.dl.api.core.layer.Layer
import org.jetbrains.kotlinx.dl.api.core.layer.NoGradients
import org.jetbrains.kotlinx.dl.api.core.layer.*
import org.jetbrains.kotlinx.dl.api.core.layer.core.ActivationLayer
import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
@@ -36,6 +35,7 @@ import org.jetbrains.kotlinx.dl.dataset.Dataset
import org.tensorflow.*
import org.tensorflow.op.Ops
import org.tensorflow.op.core.Placeholder
import org.tensorflow.op.core.Variable
import java.io.File
import java.io.FileNotFoundException
import java.nio.FloatBuffer
@@ -114,6 +114,9 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
session = Session(kGraph.tfGraph)
}

override fun variables(): List<Variable<Float>> = layers.variables().map { it.variable }
override fun frozenVariables(): List<Variable<Float>> = layers.frozenVariables().map { it.variable }

/** Helper method for preprocessing layer names and layer validation. */
internal companion object {
internal fun preProcessLayerNames(layers: Array<out Layer>) {
@@ -143,8 +146,6 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
override fun compile(optimizer: Optimizer, loss: LossFunction, metrics: List<Metric>) {
check(!isModelCompiled) { "The model is compiled already. Graph is created. Create new model and compile it." }

validateModelArchitecture()

this.loss = loss
this.metrics = metrics
this.optimizer = optimizer
@@ -172,7 +173,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel

yPredOp = forward(xOp, inputLayer)
lossOp = buildLossFunction(loss)
targets = optimizer.prepareTargets(kGraph, tf, lossOp)
targets = optimizer.prepareTargets(kGraph, layers.trainableVariables().map { it.variable }, tf, lossOp)

predictionOp = when (loss) {
is SoftmaxCrossEntropyWithLogits -> tf.withName(OUTPUT_NAME).nn.softmax(yPredOp)
@@ -190,9 +191,9 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
val basicLoss = loss.apply(tf, yPredOp, yTrueOp, numberOfLossesOp)
var totalLoss = basicLoss
// TODO: probably regularization output should be divided on numberOfLossesOp and changed together with loss before averaging
kGraph.variableRegularizers.forEach { (variable, regularizer) ->
run {
totalLoss = tf.math.add(totalLoss, regularizer.apply(tf, variable))
layers.trainableVariables().forEach { variable ->
variable.regularizer?.let { regularizer ->
totalLoss = tf.math.add(totalLoss, regularizer.apply(tf, variable.variable))
}
}
return tf.withName(TRAINING_LOSS).identity(totalLoss)
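
Summarizing the hunk above (a hedged restatement, not code from the commit): only trainable variables now contribute a regularization penalty to the loss graph, so the built loss is effectively totalLoss = basicLoss + the sum of regularizer(v) over each trainable variable v that has a regularizer attached; frozen variables and variables without a regularizer add nothing.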
@@ -207,23 +208,6 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
compile(optimizer, loss, Metrics.convert(metric))
}

/** Validates architecture. */
private fun validateModelArchitecture() {
layerValidation(layers.toList())

require(layers.none { it is NoGradients && it.isTrainable })
{
"All layers that implements NoGradient interface should be frozen (status isTrainable==false). " +
"But the following layers violates this rule: ${
layers.filter { it is NoGradients && it.isTrainable }.map { it.name }.toTypedArray()
.contentDeepToString()
}"
}
// require(layers.last() is Dense) { "DL architectures are not finished with Dense layer are not supported yet!" }
// require(layers.last().hasActivation()) { "Last layer must have an activation function." }
// require((layers.last() as Dense).activation != Activations.Sigmoid) { "The last dense layer should have Linear activation, alternative activations are not supported yet!" }
}

/** Common method for building the initial part of the model static graph layer by layer via calling build() method on each layer in correct order. */
protected abstract fun buildLayers()

@@ -277,7 +261,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
check(!isOptimizerVariableInitialized) { "Optimizer variables are initialized already!" }

logger.debug { "Initialization of TensorFlow Graph variables." }
kGraph.initializeGraphVariables(session)
layers.initializeVariables(session)
isModelInitialized = true
}

Expand All @@ -292,7 +276,7 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
check(isModelCompiled) { "The model is not compiled yet. Compile the model to use this method." }

logger.debug { "Initialization of TensorFlow Graph variables." }
kGraph.initializeGraphVariables(session)
layers.initializeVariables(session)
isModelInitialized = true
isOptimizerVariableInitialized = false
}
@@ -308,16 +292,9 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
): TrainingHistory {
check(isModelCompiled) { "The model is not compiled yet. Compile the model to use this method." }

for (layer in layers) {
check(!(layer is NoGradients && layer.isTrainable)) {
"Layer $layer has no gradients implementations in TensorFlow and should be non-trainable only. " +
"Set 'isTrainable' to 'false'."
}
}

if (!isModelInitialized) {
logger.debug { "Initialization of TensorFlow Graph variables." }
kGraph.initializeGraphVariables(session)
layers.initializeVariables(session)
isModelInitialized = true
}

@@ -981,3 +958,11 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
)
}
}

/**
* Freezes weights in all layers in this model, so they won't be changed during training.
* @see [Layer.freeze]
*/
public fun GraphTrainableModel.freeze() {
layers.forEach(Layer::freeze)
}
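
A minimal usage sketch of the extension added above (the architecture below is illustrative, and the import path assumes the extension is declared in the org.jetbrains.kotlinx.dl.api.core package alongside GraphTrainableModel):

import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.freeze
import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input

fun main() {
    val model = Sequential.of(
        Input(784),
        Dense(256),
        Dense(10)
    )
    // Freezes every layer in one call instead of setting isTrainable on each layer by hand.
    model.freeze()
}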
91 changes: 0 additions & 91 deletions api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/KGraph.kt
@@ -5,10 +5,8 @@

package org.jetbrains.kotlinx.dl.api.core

import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
import org.tensorflow.Graph
import org.tensorflow.GraphOperation
import org.tensorflow.Operand
import org.tensorflow.Session
import org.tensorflow.op.core.Assign
import org.tensorflow.op.core.AssignAdd
@@ -39,18 +37,9 @@ public class KGraph(graphDef: ByteArray, prefix: String) : AutoCloseable {
/** A list of initializer to initialize the trainableVariables. */
private val optimizerAssignAddInitializers: MutableList<AssignAdd<*>> = mutableListOf()

/** A list of variables to train. */
private val layerVariables: MutableMap<Variable<Float>, Boolean> = mutableMapOf()

/** A list of variables to train. */
internal val variableRegularizers: MutableMap<Variable<Float>, Regularizer> = mutableMapOf()

/** A list of optimizers' variables. */
private val optimizerVariables: MutableList<Variable<Float>> = mutableListOf()

/** A list of initializer to initialize the trainableVariables. */
private val initializers: MutableMap<String, Assign<Float>> = mutableMapOf()

init {
if (prefix.isEmpty()) {
tfGraph.importGraphDef(graphDef)
@@ -102,28 +91,6 @@ public class KGraph(graphDef: ByteArray, prefix: String) : AutoCloseable {
return KGraph(tfGraph.toGraphDef())
}

/**
* Adds a variable used in layer to the pool of tracked variables.
*
* @param variable Variable to track in KGraph.
* @param isTrainable When passed isTrainable = true, the KGraph automatically adds new variables to the tracked collection of graph variables.
*/
public fun addLayerVariable(variable: Variable<Float>, isTrainable: Boolean) {
check(!layerVariables.contains(variable)) { "$variable is added to graph already. Analyze and fix the static graph building process." }
layerVariables[variable] = isTrainable
}

/**
* Adds a variable related to the given [regularizer] to the pool of tracked variables.
*
* @param [variable] TensorFlow variable to track in KGraph.
* @param [regularizer] The regularizer related to the given [variable].
*/
public fun addVariableRegularizer(variable: Variable<Float>, regularizer: Regularizer) {
check(!variableRegularizers.contains(variable)) { "$variable is added to graph already. Analyze and fix the static graph building process." }
variableRegularizers[variable] = regularizer
}

/**
* Adds a variable used in optimizer to the pool of tracked variables.
*
@@ -134,17 +101,6 @@ public class KGraph(graphDef: ByteArray, prefix: String) : AutoCloseable {
optimizerVariables.add(variable)
}

/**
* Adds an initializer for layer variable tracked in KGraph.
*
* @param variableName Variable name to bind the variable and initializer.
* @param initializer Assign TensorFlow operand to initialize variable with name: [variableName].
*/
public fun addInitializer(variableName: String, initializer: Assign<Float>) {
check(!initializers.contains(variableName)) { "$variableName has initializer already. Analyze and fix the static graph building process." }
initializers[variableName] = initializer
}

/**
* Adds an optimizer initializer for optimizer variable tracked in KGraph.
*
@@ -163,60 +119,13 @@ public class KGraph(graphDef: ByteArray, prefix: String) : AutoCloseable {
optimizerAssignAddInitializers += initializer
}

/**
* Returns a list of trainable variables used in layers.
*/
public fun trainableLayerVariables(): List<Variable<Float>> {
return layerVariables.filter { it.value }.keys.toList()
}

/**
* Returns the variable's ability to be changed during the training.
*
* @param variableName Variable name.
*/
public fun isVariableTrainable(variableName: String): Boolean {
return layerVariables
.filter { it.value } // only trainable
.map { it.key.ref().op().name() } // extract names
.any { it.contains(variableName) }
}

/**
* Returns a list of non-trainable, 'frozen' variables used in layers.
*/
public fun frozenLayerVariables(): List<Variable<Float>> {
return layerVariables.filterNot { it.value }.keys.toList()
}

/**
* Returns all variables used in all model layers.
*/
public fun layerVariables(): List<Variable<Float>> {
return layerVariables.keys.toList()
}

/**
* Returns all variables used in optimizer and initialized by Assign TensorFlow operand.
*/
public fun optimizerVariables(): List<Variable<Float>> {
return optimizerVariables.toList()
}

/**
* Initializes TensorFlow graph variables used in model layers.
*/
public fun initializeGraphVariables(session: Session) {
val runner = session.runner()
// Only run the session if there is initializer ops.
if (initializers.isNotEmpty()) {
initializers.forEach {
runner.addTarget(it.value as Operand<Float>)
}
runner.run()
}
}

/**
* Initializes TensorFlow graph variables used in optimizer.
*/
@@ -1,12 +1,13 @@
/*
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.api.core

import org.jetbrains.kotlinx.dl.api.core.layer.Layer
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.weights
import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.jetbrains.kotlinx.dl.api.inference.keras.deserializeSequentialModel
import org.jetbrains.kotlinx.dl.api.inference.keras.loadSequentialModelLayers
@@ -148,7 +149,7 @@ public class Sequential(vararg layers: Layer) : GraphTrainableModel(*layers) {
var inputShape: Shape = inputLayer.computeOutputShape()

layers.filter { it !is Input }.forEach {
it.build(tf, kGraph, inputShape)
it.build(tf, inputShape)

inputShape = it.computeOutputShape(inputShape)
val tensorShape = TensorShape(inputShape)
@@ -5,7 +5,6 @@

package org.jetbrains.kotlinx.dl.api.core.layer

import org.jetbrains.kotlinx.dl.api.core.KGraph
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
import org.jetbrains.kotlinx.dl.api.core.util.getDType
@@ -33,22 +32,15 @@ public data class KVariable(

internal fun createVariable(
tf: Ops,
kGraph: KGraph,
variableName: String,
isTrainable: Boolean,
shape: Shape,
fanIn: Int,
fanOut: Int,
initializer: Initializer,
regularizer: Regularizer?
): KVariable {
val tfVariable = tf.withName(variableName).variable(shape, getDType())

val initOp = initializer.apply(fanIn, fanOut, tf, tfVariable, variableName)
kGraph.addLayerVariable(tfVariable, isTrainable)
kGraph.addInitializer(variableName, initOp)
if (regularizer != null) kGraph.addVariableRegularizer(tfVariable, regularizer)

return KVariable(
name = variableName,
shape = shape,
