Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up passing different JaxSimModel with same pytree structure to JIT-compiled functions #179

Merged
merged 12 commits into from
Jun 14, 2024

Conversation

diegoferigo
Copy link
Member

@diegoferigo diegoferigo commented Jun 14, 2024

PR #173 changed how JaxSimModel objects are hashed and compared. It was a necessary change since our frame-related data were read from JaxSimModel.description that is a Static attribute. Before that PR, the hash of the description was explicitly ignored, but that caused problems when different models were passed to the same JIT-compiled function because JAX was not taking into account possible differences in frame data. The outcome was that the frames test was not passing, and the solution was to make JaxSimModel.description properly hashable.

These days I played a bit around with different JaxSimModels (e.g. parameterized models), and I discovered that if there are two exact instances model1 and model2 created from the same URDF (therefore, not having the second model directly copied from the first one), calling a JIT-compiled function on model was extremely slow even if JIT-recompilation was not triggered. The problem is that in these cases, JAX computes the hash of static attributes, and right now the hash of ModelDescription takes hundreds of milliseconds. This is a problem because the processing of static attributes is orders of magnitude longer than the actual compiled computation.

Originally, I though ways to speed up the equality computation of ModelDescription, and this PR contains new __eq__ methods that do not call __hash__.

However, then I realized that the best solution is to move also frame-related data in KinDynParameters, similarly to what we already do for links, joints, and contacts. In this way, JaxSimModel.description can be ignored again, and there's no longer the need to compute its hash. This speeds up significantly calling function that have been JIT-compiled on model1 using model2.

There still is an overhead. I guess that JAX saves the id of model1 and skips the check on static attributes, check that instead is done for `model2. The following is the runtime on CPU using two full ErgoCub models:

import jaxsim.api as js
import resolve_robotics_uri_py
from jaxsim import VelRepr

urdf_path = resolve_robotics_uri_py.resolve_robotics_uri(
    uri="model://ergoCubSN001/model.urdf"
)

model1 = js.model.JaxSimModel.build_from_model_description(
    model_description=urdf_path,
    is_urdf=True,
)

model2 = js.model.JaxSimModel.build_from_model_description(
    model_description=urdf_path,
    is_urdf=True,
)

data = js.data.random_model_data(model=model1, velocity_representation=VelRepr.Mixed)

# First run for JIT compilation, second one for runtime.
%time _ = js.model.forward_dynamics_aba(model1, data)  # Wall time: 4.31 s
%timeit js.model.forward_dynamics_aba(model1, data)    # 296 µs ± 12.6 µs

# This should not get JIT compiled.
%time _ = js.model.forward_dynamics_aba(model2, data)  # Wall time: 1.56 ms

# This should go almost as fast as on model1.
# The overhead is due to the comparison of static attributes.
%timeit js.model.forward_dynamics_aba(model2, data)  # 883 µs ± 14.2 µs

Note: this PR goes also in the direction of exploting the on-disk JAX compilation cache supported by the GPU and TPU backends. In the past, I've never managed to make it work with JaxSim, probably due to factors related to this PR. Now things should be better since hash and equality are correct. What I suspect is still missing is wrapping static strings with a custom hash function since by default the hash of a string in Python is not the same among executions.


📚 Documentation preview 📚: https://jaxsim--179.org.readthedocs.build//179/

@diegoferigo diegoferigo self-assigned this Jun 14, 2024
@diegoferigo diegoferigo marked this pull request as ready for review June 14, 2024 09:47
Copy link
Collaborator

@flferretti flferretti left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks Diego!

src/jaxsim/parsers/descriptions/joint.py Outdated Show resolved Hide resolved
src/jaxsim/parsers/descriptions/joint.py Outdated Show resolved Hide resolved
src/jaxsim/api/data.py Outdated Show resolved Hide resolved
src/jaxsim/parsers/descriptions/model.py Outdated Show resolved Hide resolved
src/jaxsim/parsers/descriptions/link.py Outdated Show resolved Hide resolved
src/jaxsim/parsers/descriptions/link.py Outdated Show resolved Hide resolved
src/jaxsim/utils/wrappers.py Show resolved Hide resolved
@diegoferigo
Copy link
Member Author

Thanks @flferretti for the suggestion on more compact syntax. The original idea was to prevent evaluating all the conditions if one of the first ones is false. Your suggestion using all would work similarly only if its argument is a generator expression, which is not the case. I updated your suggestion to use chained and, Python stops early also in this case.

Co-authored-by: Filippo Luca Ferretti <filippo.ferretti@iit.it>
@diegoferigo diegoferigo merged commit 35f25b1 into main Jun 14, 2024
29 checks passed
@diegoferigo diegoferigo deleted the improve_eq_and_hash branch June 14, 2024 15:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants