Skip to content

Commit

Permalink
Remove extra parameters from the TensorFlowInferenceModel#copy
Browse files Browse the repository at this point in the history
TensorFlowInferenceModel does not have any layers, so it only makes sense to copy weights. And since the model class is not trainable, optimizer state does not need to be copied.
  • Loading branch information
juliabeliaeva committed Dec 19, 2022
1 parent 7df2a7c commit 530b576
Showing 1 changed file with 9 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,11 @@ public open class TensorFlowInferenceModel : InferenceModel {
}

override fun copy(): TensorFlowInferenceModel {
return copy(copiedModelName = null, saveOptimizerState = false, copyWeights = true)
return copy(copiedModelName = null)
}

public fun copy(
copiedModelName: String? = null,
saveOptimizerState: Boolean = false, // TODO, check this case
copyWeights: Boolean = true
): TensorFlowInferenceModel {
/** Returns a copy of this model. */
public fun copy(copiedModelName: String? = null): TensorFlowInferenceModel {
val model = TensorFlowInferenceModel()
model.kGraph = this.kGraph.copy()
model.tf = Ops.create(model.kGraph.tfGraph)
Expand All @@ -163,25 +160,17 @@ public open class TensorFlowInferenceModel : InferenceModel {
model.input = input
model.output = output
if (copiedModelName != null) model.name = name
// TODO: check that tensors are closed after usage
if (copyWeights) {
val modelWeightsExtractorRunner = session.runner()
val variableNames = kGraph.variableNames()
check(variableNames.isNotEmpty()) {
"Found 0 variable names in TensorFlow graph $kGraph. " +
"If copied model has no weights, set flag `copyWeights` to `false`."
}

val variableNamesToCopy = variableNames.filter { variableName ->
saveOptimizerState || !isOptimizerVariable(variableName)
}
variableNamesToCopy.forEach(modelWeightsExtractorRunner::fetch)
val modelWeights = variableNamesToCopy.zip(modelWeightsExtractorRunner.run()).toMap()

val variableNames = kGraph.variableNames()
if (variableNames.isNotEmpty()) {
val modelWeightsExtractorRunner = session.runner()
variableNames.forEach(modelWeightsExtractorRunner::fetch)
val modelWeights = variableNames.zip(modelWeightsExtractorRunner.run()).toMap()
model.loadVariables(modelWeights.keys) { variableName, _ ->
modelWeights[variableName]!!.use { it.convertTensorToMultiDimArray() }
}
}

model.isModelInitialized = true
return model
}
Expand Down

0 comments on commit 530b576

Please sign in to comment.