Skip to content

Commit

Permalink
Elementwise decoder
Browse files Browse the repository at this point in the history
Summary: Tensorf does relu or softmax after the density grid. This diff adds the ability to replicate that.

Reviewed By: bottler

Differential Revision: D40023228

fbshipit-source-id: 9f19868cd68460af98ab6e61c7f708158c26dc08
  • Loading branch information
Darijan Gudelj authored and facebook-github-bot committed Oct 13, 2022
1 parent a607dd0 commit 76cddd9
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,43 @@ def forward(


@registry.register
class IdentityDecoder(DecoderFunctionBase):
class ElementwiseDecoder(DecoderFunctionBase):
"""
Decoding function which returns its input.
Decoding function which scales the input, adds shift and then applies
`relu`, `softplus`, `sigmoid` or nothing on its input:
`result = operation(input * scale + shift)`
Members:
scale: a scalar with which input is multiplied before being shifted.
Defaults to 1.
shift: a scalar which is added to the scaled input before performing
the operation. Defaults to 0.
operation: which operation to perform on the transformed input. Options are:
`relu`, `softplus`, `sigmoid` and `identity`. Defaults to `identity`.
"""

scale: float = 1
shift: float = 0
operation: str = "identity"

def __post_init__(self):
super().__post_init__()
if self.operation not in ["relu", "softplus", "sigmoid", "identity"]:
raise ValueError(
"`operation` can only be `relu`, `softplus`, `sigmoid` or identity."
)

def forward(
self, features: torch.Tensor, z: Optional[torch.Tensor] = None
) -> torch.Tensor:
return features
transfomed_input = features * self.scale + self.shift
if self.operation == "softplus":
return torch.nn.functional.softplus(transfomed_input)
if self.operation == "relu":
return torch.nn.functional.relu(transfomed_input)
if self.operation == "sigmoid":
return torch.nn.functional.sigmoid(transfomed_input)
return transfomed_input


class MLPWithInputSkips(Configurable, torch.nn.Module):
Expand Down
34 changes: 0 additions & 34 deletions tests/implicitron/test_decoding_functions.py

This file was deleted.

0 comments on commit 76cddd9

Please sign in to comment.