diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt index 9e9adcd7f..88663a5ab 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt @@ -68,7 +68,7 @@ internal fun GraphTrainableModel.serializeModel(isKerasFullyCompatible: Boolean) (kerasLayers.first().config as LayerConfig).batch_input_shape = listOf(null, inputShape[0], inputShape[1], inputShape[2]) // TODO: refactor with method for Input layer - val config = KerasModelConfig(name = "", layers = kerasLayers) + val config = KerasModelConfig(name = name, layers = kerasLayers) return KerasModel(config = config) } diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/models/SequentialCompilationTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/models/SequentialCompilationTest.kt index 92bd29ba0..3e51f646e 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/models/SequentialCompilationTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/models/SequentialCompilationTest.kt @@ -82,7 +82,9 @@ internal class SequentialModelTest { biasInitializer = Constant(0.1f), name = "dense_2" ) - ) + ).apply { + name = "sequential_model" + } @Test fun buildModel() { @@ -99,6 +101,8 @@ internal class SequentialModelTest { @Test fun summary() { correctTestModel.use { + assertEquals("sequential_model", it.name) + it.compile(optimizer = Adam(), loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS, metric = Accuracy()) val layerDescriptions = it.summary() assertTrue(layerDescriptions[1].contentEquals("conv2d_1(Conv2D) [None, 28, 28, 32] 832"))