diff --git a/.gitignore b/.gitignore
index 120f03858..e7dfde769 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,3 +17,4 @@
 /api/cache/
 /pretrainedModels/
 /examples/savedmodels/
+/dataset/build/
diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/activation/Activations.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/activation/Activations.kt
index fd3fce544..022f5a9dc 100644
--- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/activation/Activations.kt
+++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/activation/Activations.kt
@@ -248,6 +248,18 @@ public enum class Activations {
      */
     HardShrink,
 
+    /**
+     * Gelu Function
+     *
+     * Computes the Gaussian Error Linear Unit (GELU):
+     *
+     * gelu(x) = x * P(X <= x) where P(X) ~ N(0, 1)
+     *
+     * Calls [GeluActivation] under the hood.
+     * @property [approximate], boolean to toggle approximation
+     */
+    Gelu,
+
     /**
      * Non-Parametric Linearly Scaled Hyperbolic Tangent (LiSHT) Activation Function.
      *
@@ -280,6 +292,7 @@ public enum class Activations {
                 Mish -> MishActivation()
                 HardShrink -> HardShrinkActivation()
                 LiSHT -> LishtActivation()
+                Gelu -> GeluActivation()
             }
         }
     }
@@ -436,3 +449,37 @@ public class LishtActivation : Activation {
     override fun apply(tf: Ops, features: Operand<Float>): Operand<Float> =
         tf.math.mul(features, tf.math.tanh(features))
 }
+
+/**
+ * @see [Activations.Gelu]
+ */
+public class GeluActivation(public val approximate: Boolean = false) : Activation {
+    override fun apply(tf: Ops, features: Operand<Float>): Operand<Float> {
+        if (approximate) {
+            val coeff = tf.constant(0.044715f)
+            return tf.math.mul(
+                tf.constant(0.5f), tf.math.mul(
+                    features, tf.math.add(
+                        tf.constant(1.0f), tf.math.tanh(
+                            tf.math.mul(
+                                tf.constant(0.7978845608028654f), // This value is equal to sqrt(2/pi) to avoid a constant division
+                                tf.math.add(features, tf.math.mul(coeff, tf.math.pow(features, tf.constant(3f))))
+                            )
+                        )
+                    )
+                )
+            )
+        } else {
+            return tf.math.mul(
+                tf.constant(0.5f),
+                tf.math.mul(
+                    features,
+                    tf.math.add(
+                        tf.constant(1f),
+                        tf.math.erf(tf.math.div(features, tf.constant(1.4142135623730951f)))
+                    )
+                )
+            )
+        }
+    }
+}
diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt
index 9cc0ada5b..5a4753600 100644
--- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt
+++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/KerasConstants.kt
@@ -106,6 +106,7 @@ internal const val ACTIVATION_EXP: String = "exponential"
 internal const val ACTIVATION_MISH: String = "mish"
 internal const val ACTIVATION_HARDSHRINK: String = "hardshrink"
 internal const val ACTIVATION_LISHT: String = "lisht"
+internal const val ACTIVATION_GELU: String = "gelu"
 
 // Layer settings
 internal const val CHANNELS_LAST: String = "channels_last"
diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt
index 40a02dbe4..3d7b17aa3 100644
--- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt
+++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelLoader.kt
@@ -360,6 +360,7 @@ private fun convertToActivation(activation: String): Activations {
         ACTIVATION_MISH -> Activations.Mish
         ACTIVATION_HARDSHRINK -> Activations.HardShrink
         ACTIVATION_LISHT -> Activations.LiSHT
+        ACTIVATION_GELU -> Activations.Gelu
         else -> throw IllegalStateException("$activation is not supported yet!")
     }
 }
diff --git a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt
index 4ad5bb407..e3448fa13 100644
--- a/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt
+++ b/api/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/keras/ModelSaver.kt
@@ -258,6 +258,7 @@ private fun convertToKerasActivation(activation: Activations): String? {
         Activations.Mish -> ACTIVATION_MISH
         Activations.HardShrink -> ACTIVATION_HARDSHRINK
         Activations.LiSHT -> ACTIVATION_LISHT
+        Activations.Gelu -> ACTIVATION_GELU
     }
 }
 
diff --git a/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/activation/GeluActivationTest.kt b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/activation/GeluActivationTest.kt
new file mode 100644
index 000000000..79c781285
--- /dev/null
+++ b/api/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/activation/GeluActivationTest.kt
@@ -0,0 +1,22 @@
+package org.jetbrains.kotlinx.dl.api.core.activation
+
+import org.junit.jupiter.api.Test
+
+internal class GeluActivationTest : ActivationTest() {
+
+    @Test
+    fun default() {
+        val input = floatArrayOf(-3.0f, -1.0f, 0.0f, 1.0f, 3.0f)
+        val actual = floatArrayOf(0f, 0f, 0f, 0f, 0f)
+        val expected = floatArrayOf(-0.00404951f, -0.15865529f, 0f, 0.8413447f, 2.9959507f)
+        assertActivationFunction(GeluActivation(), input, actual, expected)
+    }
+
+    @Test
+    fun approxTest() {
+        val input = floatArrayOf(-3.0f, -1.0f, 0.0f, 1.0f, 3.0f)
+        val actual = floatArrayOf(0f, 0f, 0f, 0f, 0f)
+        val expected = floatArrayOf(-0.00363752f, -0.15880796f, 0f, 0.841192f, 2.9963627f)
+        assertActivationFunction(GeluActivation(approximate = true), input, actual, expected)
+    }
+}
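Note on where the approxTest expectations come from: the approximate = true branch implements the tanh approximation gelu(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), so the expected values can be reproduced outside a TensorFlow graph with plain kotlin.math. The sketch below is illustrative only; geluApprox is a hypothetical helper written for this check, not part of the PR.

import kotlin.math.tanh

// Tanh approximation of GELU, mirroring the approximate = true branch of GeluActivation:
// gelu(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
fun geluApprox(x: Float): Float {
    val inner = 0.7978845608028654f * (x + 0.044715f * x * x * x) // 0.7978845608... = sqrt(2/pi)
    return 0.5f * x * (1f + tanh(inner))
}

fun main() {
    // Same inputs as GeluActivationTest; the printed values should match approxTest's
    // expected array (-0.00363752, -0.15880796, 0, 0.841192, 2.9963627) up to float precision.
    floatArrayOf(-3.0f, -1.0f, 0.0f, 1.0f, 3.0f).forEach { x ->
        println("geluApprox($x) = ${geluApprox(x)}")
    }
}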