From 00df3ce1b60d21c624f7c70ff986701cf10c9129 Mon Sep 17 00:00:00 2001
From: Will Shainin
Date: Mon, 24 Jun 2024 16:51:39 -0700
Subject: [PATCH 1/4] Removing stale reference window_size

---
 src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py b/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py
index 6de3e7639..7963c77b2 100644
--- a/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py
+++ b/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py
@@ -48,7 +48,6 @@ class Config:
         max_positions_per_type: int = 10
         base_reward_scale: float = 0.0
         position_reward_scale: float = 1
-        window_size: int = 10
         episode_length: int = 200
         # The threshold for the probability of opening and closing orders
         open_threshold: float = 0.5

From ae4c9a253e32aca2200c110c47c92cc754d2160f Mon Sep 17 00:00:00 2001
From: Will Shainin
Date: Mon, 24 Jun 2024 16:55:44 -0700
Subject: [PATCH 2/4] Terminated if step count >= episode len

I think this is where termination should be marked so that
episode_length matches n_steps
---
 src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py b/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py
index 7963c77b2..b62c24305 100644
--- a/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py
+++ b/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py
@@ -480,7 +480,7 @@ def step(self, action: np.ndarray) -> tuple[dict[str, np.ndarray], float, bool,
         self._step_count += 1
         terminated = False
-        if self._step_count > self.gym_config.episode_length:
+        if self._step_count >= self.gym_config.episode_length:
             terminated = True
 
         # TODO when does the episode stop?
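A note on PATCH 2: with the old strict `>` comparison, `terminated` is only set on the step after `_step_count` reaches `episode_length`, so each episode runs for episode_length + 1 steps; the `>=` check makes the step count match the configured length exactly. The sketch below is a minimal, standalone illustration of that counting logic (plain Python, independent of the gym environment; the helper name and the episode_length value of 3 are made up for the example):

    def count_steps(episode_length: int, strict: bool) -> int:
        """Count how many env steps run before termination is signaled."""
        step_count = 0
        terminated = False
        while not terminated:
            step_count += 1  # one call to step()
            # strict=True mirrors the old `>` check, strict=False the new `>=` check
            if strict:
                terminated = step_count > episode_length
            else:
                terminated = step_count >= episode_length
        return step_count

    assert count_steps(3, strict=True) == 4   # old check: one extra step per episode
    assert count_steps(3, strict=False) == 3  # new check: n_steps == episode_length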
From 6f673e2d55db61a399f0400f913187fb46011050 Mon Sep 17 00:00:00 2001
From: Will Shainin
Date: Mon, 24 Jun 2024 16:58:36 -0700
Subject: [PATCH 3/4] Adding changes from lazy pnl update (#1561)

---
 .../gym_environments/full_hyperdrive_env.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py b/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py
index b62c24305..791bebced 100644
--- a/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py
+++ b/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py
@@ -106,9 +106,18 @@ def __init__(
         self.eval_mode = gym_config.eval_mode
         self.sample_actions = gym_config.sample_actions
         if self.eval_mode:
-            local_chain_config = LocalChain.Config(block_timestamp_interval=12, db_port=5434, chain_port=10001)
+            db_port = 5434
+            chain_port = 10001
         else:
-            local_chain_config = LocalChain.Config(block_timestamp_interval=12, db_port=5435, chain_port=10002)
+            db_port = 5435
+            chain_port = 10002
+
+        local_chain_config = LocalChain.Config(
+            block_timestamp_interval=12,
+            db_port=db_port,
+            chain_port=chain_port,
+            calc_pnl=False)
+
         initial_pool_config = LocalHyperdrive.Config()
         self.chain = LocalChain(local_chain_config)
         self.interactive_hyperdrive = LocalHyperdrive(self.chain, initial_pool_config)
@@ -507,7 +516,7 @@ def _get_observation(self) -> dict[str, np.ndarray]:
             out_obs["lp_orders"] = np.zeros(2)
 
         # Observation data uses floats
-        rl_bot_wallet = self.rl_bot.get_positions(coerce_float=True)
+        rl_bot_wallet = self.rl_bot.get_positions(coerce_float=True, calc_pnl=True)
 
         if not rl_bot_wallet.empty:
             position_duration = self.interactive_hyperdrive.config.position_duration
@@ -543,7 +552,10 @@ def _calculate_reward(self) -> float:
         # The total delta for this episode
-        current_wallet = self.interactive_hyperdrive.get_positions(show_closed_positions=True, coerce_float=True)
+        current_wallet = self.interactive_hyperdrive.get_positions(
+            show_closed_positions=True,
+            calc_pnl=True,
+            coerce_float=True)
         # Filter by rl bot
         rl_bot_wallet = current_wallet[current_wallet["wallet_address"] == self.rl_bot.address]
         # The rl_bot_wallet shows the pnl of all positions

From 1ee43c3b30002547c9cd6ab70da1c41f8901a8d9 Mon Sep 17 00:00:00 2001
From: Will Shainin
Date: Thu, 27 Jun 2024 16:30:09 -0700
Subject: [PATCH 4/4] Linting fix

---
 .../gym_environments/full_hyperdrive_env.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py b/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py
index 791bebced..344511055 100644
--- a/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py
+++ b/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py
@@ -113,10 +113,8 @@ def __init__(
             chain_port = 10002
 
         local_chain_config = LocalChain.Config(
-            block_timestamp_interval=12,
-            db_port=db_port,
-            chain_port=chain_port,
-            calc_pnl=False)
+            block_timestamp_interval=12, db_port=db_port, chain_port=chain_port, calc_pnl=False
+        )
 
         initial_pool_config = LocalHyperdrive.Config()
         self.chain = LocalChain(local_chain_config)
@@ -553,9 +551,8 @@ def _calculate_reward(self) -> float:
         # The total delta for this episode
         current_wallet = self.interactive_hyperdrive.get_positions(
-            show_closed_positions=True,
-            calc_pnl=True,
-            coerce_float=True)
+            show_closed_positions=True, calc_pnl=True, coerce_float=True
+        )
         # Filter by rl bot
         rl_bot_wallet = current_wallet[current_wallet["wallet_address"] == self.rl_bot.address]
         # The rl_bot_wallet shows the pnl of all positions
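Taken together, PATCH 3 moves the environment to a lazy-PnL pattern: the local chain no longer computes PnL on every block (calc_pnl=False in LocalChain.Config), and PnL is instead requested only at the call sites that need it (calc_pnl=True in get_positions). The condensed sketch below shows that pattern using only the calls visible in the diff; the top-level import path is an assumption, the port numbers and pool config are copied from the patch, and the rest of the environment setup is omitted:

    # Import path assumed; the patch only shows the class names LocalChain / LocalHyperdrive.
    from agent0 import LocalChain, LocalHyperdrive

    # Disable per-block PnL bookkeeping; PnL will be computed lazily on request.
    chain_config = LocalChain.Config(
        block_timestamp_interval=12, db_port=5435, chain_port=10002, calc_pnl=False
    )
    chain = LocalChain(chain_config)
    hyperdrive = LocalHyperdrive(chain, LocalHyperdrive.Config())

    # ... run the bot / trade against the pool here ...

    # Request PnL only when it is actually needed, e.g. when computing the episode reward.
    positions = hyperdrive.get_positions(
        show_closed_positions=True, calc_pnl=True, coerce_float=True
    )

The apparent trade-off, consistent with the commit title, is that skipping PnL on every block keeps stepping cheap, at the cost of computing PnL at read time in _get_observation and _calculate_reward.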