diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/WeightLoader.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/WeightLoader.kt index 41901e3e3..e9beb08f5 100644 --- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/WeightLoader.kt +++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/WeightLoader.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020-2021 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. */ @@ -9,7 +9,6 @@ import io.jhdf.HdfFile import io.jhdf.api.Group import io.jhdf.api.Node import io.jhdf.dataset.DatasetBase -import org.jetbrains.kotlinx.dl.api.core.Functional import org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel import org.jetbrains.kotlinx.dl.api.core.layer.Layer import org.jetbrains.kotlinx.dl.api.inference.keras.WeightMappings.BIAS_DATA_PATH_TEMPLATE @@ -24,152 +23,55 @@ import org.jetbrains.kotlinx.dl.api.inference.keras.WeightMappings.getLayerVaria * * @param [hdfFile] File in hdf5 file format containing weights of the model. */ -public fun GraphTrainableModel.loadWeights( - hdfFile: HdfFile -) { - check(this.isModelCompiled) { "The model is not compiled yet. Compile the model to use this method." } - check(!isModelInitialized) { "Model is initialized already!" } - this.logger.info { "Starting weights loading.." } - - when { - hdfFile.attributes.containsKey("layer_names") -> loadWeightsFromHdf5Group(hdfFile, this, null) - hdfFile.children.containsKey("model_weights") -> { - loadWeightsFromHdf5Group((hdfFile as Group).getChild("model_weights") as Group, this, null) - } - else -> { - this.logger.error { "This is unknown path format. Use special method loadWeightsViaPathTemplates() to specify templates to load weights." } - } - } - - this.logger.info { "Weights are loaded." } - this.isModelInitialized = true -} +public fun GraphTrainableModel.loadWeights(hdfFile: HdfFile): Unit = loadWeights(hdfFile, layers) /** - * Loads weights from hdf5 file created in Keras TensorFlow framework for pre-defined list of layers. + * Loads weights from hdf5 file created in Keras TensorFlow framework for non-trainable (or frozen) layers only. * - * NOTE: Weights for another layers will not be loaded (should be initialized manually). + * NOTE: Weights for trainable layers will not be loaded and will be initialized via default initializers. * * @param [hdfFile] File in hdf5 file format containing weights of Sequential model. - * @param [layerList] List of layers to load weights. Weights for other layers will be initialized by initializer later. */ -public fun GraphTrainableModel.loadWeights( - hdfFile: HdfFile, - layerList: MutableList -) { - check(this.isModelCompiled) { "The model is not compiled yet. Compile the model to use this method." } - check(!isModelInitialized) { "Model is initialized already!" } - this.logger.info { "Starting weights loading.." } - - when { - hdfFile.attributes.containsKey("layer_names") -> loadWeightsFromHdf5Group(hdfFile, this, layerList) - hdfFile.children.containsKey("model_weights") -> { - loadWeightsFromHdf5Group((hdfFile as Group).getChild("model_weights") as Group, this, layerList) - } - else -> { - this.logger.error { "This is unknown path format. Use special method loadWeightsViaPathTemplates() to specify templates to load weights." } - } - } - - this.logger.info { "Weights are loaded." } - this.isModelInitialized = true +public fun GraphTrainableModel.loadWeightsForFrozenLayers(hdfFile: HdfFile) { + loadWeights(hdfFile, layers.filterNot(Layer::isTrainable)) } /** - * Loads weights from hdf5 file created in Keras TensorFlow framework for non-trainable (or frozen) layers only. + * Loads weights from hdf5 file created in Keras TensorFlow framework for pre-defined list of layers. * - * NOTE: Weights for trainable layers will not be loaded and will be initialized via default initializers. + * NOTE: Weights for another layers will not be loaded (should be initialized manually). * * @param [hdfFile] File in hdf5 file format containing weights of Sequential model. + * @param [layerList] List of layers to load weights. Weights for other layers will be initialized by initializer later. */ -public fun GraphTrainableModel.loadWeightsForFrozenLayers( - hdfFile: HdfFile -) { - val frozenLayers = mutableListOf() - this.layers.forEach { - if (!it.isTrainable) frozenLayers.add(it) +public fun GraphTrainableModel.loadWeights(hdfFile: HdfFile, layerList: List) { + val group = when { + hdfFile.attributes.containsKey("layer_names") -> hdfFile + hdfFile.children.containsKey("model_weights") -> (hdfFile as Group).getChild("model_weights") as Group + else -> null } - this.loadWeights(hdfFile, frozenLayers) -} - -private fun loadWeightsFromHdf5Group(group: Group, model: GraphTrainableModel, layerList: MutableList?) { - var originalKerasVersion = 1 - - if (group.attributes.containsKey("keras_version") && ((if (group.attributes["keras_version"] != null) group.attributes["keras_version"]?.data else "1") as String).startsWith( - "2" - ) - ) { - originalKerasVersion = 2 + if (group == null) { + logger.error { + "This is unknown path format. Use special method loadWeightsViaPathTemplates()" + + " to specify templates to load weights." + } + return } - if (originalKerasVersion == 1 - ) { + + if (group.getKerasVersion() == 1) { throw UnsupportedOperationException( "The weights loading from Keras 1.x is not supported by default!" + "\nUse loadWeightsViaPathTemplates() method to make custom loading!" ) } - if (layerList != null) { - model.layers.forEach { - run { - if (layerList.contains(it)) { - fillLayerWeights(it, group, model) - } else { - initLayerWeights(it, model) - } - } - } - } else { - model.layers.forEach { - run { - fillLayerWeights(it, group, model) - } - } - } -} - -private fun fillLayerWeights( - layer: Layer, - group: Group, - model: GraphTrainableModel -) { - val variables = getLayerVariables(layer) - if (variables == null) { - model.logger.warn { "Loading weights for the layer ${layer.name} is skipped as ${layer::class.qualifiedName} layers are not supported." } - return - } - fillLayerVariablesFromKeras(layer.name, variables, model, group) - model.logger.debug { "${layer.paramCount} parameters loaded for the layer ${layer.name}." } + loadWeights(layerList) { layer -> fillLayerWeights(layer, group, this) } } -private fun fillLayerVariablesFromKeras(layerName: String, - variables: Map>, - model: GraphTrainableModel, - group: Group -) { - val layerWeightsNode = group.children[layerName] as? Group - check(layerWeightsNode != null) { - val availableLayerNames = group.children.values.map(Node::getName) - val modelLayerNames = model.layers.map(Layer::name) - "Weights for the loaded layer $layerName are not found in .h5 file! " + - "\nh5 weight file contains weights for the following list of layers: $availableLayerNames" + - "\nDouble-check your loaded configuration which contains layers with the following names: $modelLayerNames." - } - - val nameOfWeightSubGroup = layerWeightsNode.children.keys.first() - val dataNodes = (layerWeightsNode.children[nameOfWeightSubGroup] as Group).children - - dataNodes.values.map { it as DatasetBase }.forEach { - val (name, shape) = variables[it.name] - ?: throw IllegalArgumentException( - "Parsing of h5 file for variable with name ${it.name} in layer $layerName is not supported!" - ) - val dims = it.dimensions - require(shape.map(Long::toInt).toIntArray().contentEquals(dims)) { - "$name variable shape in loaded data is ${dims.contentToString()}. Should be ${shape.contentToString()}" - } - model.fillVariable(name, it.data) - } +private fun Group.getKerasVersion(): Int { + val kerasVersionAttribute = attributes["keras_version"] ?: return 1 + if ((kerasVersionAttribute.data as String).startsWith("2")) return 2 + return 1 } /** @@ -183,108 +85,49 @@ public fun GraphTrainableModel.loadWeightsByPathTemplates( hdfFile: HdfFile, kernelDataPathTemplate: String = KERNEL_DATA_PATH_TEMPLATE, // TODO: doesnt' work for batchnorm/depthwise biasDataPathTemplate: String = BIAS_DATA_PATH_TEMPLATE -) { - check(this.isModelCompiled) { "The model is not compiled yet. Compile the model to use this method." } - check(!isModelInitialized) { "Model is initialized already!" } - this.logger.info { "Starting weights loading.." } - this.layers.forEach { - run { - fillLayerWeights( - it, - hdfFile, - LayerConvOrDensePaths("", kernelDataPathTemplate, biasDataPathTemplate), - this - ) // TODO: doesnt' work for batchnorm/depthwise - } - } - this.logger.info { "Weights are loaded." } - this.isModelInitialized = true -} +): Unit = loadWeightsByPathTemplates(hdfFile, layers, kernelDataPathTemplate, biasDataPathTemplate) /** - * Loads weights from hdf5 file created in Keras TensorFlow framework for pre-defined list of layers. + * Loads weights from hdf5 file created in Keras TensorFlow framework for non-trainable (or frozen) layers only. * - * NOTE: Weights for another layers will not be loaded (should be initialized manually). + * NOTE: Weights for trainable layers will not be loaded and will be initialized via default initializers. * * @param [hdfFile] File in hdf5 file format containing weights of Sequential model. - * @param [layerList] List of layers to load weights. Weights for other layers will be initialized by initializer later. * @param [kernelDataPathTemplate] Template path to kernel weights of the specific layer. * @param [biasDataPathTemplate] Template path to bias weights of the specific layer. */ -public fun GraphTrainableModel.loadWeightsByPathTemplates( +public fun GraphTrainableModel.loadWeightsForFrozenLayersByPathTemplates( hdfFile: HdfFile, - layerList: MutableList, kernelDataPathTemplate: String = KERNEL_DATA_PATH_TEMPLATE, biasDataPathTemplate: String = BIAS_DATA_PATH_TEMPLATE ) { - check(this.isModelCompiled) { "The model is not compiled yet. Compile the model to use this method." } - check(!isModelInitialized) { "Model is initialized already!" } - this.logger.info { "Starting weights loading.." } - this.layers.forEach { - run { - if (layerList.contains(it)) { - fillLayerWeights( - it, - hdfFile, - LayerConvOrDensePaths("", kernelDataPathTemplate, biasDataPathTemplate), - this - ) // TODO: doesnt' work for batchnorm/depthwise - } else { - initLayerWeights(it, this) - } - } - } - this.logger.info { "Weights are loaded." } - this.isModelInitialized = true -} - -private fun fillLayerWeights( - layer: Layer, - hdfFile: HdfFile, - layerPaths: LayerPaths?, - model: GraphTrainableModel -) { - val variables = getLayerVariablePathTemplates(layer, layerPaths) - if (variables == null) { - model.logger.warn { "Loading weights for the layer ${layer.name} is skipped as ${layer::class.qualifiedName} layers are not supported." } - return - } - variables.forEach { (variableName, variableDataPathTemplate) -> - val data = hdfFile.getDatasetByPath(variableDataPathTemplate.format(layer.name, layer.name)).data - model.fillVariable(variableName, data) - } - model.logger.debug { "${layer.paramCount} parameters loaded for the layer ${layer.name}." } -} - -private fun initLayerWeights(layer: Layer, model: GraphTrainableModel) { - val variables = getLayerVariableNames(layer) - if (variables == null) { - model.logger.warn { "Initializing weights for the layer ${layer.name} is skipped as ${layer::class.qualifiedName} layers are not supported." } - return - } - variables.forEach(model::runAssignOpByVarName) - model.logger.debug { "${layer.paramCount} parameters initialized for the layer ${layer.name}." } + loadWeightsByPathTemplates( + hdfFile, layers.filterNot(Layer::isTrainable), + kernelDataPathTemplate, biasDataPathTemplate + ) } /** - * Loads weights from hdf5 file created in Keras TensorFlow framework for non-trainable (or frozen) layers only. + * Loads weights from hdf5 file created in Keras TensorFlow framework for pre-defined list of layers. * - * NOTE: Weights for trainable layers will not be loaded and will be initialized via default initializers. + * NOTE: Weights for another layers will not be loaded (should be initialized manually). * * @param [hdfFile] File in hdf5 file format containing weights of Sequential model. + * @param [layerList] List of layers to load weights. Weights for other layers will be initialized by initializer later. * @param [kernelDataPathTemplate] Template path to kernel weights of the specific layer. * @param [biasDataPathTemplate] Template path to bias weights of the specific layer. */ -public fun Functional.loadWeightsForFrozenLayersByPathTemplates( +public fun GraphTrainableModel.loadWeightsByPathTemplates( hdfFile: HdfFile, + layerList: List, kernelDataPathTemplate: String = KERNEL_DATA_PATH_TEMPLATE, biasDataPathTemplate: String = BIAS_DATA_PATH_TEMPLATE ) { - val frozenLayers = mutableListOf() - this.layers.forEach { - if (!it.isTrainable) frozenLayers.add(it) + // TODO: doesnt' work for batchnorm/depthwise + val layerPaths = LayerConvOrDensePaths("", kernelDataPathTemplate, biasDataPathTemplate) + loadWeights(layerList) { layer -> + fillLayerWeights(layer, hdfFile, layerPaths, this) } - this.loadWeightsByPathTemplates(hdfFile, frozenLayers, kernelDataPathTemplate, biasDataPathTemplate) } /** @@ -301,45 +144,25 @@ public fun GraphTrainableModel.loadWeightsByPaths( missedWeights: MissedWeightsStrategy = MissedWeightsStrategy.INITIALIZE, forFrozenLayersOnly: Boolean = false // TODO: probably it should be a flag in all methods ) { - check(this.isModelCompiled) { "The model is not compiled yet. Compile the model to use this method." } - check(!isModelInitialized) { "Model is initialized already!" } - this.logger.info { "Starting weights loading.." } - - var layersToLoad = this.layers - var layersToInit = this.layers - - if (forFrozenLayersOnly) { - layersToLoad = layersToLoad.filter { !it.isTrainable } - layersToInit = layersToInit.filter { it.isTrainable } - layersToInit.forEach { - initLayerWeights(it, this) - } - } + val layersToLoad = if (forFrozenLayersOnly) layers.filterNot(Layer::isTrainable) else layers - layersToLoad.forEach { - run { - val initializedLayerName = it.name - val layerWeightPaths = weightPaths.find { initializedLayerName == it.layerName } - if (layerWeightPaths != null) { - fillLayerWeights(it, hdfFile, layerWeightPaths, this) - } else { - if (missedWeights == MissedWeightsStrategy.LOAD_CUSTOM_PATH) { - fillLayerWeights( - it, - hdfFile, - null, // TODO: refactor = it doesn't work for batchnorm or depthwise - this - ) - } else { - this.logger.warn { "Layer weight paths for ${it.name} are not found in 'weightPaths' object. It will be initialized by default initializer." } - initLayerWeights(it, this) - } + val layersToWeightPaths = layersToLoad.mapNotNull { layer -> + val layerPaths = weightPaths.find { layer.name == it.layerName } + if (layerPaths == null && missedWeights == MissedWeightsStrategy.INITIALIZE) { + logger.warn { + "Layer weight paths for ${layer.name} are not found in 'weightPaths' object." + + " Initialization is going to be done by default initializer." } + return@mapNotNull null } - } + layer to layerPaths + // TODO: refactor when weight path is not provided and strategy is not INITIALIZE it won't work for batchnorm or depthwise + }.toMap() - this.logger.info { "Weights are loaded." } - this.isModelInitialized = true // TODO: it should depend on what is happened with missed weights + loadWeights(layersToWeightPaths.keys) { layer -> + fillLayerWeights(layer, hdfFile, layersToWeightPaths[layer], this) + } + // TODO: isModelInitialized should depend on what is happened with missed weights } /** This strategy defines the behaviour during weights' loading if the weights are not found in the h5 file by the standard Keras paths. */ @@ -363,27 +186,94 @@ public enum class MissedWeightsStrategy { */ public fun GraphTrainableModel.loadWeightsByPaths( hdfFile: HdfFile, - layerList: MutableList, + layerList: List, kernelDataPathTemplate: String = KERNEL_DATA_PATH_TEMPLATE, biasDataPathTemplate: String = BIAS_DATA_PATH_TEMPLATE ) { - check(this.isModelCompiled) { "The model is not compiled yet. Compile the model to use this method." } - check(!isModelInitialized) { "Model is initialized already!" } - this.logger.info { "Starting weights loading.." } - this.layers.forEach { - run { - if (layerList.contains(it)) { - fillLayerWeights( - it, - hdfFile, - LayerConvOrDensePaths("", kernelDataPathTemplate, biasDataPathTemplate), - this - ) // TODO: does not work for BatchNorm/Depthwise - } else { - initLayerWeights(it, this) - } + // TODO: does not work for BatchNorm/Depthwise + val layerConvOrDensePaths = LayerConvOrDensePaths("", kernelDataPathTemplate, biasDataPathTemplate) + loadWeights(layerList) { layer -> + fillLayerWeights(layer, hdfFile, layerConvOrDensePaths, this) + } +} + +private fun GraphTrainableModel.loadWeights(layersToLoad: Collection, loadWeightsBlock: (Layer) -> Unit) { + check(isModelCompiled) { "Model is not compiled yet. Compile the model before loading weights." } + check(!isModelInitialized) { "Model is already initialized." } + logger.info { "Starting loading weights..." } + + val layerSet = layersToLoad.toSet() + layers.forEach { layer -> + if (layerSet.contains(layer)) { + loadWeightsBlock(layer) + } else { + initLayerWeights(layer, this) } } - this.logger.info { "Weights are loaded." } - this.isModelInitialized = true + + logger.info { "Weights are loaded." } + isModelInitialized = true +} + +private fun fillLayerWeights(layer: Layer, group: Group, model: GraphTrainableModel) { + val variables = getLayerVariables(layer) + if (variables == null) { + model.logger.warn { "Loading weights for the layer ${layer.name} is skipped as ${layer::class.qualifiedName} layers are not supported." } + return + } + fillLayerVariablesFromKeras(layer.name, variables, model, group) + model.logger.debug { "${layer.paramCount} parameters loaded for the layer ${layer.name}." } +} + +private fun fillLayerVariablesFromKeras(layerName: String, + variables: Map>, + model: GraphTrainableModel, + group: Group +) { + val layerWeightsNode = group.children[layerName] as? Group + check(layerWeightsNode != null) { + val availableLayerNames = group.children.values.map(Node::getName) + val modelLayerNames = model.layers.map(Layer::name) + "Weights for the loaded layer $layerName are not found in .h5 file! " + + "\nh5 weight file contains weights for the following list of layers: $availableLayerNames" + + "\nDouble-check your loaded configuration which contains layers with the following names: $modelLayerNames." + } + + val nameOfWeightSubGroup = layerWeightsNode.children.keys.first() + val dataNodes = (layerWeightsNode.children[nameOfWeightSubGroup] as Group).children + + dataNodes.values.map { it as DatasetBase }.forEach { + val (name, shape) = variables[it.name] + ?: throw IllegalArgumentException( + "Parsing of h5 file for variable with name ${it.name} in layer $layerName is not supported!" + ) + val dims = it.dimensions + require(shape.map(Long::toInt).toIntArray().contentEquals(dims)) { + "$name variable shape in loaded data is ${dims.contentToString()}. Should be ${shape.contentToString()}" + } + model.fillVariable(name, it.data) + } +} + +private fun fillLayerWeights(layer: Layer, hdfFile: HdfFile, layerPaths: LayerPaths?, model: GraphTrainableModel) { + val variables = getLayerVariablePathTemplates(layer, layerPaths) + if (variables == null) { + model.logger.warn { "Loading weights for the layer ${layer.name} is skipped as ${layer::class.qualifiedName} layers are not supported." } + return + } + variables.forEach { (variableName, variableDataPathTemplate) -> + val data = hdfFile.getDatasetByPath(variableDataPathTemplate.format(layer.name, layer.name)).data + model.fillVariable(variableName, data) + } + model.logger.debug { "${layer.paramCount} parameters loaded for the layer ${layer.name}." } +} + +private fun initLayerWeights(layer: Layer, model: GraphTrainableModel) { + val variables = getLayerVariableNames(layer) + if (variables == null) { + model.logger.warn { "Initializing weights for the layer ${layer.name} is skipped as ${layer::class.qualifiedName} layers are not supported." } + return + } + variables.forEach(model::runAssignOpByVarName) + model.logger.debug { "${layer.paramCount} parameters initialized for the layer ${layer.name}." } } \ No newline at end of file diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/TransferLearningFromKerasTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/TransferLearningFromKerasTest.kt index 7a11f3cb4..e90aca8eb 100644 --- a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/TransferLearningFromKerasTest.kt +++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/integration/TransferLearningFromKerasTest.kt @@ -386,7 +386,7 @@ class TransferLearningTest : IntegrationTest() { it.loadWeights(hdfFile) } assertEquals( - "Model is initialized already!", + "Model is already initialized.", exception.message ) }