Skip to content

Commit

Permalink
rollout exchange chunk_divisor arg for chunk_size
Browse files Browse the repository at this point in the history
  • Loading branch information
aftersomemath committed Dec 12, 2024
1 parent 298ab2f commit 169cf99
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
18 changes: 12 additions & 6 deletions python/mujoco/rollout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ Roll out open-loop trajectories from initial states, get resulting states and se
output arguments (optional):
state (nroll x nstep x nstate) nroll nstep states
sensordata (nroll x nstep x nsendordata) nroll trajectories of nstep sensordata vectors
chunk_divisor integer, determines threadpool chunk size according to
chunk_size = max(1, nroll / (nthread * chunk_divisor)
chunk_size integer, determines threadpool chunk size. If unspecified
chunk_size = max(1, nroll / (nthread * 10)
)";

// C-style rollout function, assumes all arguments are valid
Expand Down Expand Up @@ -241,7 +241,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
std::optional<const PyCArray> control,
std::optional<const PyCArray> state,
std::optional<const PyCArray> sensordata,
int chunk_divisor
std::optional<int> chunk_size
) {
// get raw pointers
int nroll = state0.shape(0);
Expand Down Expand Up @@ -281,11 +281,17 @@ PYBIND11_MODULE(_rollout, pymodule) {

// call unsafe rollout function
if (nthread > 1 && nroll > 1) {
int chunk_size = std::max(1, nroll / (chunk_divisor * nthread));
int chunk_size_final = 1;
if (!chunk_size.has_value()) {
chunk_size_final = std::max(1, nroll / (10 * nthread));
}
else {
chunk_size_final = *chunk_size;
}
InterceptMjErrors(_unsafe_rollout_threaded)(
model_ptrs, data_ptrs, nroll, nstep, control_spec, state0_ptr,
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr,
nthread, chunk_size);
nthread, chunk_size_final);
}
else {
InterceptMjErrors(_unsafe_rollout)(
Expand All @@ -303,7 +309,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
py::arg("control") = py::none(),
py::arg("state") = py::none(),
py::arg("sensordata") = py::none(),
py::arg("chunk_divisor") = 10,
py::arg("chunk_size") = py::none(),
py::doc(rollout_doc)
);
}
Expand Down
12 changes: 7 additions & 5 deletions python/mujoco/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def rollout(
initial_warmstart: Optional[npt.ArrayLike] = None,
state: Optional[npt.ArrayLike] = None,
sensordata: Optional[npt.ArrayLike] = None,
chunk_divisor: int = 10,
chunk_size: int = None,
):
"""Rolls out open-loop trajectories from initial states, get subsequent states and sensor values.
Expand All @@ -60,8 +60,8 @@ def rollout(
(nroll x nstep x nstate)
sensordata: Sensor data output array (optional).
(nroll x nstep x nsensordata)
chunk_divisor: Determines threadpool chunk size according to
chunk_size = max(1, nroll / (nthread * chunk_divisor)
chunk_size: Determines threadpool chunk size. If unspecified,
chunk_size = max(1, nroll / (nthread * 10)
Returns:
state:
Expand All @@ -88,7 +88,7 @@ def rollout(
control,
state,
sensordata,
chunk_divisor,
chunk_size,
)
return state, sensordata

Expand All @@ -102,6 +102,8 @@ def rollout(
# check types
if nstep and not isinstance(nstep, int):
raise ValueError('nstep must be an integer')
if chunk_size and not isinstance(chunk_size, int):
raise ValueError('chunk_size must be an integer')
_check_must_be_numeric(
initial_state=initial_state,
initial_warmstart=initial_warmstart,
Expand Down Expand Up @@ -202,7 +204,7 @@ def rollout(
control,
state,
sensordata,
chunk_divisor,
chunk_size,
)

# return outputs
Expand Down

0 comments on commit 169cf99

Please sign in to comment.