Remove the 'device' property, add fix in __init__ (#50)
* Remove the 'device' property, add fix in __init__

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* "fix" the example algo __init__

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
lebrice authored Sep 26, 2024
1 parent 64c6f36 commit 725d9fa
Showing 3 changed files with 12 additions and 33 deletions.
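For orientation, here is a minimal, self-contained sketch of the pattern this commit adopts (the class name and tensor shape below are illustrative, and it assumes Lightning's built-in `device` property reads the module's internal `_device` attribute, which is what the fix in `__init__` relies on): rather than overriding `device`, the constructor seeds `_device` from the first parameter's device, so anything created with `device=self.device` in `__init__` already lands on the right device.

    import torch
    from lightning import LightningModule


    class MyAlgorithm(LightningModule):  # illustrative name, not the project's class
        def __init__(self, network: torch.nn.Module):
            super().__init__()
            self.network = network
            # Seed the attribute behind LightningModule's `device` property from the
            # first parameter's device, falling back to CPU for parameter-less modules.
            self._device = next((p.device for p in self.parameters()), torch.device("cpu"))
            # Tensors built here with device=self.device now follow the network's device.
            self.example_input_array = torch.zeros((1, 3, 32, 32), device=self.device)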
31 changes: 9 additions & 22 deletions project/algorithms/example.py
@@ -7,14 +7,12 @@
 ```
 """

-import dataclasses
 import functools
 from logging import getLogger
 from typing import Any, Literal

 import torch
 from lightning import LightningModule
-from omegaconf import DictConfig
 from torch import Tensor
 from torch.nn import functional as F
 from torch.optim.optimizer import Optimizer
@@ -48,20 +46,20 @@ def __init__(
         self.datamodule = datamodule
         self.network = network
         self.optimizer_config = optimizer_config
-        assert dataclasses.is_dataclass(optimizer_config) or isinstance(
-            optimizer_config, dict | DictConfig
-        ), optimizer_config

+        # Save hyper-parameters.
+        self.save_hyperparameters(ignore=["datamodule", "network"])
+
+        # Small fix for the `device` property in LightningModule, which is CPU by default.
+        self._device = next((p.device for p in self.parameters()), torch.device("cpu"))
         # Used by Pytorch-Lightning to compute the input/output shapes of the network.
         self.example_input_array = torch.zeros(
             (datamodule.batch_size, *datamodule.dims), device=self.device
         )
-        # Do a forward pass to initialize any lazy weights. This is necessary for distributed
-        # training and to infer shapes.
-        _ = self.network(self.example_input_array)
-
-        # Save hyper-parameters.
-        self.save_hyperparameters(ignore=["datamodule", "network"])
+        if any(torch.nn.parameter.is_lazy(p) for p in self.network.parameters()):
+            # Do a forward pass to initialize any lazy weights. This is necessary for distributed
+            # training and to display network activation shapes in the summary.
+            _ = self.network(self.example_input_array)

     def forward(self, input: Tensor) -> Tensor:
         logits = self.network(input)
@@ -99,14 +97,3 @@ def configure_optimizers(self):
         optimizer_partial = instantiate(self.optimizer_config)
         optimizer = optimizer_partial(self.parameters())
         return optimizer
-
-    @property
-    def device(self) -> torch.device:
-        """Small fixup for the `device` property in LightningModule, which is CPU by default."""
-        if self._device.type == "cpu":
-            self._device = next((p.device for p in self.parameters()), torch.device("cpu"))
-        device = self._device
-        # make this more explicit to always include the index
-        if device.type == "cuda" and device.index is None:
-            return torch.device("cuda", index=torch.cuda.current_device())
-        return device
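The new guard in __init__ relies on torch.nn.parameter.is_lazy, which is True for parameters of lazy modules that have not been materialized yet, so the warm-up forward pass only runs when it is actually needed. A small standalone sketch of that check (the toy network is illustrative, not from the project):

    import torch
    from torch import nn

    # A network containing a lazy layer: its weight stays an UninitializedParameter
    # until a first forward pass infers the input size and materializes it.
    lazy_net = nn.Sequential(nn.LazyLinear(out_features=10), nn.ReLU())
    print(any(nn.parameter.is_lazy(p) for p in lazy_net.parameters()))  # True

    _ = lazy_net(torch.zeros(2, 4))  # warm-up forward pass with a dummy batch
    print(any(nn.parameter.is_lazy(p) for p in lazy_net.parameters()))  # False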
3 changes: 3 additions & 0 deletions project/algorithms/hf_example.py
@@ -47,6 +47,9 @@ def __init__(
             experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S"),
         )

+        # Small fix for the `device` property in LightningModule, which is CPU by default.
+        self._device = next((p.device for p in self.parameters()), torch.device("cpu"))
+
     def forward(
         self,
         input_ids: torch.Tensor,
11 changes: 0 additions & 11 deletions project/algorithms/jax_example.py
@@ -152,17 +152,6 @@ def configure_callbacks(self) -> list[Callback]:
             ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes),
         ]
-
-    @property
-    def device(self) -> torch.device:
-        """Small fixup for the `device` property in LightningModule, which is CPU by default."""
-        if self._device.type == "cpu":
-            self._device = next((p.device for p in self.parameters()), torch.device("cpu"))
-        device = self._device
-        # make this more explicit to always include the index
-        if device.type == "cuda" and device.index is None:
-            return torch.device("cuda", index=torch.cuda.current_device())
-        return device


 # Register a handler function to "convert" `torch.nn.Parameter`s to jax Arrays: they can be viewed
 # as jax Arrays by just viewing their data as a jax array.
