From f4be8198d8f3a3288d6d0cfb38a7f4af130cafdc Mon Sep 17 00:00:00 2001 From: Julia Beliaeva Date: Mon, 19 Dec 2022 17:05:16 +0100 Subject: [PATCH] Support copying optimizer variables in Sequential and Functional models --- .../kotlinx/dl/api/core/Functional.kt | 41 +++++++++---------- .../dl/api/core/GraphTrainableModel.kt | 21 ++++++++++ .../kotlinx/dl/api/core/Sequential.kt | 40 +++++++++--------- .../api/inference/TensorFlowInferenceModel.kt | 24 ++++++----- .../TransferLearningFromKerasTest.kt | 2 +- 5 files changed, 73 insertions(+), 55 deletions(-) diff --git a/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Functional.kt b/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Functional.kt index 6869cf149..18b499710 100644 --- a/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Functional.kt +++ b/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Functional.kt @@ -9,7 +9,6 @@ import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.core.Input import org.jetbrains.kotlinx.dl.api.core.layer.freeze import org.jetbrains.kotlinx.dl.api.core.layer.setOutputShape -import org.jetbrains.kotlinx.dl.api.core.layer.weights import org.jetbrains.kotlinx.dl.api.core.util.sortTopologically import org.jetbrains.kotlinx.dl.api.inference.keras.* import org.tensorflow.Operand @@ -108,28 +107,26 @@ public class Functional(vararg layers: Layer) : GraphTrainableModel(*layers) { file.writeBytes(kGraph.tfGraph.toGraphDef()) } - /** Returns a copy of this model. 
*/ - // TODO: support saveOptimizerState=true with assignment of intermediate optimizer state - public fun copy(saveOptimizerState: Boolean = false, copyWeights: Boolean = true): Functional { - val serializedModel = serializeModel(true) - val deserializedModel = deserializeFunctionalModel(serializedModel) - if (!copyWeights) { - return deserializedModel - } else { - // TODO: make deep copies, not just links - deserializedModel.compile( - optimizer = this.optimizer, - loss = this.loss, - metrics = this.metrics - ) - - deserializedModel.layers.forEach { - it.weights = this.getLayer(it.name).weights - } - - deserializedModel.isModelInitialized = true + override fun copy(): Functional { + return copy(copiedModelName = null, copyOptimizerState = false, copyWeights = true) + } - return deserializedModel + /** + * Creates a copy of this model. + * + * @param [copiedModelName] a name for the copy + * @param [copyOptimizerState] whether optimizer state needs to be copied + * @param [copyWeights] whether model weights need to be copied + * @return A copied inference model. + */ + public fun copy(copiedModelName: String? 
= null, + copyOptimizerState: Boolean = false, + copyWeights: Boolean = true + ): Functional { + val serializedModel = serializeModel(true) + return deserializeFunctionalModel(serializedModel).also { modelCopy -> + if (copiedModelName != null) modelCopy.name = copiedModelName + if (copyWeights) copyWeightsTo(modelCopy, copyOptimizerState) } } diff --git a/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt b/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt index 50757ef76..6446c5233 100644 --- a/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt +++ b/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt @@ -956,6 +956,27 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel variable.initializerOperation.run(session) } + protected fun copyWeightsTo(model: GraphTrainableModel, copyOptimizerState: Boolean) { + // TODO: make deep copies, not just links + model.compile( + optimizer = this.optimizer, + loss = this.loss, + metrics = this.metrics + ) + + model.layers.forEach { + it.weights = this.getLayer(it.name).weights + } + + if (copyOptimizerState) { + val optimizerVariables = kGraph.variableNames().filter(::isOptimizerVariable) + copyVariablesToModel(model, optimizerVariables) + model.isOptimizerVariableInitialized = true + } + + model.isModelInitialized = true + } + /** * Return layer by [layerName]. 
* diff --git a/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Sequential.kt b/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Sequential.kt index 2be54d16e..7721b8e17 100644 --- a/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Sequential.kt +++ b/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Sequential.kt @@ -8,7 +8,6 @@ package org.jetbrains.kotlinx.dl.api.core import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.core.layer.core.Input import org.jetbrains.kotlinx.dl.api.core.layer.setOutputShape -import org.jetbrains.kotlinx.dl.api.core.layer.weights import org.jetbrains.kotlinx.dl.api.inference.keras.* import org.tensorflow.Operand import org.tensorflow.op.core.Placeholder @@ -41,27 +40,26 @@ public class Sequential(vararg layers: Layer) : GraphTrainableModel(*layers) { return input to output } - /** Returns a copy of this model. */ - // TODO: implement the saving of optimizer state - public fun copy(saveOptimizerState: Boolean = false, copyWeights: Boolean = true): Sequential { - val serializedModel = serializeModel(true) - val deserializedModel = deserializeSequentialModel(serializedModel) - if (!copyWeights) { - return deserializedModel - } else { - deserializedModel.compile( - optimizer = this.optimizer, - loss = this.loss, - metrics = this.metrics - ) - - deserializedModel.layers.forEach { - it.weights = this.getLayer(it.name).weights - } - - deserializedModel.isModelInitialized = true + override fun copy(): Sequential { + return copy(copiedModelName = null, copyOptimizerState = false, copyWeights = true) + } - return deserializedModel + /** + * Creates a copy of this model. + * + * @param [copiedModelName] a name for the copy + * @param [copyOptimizerState] whether optimizer state needs to be copied + * @param [copyWeights] whether model weights need to be copied + * @return A copied inference model. + */ + public fun copy(copiedModelName: String? 
= null, + copyOptimizerState: Boolean = false, + copyWeights: Boolean = true + ): Sequential { + val serializedModel = serializeModel(true) + return deserializeSequentialModel(serializedModel).also { modelCopy -> + if (copiedModelName != null) modelCopy.name = copiedModelName + if (copyWeights) copyWeightsTo(modelCopy, copyOptimizerState) } } diff --git a/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel.kt b/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel.kt index adae5e497..191286a43 100644 --- a/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel.kt +++ b/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel.kt @@ -160,21 +160,23 @@ public open class TensorFlowInferenceModel : InferenceModel { model.input = input model.output = output if (copiedModelName != null) model.name = name - - val variableNames = kGraph.variableNames() - if (variableNames.isNotEmpty()) { - val modelWeightsExtractorRunner = session.runner() - variableNames.forEach(modelWeightsExtractorRunner::fetch) - val modelWeights = variableNames.zip(modelWeightsExtractorRunner.run()).toMap() - model.loadVariables(modelWeights.keys) { variableName, _ -> - modelWeights[variableName]!!.use { it.convertTensorToMultiDimArray() } - } - } - + copyVariablesToModel(model, kGraph.variableNames()) model.isModelInitialized = true return model } + protected fun copyVariablesToModel(model: TensorFlowInferenceModel, variableNames: List<String>) { + if (variableNames.isEmpty()) return + + val modelWeightsExtractorRunner = session.runner() + variableNames.forEach(modelWeightsExtractorRunner::fetch) + val modelWeights = variableNames.zip(modelWeightsExtractorRunner.run()).toMap() + + model.loadVariables(modelWeights.keys) { variableName, _ -> + modelWeights[variableName]!!.use { it.convertTensorToMultiDimArray() } + } + } + /** * Loads variable data for variable names 
in the provided collection using a provided function. * @param [variableNames] Variable names to load. diff --git a/tensorflow/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/TransferLearningFromKerasTest.kt b/tensorflow/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/TransferLearningFromKerasTest.kt index f6ef17ad9..ff5bef2d6 100644 --- a/tensorflow/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/TransferLearningFromKerasTest.kt +++ b/tensorflow/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/TransferLearningFromKerasTest.kt @@ -306,7 +306,7 @@ class TransferLearningTest : IntegrationTest() { it.loadWeights(hdfFile) - val copy = it.copy(saveOptimizerState = false, copyWeights = true) + val copy = it.copy(copyOptimizerState = false, copyWeights = true) assertTrue(copy.layers.size == 11) copy.close()