
Understanding some RNN logic for QMIX SMAX algorithm #119

Open
Chulabhaya opened this issue Oct 29, 2024 · 3 comments
Labels: bug (Something isn't working)

Chulabhaya commented Oct 29, 2024

Hey all, hope everyone is doing well! What follows may be a bit of a dumb question, but I just wanted to clarify how this works for my own algorithm development, which builds on your excellent code.

The QMIX code uses a ScannedRNN class where you pass in a sequence of observations and dones; wherever a done flag is true, the hidden state is reset before the corresponding obs at that timestep is passed through:

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn


class ScannedRNN(nn.Module):
    @partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        """Applies the module."""
        rnn_state = carry
        ins, resets = x
        hidden_size = ins.shape[-1]
        # Wherever a reset (done) flag is set, replace the carried hidden
        # state with a freshly initialized one before the GRU step.
        rnn_state = jnp.where(
            resets[:, np.newaxis],
            self.initialize_carry(hidden_size, *ins.shape[:-1]),
            rnn_state,
        )
        new_rnn_state, y = nn.GRUCell(hidden_size)(rnn_state, ins)
        return new_rnn_state, y

    @staticmethod
    def initialize_carry(hidden_size, *batch_size):
        # Use a dummy key since the default carry init is deterministic (zeros).
        return nn.GRUCell(hidden_size, parent=None).initialize_carry(
            jax.random.PRNGKey(0), (*batch_size, hidden_size)
        )
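
For completeness, here is a minimal usage sketch (mine, with illustrative shapes; note the quoted code assumes the input feature dim equals the hidden size) showing how the scan consumes a (time, batch, feature) sequence of obs and dones:

# Illustrative sketch, not repo code: drive ScannedRNN over a short sequence.
T, B, H = 5, 4, 16                        # time steps, batch size, hidden size
obs_seq = jnp.zeros((T, B, H))            # scanned over axis 0
done_seq = jnp.zeros((T, B), dtype=bool)  # scanned over axis 0

model = ScannedRNN()
init_carry = ScannedRNN.initialize_carry(H, B)
params = model.init(jax.random.PRNGKey(0), init_carry, (obs_seq, done_seq))
# At each step t, the carry is zeroed wherever done_seq[t] is True *before*
# obs_seq[t] is fed to the GRU cell.
final_carry, outputs = model.apply(params, init_carry, (obs_seq, done_seq))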

This makes sense to me. However, I noticed that when data is actually collected, a given timestep consists of the last obs + the new done, rather than the last obs + the last done.

[image: screenshot of the QMIX data-collection code storing the last obs with the new done]

Doesn't this mean that when the RNN resets the hidden state, the observation scanned in alongside that reset is the previous observation (from the episode that just ended), rather than the current observation generated after the environment reset? In other words, the first step of the RNN's new sequence is fed an old-episode observation instead of the first observation of the new episode.
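
To make the off-by-one concrete, here is a toy illustration (mine, not repo code) of which observation ends up paired with the reset under each storage convention:

import numpy as np

# Two episodes: env.step at t=2 returns done=True and auto-resets,
# so episode 0 covers t=0..2 and episode 1 starts at t=3.
obs_episode = np.array([0, 0, 0, 1, 1])      # which episode each stored obs is from

# Convention A (QMIX as quoted): the NEW done is stored with the obs just acted on.
done_new = np.array([0, 0, 1, 0, 0], bool)   # reset fires at t=2, obs still from ep 0
# Convention B (MAPPO): the LAST done is stored with the current obs.
done_last = np.array([0, 0, 0, 1, 0], bool)  # reset fires at t=3, first obs of ep 1

# ScannedRNN zeroes the carry wherever the flag is True, then feeds that obs:
print(obs_episode[done_new])   # [0] -> an old-episode obs starts the new sequence
print(obs_episode[done_last])  # [1] -> the new-episode obs starts the new sequence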

Chulabhaya (Author) commented Oct 31, 2024

To follow up on this: I looked through the SMAX MAPPO code, and there a given timestep consists of both the last done and the last obs. I believe this is correct, because when you sample a batch of that data and feed it into the same ScannedRNN, the dones reset the hidden states at the right positions: the observations scanned in alongside the resets are the new observations from the new episodes (see the toy illustration above). So is one of these implementations correct and the other wrong, or are they done differently for intentional reasons?

def _env_step(runner_state, unused):
    train_states, env_state, last_obs, last_done, hstates, rng = (
        runner_state
    )

    # SELECT ACTION
    rng, _rng = jax.random.split(rng)
    avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state)
    avail_actions = jax.lax.stop_gradient(
        batchify(avail_actions, env.agents, config["NUM_ACTORS"])
    )
    obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
    ac_in = (
        obs_batch[np.newaxis, :],
        last_done[np.newaxis, :],
        avail_actions,
    )
    ac_hstate, pi = actor_network.apply(
        train_states[0].params, hstates[0], ac_in
    )
    action = pi.sample(seed=_rng)
    log_prob = pi.log_prob(action)
    env_act = unbatchify(
        action, env.agents, config["NUM_ENVS"], env.num_agents
    )
    env_act = {k: v.squeeze() for k, v in env_act.items()}

    # VALUE
    # output of wrapper is (num_envs, num_agents, world_state_size)
    # swap axes to (num_agents, num_envs, world_state_size) before reshaping to (num_actors, world_state_size)
    world_state = last_obs["world_state"].swapaxes(0, 1)
    world_state = world_state.reshape((config["NUM_ACTORS"], -1))

    cr_in = (
        world_state[None, :],
        last_done[np.newaxis, :],
    )
    cr_hstate, value = critic_network.apply(
        train_states[1].params, hstates[1], cr_in
    )

    # STEP ENV
    rng, _rng = jax.random.split(rng)
    rng_step = jax.random.split(_rng, config["NUM_ENVS"])
    obsv, env_state, reward, done, info = jax.vmap(
        env.step, in_axes=(0, 0, 0)
    )(rng_step, env_state, env_act)
    info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
    done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
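    # NOTE: the transition stores obs_batch (= last_obs, the obs acted on at
    # this step) together with last_done, so at replay time a done=True flag
    # is scanned in alongside the first obs of the new episode.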
    transition = Transition(
        jnp.tile(done["__all__"], env.num_agents),
        last_done,
        action.squeeze(),
        value.squeeze(),
        batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
        log_prob.squeeze(),
        obs_batch,
        world_state,
        info,
        avail_actions,
    )
    runner_state = (
        train_states,
        env_state,
        obsv,
        done_batch,
        (ac_hstate, cr_hstate),
        rng,
    )
    return runner_state, transition
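
At update time the stored sequences are replayed through the same ScannedRNN, which is where this alignment pays off. Roughly (a sketch; the traj_batch attribute names are my guesses from the positional Transition fields above):

# Sketch of the replay at update time, shapes (T, num_actors, ...).
# Because the stored done flags are last_done, each done=True is scanned in
# together with the first obs of its new episode, so the carry reset and the
# sequence start coincide.
ac_in = (
    traj_batch.obs,            # obs the agents acted on at each step
    traj_batch.done,           # last_done, aligned with those obs
    traj_batch.avail_actions,
)
_, pi = actor_network.apply(actor_params, init_hstate, ac_in)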

amacrutherford (Collaborator) commented
@mttga mind checking this? Pretty sure the MAPPO way is correct, as we had to fix it to use last_done.

amacrutherford added the bug label on Nov 5, 2024
Chulabhaya (Author) commented Nov 5, 2024

@amacrutherford Thanks for the follow-up! I'm happy to make a PR with the update if you determine it is a bug. I also ran a couple of experiments using last_dones and got better performance on two of three (and equal performance on the third).
