Skip to content

Commit

Permalink
Override copy function in Functional and Sequential models
Browse files Browse the repository at this point in the history
  • Loading branch information
juliabeliaeva committed Dec 19, 2022
1 parent 90b046d commit bbf1be7
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ public class Functional(vararg layers: Layer) : GraphTrainableModel(*layers) {
file.writeBytes(kGraph.tfGraph.toGraphDef())
}

/** Returns a copy of this model. */
public fun copy(saveOptimizerState: Boolean = false, copyWeights: Boolean = true): Functional {
override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): Functional {
val serializedModel = serializeModel(true)
val deserializedModel = deserializeFunctionalModel(serializedModel)
if (copiedModelName != null) deserializedModel.name = copiedModelName
if (!copyWeights) return deserializedModel

// TODO: make deep copies, not just links
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ public class Sequential(vararg layers: Layer) : GraphTrainableModel(*layers) {
return input to output
}

/** Returns a copy of this model. */
public fun copy(saveOptimizerState: Boolean = false, copyWeights: Boolean = true): Sequential {
override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): Sequential {
val serializedModel = serializeModel(true)
val deserializedModel = deserializeSequentialModel(serializedModel)
if (copiedModelName != null) deserializedModel.name = copiedModelName
if (!copyWeights) return deserializedModel

deserializedModel.compile(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ public open class TensorFlowInferenceModel : InferenceModel {
return copy(copiedModelName = null, saveOptimizerState = false, copyWeights = true)
}

public fun copy(
/** Returns a copy of this model. */
public open fun copy(
copiedModelName: String? = null,
saveOptimizerState: Boolean = false, // TODO, check this case
copyWeights: Boolean = true
Expand Down

0 comments on commit bbf1be7

Please sign in to comment.