Skip to content

Commit

Permalink
Remove extra parameters from the InferenceModel#copy
Browse files Browse the repository at this point in the history
copyWeight and saveOptimizerState parameters are unused in the OnnxInferenceModel, and InferenceModel is not required to have a name, so copiedModelName does not always make sense as well. It's better to have a more generic copy method in the interface, and move these parameters to the appropriate implementing classes.
  • Loading branch information
juliabeliaeva committed Dec 19, 2022
1 parent b2b7af9 commit 7df2a7c
Show file tree
Hide file tree
Showing 15 changed files with 33 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,9 @@ public interface InferenceModel : AutoCloseable {
public fun reshape(vararg dims: Long)

/**
* Creates a copy.
* Creates a copy of this model.
*
* @param [copiedModelName] Set up this name to make a copy with a new name.
* @return A copied inference model.
*/
public fun copy(
copiedModelName: String? = null,
saveOptimizerState: Boolean = false,
copyWeights: Boolean = true
): InferenceModel
public fun copy(): InferenceModel
}
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class ModelCopyingTestSuite {
for (modelType in modelTypes) {
modelHub.loadPretrainedModel(modelType, LoadingMode.SKIP_LOADING_IF_EXISTS).use { model ->
@Suppress("UNCHECKED_CAST")
(model.copy("model_copy") as M).use { modelCopy ->
(model.copy() as M).use { modelCopy ->

val detectionResult = detect(model, imageFile)
val detectionResultCopy = detect(modelCopy, imageFile)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,9 @@ public class ImageRecognitionModel(
return predictObject(ImageConverter.toBufferedImage(imageFile))
}

override fun copy(
copiedModelName: String?,
saveOptimizerState: Boolean,
copyWeights: Boolean
): ImageRecognitionModel {
override fun copy(): ImageRecognitionModel {
return ImageRecognitionModel(
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
internalModel.copy(),
inputColorMode,
channelsFirst,
preprocessor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,8 @@ public class FaceDetectionModel(
.toFloatArray { layout = TensorLayout.NCHW }
.call(ONNXModels.FaceDetection.defaultPreprocessor)

override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): InferenceModel {
return FaceDetectionModel(
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
modelKindDescription
)
override fun copy(): InferenceModel {
return FaceDetectionModel(internalModel.copy(), modelKindDescription)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,8 @@ public class Fan2D106FaceAlignmentModel(
.rotate { degrees = targetRotation.toFloat() }
.toFloatArray { layout = TensorLayout.NCHW }

override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): InferenceModel {
return Fan2D106FaceAlignmentModel(
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
modelKindDescription
)
override fun copy(): InferenceModel {
return Fan2D106FaceAlignmentModel(internalModel.copy(), modelKindDescription)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,8 @@ public open class OnnxInferenceModel private constructor(
}
}

override fun copy(
copiedModelName: String?,
saveOptimizerState: Boolean,
copyWeights: Boolean
): OnnxInferenceModel {
override fun copy(): OnnxInferenceModel {
val model = OnnxInferenceModel(modelSource)
model.name = copiedModelName
if (inputShape != null) {
model.reshape(*inputDimensions)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ public class FaceDetectionModel(
.toFloatArray { }
.call(ONNXModels.FaceDetection.defaultPreprocessor)

override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): InferenceModel {
return FaceDetectionModel(
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
modelKindDescription
)
override fun copy(): InferenceModel {
return FaceDetectionModel(internalModel.copy(), modelKindDescription)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,8 @@ public class Fan2D106FaceAlignmentModel(
return detectLandmarks(ImageConverter.toBufferedImage(imageFile))
}

override fun copy(
copiedModelName: String?,
saveOptimizerState: Boolean,
copyWeights: Boolean
): Fan2D106FaceAlignmentModel {
return Fan2D106FaceAlignmentModel(
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
modelKindDescription
)
override fun copy(): Fan2D106FaceAlignmentModel {
return Fan2D106FaceAlignmentModel(internalModel.copy(), modelKindDescription)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,7 @@ public class EfficientDetObjectDetectionModel(
return detectObjects(ImageConverter.toBufferedImage(imageFile), topK)
}

override fun copy(
copiedModelName: String?,
saveOptimizerState: Boolean,
copyWeights: Boolean
): EfficientDetObjectDetectionModel {
return EfficientDetObjectDetectionModel(
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
modelKindDescription
)
override fun copy(): EfficientDetObjectDetectionModel {
return EfficientDetObjectDetectionModel(internalModel.copy(), modelKindDescription)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,7 @@ public class SSDMobileNetV1ObjectDetectionModel(
return detectObjects(ImageConverter.toBufferedImage(imageFile), topK)
}

override fun copy(
copiedModelName: String?,
saveOptimizerState: Boolean,
copyWeights: Boolean
): SSDMobileNetV1ObjectDetectionModel {
return SSDMobileNetV1ObjectDetectionModel(
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
modelKindDescription
)
override fun copy(): SSDMobileNetV1ObjectDetectionModel {
return SSDMobileNetV1ObjectDetectionModel(internalModel.copy(), modelKindDescription)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,7 @@ public class SSDObjectDetectionModel(
return foundObjects
}


override fun copy(
copiedModelName: String?,
saveOptimizerState: Boolean,
copyWeights: Boolean
): SSDObjectDetectionModel {
return SSDObjectDetectionModel(
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
modelKindDescription
)
override fun copy(): SSDObjectDetectionModel {
return SSDObjectDetectionModel(internalModel.copy(), modelKindDescription)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,7 @@ public class MultiPoseDetectionModel(
return detectPoses(ImageConverter.toBufferedImage(imageFile), confidence)
}

override fun copy(
copiedModelName: String?,
saveOptimizerState: Boolean,
copyWeights: Boolean
): MultiPoseDetectionModel {
return MultiPoseDetectionModel(
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
modelKindDescription
)
override fun copy(): MultiPoseDetectionModel {
return MultiPoseDetectionModel(internalModel.copy(), modelKindDescription)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,7 @@ public class SinglePoseDetectionModel(
return detectPose(ImageConverter.toBufferedImage(imageFile))
}

override fun copy(
copiedModelName: String?,
saveOptimizerState: Boolean,
copyWeights: Boolean
): SinglePoseDetectionModel {
return SinglePoseDetectionModel(
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
modelKindDescription
)
override fun copy(): SinglePoseDetectionModel {
return SinglePoseDetectionModel(internalModel.copy(), modelKindDescription)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,14 @@ public open class TensorFlowInferenceModel : InferenceModel {
}
}

override fun copy(
copiedModelName: String?,
saveOptimizerState: Boolean, // TODO, check this case
copyWeights: Boolean
override fun copy(): TensorFlowInferenceModel {
return copy(copiedModelName = null, saveOptimizerState = false, copyWeights = true)
}

public fun copy(
copiedModelName: String? = null,
saveOptimizerState: Boolean = false, // TODO, check this case
copyWeights: Boolean = true
): TensorFlowInferenceModel {
val model = TensorFlowInferenceModel()
model.kGraph = this.kGraph.copy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ class TransferLearningTest : IntegrationTest() {

it.loadWeights(hdfFile)

val copy = it.copy()
val copy = it.copy(saveOptimizerState = false, copyWeights = true)
assertTrue(copy.layers.size == 11)
copy.close()

Expand Down

0 comments on commit 7df2a7c

Please sign in to comment.