Simplify api for building Layers (#408)
* Move assigning Layer#outputShape out of Layer implementations

* Cleanup duplicated input shapes check methods in the AbstractMerge

* Make exception message in the AbstractMerge more informative

* Simplify building a Layer and computing output shape in functional api

Make functional api signatures similar to the sequential api.

* Combine build and computeOutputShape functions in the Layer

* Combine buildLayers and forward functions in the GraphTrainableModel

* Combine build and forward functions in the Layer
juliabeliaeva authored Aug 1, 2022
1 parent 8ab9d85 commit 579f167
Showing 78 changed files with 281 additions and 961 deletions.
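At a glance, the commit collapses the old three-step Layer contract (build, computeOutputShape, forward) into a single build call that creates variables and emits the graph operations in one pass; each layer's output shape is then read back from the operand TensorFlow returns. A simplified before/after of the signatures, assembled from the diffs below:

// Before: three separate steps per layer
public abstract fun build(tf: Ops, inputShape: Shape)
public abstract fun computeOutputShape(inputShape: Shape): Shape
public abstract fun forward(tf: Ops, input: Operand<Float>,
                            isTraining: Operand<Boolean>,
                            numberOfLosses: Operand<Float>?): Operand<Float>

// After: one step; the returned operand carries the output shape
public abstract fun build(tf: Ops, input: Operand<Float>,
                          isTraining: Operand<Boolean>,
                          numberOfLosses: Operand<Float>?): Operand<Float>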
25 changes: 10 additions & 15 deletions api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Functional.kt
@@ -13,6 +13,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.weights
import org.jetbrains.kotlinx.dl.api.core.util.sortTopologically
import org.jetbrains.kotlinx.dl.api.inference.keras.*
import org.tensorflow.Operand
+import org.tensorflow.op.core.Placeholder
import java.io.File
import java.io.FileNotFoundException

@@ -97,7 +98,7 @@ public class Functional(vararg layers: Layer) : GraphTrainableModel(*layers) {
        if (topModel is Sequential && layers.size > 1) {
            // establish edges in DAG
            topLayers.subList(1, topLayers.size).forEachIndexed { index, layer ->
-                val topLayersIndex = index - 1 + 1
+                val topLayersIndex = index - 1 + 1
                // shift -1 to take previous, but shift +1 because it's an index in subList, started from 1
                layer.inboundLayers.add(topLayers[topLayersIndex])
            }
@@ -254,24 +255,18 @@ public class Functional(vararg layers: Layer) : GraphTrainableModel(*layers) {
        }
    }

-    override fun buildLayers() {
-        inputLayer.build(tf)
-        inputLayer.computeOutputShape()
+    override fun buildLayers(training: Operand<Boolean>, numberOfLosses: Operand<Float>): Pair<Placeholder<Float>, Operand<Float>> {
+        val input = inputLayer.build(tf)
+        inputLayer.setOutputShape(input.asOutput().shape())
+        val output = mutableMapOf<Layer, Operand<Float>>(inputLayer to input)

        layers.filter { it !is Input }.forEach { layer ->
-            layer.buildFromInboundLayers(tf)
-            val outputShape = layer.computeOutputShapeFromInboundLayers()
-            layer.setOutputShape(outputShape)
-            logger.debug { "${layer.name}; $layer; outputShape: $outputShape" }
+            val out = layer.build(tf, layer.inboundLayers.map { output[it]!! }, training, numberOfLosses)
+            output[layer] = out
+            layer.setOutputShape(out.asOutput().shape())
        }
-    }

-    override fun forward(input: Operand<Float>, inputLayer: Input): Operand<Float> {
-        val output = mutableMapOf<Layer, Operand<Float>>(inputLayer to inputLayer.forward(tf, input, training, numberOfLossesOp))
-        for (layer in layers.filter { it !is Input }) {
-            output[layer] = layer.forward(tf, layer.inboundLayers.map { output[it]!! }, training, numberOfLossesOp)
-        }
-        return output[layers.last()]!!
+        return input to output[layers.last()]!!
    }

    override fun save(
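For orientation, a minimal sketch of a functional model that exercises this code path, assuming Functional.of and the layer invoke() operator documented in Layer.kt below; buildLayers walks these layers in topological order, feeding each one the operands produced by its inboundLayers:

import org.jetbrains.kotlinx.dl.api.core.Functional
import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten

// layer(previous) records `previous` in the layer's inboundLayers
val input = Input(28, 28, 1)
val flatten = Flatten()(input)
val dense = Dense(10)(flatten)
val model = Functional.of(input, flatten, dense)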
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt
@@ -156,7 +156,18 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
        this.metrics = metrics
        this.optimizer = optimizer

-        buildLayers()
+        training = tf.withName("training").placeholder(
+            Boolean::class.javaObjectType,
+            Placeholder.shape(Shape.scalar())
+        )
+        numberOfLossesOp = tf.withName("numberOfLosses").placeholder(
+            getDType(),
+            Placeholder.shape(Shape.scalar())
+        )
+
+        val (input, output) = buildLayers(training, numberOfLossesOp)
+        xOp = input
+        yPredOp = output

        // should be after outputShape calculation
        numberOfClasses = when (val lastLayer = layers.last()) {
@@ -165,19 +176,7 @@
            else -> 1
        }

-        xOp = inputLayer.input
        yTrueOp = tf.placeholder(getDType()) as Operand<Float>
-        numberOfLossesOp = tf.withName("numberOfLosses").placeholder(
-            getDType(),
-            Placeholder.shape(Shape.scalar())
-        )
-
-        training = tf.withName("training").placeholder(
-            Boolean::class.javaObjectType,
-            Placeholder.shape(Shape.scalar())
-        )
-
-        yPredOp = forward(xOp, inputLayer)
        lossOp = buildLossFunction(loss)
        targets = optimizer.prepareTargets(kGraph, layers.trainableVariables().map { it.variable }, tf, lossOp)

@@ -212,11 +211,11 @@
        compile(optimizer, loss, Metrics.convert(metric))
    }

-    /** Common method for building the initial part of the model static graph layer by layer via calling build() method on each layer in correct order. */
-    protected abstract fun buildLayers()
-
-    /** Forms forward path as a part of the model static graph layer by layer via calling forward() method on each layer in correct order. */
-    protected abstract fun forward(input: Operand<Float>, inputLayer: Input): Operand<Float>
+    /** Common method for building model static graph layer by layer via calling build() method on each layer in correct order. */
+    protected abstract fun buildLayers(
+        training: Operand<Boolean>,
+        numberOfLosses: Operand<Float>
+    ): Pair<Placeholder<Float>, Operand<Float>>

    override fun fit(
        trainingDataset: Dataset,
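From the caller's side nothing changes: compile() now creates the two placeholders first and then builds the whole static graph in a single buildLayers pass. A minimal sketch of the unchanged external API:

import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten
import org.jetbrains.kotlinx.dl.api.core.loss.Losses
import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam

fun main() {
    val model = Sequential.of(
        Input(28, 28, 1),
        Flatten(),
        Dense(10)
    )
    // compile() now: create the `training` and `numberOfLosses` placeholders,
    // call buildLayers(...) once, and wire xOp/yPredOp from its result.
    model.compile(
        optimizer = Adam(),
        loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
        metric = Metrics.ACCURACY
    )
}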
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Sequential.kt
@@ -9,10 +9,9 @@ import org.jetbrains.kotlinx.dl.api.core.layer.Layer
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.setOutputShape
import org.jetbrains.kotlinx.dl.api.core.layer.weights
-import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.jetbrains.kotlinx.dl.api.inference.keras.*
import org.tensorflow.Operand
-import org.tensorflow.Shape
+import org.tensorflow.op.core.Placeholder
import java.io.File
import java.io.FileNotFoundException
@@ -142,25 +141,17 @@ public class Sequential(vararg layers: Layer) : GraphTrainableModel(*layers) {
        }
    }

-    override fun buildLayers() {
-        inputLayer.build(tf)
-        var inputShape: Shape = inputLayer.computeOutputShape()
+    override fun buildLayers(training: Operand<Boolean>, numberOfLosses: Operand<Float>): Pair<Placeholder<Float>, Operand<Float>> {
+        val input = inputLayer.build(tf)
+        inputLayer.setOutputShape(input.asOutput().shape())
+        var output: Operand<Float> = input

        layers.filter { it !is Input }.forEach { layer ->
-            layer.build(tf, inputShape)
-
-            inputShape = layer.computeOutputShape(inputShape)
-            layer.setOutputShape(inputShape)
-            logger.debug { "${layer.name}; $layer; outputShape: $inputShape" }
+            output = layer.build(tf, output, training, numberOfLossesOp)
+            layer.setOutputShape(output.asOutput().shape())
        }
-    }

-    override fun forward(input: Operand<Float>, inputLayer: Input): Operand<Float> {
-        var out: Operand<Float> = input
-        for (layer in layers) {
-            out = layer.forward(tf, out, training, numberOfLossesOp)
-        }
-        return out
+        return input to output
    }

    /** Returns a copy of this model. */
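Because shapes are now taken from the built operands via output.asOutput().shape(), every layer's outputShape is available right after compile(). A small hedged check, reusing the model from the previous sketch:

// Each outputShape below was assigned inside buildLayers() from the operand
// returned by the layer's build(); e.g. a Dense(10) layer prints [-1, 10].
model.layers.forEach { layer ->
    println("${layer.name}: ${layer.outputShape}")
}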
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Layer.kt
@@ -30,63 +30,33 @@ public abstract class Layer(public var name: String) {
    public var outboundLayers: MutableList<Layer> = mutableListOf()

    /**
-     * Extend this function to define variables in layer.
+     * Extend this function to define variables in the layer and compute layer output.
     *
     * @param [tf] TensorFlow graph API for building operations.
-     * @param [inputShape] Input shape, result of [computeOutputShape] call from previous layer.
+     * @param [input] Layer input.
+     * @param [isTraining] TensorFlow operand for switching between training and inference modes.
+     * @param [numberOfLosses] TensorFlow operand for batch size data.
     */
-    public abstract fun build(tf: Ops, inputShape: Shape)
+    public abstract fun build(tf: Ops,
+                              input: Operand<Float>,
+                              isTraining: Operand<Boolean>,
+                              numberOfLosses: Operand<Float>?): Operand<Float>

    /**
-     * Extend this function to define variables in layer.
+     * Extend this function to define variables in the layer and compute layer output.
     *
     * NOTE: This function should be overridden for layers with multiple inputs.
     * NOTE: Used in Functional API
     *
     * @param [tf] TensorFlow graph API for building operations.
+     * @param [input] Layer input list.
+     * @param [isTraining] TensorFlow operand for switching between training and inference modes.
+     * @param [numberOfLosses] TensorFlow operand for batch size data.
     */
-    public fun buildFromInboundLayers(tf: Ops) {
-        require(inboundLayers.isNotEmpty()) { "There is no inbound layers to compute output shape" }
-        build(tf, inboundLayers[0].outputShape.toShape())
-    }
-
-    /**
-     * Computes output shape, based on [inputShape] and [Layer] type.
-     */
-    public abstract fun computeOutputShape(inputShape: Shape): Shape
-
-    /**
-     * Computes output shape, based on input shapes of inbound layers.
-     *
-     * NOTE: This function should be overridden for layers with multiple inputs.
-     * NOTE: Used in Functional API
-     */
-    public open fun computeOutputShapeFromInboundLayers(): TensorShape {
-        require(inboundLayers.isNotEmpty()) { "There is no inbound layers to compute output shape" }
-        return TensorShape(computeOutputShape(inboundLayers[0].outputShape.toShape()))
-    }
-
-    /**
-     * Builds main layer input transformation with [tf]. Depends on [Layer] type.
-     */
-    public abstract fun forward(
-        tf: Ops,
-        input: Operand<Float>,
-        isTraining: Operand<Boolean>,
-        numberOfLosses: Operand<Float>?
-    ): Operand<Float>
-
-    /**
-     * Builds main layer input transformation with [tf]. Depends on [Layer] type.
-     */
-    public open fun forward(
-        tf: Ops,
-        input: List<Operand<Float>>,
-        isTraining: Operand<Boolean>,
-        numberOfLosses: Operand<Float>?
-    ): Operand<Float> {
-        return forward(tf, input[0], isTraining, numberOfLosses)
+    public open fun build(tf: Ops,
+                          input: List<Operand<Float>>,
+                          isTraining: Operand<Boolean>,
+                          numberOfLosses: Operand<Float>?): Operand<Float> {
+        return build(tf, input.first(), isTraining, numberOfLosses)
    }

    /** Important part of functional API. It takes [layers] as input and saves them to the [inboundLayers] of the given layer. */
@@ -129,10 +99,7 @@ internal fun LongArray.toIntArray(): IntArray {
}

internal fun Layer.setOutputShape(shape: Shape) {
-    setOutputShape(TensorShape(shape))
-}
-
-internal fun Layer.setOutputShape(tensorShape: TensorShape) {
+    val tensorShape = TensorShape(shape)
    check(tensorShape.tail().all { elem -> elem > 0 })
    {
        "The last dimensions (except first = -1) of shape of layer $name contains zero or negative dimension values: ${tensorShape}.\n" +
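To see what the consolidated contract means for layer authors, here is a hypothetical stateless layer written against the new single build method. Scale and factor are illustrative names, not part of this commit, and the sketch assumes hasActivation is the only other abstract member a minimal layer must provide:

public class Scale(private val factor: Float, name: String = "") : Layer(name) {

    // One method now creates variables (none needed here) and emits the ops;
    // the output shape is whatever TensorFlow infers for the returned operand.
    override fun build(tf: Ops,
                       input: Operand<Float>,
                       isTraining: Operand<Boolean>,
                       numberOfLosses: Operand<Float>?): Operand<Float> =
        tf.math.mul(input, tf.constant(factor))

    override val hasActivation: Boolean get() = false
}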
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/AbstractActivationLayer.kt
@@ -6,9 +6,7 @@
package org.jetbrains.kotlinx.dl.api.core.layer.activation

import org.jetbrains.kotlinx.dl.api.core.layer.Layer
-import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.tensorflow.Operand
-import org.tensorflow.Shape
import org.tensorflow.op.Ops

/**
@@ -34,19 +32,12 @@ public abstract class AbstractActivationLayer(name: String) : Layer(name) {
        input: Operand<Float>
    ): Operand<Float>

-    override fun forward(
+    override fun build(
        tf: Ops,
        input: Operand<Float>,
        isTraining: Operand<Boolean>,
        numberOfLosses: Operand<Float>?
    ): Operand<Float> = forward(tf, input)

-    override fun build(tf: Ops, inputShape: Shape): Unit = Unit
-
-    override fun computeOutputShape(inputShape: Shape): Shape {
-        this.outputShape = TensorShape(inputShape)
-        return inputShape
-    }
-
    override val hasActivation: Boolean get() = true
}
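With the no-op build and the pass-through computeOutputShape gone, a stateless activation layer implements just the single-argument forward. A hypothetical example (Swish is not part of this commit):

public class Swish(name: String = "") : AbstractActivationLayer(name) {
    // swish(x) = x * sigmoid(x)
    override fun forward(tf: Ops, input: Operand<Float>): Operand<Float> =
        tf.math.mul(input, tf.math.sigmoid(input))
}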
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/PReLU.kt
@@ -51,7 +51,8 @@ public class PReLU(

    override var isTrainable: Boolean = true

-    override fun build(tf: Ops, inputShape: Shape) {
+    override fun forward(tf: Ops, input: Operand<Float>): Operand<Float> {
+        val inputShape = input.asOutput().shape()
        val alphaShapeArray = inputShape.toLongArray().drop(1).toLongArray()
        if (sharedAxes != null) {
            for (axis in sharedAxes) {
@@ -72,9 +73,7 @@
            alphaInitializer,
            alphaRegularizer
        )
-    }

-    override fun forward(tf: Ops, input: Operand<Float>): Operand<Float> {
        // It's equivalent to: `-alpha * relu(-x) + relu(x)`
        val positive = tf.nn.relu(input)
        val negative = tf.math.mul(tf.math.neg(alpha.variable), tf.nn.relu(tf.math.neg(input)))
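The identity in the comment above is easy to verify on scalars; a tiny reference implementation of -alpha * relu(-x) + relu(x):

// prelu(x) = x for x > 0, alpha * x otherwise
fun preluReference(x: Float, alpha: Float): Float =
    maxOf(x, 0f) - alpha * maxOf(-x, 0f)

// e.g. preluReference(-2f, 0.1f) == -0.2f and preluReference(3f, 0.1f) == 3f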
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/AbstractConv.kt
@@ -9,7 +9,6 @@ import org.jetbrains.kotlinx.dl.api.core.activation.Activations
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer
import org.jetbrains.kotlinx.dl.api.core.layer.*
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
-import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.jetbrains.kotlinx.dl.api.core.shape.shapeFromDims
import org.tensorflow.Operand
import org.tensorflow.Shape
@@ -65,7 +64,12 @@ public abstract class AbstractConv(
    public override val variables: List<KVariable>
        get() = listOfNotNull(kernel, bias)

-    override fun build(tf: Ops, inputShape: Shape) {
+    override fun build(tf: Ops,
+                       input: Operand<Float>,
+                       isTraining: Operand<Boolean>,
+                       numberOfLosses: Operand<Float>?
+    ): Operand<Float> {
+        val inputShape = input.asOutput().shape()
        // Amount of channels should be the last value in the inputShape
        val numberOfChannels = inputShape.size(inputShape.numDimensions() - 1)

@@ -96,24 +100,9 @@
                biasRegularizer
            )
        }
-    }

-    override fun computeOutputShape(inputShape: Shape): Shape {
-        val shape = defineOutputShape(inputShape)
-        outputShape = TensorShape(shape)
-        return shape
-    }
-
-    override fun forward(
-        tf: Ops,
-        input: Operand<Float>,
-        isTraining: Operand<Boolean>,
-        numberOfLosses: Operand<Float>?
-    ): Operand<Float> {
        val convolution = convImplementation(tf, input)

        val withBias = bias?.let { tf.nn.biasAdd(convolution, it.variable) } ?: convolution

        return Activations.convert(activation).apply(tf, withBias, name)
    }

@@ -149,15 +138,6 @@

    /** The actual layer operation implementation without adding the bias which is added by the abstract class. */
    protected abstract fun convImplementation(tf: Ops, input: Operand<Float>): Operand<Float>
-
-    /**
-     * Actual implementation of [computeOutputShape] which only defines the value
-     * of output shape without the need of saving it to some variable.
-     *
-     * @param inputShape which can be used to define the output shape
-     * @return the defined output shape that is saved in class variable and returned by [computeOutputShape]]
-     */
-    protected abstract fun defineOutputShape(inputShape: Shape): Shape
}

private fun multiply(values: LongArray) = values.fold(1L, Long::times)
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv1D.kt
@@ -14,11 +14,9 @@ import org.jetbrains.kotlinx.dl.api.core.layer.requireArraySize
import org.jetbrains.kotlinx.dl.api.core.layer.toLongList
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
-import org.jetbrains.kotlinx.dl.api.core.shape.convOutputLength
import org.jetbrains.kotlinx.dl.api.core.util.convBiasVarName
import org.jetbrains.kotlinx.dl.api.core.util.convKernelVarName
import org.tensorflow.Operand
-import org.tensorflow.Shape
import org.tensorflow.op.Ops
import org.tensorflow.op.core.Squeeze
import org.tensorflow.op.nn.Conv2d
@@ -127,21 +125,6 @@ public class Conv1D(
        }
    }

-    protected override fun defineOutputShape(inputShape: Shape): Shape {
-        val batchSize = inputShape.size(0)
-        val colsCount = inputShape.size(1)
-
-        val cols = convOutputLength(
-            colsCount,
-            kernelLength,
-            padding,
-            strides[1],
-            dilations[1]
-        )
-
-        return Shape.make(batchSize, cols, filters.toLong())
-    }
-
    override fun toString(): String {
        return "Conv1D(name = $name, isTrainable=$isTrainable, filters=$filters, kernelSize=${kernelSize.contentToString()}, " +
                "strides=${strides.contentToString()}, dilations=${dilations.contentToString()}, " +
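For reference, the output length that defineOutputShape used to compute via convOutputLength, and that TensorFlow's shape inference now provides, follows the usual Keras-style rule. A sketch under that assumption (not KotlinDL API):

fun convOutputLengthRef(inputLength: Long, kernel: Long, stride: Long, dilation: Long, padding: String): Long {
    val dilatedKernel = (kernel - 1) * dilation + 1
    val beforeStride = when (padding) {
        "valid" -> inputLength - dilatedKernel + 1
        "same" -> inputLength
        else -> error("unsupported padding: $padding")
    }
    return (beforeStride + stride - 1) / stride  // ceiling division
}

// e.g. convOutputLengthRef(28, 3, 1, 1, "valid") == 26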
(Diff truncated; the remaining changed files are not shown.)
