Support copying optimizer variables in Sequential and Functional models
juliabeliaeva committed Dec 19, 2022
1 parent 530b576 commit f4be819
Showing 5 changed files with 73 additions and 55 deletions.
@@ -9,7 +9,6 @@ import org.jetbrains.kotlinx.dl.api.core.layer.Layer
 import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
 import org.jetbrains.kotlinx.dl.api.core.layer.freeze
 import org.jetbrains.kotlinx.dl.api.core.layer.setOutputShape
-import org.jetbrains.kotlinx.dl.api.core.layer.weights
 import org.jetbrains.kotlinx.dl.api.core.util.sortTopologically
 import org.jetbrains.kotlinx.dl.api.inference.keras.*
 import org.tensorflow.Operand
@@ -108,28 +107,26 @@ public class Functional(vararg layers: Layer) : GraphTrainableModel(*layers) {
         file.writeBytes(kGraph.tfGraph.toGraphDef())
     }
 
-    /** Returns a copy of this model. */
-    // TODO: support saveOptimizerState=true with assignment of intermediate optimizer state
-    public fun copy(saveOptimizerState: Boolean = false, copyWeights: Boolean = true): Functional {
-        val serializedModel = serializeModel(true)
-        val deserializedModel = deserializeFunctionalModel(serializedModel)
-        if (!copyWeights) {
-            return deserializedModel
-        } else {
-            // TODO: make deep copies, not just links
-            deserializedModel.compile(
-                optimizer = this.optimizer,
-                loss = this.loss,
-                metrics = this.metrics
-            )
-
-            deserializedModel.layers.forEach {
-                it.weights = this.getLayer(it.name).weights
-            }
-
-            deserializedModel.isModelInitialized = true
-
-            return deserializedModel
-        }
+    override fun copy(): Functional {
+        return copy(copiedModelName = null, copyOptimizerState = false, copyWeights = true)
+    }
+
+    /**
+     * Creates a copy of this model.
+     *
+     * @param [copiedModelName] a name for the copy
+     * @param [copyOptimizerState] whether optimizer state needs to be copied
+     * @param [copyWeights] whether model weights need to be copied
+     * @return A copied inference model.
+     */
+    public fun copy(copiedModelName: String? = null,
+                    copyOptimizerState: Boolean = false,
+                    copyWeights: Boolean = true
+    ): Functional {
+        val serializedModel = serializeModel(true)
+        return deserializeFunctionalModel(serializedModel).also { modelCopy ->
+            if (copiedModelName != null) modelCopy.name = copiedModelName
+            if (copyWeights) copyWeightsTo(modelCopy, copyOptimizerState)
+        }
     }
 
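The reworked API separates renaming, weight copying, and optimizer-state copying into explicit parameters, while the parameterless override keeps the inherited copy() contract working. A minimal usage sketch (not part of the commit; the model value and names below are illustrative assumptions):

// Hypothetical caller code, assuming `model` is a compiled and trained Functional model.
// Copy the architecture and weights under a new name, without optimizer state:
val inferenceCopy = model.copy(copiedModelName = "model-for-inference")

// Copy weights plus the optimizer's accumulated variables, so the copy can
// keep training from the same optimizer state:
val trainableCopy = model.copy(copyOptimizerState = true)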
@@ -956,6 +956,27 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
         variable.initializerOperation.run(session)
     }
 
+    protected fun copyWeightsTo(model: GraphTrainableModel, copyOptimizerState: Boolean) {
+        // TODO: make deep copies, not just links
+        model.compile(
+            optimizer = this.optimizer,
+            loss = this.loss,
+            metrics = this.metrics
+        )
+
+        model.layers.forEach {
+            it.weights = this.getLayer(it.name).weights
+        }
+
+        if (copyOptimizerState) {
+            val optimizerVariables = kGraph.variableNames().filter(::isOptimizerVariable)
+            copyVariablesToModel(model, optimizerVariables)
+            model.isOptimizerVariableInitialized = true
+        }
+
+        model.isModelInitialized = true
+    }
+
     /**
      * Return layer by [layerName].
      *
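The shared helper re-compiles the copy with the same optimizer, loss, and metrics, links the layer weights, and, when copyOptimizerState is set, transfers every graph variable recognized as optimizer state before marking it initialized. In practice this is what lets a copy resume training where the original left off; a hedged sketch (the dataset, epoch, and batch-size values are illustrative, not taken from the commit):

// Hypothetical continuation-of-training scenario, with `model` and `trainDataset`
// prepared elsewhere.
model.fit(trainDataset, epochs = 5, batchSize = 100)

// With copyOptimizerState = true the copy also receives the optimizer's
// accumulators (for example Adam's moment estimates), so its next fit()
// call does not restart from a freshly initialized optimizer.
val resumed = model.copy(copyOptimizerState = true)
resumed.fit(trainDataset, epochs = 5, batchSize = 100)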
@@ -8,7 +8,6 @@ package org.jetbrains.kotlinx.dl.api.core
 import org.jetbrains.kotlinx.dl.api.core.layer.Layer
 import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
 import org.jetbrains.kotlinx.dl.api.core.layer.setOutputShape
-import org.jetbrains.kotlinx.dl.api.core.layer.weights
 import org.jetbrains.kotlinx.dl.api.inference.keras.*
 import org.tensorflow.Operand
 import org.tensorflow.op.core.Placeholder
@@ -41,27 +40,26 @@ public class Sequential(vararg layers: Layer) : GraphTrainableModel(*layers) {
         return input to output
     }
 
-    /** Returns a copy of this model. */
-    // TODO: implement the saving of optimizer state
-    public fun copy(saveOptimizerState: Boolean = false, copyWeights: Boolean = true): Sequential {
-        val serializedModel = serializeModel(true)
-        val deserializedModel = deserializeSequentialModel(serializedModel)
-        if (!copyWeights) {
-            return deserializedModel
-        } else {
-            deserializedModel.compile(
-                optimizer = this.optimizer,
-                loss = this.loss,
-                metrics = this.metrics
-            )
-
-            deserializedModel.layers.forEach {
-                it.weights = this.getLayer(it.name).weights
-            }
-
-            deserializedModel.isModelInitialized = true
-
-            return deserializedModel
-        }
+    override fun copy(): Sequential {
+        return copy(copiedModelName = null, copyOptimizerState = false, copyWeights = true)
+    }
+
+    /**
+     * Creates a copy of this model.
+     *
+     * @param [copiedModelName] a name for the copy
+     * @param [copyOptimizerState] whether optimizer state needs to be copied
+     * @param [copyWeights] whether model weights need to be copied
+     * @return A copied inference model.
+     */
+    public fun copy(copiedModelName: String? = null,
+                    copyOptimizerState: Boolean = false,
+                    copyWeights: Boolean = true
+    ): Sequential {
+        val serializedModel = serializeModel(true)
+        return deserializeSequentialModel(serializedModel).also { modelCopy ->
+            if (copiedModelName != null) modelCopy.name = copiedModelName
+            if (copyWeights) copyWeightsTo(modelCopy, copyOptimizerState)
+        }
     }
 
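Sequential gets the identical treatment, so both model types now share the copy semantics implemented in GraphTrainableModel. A hedged end-to-end sketch (the layer configuration, hyperparameters, and dataset are illustrative assumptions, not taken from this commit):

// Hypothetical usage; trainDataset stands for a Dataset prepared elsewhere.
val model = Sequential.of(
    Input(28, 28, 1),
    Flatten(),
    Dense(10)
)
model.compile(Adam(), Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS, Metrics.ACCURACY)
model.fit(trainDataset, epochs = 3, batchSize = 100)

val plainCopy = model.copy()  // weights only, via the override
val fullCopy = model.copy(copiedModelName = "copy", copyOptimizerState = true)  // weights + optimizer state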
@@ -160,21 +160,23 @@ public open class TensorFlowInferenceModel : InferenceModel {
         model.input = input
         model.output = output
         if (copiedModelName != null) model.name = name
-
-        val variableNames = kGraph.variableNames()
-        if (variableNames.isNotEmpty()) {
-            val modelWeightsExtractorRunner = session.runner()
-            variableNames.forEach(modelWeightsExtractorRunner::fetch)
-            val modelWeights = variableNames.zip(modelWeightsExtractorRunner.run()).toMap()
-            model.loadVariables(modelWeights.keys) { variableName, _ ->
-                modelWeights[variableName]!!.use { it.convertTensorToMultiDimArray() }
-            }
-        }
-
+        copyVariablesToModel(model, kGraph.variableNames())
         model.isModelInitialized = true
         return model
     }
 
+    protected fun copyVariablesToModel(model: TensorFlowInferenceModel, variableNames: List<String>) {
+        if (variableNames.isEmpty()) return
+
+        val modelWeightsExtractorRunner = session.runner()
+        variableNames.forEach(modelWeightsExtractorRunner::fetch)
+        val modelWeights = variableNames.zip(modelWeightsExtractorRunner.run()).toMap()
+
+        model.loadVariables(modelWeights.keys) { variableName, _ ->
+            modelWeights[variableName]!!.use { it.convertTensorToMultiDimArray() }
+        }
+    }
+
     /**
      * Loads variable data for variable names in the provided collection using a provided function.
      * @param [variableNames] Variable names to load.
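Factoring the variable transfer out into copyVariablesToModel lets the existing inference-model copy path and the new optimizer-state path in GraphTrainableModel share one mechanism: fetch the named variables from this model's session and load the resulting tensors into the target model. Because the helper is protected, a subclass could also use it to move an arbitrary subset of variables; a hedged sketch (the name prefix and targetModel are hypothetical):

// Inside a subclass of TensorFlowInferenceModel (illustrative only):
// copy just the variables belonging to one layer into another model.
val conv1Variables = kGraph.variableNames().filter { it.startsWith("conv2d_1") }
copyVariablesToModel(targetModel, conv1Variables)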
@@ -306,7 +306,7 @@ class TransferLearningTest : IntegrationTest() {
 
             it.loadWeights(hdfFile)
 
-            val copy = it.copy(saveOptimizerState = false, copyWeights = true)
+            val copy = it.copy(copyOptimizerState = false, copyWeights = true)
             assertTrue(copy.layers.size == 11)
             copy.close()
 
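Only the parameter name changes here, since the default behaviour (copy weights, skip optimizer state) is the same as before. A companion check exercising the new flag might look like the following (hypothetical; the model in this test is only loaded from HDF5, so there is little optimizer state to carry over):

// Hypothetical additional assertion, mirroring the existing test body.
val statefulCopy = it.copy(copyOptimizerState = true, copyWeights = true)
assertTrue(statefulCopy.layers.size == 11)
statefulCopy.close()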
