Fix variables initialization (#355)
* Introduce InitializerOperation class to represent variable initialization

For some initializers, like Zeros, the initializer operations are named differently from what is expected, so such operations could not be found by name. Keeping the initializer operation in the KVariable allows accessing it directly (a short sketch follows the commit metadata below).

* Cleanup saving and loading variable values

* Move getVariablesAndTensors function down to the GraphTrainableModel

* Do not check whether a variable is related to frozen layers for a pure TensorFlowInferenceModel

This check cannot work there, as the model does not have any layers.

* Reuse loadVariablesFromTxt function

* Unify loading variable data from txt files and model copying

* Find the corresponding KVariable object when assigning a variable value in GraphTrainableModel
juliabeliaeva authored May 9, 2022
1 parent 30c45ae commit a0727b2
Showing 9 changed files with 174 additions and 200 deletions.
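The first commit-message bullet is the heart of the change. A minimal sketch of the idea, simplified to the essentials (the real KVariable in the diff below also keeps the shape, the Variable object, and the regularizer; only the TensorFlow types are real here):

import org.tensorflow.Operand
import org.tensorflow.Session

// Before: the assign op was looked up in the graph by a conventional name
// (defaultAssignOpName), which fails for initializers such as Zeros whose
// operations end up with unexpected names.
// After: the variable keeps a direct handle to its initializer operation.
data class InitializerOperationSketch(val assign: Operand<Float>) {
    fun run(session: Session) {
        // No lookup by name: run the assign target directly.
        session.runner().addTarget(assign).run()
    }
}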
@@ -37,7 +37,6 @@ import org.tensorflow.op.Ops
import org.tensorflow.op.core.Placeholder
import org.tensorflow.op.core.Variable
import java.io.File
import java.io.FileNotFoundException
import java.nio.FloatBuffer
import java.nio.file.Files
import java.nio.file.Paths
@@ -114,8 +113,14 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
session = Session(kGraph.tfGraph)
}

override fun variables(): List<Variable<Float>> = layers.variables().map { it.variable }
override fun frozenVariables(): List<Variable<Float>> = layers.frozenVariables().map { it.variable }
/**
* Returns a list of layer variables in this model.
*/
private fun layerVariables(): List<KVariable> = layers.variables()
/**
* Returns a list of non-trainable, 'frozen' layer variables in this model.
*/
private fun frozenLayerVariables(): List<KVariable> = layers.frozenVariables()

/** Helper method for preprocessing layer names and layer validation. */
internal companion object {
@@ -894,33 +899,68 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
}
}

/** Returns a list of variables paired with their data. */
private fun getVariablesAndTensors(saveOptimizerState: Boolean): List<Pair<Variable<Float>, Tensor<*>>> {
var variables = layerVariables().map { it.variable }
if (saveOptimizerState) {
variables = variables + kGraph.optimizerVariables()
}

val modelWeightsExtractorRunner = session.runner()
variables.forEach(modelWeightsExtractorRunner::fetch)
return variables.zip(modelWeightsExtractorRunner.run())
}

override fun loadWeights(modelDirectory: File, loadOptimizerState: Boolean) {
check(isModelCompiled) { "The model is not compiled yet. Compile the model to use this method." }
check(!isModelInitialized) { "The model is initialized already." }

Files.createDirectories(modelDirectory.toPath())
// Load variables names
val file = File("${modelDirectory.absolutePath}/variableNames.txt")
loadVariablesFromTxt(modelDirectory.path) { variableName ->
if (!isOptimizerVariable(variableName)) true
else if (loadOptimizerState) !isVariableRelatedToFrozenLayer(variableName)
else false
}

if (!file.exists()) throw FileNotFoundException(
"File 'variableNames.txt' is not found. This file must be in the model directory. " +
"It is generated during Sequential model saving with SavingFormat.TF_GRAPH_CUSTOM_VARIABLES or SavingFormat.JSON_CONFIG_CUSTOM_VARIABLES."
)
isModelInitialized = true
if (loadOptimizerState) isOptimizerVariableInitialized = true
}

val variableNames = file.readLines()
// TODO: common code could be refactored with the link to the function (load variable)
if (variableNames.isNotEmpty()) {
for (variableName in variableNames) {
if (!loadOptimizerState && variableName.startsWith("optimizer")) // skip loading optimizers' variables
continue
else if (loadOptimizerState && isOptimizerNameAndRelatedToFrozenLayer(variableName)) // skip loading optimizers' variables for frozen layers
continue
else loadVariable(variableName, modelDirectory.absolutePath)
/** Checks whether the variable with the name [variableName] belongs to a frozen layer. */
private fun isVariableRelatedToFrozenLayer(variableName: String): Boolean {
return frozenLayerVariables().map { it.name }.any { variableName.contains(it) }
}

/**
* Loads variable data for variable names in the provided collection using a provided function.
* @param [variableNames] Variable names to load.
* @param [getData] Function that returns variable data by variable name and shape.
*/
protected override fun loadVariables(variableNames: Collection<String>, getData: (String, Shape) -> Any) {
val layerVariablesByName = layerVariables().associateBy { it.name }

for (variableName in variableNames) {
val variableOperation = kGraph.tfGraph.operation(variableName)
check(variableOperation != null) { "Operation $variableName is not found in static graph." }
val variableShape = variableOperation.output<Float>(0).shape()

val data = getData(variableName, variableShape)

val variable = layerVariablesByName[variableName]
if (variable != null) {
fill(variable, data)
} else {
assignVariable(variableName, variableShape, data)
}
}
}
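// Note the two assignment paths above: a variable known to a layer is filled
// through its InitializerOperation (fill), while any other variable, e.g. an
// optimizer slot, falls back to assignVariable, which assigns by variable
// name at the graph level.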

isModelInitialized = true
if (loadOptimizerState) isOptimizerVariableInitialized = true
internal fun fill(variable: KVariable, data: Any) {
variable.initializerOperation.fill(data, session)
}

internal fun init(variable: KVariable) {
variable.initializerOperation.run(session)
}
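The predicate that the new loadWeights passes to loadVariablesFromTxt (see the top of this hunk) as a self-contained sketch; isOptimizerVariable and the frozen-layer check are modeled by stand-in lambdas here:

fun shouldLoad(
    variableName: String,
    loadOptimizerState: Boolean,
    isOptimizerVariable: (String) -> Boolean,
    isRelatedToFrozenLayer: (String) -> Boolean
): Boolean =
    if (!isOptimizerVariable(variableName)) true // model weights are always loaded
    else if (loadOptimizerState) !isRelatedToFrozenLayer(variableName)
    else false // optimizer state was not requested

fun main() {
    val isOptimizer = { name: String -> name.startsWith("optimizer") }
    val frozen = { name: String -> name.contains("conv2d_1") }
    println(shouldLoad("conv2d_1/kernel", false, isOptimizer, frozen))            // true
    println(shouldLoad("optimizer/conv2d_1/momentum", true, isOptimizer, frozen)) // false
    println(shouldLoad("optimizer/dense_2/momentum", true, isOptimizer, frozen))  // true
}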

/**
@@ -5,6 +5,7 @@

package org.jetbrains.kotlinx.dl.api.core.initializer

import org.jetbrains.kotlinx.dl.api.core.layer.InitializerOperation
import org.jetbrains.kotlinx.dl.api.core.shape.shapeOperand
import org.jetbrains.kotlinx.dl.api.core.util.defaultAssignOpName
import org.jetbrains.kotlinx.dl.api.core.util.defaultInitializerOpName
@@ -28,21 +29,21 @@ public abstract class Initializer {
* @param [fanOut] The maximum number of inputs that the output of an initializer can feed to other steps.
* @param [tf] Tensorflow Ops Accessor
* @param [input] Variable to initialize
* @return Assign operand created.
* @return initializer operation.
* @see InitializerOperation
*/
public fun apply(
fanIn: Int,
fanOut: Int,
tf: Ops,
input: Operand<Float>,
name: String
): Assign<Float> {
return tf.withName(defaultAssignOpName(name)).assign(
input, initialize(
fanIn, fanOut, tf,
shapeOperand(tf, input.asOutput().shape()), defaultInitializerOpName(name)
)
): InitializerOperation {
val initialize = initialize(
fanIn, fanOut, tf,
shapeOperand(tf, input.asOutput().shape()), defaultInitializerOpName(name)
)
return InitializerOperation(tf.withName(defaultAssignOpName(name)).assign(input, initialize), initialize)
}
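A hypothetical call site for the changed apply; tf, shape, and session are assumed to be in scope, and HeNormal stands in for any concrete Initializer:

val variable = tf.withName("dense_1/kernel").variable(shape, getDType())
val initializerOperation = HeNormal().apply(
    fanIn = 64, fanOut = 32,
    tf = tf, input = variable, name = "dense_1/kernel"
)
// The result can be run right away to apply the default initialization,
// or fed saved data later via initializerOperation.fill(data, session).
initializerOperation.run(session)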


@@ -9,7 +9,9 @@ import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
import org.jetbrains.kotlinx.dl.api.core.util.getDType
import org.tensorflow.Operand
import org.tensorflow.Session
import org.tensorflow.Shape
import org.tensorflow.Tensor
import org.tensorflow.op.Ops
import org.tensorflow.op.core.Variable

@@ -19,17 +21,52 @@ import org.tensorflow.op.core.Variable
* @property [name] name of the variable
* @property [shape] shape of the variable
* @property [variable] corresponding [Variable] object
* @property [initializerOperand] variable initializer
* @property [initializerOperation] variable initializer
* @property [regularizer] variable regularizer
*/
public data class KVariable(
val name: String,
val shape: Shape,
val variable: Variable<Float>,
val initializerOperand: Operand<Float>,
val initializerOperation: InitializerOperation,
val regularizer: Regularizer?
)

/**
* A class that keeps the information needed to initialize a [KVariable].
*
* @property [assign] assign operation for the variable
* @property [initialValue] value generated by the default initializer
*/
public data class InitializerOperation(val assign: Operand<Float>, val initialValue: Operand<Float>) {
/**
* Initialize the variable using the default initializer
* @param [session] session to use
*/
public fun run(session: Session) {
session.runner().addTarget(assign).run()
}

/**
* Fill the variable using the given data
* @param [data] data to use, should have correct shape and type
* @param [session] session to use
*/
public fun fill(data: Any, session: Session) {
var tensorData = data
if (data is Array<*> && data.isArrayOf<Float>()) {
tensorData = (data as Array<Float>).toFloatArray()
}

Tensor.create(tensorData).use { tensor ->
session.runner()
.feed(initialValue, tensor)
.addTarget(assign)
.run()
}
}
}

internal fun createVariable(
tf: Ops,
variableName: String,
@@ -40,12 +77,19 @@ internal fun createVariable(
regularizer: Regularizer?
): KVariable {
val tfVariable = tf.withName(variableName).variable(shape, getDType())
val initOp = initializer.apply(fanIn, fanOut, tf, tfVariable, variableName)
val initializerOperation = initializer.apply(fanIn, fanOut, tf, tfVariable, variableName)
return KVariable(
name = variableName,
shape = shape,
variable = tfVariable,
initializerOperand = initOp,
initializerOperation = initializerOperation,
regularizer = regularizer
)
}

internal fun List<InitializerOperation>.init(session: Session) {
if (isEmpty()) return
val runner = session.runner()
map { it.assign }.forEach(runner::addTarget)
runner.run()
}
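Both entry points of InitializerOperation in action, as a sketch; kVariable and session are assumed to be in scope, and loadKernelFromDisk is a hypothetical helper:

// Default initialization: run the assign op built by the initializer.
kVariable.initializerOperation.run(session)

// Restoring saved weights: feed concrete values into the same assign op,
// bypassing the value generated by the initializer.
val savedKernel = loadKernelFromDisk() // hypothetical, e.g. Array<FloatArray>
kVariable.initializerOperation.fill(savedKernel, session)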
@@ -5,7 +5,7 @@

package org.jetbrains.kotlinx.dl.api.core.layer

import org.jetbrains.kotlinx.dl.api.core.TrainableModel
import org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel
import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.tensorflow.Operand
import org.tensorflow.Shape
@@ -21,7 +21,7 @@ public abstract class Layer(public var name: String) {
public lateinit var outputShape: TensorShape

/** Model where this layer is used. */
public var parentModel: TrainableModel? = null
public var parentModel: GraphTrainableModel? = null

/** Returns inbound layers. */
public var inboundLayers: MutableList<Layer> = mutableListOf()
@@ -52,14 +52,7 @@ internal fun List<Layer>.frozenVariables(): List<KVariable> {
/**
* Initializes this layer's variables using the provided initializer operations.
*/
public fun ParametrizedLayer.initialize(session: Session) {
// Only run the session if there is initializer ops.
if (variables.isNotEmpty()) {
val runner = session.runner()
variables.map { it.initializerOperand }.forEach(runner::addTarget)
runner.run()
}
}
public fun ParametrizedLayer.initialize(session: Session): Unit = variables.map { it.initializerOperation }.init(session)
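The one-liner above delegates to the new List<InitializerOperation>.init extension defined earlier, which collects every assign op into a single Session.Runner, so all variables of a layer are initialized in one session run instead of one run per variable. A hypothetical caller (model and session assumed in scope):

model.layers
    .filterIsInstance<ParametrizedLayer>()
    .forEach { it.initialize(session) }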

/**
* Initializes variables for [ParametrizedLayer] instances using the provided initializer operations.
@@ -29,7 +29,10 @@ public var Layer.weights: Map<String, Array<*>>
val model = parentModel
requireNotNull(model) { "Layer '$name' is not related to any model" }

for ((name, value) in weights) {
model.assignVariable(name, value)
for (variable in variables) {
val value = weights[variable.name]
if (value != null) {
model.fill(variable, value)
}
}
}
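A hypothetical use of the rewritten setter; "dense_1" and the weight arrays are made up. Matching by variable name means map entries that do not correspond to any of the layer's variables are now skipped instead of being assigned through a by-name graph lookup:

val layer = model.layers.first { it.name == "dense_1" }
layer.weights = mapOf(
    "dense_1/kernel" to kernelArray, // hypothetical Array<*> of saved values
    "dense_1/bias" to biasArray
)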