Simplify api for building Layers #408

Merged: 7 commits, Aug 1, 2022
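In short: before this change every Layer implemented three separate steps (variable creation, shape inference, op emission), and each model class had to drive them in the right order twice, once in buildLayers() and once in forward(). The PR collapses the three steps into a single build() call and reads output shapes back from the returned TensorFlow operands. A condensed restatement of the contract change (a simplified sketch, not the actual KotlinDL sources):

    import org.tensorflow.Operand
    import org.tensorflow.Shape
    import org.tensorflow.op.Ops

    // Old contract: three calls that model classes had to sequence by hand.
    interface OldLayerContract {
        fun build(tf: Ops, inputShape: Shape)                // create variables
        fun computeOutputShape(inputShape: Shape): Shape     // infer shape manually
        fun forward(tf: Ops, input: Operand<Float>,          // emit graph ops
                    isTraining: Operand<Boolean>,
                    numberOfLosses: Operand<Float>?): Operand<Float>
    }

    // New contract: one call creates variables and emits ops; the output shape
    // is read from the returned operand via asOutput().shape().
    interface NewLayerContract {
        fun build(tf: Ops, input: Operand<Float>,
                  isTraining: Operand<Boolean>,
                  numberOfLosses: Operand<Float>?): Operand<Float>
    }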
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Functional.kt (25 changes: 10 additions & 15 deletions)
@@ -13,6 +13,7 @@ import org.jetbrains.kotlinx.dl.api.core.layer.weights
import org.jetbrains.kotlinx.dl.api.core.util.sortTopologically
import org.jetbrains.kotlinx.dl.api.inference.keras.*
import org.tensorflow.Operand
+import org.tensorflow.op.core.Placeholder
import java.io.File
import java.io.FileNotFoundException

@@ -97,7 +98,7 @@ public class Functional(vararg layers: Layer) : GraphTrainableModel(*layers) {
if (topModel is Sequential && layers.size > 1) {
// establish edges in DAG
topLayers.subList(1, topLayers.size).forEachIndexed { index, layer ->
val topLayersIndex = index - 1 + 1
// shift -1 to take previous, but shift +1 because it's an index in subList, started from 1
layer.inboundLayers.add(topLayers[topLayersIndex])
}
@@ -254,24 +255,18 @@ public class Functional(vararg layers: Layer) : GraphTrainableModel(*layers) {
}
}

-override fun buildLayers() {
-    inputLayer.build(tf)
-    inputLayer.computeOutputShape()
+override fun buildLayers(training: Operand<Boolean>, numberOfLosses: Operand<Float>): Pair<Placeholder<Float>, Operand<Float>> {
+    val input = inputLayer.build(tf)
+    inputLayer.setOutputShape(input.asOutput().shape())
+    val output = mutableMapOf<Layer, Operand<Float>>(inputLayer to input)

layers.filter { it !is Input }.forEach { layer ->
-    layer.buildFromInboundLayers(tf)
-    val outputShape = layer.computeOutputShapeFromInboundLayers()
-    layer.setOutputShape(outputShape)
-    logger.debug { "${layer.name}; $layer; outputShape: $outputShape" }
+    val out = layer.build(tf, layer.inboundLayers.map { output[it]!! }, training, numberOfLosses)
+    output[layer] = out
+    layer.setOutputShape(out.asOutput().shape())
}
-}

-override fun forward(input: Operand<Float>, inputLayer: Input): Operand<Float> {
-    val output = mutableMapOf<Layer, Operand<Float>>(inputLayer to inputLayer.forward(tf, input, training, numberOfLossesOp))
-    for (layer in layers.filter { it !is Input }) {
-        output[layer] = layer.forward(tf, layer.inboundLayers.map { output[it]!! }, training, numberOfLossesOp)
-    }
-    return output[layers.last()]!!
+    return input to output[layers.last()]!!
}

override fun save(
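The new Functional.buildLayers is a single pass over the topologically sorted layers: each layer's output operand is cached in a map and handed to its outbound layers, so shape inference simply falls out of graph construction. The same idea in isolation (a generic sketch with hypothetical names, not KotlinDL code):

    // Walk nodes in topological order, cache each node's output, and feed the
    // cached outputs of inbound nodes to each successor, as buildLayers does.
    fun <N, T> evaluateDag(
        sorted: List<N>,             // nodes in topological order; first is the input
        inbound: (N) -> List<N>,     // inbound edges of a node
        compute: (N, List<T>) -> T,  // build a node's output from its inputs' outputs
        seed: T                      // the input node's output
    ): T {
        val output = mutableMapOf(sorted.first() to seed)
        for (node in sorted.drop(1)) {
            output[node] = compute(node, inbound(node).map { output.getValue(it) })
        }
        return output.getValue(sorted.last())
    }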
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/GraphTrainableModel.kt
@@ -156,7 +156,18 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
this.metrics = metrics
this.optimizer = optimizer

-buildLayers()
+training = tf.withName("training").placeholder(
+    Boolean::class.javaObjectType,
+    Placeholder.shape(Shape.scalar())
+)
+numberOfLossesOp = tf.withName("numberOfLosses").placeholder(
+    getDType(),
+    Placeholder.shape(Shape.scalar())
+)
+
+val (input, output) = buildLayers(training, numberOfLossesOp)
+xOp = input
+yPredOp = output

// should be after outputShape calculation
numberOfClasses = when (val lastLayer = layers.last()) {
@@ -165,19 +176,7 @@
else -> 1
}

-xOp = inputLayer.input
-yTrueOp = tf.placeholder(getDType()) as Operand<Float>
-numberOfLossesOp = tf.withName("numberOfLosses").placeholder(
-    getDType(),
-    Placeholder.shape(Shape.scalar())
-)
-
-training = tf.withName("training").placeholder(
-    Boolean::class.javaObjectType,
-    Placeholder.shape(Shape.scalar())
-)
-
-yPredOp = forward(xOp, inputLayer)
lossOp = buildLossFunction(loss)
targets = optimizer.prepareTargets(kGraph, layers.trainableVariables().map { it.variable }, tf, lossOp)

@@ -212,11 +211,11 @@ public abstract class GraphTrainableModel(vararg layers: Layer) : TrainableModel
compile(optimizer, loss, Metrics.convert(metric))
}

-/** Common method for building the initial part of the model static graph layer by layer via calling build() method on each layer in correct order. */
-protected abstract fun buildLayers()
-
-/** Forms forward path as a part of the model static graph layer by layer via calling forward() method on each layer in correct order. */
-protected abstract fun forward(input: Operand<Float>, inputLayer: Input): Operand<Float>
+/** Common method for building the model's static graph layer by layer via calling the build() method on each layer in the correct order. */
+protected abstract fun buildLayers(
+    training: Operand<Boolean>,
+    numberOfLosses: Operand<Float>
+): Pair<Placeholder<Float>, Operand<Float>>

override fun fit(
trainingDataset: Dataset,
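Note the reordering in compile(): the training and numberOfLosses placeholders must now exist before the graph is built, because buildLayers() threads them into every layer's build(); previously they could be created after buildLayers(), since ops were only emitted in the later forward() pass. None of this is visible to user code; the public API is unchanged. Typical usage, assuming the standard KotlinDL Sequential/compile API of this era:

    import org.jetbrains.kotlinx.dl.api.core.Sequential
    import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense
    import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
    import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten
    import org.jetbrains.kotlinx.dl.api.core.loss.Losses
    import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
    import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam

    fun main() {
        val model = Sequential.of(
            Input(28, 28, 1),
            Flatten(),
            Dense(10)
        )
        // compile() is where the graph is now built in one pass:
        // placeholders first, then buildLayers(training, numberOfLossesOp).
        model.compile(
            optimizer = Adam(),
            loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
            metric = Metrics.ACCURACY
        )
        model.close()
    }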
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/Sequential.kt
@@ -9,10 +9,9 @@ import org.jetbrains.kotlinx.dl.api.core.layer.Layer
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.setOutputShape
import org.jetbrains.kotlinx.dl.api.core.layer.weights
-import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.jetbrains.kotlinx.dl.api.inference.keras.*
import org.tensorflow.Operand
-import org.tensorflow.Shape
+import org.tensorflow.op.core.Placeholder
import java.io.File
import java.io.FileNotFoundException

@@ -142,25 +141,17 @@ public class Sequential(vararg layers: Layer) : GraphTrainableModel(*layers) {
}
}

-override fun buildLayers() {
-    inputLayer.build(tf)
-    var inputShape: Shape = inputLayer.computeOutputShape()
+override fun buildLayers(training: Operand<Boolean>, numberOfLosses: Operand<Float>): Pair<Placeholder<Float>, Operand<Float>> {
+    val input = inputLayer.build(tf)
+    inputLayer.setOutputShape(input.asOutput().shape())
+    var output: Operand<Float> = input

layers.filter { it !is Input }.forEach { layer ->
-    layer.build(tf, inputShape)
-
-    inputShape = layer.computeOutputShape(inputShape)
-    layer.setOutputShape(inputShape)
-    logger.debug { "${layer.name}; $layer; outputShape: $inputShape" }
+    output = layer.build(tf, output, training, numberOfLossesOp)
+    layer.setOutputShape(output.asOutput().shape())
}
-}

-override fun forward(input: Operand<Float>, inputLayer: Input): Operand<Float> {
-    var out: Operand<Float> = input
-    for (layer in layers) {
-        out = layer.forward(tf, out, training, numberOfLossesOp)
-    }
-    return out
+    return input to output
}

/** Returns a copy of this model. */
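For the linear case the new buildLayers degenerates to threading one operand through the layer list, i.e. a left fold. A minimal standalone illustration (hypothetical helper, not KotlinDL code):

    // Sequential's forEach above is equivalent to folding the input through
    // the layers, each stage consuming the previous stage's output.
    fun <T> chain(input: T, stages: List<(T) -> T>): T =
        stages.fold(input) { acc, stage -> stage(acc) }

    fun main() {
        // (1 + 1) * 2 == 4: stages are applied left to right
        println(chain(1, listOf({ x: Int -> x + 1 }, { x: Int -> x * 2 })))
    }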
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/Layer.kt
@@ -30,63 +30,33 @@ public abstract class Layer(public var name: String) {
public var outboundLayers: MutableList<Layer> = mutableListOf()

/**
- * Extend this function to define variables in layer.
+ * Extend this function to define variables in the layer and compute layer output.
 *
 * @param [tf] TensorFlow graph API for building operations.
- * @param [inputShape] Input shape, result of [computeOutputShape] call from previous layer.
+ * @param [input] Layer input.
+ * @param [isTraining] TensorFlow operand for switching between training and inference modes.
+ * @param [numberOfLosses] TensorFlow operand for batch size data.
 */
-public abstract fun build(tf: Ops, inputShape: Shape)

+public abstract fun build(tf: Ops,
+                          input: Operand<Float>,
+                          isTraining: Operand<Boolean>,
+                          numberOfLosses: Operand<Float>?): Operand<Float>

/**
- * Extend this function to define variables in layer.
+ * Extend this function to define variables in the layer and compute layer output.
 *
+ * NOTE: This function should be overridden for layers with multiple inputs.
+ * NOTE: Used in Functional API
+ *
 * @param [tf] TensorFlow graph API for building operations.
+ * @param [input] Layer input list.
+ * @param [isTraining] TensorFlow operand for switching between training and inference modes.
+ * @param [numberOfLosses] TensorFlow operand for batch size data.
 */
-public fun buildFromInboundLayers(tf: Ops) {
-    require(inboundLayers.isNotEmpty()) { "There is no inbound layers to compute output shape" }
-    build(tf, inboundLayers[0].outputShape.toShape())
-}
-
-/**
- * Computes output shape, based on [inputShape] and [Layer] type.
- */
-public abstract fun computeOutputShape(inputShape: Shape): Shape
-
-/**
- * Computes output shape, based on input shapes of inbound layers.
- *
- * NOTE: This function should be overridden for layers with multiple inputs.
- * NOTE: Used in Functional API
- */
-public open fun computeOutputShapeFromInboundLayers(): TensorShape {
-    require(inboundLayers.isNotEmpty()) { "There is no inbound layers to compute output shape" }
-    return TensorShape(computeOutputShape(inboundLayers[0].outputShape.toShape()))
-}
-
-/**
- * Builds main layer input transformation with [tf]. Depends on [Layer] type.
- */
-public abstract fun forward(
-    tf: Ops,
-    input: Operand<Float>,
-    isTraining: Operand<Boolean>,
-    numberOfLosses: Operand<Float>?
-): Operand<Float>
-
-/**
- * Builds main layer input transformation with [tf]. Depends on [Layer] type.
- */
-public open fun forward(
-    tf: Ops,
-    input: List<Operand<Float>>,
-    isTraining: Operand<Boolean>,
-    numberOfLosses: Operand<Float>?
-): Operand<Float> {
-    return forward(tf, input[0], isTraining, numberOfLosses)
+public open fun build(tf: Ops,
+                      input: List<Operand<Float>>,
+                      isTraining: Operand<Boolean>,
+                      numberOfLosses: Operand<Float>?): Operand<Float> {
+    return build(tf, input.first(), isTraining, numberOfLosses)
}

/** Important part of functional API. It takes [layers] as input and saves them to the [inboundLayers] of the given layer. */
@@ -129,10 +99,7 @@ internal fun LongArray.toIntArray(): IntArray {
}

internal fun Layer.setOutputShape(shape: Shape) {
-    setOutputShape(TensorShape(shape))
-}
-
-internal fun Layer.setOutputShape(tensorShape: TensorShape) {
+    val tensorShape = TensorShape(shape)
check(tensorShape.tail().all { elem -> elem > 0 })
{
"The last dimensions (except first = -1) of shape of layer $name contains zero or negative dimension values: ${tensorShape}.\n" +
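What this buys layer authors: a minimal custom layer now overrides one method instead of three. A sketch against the post-PR base class (assuming hasActivation is the only other abstract member at this point in the refactoring; DoubleIt is a hypothetical example):

    import org.jetbrains.kotlinx.dl.api.core.layer.Layer
    import org.tensorflow.Operand
    import org.tensorflow.op.Ops

    class DoubleIt(name: String = "double_it") : Layer(name) {
        override val hasActivation: Boolean get() = false

        override fun build(
            tf: Ops,
            input: Operand<Float>,
            isTraining: Operand<Boolean>,
            numberOfLosses: Operand<Float>?
        ): Operand<Float> {
            // No variables to create; just emit the op. The framework now reads
            // the output shape from the returned operand, so there is no
            // computeOutputShape to implement.
            return tf.math.mul(input, tf.constant(2.0f))
        }
    }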
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/AbstractActivationLayer.kt
@@ -6,9 +6,7 @@
package org.jetbrains.kotlinx.dl.api.core.layer.activation

import org.jetbrains.kotlinx.dl.api.core.layer.Layer
-import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.tensorflow.Operand
-import org.tensorflow.Shape
import org.tensorflow.op.Ops

/**
@@ -34,19 +32,12 @@ public abstract class AbstractActivationLayer(name: String) : Layer(name) {
input: Operand<Float>
): Operand<Float>

-override fun forward(
+override fun build(
tf: Ops,
input: Operand<Float>,
isTraining: Operand<Boolean>,
numberOfLosses: Operand<Float>?
): Operand<Float> = forward(tf, input)

-override fun build(tf: Ops, inputShape: Shape): Unit = Unit
-
-override fun computeOutputShape(inputShape: Shape): Shape {
-    this.outputShape = TensorShape(inputShape)
-    return inputShape
-}

override val hasActivation: Boolean get() = true
}
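With the no-op build and the identity computeOutputShape deleted, a concrete activation layer only supplies the element-wise op; the base class's build() above delegates to it. A sketch (Swish is a hypothetical example, not part of the PR):

    import org.jetbrains.kotlinx.dl.api.core.layer.activation.AbstractActivationLayer
    import org.tensorflow.Operand
    import org.tensorflow.op.Ops

    class Swish(name: String = "swish") : AbstractActivationLayer(name) {
        // swish(x) = x * sigmoid(x); the shape is unchanged, so nothing else to declare
        override fun forward(tf: Ops, input: Operand<Float>): Operand<Float> =
            tf.math.mul(input, tf.math.sigmoid(input))
    }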
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/activation/PReLU.kt
Expand Up @@ -51,7 +51,8 @@ public class PReLU(

override var isTrainable: Boolean = true

-override fun build(tf: Ops, inputShape: Shape) {
+override fun forward(tf: Ops, input: Operand<Float>): Operand<Float> {
+    val inputShape = input.asOutput().shape()
val alphaShapeArray = inputShape.toLongArray().drop(1).toLongArray()
if (sharedAxes != null) {
for (axis in sharedAxes) {
@@ -72,9 +73,7 @@
alphaInitializer,
alphaRegularizer
)
-}
-
-override fun forward(tf: Ops, input: Operand<Float>): Operand<Float> {
// It's equivalent to: `-alpha * relu(-x) + relu(x)`
val positive = tf.nn.relu(input)
val negative = tf.math.mul(tf.math.neg(alpha.variable), tf.nn.relu(tf.math.neg(input)))
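The identity in the comment above is worth spelling out: relu(x) - alpha * relu(-x) reproduces PReLU on both halves of the domain. A quick standalone check:

    fun relu(x: Float) = maxOf(0f, x)

    // prelu(x) = x for x > 0, alpha * x otherwise
    fun prelu(x: Float, alpha: Float) = relu(x) - alpha * relu(-x)

    fun main() {
        check(prelu(3f, 0.1f) == 3f)      // positive half: relu(-x) vanishes
        check(prelu(-2f, 0.1f) == -0.2f)  // negative half: -alpha * relu(-x) = alpha * x
    }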
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/AbstractConv.kt
@@ -9,7 +9,6 @@ import org.jetbrains.kotlinx.dl.api.core.activation.Activations
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer
import org.jetbrains.kotlinx.dl.api.core.layer.*
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
-import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.jetbrains.kotlinx.dl.api.core.shape.shapeFromDims
import org.tensorflow.Operand
import org.tensorflow.Shape
@@ -65,7 +64,12 @@ public abstract class AbstractConv(
public override val variables: List<KVariable>
get() = listOfNotNull(kernel, bias)

-override fun build(tf: Ops, inputShape: Shape) {
+override fun build(tf: Ops,
+                   input: Operand<Float>,
+                   isTraining: Operand<Boolean>,
+                   numberOfLosses: Operand<Float>?
+): Operand<Float> {
+    val inputShape = input.asOutput().shape()
// Amount of channels should be the last value in the inputShape
val numberOfChannels = inputShape.size(inputShape.numDimensions() - 1)

@@ -96,24 +100,9 @@
biasRegularizer
)
}
-}
-
-override fun computeOutputShape(inputShape: Shape): Shape {
-    val shape = defineOutputShape(inputShape)
-    outputShape = TensorShape(shape)
-    return shape
-}
-
-override fun forward(
-    tf: Ops,
-    input: Operand<Float>,
-    isTraining: Operand<Boolean>,
-    numberOfLosses: Operand<Float>?
-): Operand<Float> {
val convolution = convImplementation(tf, input)

val withBias = bias?.let { tf.nn.biasAdd(convolution, it.variable) } ?: convolution

return Activations.convert(activation).apply(tf, withBias, name)
}

@@ -149,15 +138,6 @@

/** The actual layer operation implementation without adding the bias which is added by the abstract class. */
protected abstract fun convImplementation(tf: Ops, input: Operand<Float>): Operand<Float>

-/**
- * Actual implementation of [computeOutputShape] which only defines the value
- * of output shape without the need of saving it to some variable.
- *
- * @param inputShape which can be used to define the output shape
- * @return the defined output shape that is saved in class variable and returned by [computeOutputShape]]
- */
-protected abstract fun defineOutputShape(inputShape: Shape): Shape
}

private fun multiply(values: LongArray) = values.fold(1L, Long::times)
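defineOutputShape and its per-layer shape arithmetic could be deleted because TensorFlow already infers the shape of the convolution op; the framework just reads it off the returned operand. For reference, the standard convolution arithmetic the removed helpers implemented (a simplified restatement; the library's real convOutputLength helper has a different signature):

    // Output length of a convolution along one spatial axis.
    fun convOutputLength(inLength: Long, kernel: Int, stride: Int, dilation: Int, padSame: Boolean): Long {
        val dilatedKernel = (kernel - 1) * dilation + 1
        return if (padSame) (inLength + stride - 1) / stride   // SAME: ceil(in / stride)
        else (inLength - dilatedKernel) / stride + 1           // VALID
    }

    fun main() {
        // 28-wide input, 3-wide kernel, stride 1: VALID gives 26, SAME keeps 28
        check(convOutputLength(28, 3, 1, 1, padSame = false) == 26L)
        check(convOutputLength(28, 3, 1, 1, padSame = true) == 28L)
    }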
api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv1D.kt
@@ -14,11 +14,9 @@ import org.jetbrains.kotlinx.dl.api.core.layer.requireArraySize
import org.jetbrains.kotlinx.dl.api.core.layer.toLongList
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
-import org.jetbrains.kotlinx.dl.api.core.shape.convOutputLength
import org.jetbrains.kotlinx.dl.api.core.util.convBiasVarName
import org.jetbrains.kotlinx.dl.api.core.util.convKernelVarName
import org.tensorflow.Operand
-import org.tensorflow.Shape
import org.tensorflow.op.Ops
import org.tensorflow.op.core.Squeeze
import org.tensorflow.op.nn.Conv2d
@@ -127,21 +125,6 @@
}
}

-protected override fun defineOutputShape(inputShape: Shape): Shape {
-    val batchSize = inputShape.size(0)
-    val colsCount = inputShape.size(1)
-
-    val cols = convOutputLength(
-        colsCount,
-        kernelLength,
-        padding,
-        strides[1],
-        dilations[1]
-    )
-
-    return Shape.make(batchSize, cols, filters.toLong())
-}

override fun toString(): String {
return "Conv1D(name = $name, isTrainable=$isTrainable, filters=$filters, kernelSize=${kernelSize.contentToString()}, " +
"strides=${strides.contentToString()}, dilations=${dilations.contentToString()}, " +