-
Notifications
You must be signed in to change notification settings - Fork 655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
LayerNorm using PyTorch #1069
LayerNorm using PyTorch #1069
Conversation
pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
Show resolved
Hide resolved
tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java
Show resolved
Hide resolved
@@ -144,6 +144,25 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNBatchNorm( | |||
API_END_RETURN() | |||
} | |||
|
|||
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNLayerNorm(JNIEnv* env, jobject jthis, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please run:
./gradlew formatCpp
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just did
gradlew formatCpp
Found C:\djl\\gradle\wrapper\gradle-wrapper.jar
Deprecated Gradle features were used in this build, making it incompatible with Gradle 8.0.
Use '--warning-mode all' to show the individual deprecation warnings.
See https://docs.gradle.org/7.0.2/userguide/command_line_interface.html#sec:command_line_warnings
BUILD SUCCESSFUL in 1s
6 actionable tasks: 6 executed
but not seen an effect. Do I miss something?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, formatCpp doesn't work on Windows
Here is what it should looks like:
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNLayerNorm(
JNIEnv* env, jobject jthis, jlong jinput, jlongArray jnormalizedshape, jlong jweight, jlong jbias, jdouble jeps) {
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I copy/paste the two lines ... hope it fits ... looks like I should build djl on linux.
…Ex.java Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
…DArrayEx.java Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
…ngine/TfNDArrayEx.java Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
Codecov Report
@@ Coverage Diff @@
## master #1069 +/- ##
============================================
+ Coverage 69.98% 70.02% +0.03%
- Complexity 5212 5227 +15
============================================
Files 510 511 +1
Lines 23255 23339 +84
Branches 2489 2492 +3
============================================
+ Hits 16276 16342 +66
- Misses 5650 5665 +15
- Partials 1329 1332 +3
Continue to review full report at Codecov.
|
@@ -144,6 +144,25 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNBatchNorm( | |||
API_END_RETURN() | |||
} | |||
|
|||
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNLayerNorm(JNIEnv* env, jobject jthis, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, formatCpp doesn't work on Windows
Here is what it should looks like:
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNLayerNorm(
JNIEnv* env, jobject jthis, jlong jinput, jlongArray jnormalizedshape, jlong jweight, jlong jbias, jdouble jeps) {
(PtNDArray) gamma, | ||
(PtNDArray) beta, | ||
eps)); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add en empty line here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have tried to fix it ... but could be that I am blind here ... I am using
gradlew formatJava
and
gradlew build
@lanking520 @stu1130 Please take a look. |
Description
Similar to BatchNormalization there exist some other variants of normalizing data flowing through the network that have been implemented by the underlying ai frameworks. Here LayerNorm as one of them has been wired up to be used with PyTorch.
This PullRequest would close #1057.