Skip to content

Commit

Permalink
Adding ParallelBlock (#1088)
Browse files Browse the repository at this point in the history
* Merging

* Merging

* Adding BlockContainerDict

* Adding ParallelBlock

* Improving doc-strings

* Adding doc-strings to BlockContainerDict

* Changes according to PR-review

* Move example-usage to class instead of forward-method
  • Loading branch information
marcromeyn authored May 22, 2023
1 parent b3562dd commit 67719df
Show file tree
Hide file tree
Showing 6 changed files with 550 additions and 7 deletions.
1 change: 0 additions & 1 deletion .github/workflows/pytorch.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

name: pytorch

on:
Expand Down
4 changes: 2 additions & 2 deletions merlin/models/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@
#

from merlin.models.torch.batch import Batch, Sequence
from merlin.models.torch.block import Block
from merlin.models.torch.block import Block, ParallelBlock

__all__ = ["Batch", "Block", "Sequence"]
__all__ = ["Batch", "Block", "ParallelBlock", "Sequence"]
239 changes: 238 additions & 1 deletion merlin/models/torch/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch import nn

from merlin.models.torch.batch import Batch
from merlin.models.torch.container import BlockContainer
from merlin.models.torch.container import BlockContainer, BlockContainerDict


class Block(BlockContainer):
Expand Down Expand Up @@ -98,3 +98,240 @@ def copy(self) -> "Block":
The copy of the current block.
"""
return deepcopy(self)


class ParallelBlock(Block):
    """A block that routes one input through several named branches.

    The block is organised in three stages that are held as separate
    containers:

    * ``pre`` - modules applied one after another to the incoming data.
    * ``branches`` - named branches; every branch receives the output of
      the ``pre`` stage and may itself consist of several modules.
    * ``post`` - modules applied one after another to the merged outputs
      of all branches.

    The branches are "parallel" in the logical sense only: they are
    evaluated one after another in a plain loop. The result of every branch
    is merged into a single dictionary. A branch that returns one tensor is
    stored under the name of the branch, while a branch that returns a
    dictionary of tensors is flattened into the merged dictionary. Before
    the ``post`` stage runs, the merged result is therefore always a
    ``Dict[str, torch.Tensor]``.

    Example usage::

        >>> parallel_block = ParallelBlock({"a": nn.LazyLinear(2), "b": nn.LazyLinear(2)})
        >>> x = torch.randn(2, 2)
        >>> output = parallel_block(x)
        >>> # The output is a dictionary containing the output of each branch
        >>> # (the numbers below are illustrative; the weights are random).
        >>> print(output)
        {
            'a': tensor([[-0.0801,  0.0436],
                         [ 0.1235, -0.0318]]),
            'b': tensor([[ 0.0918, -0.0274],
                         [-0.0652,  0.0381]])
        }

    Parameters
    ----------
    *inputs : Union[nn.Module, Dict[str, nn.Module]]
        Modules, or dictionaries mapping a branch name to a module. They are
        passed unchanged to ``BlockContainerDict``, which builds the branches.
    """

    def __init__(
        self,
        *inputs: Union[nn.Module, Dict[str, nn.Module]],
    ):
        # The parent initialiser has to run before any sub-module is
        # attached to this instance.
        super().__init__()

        self.pre = BlockContainer(name="pre")
        self.branches = BlockContainerDict(*inputs)
        self.post = BlockContainer(name="post")

    def forward(
        self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
    ):
        """Run the pre-processing, branching and post-processing stages.

        1. Every module of the ``pre`` stage is applied in order.
        2. Every branch is applied to the result of step 1, module by module.
           Tensor results are stored under the branch name; dictionary
           results are flattened into the merged dictionary.
        3. Every module of the ``post`` stage is applied in order to the
           merged dictionary.

        Parameters
        ----------
        inputs : Union[torch.Tensor, Dict[str, torch.Tensor]]
            The input tensor or dictionary of tensors.
        batch : Optional[Batch], default=None
            Optional batch that is forwarded as ``batch=`` to every module.

        Returns
        -------
        Dict[str, torch.Tensor]
            The merged branch outputs, after any ``post`` modules have been
            applied to them.

        Raises
        ------
        TypeError
            If a branch returns something other than a tensor or a
            dictionary of tensors.
        RuntimeError
            If two branches produce the same output name.
        """
        # Stage 1: sequential pre-processing.
        for layer in self.pre.values:
            inputs = layer(inputs, batch=batch)

        # Stage 2: each branch starts from the same pre-processed input.
        merged = {}
        for name, container in self.branches.items():
            branch = inputs
            for layer in container.values:
                branch = layer(branch, batch=batch)

            # Normalise the branch result to a dictionary of tensors.
            if isinstance(branch, torch.Tensor):
                flattened = {name: branch}
            elif torch.jit.isinstance(branch, Dict[str, torch.Tensor]):
                flattened = branch
            else:
                raise TypeError(
                    f"Branch output must be a tensor or a dictionary of tensors. Got {type(branch)}"
                )

            # Refuse to silently overwrite the output of an earlier branch.
            for key in flattened.keys():
                if key in merged:
                    raise RuntimeError(f"Duplicate output name: {key}")

            merged.update(flattened)

        # Stage 3: sequential post-processing of the merged outputs.
        for layer in self.post.values:
            merged = layer(merged, batch=batch)

        return merged

    def append(self, module: nn.Module):
        """Add a module to the end of the post-processing stage.

        Parameters
        ----------
        module : nn.Module
            The module to append.

        Returns
        -------
        ParallelBlock
            This block, so that calls can be chained.
        """
        self.post.append(module)
        return self

    def prepend(self, module: nn.Module):
        """Add a module to the front of the pre-processing stage.

        Parameters
        ----------
        module : nn.Module
            The module to prepend.

        Returns
        -------
        ParallelBlock
            This block, so that calls can be chained.
        """
        self.pre.prepend(module)
        return self

    def append_to(self, name: str, module: nn.Module):
        """Add a module to the end of one branch.

        Parameters
        ----------
        name : str
            The name of the branch.
        module : nn.Module
            The module to append.

        Returns
        -------
        ParallelBlock
            This block, so that calls can be chained.
        """
        target = self.branches[name]
        target.append(module)
        return self

    def prepend_to(self, name: str, module: nn.Module):
        """Add a module to the front of one branch.

        Parameters
        ----------
        name : str
            The name of the branch.
        module : nn.Module
            The module to prepend.

        Returns
        -------
        ParallelBlock
            This block, so that calls can be chained.
        """
        target = self.branches[name]
        target.prepend(module)
        return self

    def append_for_each(self, module: nn.Module, shared=False):
        """Add a module to the end of every branch.

        The work is delegated to ``self.branches.append_for_each``.

        Parameters
        ----------
        module : nn.Module
            The module to append.
        shared : bool, default=False
            If True, the same module is shared across all branches.
            Otherwise a deep copy of the module is used in each branch.

        Returns
        -------
        ParallelBlock
            This block, so that calls can be chained.
        """
        self.branches.append_for_each(module, shared=shared)
        return self

    def prepend_for_each(self, module: nn.Module, shared=False):
        """Add a module to the front of every branch.

        The work is delegated to ``self.branches.prepend_for_each``.

        Parameters
        ----------
        module : nn.Module
            The module to prepend.
        shared : bool, default=False
            If True, the same module is shared across all branches.
            Otherwise a deep copy of the module is used in each branch.

        Returns
        -------
        ParallelBlock
            This block, so that calls can be chained.
        """
        self.branches.prepend_for_each(module, shared=shared)
        return self

    def __getitem__(self, idx: Union[slice, int, str]):
        """Look up a branch by name, or a stage by position.

        Parameters
        ----------
        idx : Union[slice, int, str]
            The name of an existing branch, ``0`` for the ``pre`` stage, or
            ``2`` / ``-1`` for the ``post`` stage.

        Returns
        -------
        The container of the requested branch or stage.

        Raises
        ------
        IndexError
            For every other value, including unknown branch names, slices
            and the position ``1``.
        """
        # NOTE(review): position 1 is not mapped to ``self.branches``;
        # confirm whether that gap is intentional before relying on it.
        if isinstance(idx, str) and idx in self.branches:
            selected = self.branches[idx]
        elif idx == 0:
            selected = self.pre
        elif idx == -1 or idx == 2:
            selected = self.post
        else:
            raise IndexError(f"Index {idx} is out of range for {self.__class__.__name__}")

        return selected

    def __len__(self):
        """Return the number of branches."""
        branch_count = len(self.branches)
        return branch_count

    def __contains__(self, name):
        """Return whether ``name`` is contained in the branches."""
        found = name in self.branches
        return found
Loading

0 comments on commit 67719df

Please sign in to comment.