Speed up passing different `JaxSimModel` with same pytree structure to JIT-compiled functions #179
Conversation
In this way, JaxSimModel.description is never used by JIT-compiled functions, and can be treated as an ignored static leaf of the pytree (otherwise computing its hash would be expensive).
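A minimal sketch of what treating the description as an ignored static leaf can look like with plain JAX pytree registration (my own illustration under assumed class and field names; the actual JaxSim implementation differs):

```python
import dataclasses

import jax
import jax.numpy as jnp


class _IgnoredStatic:
    """Wrapper whose hash and equality ignore the wrapped value."""

    def __init__(self, value):
        self.value = value

    def __hash__(self) -> int:
        return hash(None)

    def __eq__(self, other) -> bool:
        return isinstance(other, _IgnoredStatic)


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class Model:
    data: jnp.ndarray      # dynamic leaf, traced by JIT
    description: object    # static metadata, ignored in the cache key below

    def tree_flatten(self):
        # The description goes into the aux data wrapped so that its hash and
        # equality are trivial: the JIT cache key no longer depends on its content.
        return (self.data,), (_IgnoredStatic(self.description),)

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(data=children[0], description=aux[0].value)


@jax.jit
def total(model: Model) -> jnp.ndarray:
    return jnp.sum(model.data)


m1 = Model(data=jnp.ones(3), description="parsed from URDF, expensive to hash")
m2 = Model(data=jnp.ones(3), description="an equivalent but distinct description")

total(m1)  # traced and compiled
total(m2)  # reuses the compiled function without hashing the description content
```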
LGTM, thanks Diego!
Thanks @flferretti for the suggestion on a more compact syntax. The original idea was to prevent evaluating all the conditions if one of the first ones is false. Your suggestion using `|`
Co-authored-by: Filippo Luca Ferretti <filippo.ferretti@iit.it>
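A toy illustration (my own example, not from this PR) of the trade-off behind the comment above: Python's `and` short-circuits, while `|` always evaluates both operands, which matters when the conditions are expensive boolean checks.

```python
def cheap_check() -> bool:
    return False


def expensive_check() -> bool:
    # Imagine this triggering an expensive hash or traversal.
    return True


# Short-circuit: expensive_check() is never called, since cheap_check() is False.
result_and = cheap_check() and expensive_check()

# Bitwise operator: both operands are evaluated before they are combined.
result_or = cheap_check() | expensive_check()
```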
PR #173 changed how `JaxSimModel` objects are hashed and compared. It was a necessary change, since our frame-related data were read from `JaxSimModel.description`, which is a `Static` attribute. Before that PR, the hash of the description was explicitly ignored, but that caused problems when different models were passed to the same JIT-compiled function, because JAX did not take into account possible differences in frame data. The outcome was that the frames test was not passing, and the solution was to make `JaxSimModel.description` properly hashable.

These days I played around a bit with different `JaxSimModel`s (e.g. parameterized models), and I discovered that if there are two exact instances `model1` and `model2` created from the same URDF (therefore, the second model is not a direct copy of the first one), calling on `model2` a function that was JIT-compiled on `model1` was extremely slow, even though no JIT recompilation was triggered. The problem is that in these cases JAX computes the hash of static attributes, and right now the hash of `ModelDescription` takes hundreds of milliseconds. This is a problem because the processing of static attributes becomes orders of magnitude longer than the actual compiled computation.
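To illustrate the kind of overhead described above, here is a toy sketch (my own example, not JaxSim code) where a static argument with an expensive `__hash__` slows down every dispatch of an already-compiled function:

```python
import time
from functools import partial

import jax
import jax.numpy as jnp


class SlowDescription:
    """Stand-in for a static attribute whose hash is expensive to compute."""

    def __hash__(self) -> int:
        time.sleep(0.1)  # simulate a hash taking hundreds of milliseconds
        return 0

    def __eq__(self, other) -> bool:
        return isinstance(other, SlowDescription)


@partial(jax.jit, static_argnames=["description"])
def forward(x, description):
    # The description acts only as part of the JIT cache key, not in the computation.
    return x + 1.0


x = jnp.zeros(3)
forward(x, description=SlowDescription())  # first call: traced and compiled
forward(x, description=SlowDescription())  # cache hit, no recompilation, but the
                                           # dispatch still pays the slow __hash__
```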
Originally, I thought about ways to speed up the equality computation of `ModelDescription`, and this PR contains new `__eq__` methods that do not call `__hash__`.
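The kind of change I mean is sketched below (an illustrative example with made-up class and field names, not the actual JaxSim diff): instead of defining equality through hashing, compare the fields directly and return as soon as one of them differs.

```python
import dataclasses


@dataclasses.dataclass
class ExampleDescription:
    name: str
    mass: float
    index: int

    def __hash__(self) -> int:
        # Potentially expensive: hashes every field, possibly recursively.
        return hash((self.name, self.mass, self.index))

    def __eq__(self, other) -> bool:
        # Cheap equality that does not go through __hash__: field-by-field
        # comparison, short-circuiting on the first mismatch.
        if not isinstance(other, ExampleDescription):
            return False

        return (
            self.index == other.index
            and self.name == other.name
            and self.mass == other.mass
        )
```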
However, I then realized that the best solution is to also move frame-related data into `KinDynParameters`, similarly to what we already do for links, joints, and contacts. In this way, `JaxSimModel.description` can be ignored again, and there is no longer any need to compute its hash. This significantly speeds up calling functions that have been JIT-compiled on `model1` using `model2`.
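A rough sketch of the idea (field names here are hypothetical and the actual `KinDynParameters` layout may differ): frame data is stored with cheap-to-hash static metadata and array-valued pytree leaves, so no expensive description object needs to be hashed at dispatch time.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class FrameParameters:
    names: tuple[str, ...]     # static: cheap to hash and compare
    parents: tuple[int, ...]   # static: index of the parent link of each frame
    transforms: jnp.ndarray    # dynamic: (n_frames, 4, 4) homogeneous transforms

    def tree_flatten(self):
        # Arrays are dynamic leaves; the small tuples form the static aux data.
        return (self.transforms,), (self.names, self.parents)

    @classmethod
    def tree_unflatten(cls, aux, children):
        names, parents = aux
        return cls(names=names, parents=parents, transforms=children[0])
```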
There is still some overhead. I guess that JAX saves the id of `model1` and skips the check on static attributes, a check that is instead performed for `model2`. The following is the runtime on CPU using two full ErgoCub models:

Note: this PR also goes in the direction of exploiting the on-disk JAX compilation cache supported by the GPU and TPU backends. In the past, I never managed to make it work with JaxSim, probably due to factors related to this PR. Now things should be better, since hash and equality are correct. What I suspect is still missing is wrapping static strings with a custom hash function, since by default the hash of a string in Python is not the same across executions.
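Regarding that last point, a minimal sketch of such a wrapper (an assumption on my side, not existing JaxSim code) could rely on a deterministic digest instead of the built-in `str` hash, which is randomized per process via `PYTHONHASHSEED`:

```python
import hashlib


class HashedString:
    """Wraps a string so that equal strings hash identically in every execution."""

    def __init__(self, value: str):
        self.value = value

    def __hash__(self) -> int:
        # blake2b is deterministic across runs, unlike hash(str).
        digest = hashlib.blake2b(self.value.encode(), digest_size=8).digest()
        return int.from_bytes(digest, "big")

    def __eq__(self, other) -> bool:
        return isinstance(other, HashedString) and self.value == other.value
```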
📚 Documentation preview 📚: https://jaxsim--179.org.readthedocs.build//179/