Skip to content

Commit

Permalink
Jit the returned lcm function (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
mj023 authored Jun 1, 2024
1 parent bbdb2a1 commit 3bed4b6
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/lcm/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def get_lcm_function(
model,
targets="solve",
debug_mode=True, # noqa: FBT002
jit=True, # noqa: FBT002
interpolation_options=None,
):
"""Entry point for users to get high level functions generated by lcm.
Expand All @@ -47,6 +48,7 @@ def get_lcm_function(
targets (str or iterable): The requested function types. Currently only
"solve", "simulate" and "solve_and_simulate" are supported.
debug_mode (bool): Whether to log debug messages.
jit (bool): Whether to jit the returned function.
interpolation_options (dict): Dictionary of keyword arguments for interpolation
via map_coordinates.
Expand Down Expand Up @@ -139,13 +141,13 @@ def get_lcm_function(
utility_and_feasibility=u_and_f,
continuous_choice_variables=list(_choice_grids),
)
compute_ccv_functions.append(jax.jit(compute_ccv))
compute_ccv_functions.append(compute_ccv)

compute_ccv_argmax = create_compute_conditional_continuation_policy(
utility_and_feasibility=u_and_f,
continuous_choice_variables=list(_choice_grids),
)
compute_ccv_policy_functions.append(jax.jit(compute_ccv_argmax))
compute_ccv_policy_functions.append(compute_ccv_argmax)

# create list of emax_calculators
# ==============================================================================
Expand All @@ -158,7 +160,7 @@ def get_lcm_function(
choice_segments=choice_segments[period],
params=_mod.params,
)
emax_calculators.append(jax.jit(calculator))
emax_calculators.append(calculator)

# ==================================================================================
# select requested solver and partial arguments into it
Expand All @@ -172,7 +174,8 @@ def get_lcm_function(
emax_calculators=emax_calculators,
logger=logger,
)
solve_model = _solve_model

solve_model = jax.jit(_solve_model) if jit else _solve_model

_next_state_simulate = get_next_state_function(model=_mod, target="simulate")
simulate_model = partial(
Expand Down

0 comments on commit 3bed4b6

Please sign in to comment.