docs: Add empty lines before Args: in docstring (#284)
* Add empty lines

* Add empty lines above Args:
akihironitta authored Oct 19, 2020
1 parent 3187ebe commit d32d3eb
Showing 13 changed files with 111 additions and 7 deletions.
1 change: 1 addition & 0 deletions pl_bolts/callbacks/self_supervised.py
@@ -59,6 +59,7 @@ def on_pretrain_routine_start(self, trainer, pl_module):
def get_representations(self, pl_module, x):
"""
Override this to customize for the particular model
Args:
pl_module:
x:
12 changes: 12 additions & 0 deletions pl_bolts/datamodules/experience_source.py
@@ -83,8 +83,10 @@ def __init__(self, env, agent, n_steps: int = 1) -> None:
def runner(self, device: torch.device) -> Tuple[Experience]:
"""Experience Source iterator yielding Tuple of experiences for n_steps. These come from the pool
of environments provided by the user.
Args:
device: current device to be used for executing experience steps
Returns:
Tuple of Experiences
"""
@@ -113,6 +115,7 @@ def update_history_queue(self, env_idx, exp, history) -> None:
Updates the experience history queue with the latest experiences. If an experience step is in
the done state, the history is incrementally appended to the queue, removing the tail of the history
each time.
Args:
env_idx: index of the environment
exp: the current experience
@@ -172,10 +175,12 @@ def env_actions(self, device) -> List[List[int]]:
def env_step(self, env_idx: int, env: Env, action: List[int]) -> Experience:
"""
Carries out a step through the given environment using the given action
Args:
env_idx: index of the current environment
env: env at index env_idx
action: action for this environment step
Returns:
Experience tuple
"""
@@ -192,6 +197,7 @@ def update_env_stats(self, env_idx: int) -> None:
"""
To be called at the end of the history tail generation during the termination state. Updates the stats
tracked for all environments
Args:
env_idx: index of the environment used to update stats
"""
@@ -238,8 +244,10 @@ def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99):
def runner(self, device: torch.device) -> Experience:
"""
Iterates through the experience tuple and calculates the discounted experience
Args:
device: current device to be used for executing experience steps
Yields:
Discounted Experience
"""
@@ -255,8 +263,10 @@ def split_head_tail_exp(self, experiences: Tuple[Experience]) -> Tuple[List, Tup
"""
Takes in a tuple of experiences and returns the last state and tail experiences based on
whether the last state is the end of an episode
Args:
experiences: Tuple of N Experience
Returns:
last state (Array or None) and remaining Experience
"""
@@ -271,8 +281,10 @@ def discount_rewards(self, experiences: Tuple[Experience]) -> float:
def discount_rewards(self, experiences: Tuple[Experience]) -> float:
"""
Calculates the discounted reward over N experiences
Args:
experiences: Tuple of Experience
Returns:
total discounted reward
"""
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/kitti_datamodule.py
@@ -46,7 +46,7 @@ def __init__(
Trainer().fit(model, dm)
-    Args::
+    Args:
data_dir: where to load the data from path, i.e. '/path/to/folder/with/data_semantics/'
val_split: size of validation test (default 0.2)
test_split: size of test set (default 0.1)
6 changes: 6 additions & 0 deletions pl_bolts/losses/rl.py
@@ -13,11 +13,13 @@ def dqn_loss(batch: Tuple[torch.Tensor, torch.Tensor], net: nn.Module,
target_net: nn.Module, gamma: float = 0.99) -> torch.Tensor:
"""
Calculates the MSE loss using a mini batch from the replay buffer
Args:
batch: current mini batch of replay data
net: main training network
target_net: target network of the main training network
gamma: discount factor
Returns:
loss
"""
@@ -45,11 +47,13 @@ def double_dqn_loss(batch: Tuple[torch.Tensor, torch.Tensor], net: nn.Module,
Calculates the MSE loss using a mini batch from the replay buffer. This uses the Double DQN
improvement to the original DQN loss: the actions chosen by the train network are used to pick the
values from the target network. This code is heavily commented in order to explain the process clearly
Args:
batch: current mini batch of replay data
net: main training network
target_net: target network of the main training network
gamma: discount factor
Returns:
loss
"""
@@ -89,12 +93,14 @@ def per_dqn_loss(batch: Tuple[torch.Tensor, torch.Tensor], batch_weights: List,
target_net: nn.Module, gamma: float = 0.99) -> Tuple[torch.Tensor, np.ndarray]:
"""
Calculates the MSE loss with the priority weights of the batch from the PER buffer
Args:
batch: current mini batch of replay data
batch_weights: how each of these samples is weighted in terms of priority
net: main training network
target_net: target network of the main training network
gamma: discount factor
Returns:
loss and batch_weights
"""
21 changes: 15 additions & 6 deletions pl_bolts/models/rl/common/agents.py
@@ -20,9 +20,11 @@ def __init__(self, net: nn.Module):
def __call__(self, state: torch.Tensor, device: str, *args, **kwargs) -> List[int]:
"""
Using the given network, decide what action to carry out
Args:
state: current state of the environment
device: device used for current batch
Returns:
action
"""
@@ -51,9 +53,11 @@ def __init__(
def __call__(self, state: torch.Tensor, device: str) -> List[int]:
"""
Takes in the current state and returns the action based on the agent's policy
Args:
state: current state of the environment
device: the device used for the current batch
Returns:
action defined by policy
"""
@@ -79,12 +83,14 @@ def get_random_action(self, state: torch.Tensor) -> int:

def get_action(self, state: torch.Tensor, device: torch.device):
"""
Returns the best action based on the Q values of the network
Args:
state: current state of the environment
device: the device used for the current batch
Returns:
action defined by Q values
"""
if not isinstance(state, torch.Tensor):
state = torch.tensor(state, device=device)
@@ -96,6 +102,7 @@ def get_action(self, state: torch.Tensor, device: torch.device):
def update_epsilon(self, step: int) -> None:
"""
Updates the epsilon value based on the current step
Args:
step: current global step
"""
@@ -109,9 +116,11 @@ class PolicyAgent(Agent):
def __call__(self, states: torch.Tensor, device: str) -> List[int]:
"""
Takes in the current state and returns the action based on the agent's policy
Args:
states: current state of the environment
device: the device used for the current batch
Returns:
action defined by policy
"""
15 changes: 15 additions & 0 deletions pl_bolts/models/rl/common/memory.py
@@ -30,6 +30,7 @@ def __len__(self) -> None:
def append(self, experience: Experience) -> None:
"""
Add experience to the buffer
Args:
experience: tuple (state, action, reward, done, new_state)
"""
@@ -65,8 +66,10 @@ class ReplayBuffer(Buffer):
def sample(self, batch_size: int) -> Tuple:
"""
Takes a sample of the buffer
Args:
batch_size: current batch_size
Returns:
a batch of tuple np arrays of state, action, reward, done, next_state
"""
@@ -107,6 +110,7 @@ def __init__(self, capacity: int, n_steps: int = 1, gamma: float = 0.99) -> None
def append(self, exp: Experience) -> None:
"""
Add experience to the buffer
Args:
exp: tuple (state, action, reward, done, new_state)
"""
@@ -128,6 +132,7 @@ def update_history_queue(self, exp) -> None:
Updates the experience history queue with the latest experiences. If an experience step is in
the done state, the history is incrementally appended to the queue, removing the tail of the history
each time.
Args:
env_idx: index of the environment
exp: the current experience
@@ -161,8 +166,10 @@ def split_head_tail_exp(self, experiences: Tuple[Experience]) -> Tuple[List, Tup
"""
Takes in a tuple of experiences and returns the last state and tail experiences based on
whether the last state is the end of an episode
Args:
experiences: Tuple of N Experience
Returns:
last state (Array or None) and remaining Experience
"""
@@ -177,8 +184,10 @@ def discount_rewards(self, experiences: Tuple[Experience]) -> float:
def discount_rewards(self, experiences: Tuple[Experience]) -> float:
"""
Calculates the discounted reward over N experiences
Args:
experiences: Tuple of Experience
Returns:
total discounted reward
"""
@@ -233,8 +242,10 @@ def __init__(self, buffer_size, prob_alpha=0.6, beta_start=0.4, beta_frames=1000
def update_beta(self, step) -> float:
"""
Update the beta value which accounts for the bias in the PER
Args:
step: current global step
Returns:
beta value for this indexed experience
"""
@@ -246,6 +257,7 @@ def update_beta(self, step) -> float:
def append(self, exp) -> None:
"""
Adds experiences from exp_source to the PER buffer
Args:
exp: experience tuple being added to the buffer
"""
@@ -266,8 +278,10 @@ def append(self, exp) -> None:
def sample(self, batch_size=32) -> Tuple:
"""
Takes a prioritized sample from the buffer
Args:
batch_size: size of sample
Returns:
sample of experiences chosen with ranked probability
"""
@@ -308,6 +322,7 @@ def update_priorities(self, batch_indices: List, batch_priorities: List) -> None
"""
Update the priorities from the last batch; this should be called after the loss for this batch has been
calculated.
Args:
batch_indices: index of each datum in the batch
batch_priorities: priority of each datum in the batch
(Remaining changed files in this commit not shown.)