Run GPU tests on Jax + Torch #1160
Conversation
/gcbrun
/gcbrun
/gcbrun
/gcbrun
/gcbrun
@@ -1,7 +1,4 @@
keras-core
# Consider handling GPU here.
torch>=2.0.1+cpu
Curious what your take is here -- this can definitely be handled differently for the docker configs, but the way I see it these aren't really requirements for KerasNLP. wdyt?
My take is basically what is laid out here. `install_requires` should list the minimum needed for a valid install, but `requirements.txt` should be an exhaustive and repeatable recipe for a complete environment.
After the next tf release, I think we can get by with just a `requirements.txt` and a `requirements-cuda.txt`. These should be used by all our tooling (no listing extra deps in docker files, etc.), and used by contributors to set up a development environment that matches our CI.
Short term, we could consider a `requirements-jax.txt`, `requirements-tf.txt`, etc., that match our docker setups. Then all our version pinning, hacking, etc., is consolidated to requirements files.
No need to figure this out all on this PR though. Fine to land as is and keep tweaking.
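For illustration only, a minimal sketch of how the proposed split might be consumed; the file names come from the comment above, and their exact contents would be decided in a follow-up:
```
# Hypothetical usage of the proposed requirements split:
# CPU-only development environment, matching standard CI:
pip install -r requirements.txt

# GPU test environment (docker images and GPU CI would use this instead):
pip install -r requirements-cuda.txt
```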
Okay @mattdangerw I'm sending this for review, but many of the tests are still failing on GCB. For PyTorch, basically all of the failures are due to some sort of GPU device placement issue. I tried just updating the test cases, but it ended up breaking a bunch of other tests with some cascading failures. The Jax+TF failures should now be just the …
/gcbrun
/gcbrun
@ianstenbit flipping back to you with a few last comments on the test setup!
Thanks very much for tackling this! This last failure is a known flake; we can ignore it for now and I will open up a fix.
- Repeat the last two steps for Jax and Torch (replacing "tensorflow" with "jax"
  or "torch" in the docker image target name). `Dockerfile` for jax (see the
  sketch after this snippet):
  ```
  FROM nvidia/cuda:11.7.1-base-ubuntu20.04
  ```
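As a rough, hedged sketch of how the rest of the jax `Dockerfile` might look on top of that base image; the apt packages and the `jax[cuda11_pip]` install line are assumptions taken from the upstream JAX install docs, not from this PR's actual docker configs:
```
# Sketch only -- the real Dockerfile lives in this PR's docker configs.
FROM nvidia/cuda:11.7.1-base-ubuntu20.04

# The CUDA base image ships without Python, so install it first.
RUN apt-get update && \
    apt-get install -y python3 python3-pip git && \
    rm -rf /var/lib/apt/lists/*

# Install jax with pip-managed CUDA libraries, then the repo requirements.
RUN python3 -m pip install --upgrade "jax[cuda11_pip]" \
    -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
COPY requirements.txt .
RUN python3 -m pip install -r requirements.txt
```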
Why not `11.8.0` here? That's the version tf depends on, and it would be nice to consolidate on one version of CUDA for our testing. https://www.tensorflow.org/install/pip
I did 11.7 because that's the current default for PyTorch, and I figured we might as well use the same base image for the two to reduce the possibility of variance. But 11.8 should work for both; we'd just need to update the pip install commands to use the correct CUDA version.
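If we did consolidate on 11.8, the torch install would presumably just point at the cu118 wheel index. A sketch, following PyTorch's standard install instructions rather than anything in this PR:
```
# Hypothetical cu118 install, if the base image moved to 11.8.0:
pip install torch --index-url https://download.pytorch.org/whl/cu118
```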
@@ -65,6 +65,8 @@ jobs:
      - name: Install dependencies
        run: |
          pip install -r requirements.txt --progress-bar off
          pip install torch>=2.0.1+cpu --progress-bar off
I think this would also need to go into the publish-to-PyPI action (though long term, I still favor a requirements file as a single source of truth to avoid duplication).
Yeah, this feels a little janky to me as well. It's possible that we could leave these in `requirements.txt` for now and just have the GPU CI workflow uninstall the CPU versions before installing the GPU versions.
I'm fine either way; this just seemed like less of a headache for the GPU tests.
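For concreteness, the alternative floated above might look roughly like this in the GPU workflow, assuming `requirements.txt` keeps the CPU wheels and only GPU CI swaps them out; the cu117 index URL matches the 11.7 base image discussed earlier and is an assumption, not part of this PR:
```
# Install everything as usual (pulls the CPU torch wheel from requirements.txt),
# then replace it with the CUDA build for the GPU test run.
pip install -r requirements.txt --progress-bar off
pip uninstall -y torch
pip install torch --index-url https://download.pytorch.org/whl/cu117
```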
Here's a fix for that flake, but I think we can merge things in any order: #1171
I can't technically LGTM since I originally opened up the PR, but... LGTM. Feel free to merge this whenever you're ready.
Thanks!!
* Update requirements and README
* Update test configs
* Fix normal CI
* . - but that is actually the commit message
* Fix cloudbuild dockerfile
* Fix docker configs
* Some test case fixes
* Fix rich imports
* Fix test case
* Revert test case
* Fix gpt_neo_x saving
* Skip xlm roberta presets on jax/torch for now
* Fix torch GPU detach errors
* More detach fixes

Co-authored-by: Matt Watson <mattdangerw@gmail.com>
This was done for KerasCV in keras-team/keras-cv#1935.
I've made the relevant changes to our GCB config on the backend already, and I've built the docker images for each of the three test suites.