Add GPU testing for Jax+Torch #1935
Conversation
/gcbrun
Looks good!
std.format(
  |||
    export KERAS_BACKEND=%s
    export JAX_ENABLE_X64=true
Why do you end up needing int64/float64? Just curious really.
In KerasNLP we have just been trying to go int32 everywhere by default.
I think we probably should just do int32 everywhere. During NMS porting I had some things that were real sticklers about int64 in TF, and I never found a way to work around it.
Could we potentially just change the dtype for those tests? Asking since we're always flirting with OOM issues, and int64 can't help.
Potentially, yes. I did some looking, and it seems like the only place we're using int64 internally with KerasCore is in the YOLOV8 label encoder -- I guess I already got rid of them in NMS. I'll check with Tirth whether he's planning to get rid of those.
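The OOM worry here is concrete: an int64 tensor occupies exactly twice the memory of an int32 tensor with the same number of elements. A minimal NumPy illustration (the shape is arbitrary, chosen only to show the footprint):

```python
import numpy as np

# int32 vs int64 memory footprint for the same shape
a32 = np.zeros((1024, 1024), dtype=np.int32)
a64 = a32.astype(np.int64)

print(a32.nbytes)  # 4 bytes per element: 1024 * 1024 * 4
print(a64.nbytes)  # 8 bytes per element: 1024 * 1024 * 8
```

Doubling every index tensor's footprint matters when GPU memory is already tight, which is why defaulting to int32 (as KerasNLP does) is attractive.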
Thanks!
@@ -78,7 +78,7 @@ jobs:
        KERAS_BACKEND: ${{ matrix.backend }}
        JAX_ENABLE_X64: true
      run: |
        pytest --run_large keras_cv/bounding_box \
I assume --run_large was a mistake here since this is CI?
Not a mistake; I was just doing this during development of the rebase so that we covered all the tests. But yes, we shouldn't include it anymore.
* Add GPU testing for torch and jax
* Consolidate cloudbuild files
* Reformat image name
* gcr.io/cloud-builders/docker
* Underscores are hard
* Yay Docker
* I have activated my second brain cell
* IMAGE_NAME
* Entrypoint fix in jssonnet
* Re-do env variables in jssonnet
* Another one
* Testing an idea
* Try string format
* Remove bad export
* Rename + try Torch docker image
* Create a base test case with Numpy conversion
* Some test fixes
* Some test fixes
* We out here fixing tests
* Test fixes -- morning style!
* Update README and include a CUDA verification test
* Better cuda test
* Working docker config
* Last round of test fixes ... maybe?
* Fix docstring + add attribution to Matt
In order to make this work, I also had to create a new base TestCase, which is why the diff got so massive. But hey, we're now doing GPU testing with CUDA on 3 different backends -- which is pretty dang cool!
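A base TestCase for multi-backend testing usually works by converting backend tensors to NumPy before asserting, so the same test body runs on TF, JAX, and Torch. The sketch below is a hypothetical minimal version of that idea; the helper and method names are illustrative, not the actual keras_cv code:

```python
# Hypothetical sketch of a backend-agnostic base TestCase: assertion
# helpers convert backend tensors to numpy first, so one test suite can
# run unchanged against TF, JAX, and Torch outputs.
import unittest

import numpy as np


def to_numpy(x):
    """Best-effort conversion of a backend tensor to a numpy array."""
    if hasattr(x, "numpy"):
        # Covers TF eager tensors and torch CPU tensors.
        return x.numpy()
    # Covers JAX arrays, numpy arrays, and plain Python sequences.
    return np.asarray(x)


class TestCase(unittest.TestCase):
    def assertAllClose(self, actual, expected, atol=1e-6):
        # Compare after normalizing both sides to numpy.
        np.testing.assert_allclose(
            to_numpy(actual), to_numpy(expected), atol=atol
        )
```

The payoff is that individual tests never mention a backend: as long as the model output supports one of the conversion paths, `assertAllClose` just works.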