-
Notifications
You must be signed in to change notification settings - Fork 280
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for PyTorch model #346
Conversation
Test FAILed. |
@@ -1,15 +1,20 @@ | |||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pulled the latest cloudpickle. This fixes issues with recursive dependencies that caused issues when serializing pytorch models.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we're replacing Cloudpickle, let's just switch to using the cloudpickle pip package (pip install cloudpickle
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See #343
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
PYTORCH_MODEL_RELATIVE_PATH) | ||
|
||
try: | ||
torch.save(pytorch_model.state_dict(), torch_weights_save_loc) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We now save the model weights and class definition separately. Weights are saved via PyTorch, and the class is pickled with CloudPickle
Test PASSed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two small changes then this is good to go. @Corey-Zumar can you go ahead and make those?
@@ -1,15 +1,20 @@ | |||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we're replacing Cloudpickle, let's just switch to using the cloudpickle pip package (pip install cloudpickle
).
&& apt-get update --fix-missing \ | ||
&& apt-get install -yqq -t jessie-backports openjdk-8-jdk \ | ||
&& conda install -y --file /lib/python_container_conda_deps.txt \ | ||
&& conda install pytorch torchvision -c soumith |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They created a PyTorch conda channel recently, so this should be conda install pytorch torchvision -c pytorch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Test FAILed. |
Test FAILed. |
Test FAILed. |
Test FAILed. |
Test FAILed. |
Test FAILed. |
jenkins test this please |
Test FAILed. |
Test PASSed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
@haofanwang Added some updates to fix the class serialization issue and accidentally closed #322 in the process. Take a look and let me know if you have any questions.
This should be good to go once @dcrankshaw takes a final pass.
Fixes #343, #314