Skip to content

Commit

Permalink
Add __call__ typing to CategoricalPolicy
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Aug 6, 2023
1 parent bf1632f commit fc739c7
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions d3rlpy/models/torch/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,6 @@ def __init__(self, encoder: Encoder, hidden_size: int, action_size: int):

def forward(self, x: torch.Tensor) -> Categorical:
return Categorical(logits=self._fc(self._encoder(x)))

def __call__(self, x: torch.Tensor) -> Categorical:
return super().__call__(x)

0 comments on commit fc739c7

Please sign in to comment.