This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Bugfix/mps acceleration support #10

Merged 3 commits on Oct 23, 2023
4 changes: 4 additions & 0 deletions ahcore/transforms/pre_transforms.py
@@ -232,6 +232,10 @@ def __call__(self, sample: DlupDatasetSample) -> dict[str, DlupDatasetSample]:
roi = sample["annotation_data"]["roi"]
sample["roi"] = torch.from_numpy(roi[np.newaxis, ...]).float()

sample["mpp"] = torch.tensor(
sample["mpp"], dtype=torch.float32 if torch.backends.mps.is_available() else torch.float64
)

return sample

def __repr__(self) -> str:
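The added lines cast the sample's `mpp` value to a torch tensor, choosing float32 when MPS is available because Apple's MPS backend does not support float64 tensors. A minimal standalone sketch of the same dtype selection (the `mpp_value` below is a hypothetical microns-per-pixel value for illustration, not taken from the PR):

import torch

# Assumption: the MPS backend cannot allocate float64 tensors, so fall back
# to float32 whenever MPS is available; keep float64 precision otherwise.
mpp_value = 0.5  # hypothetical microns-per-pixel value
dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
mpp = torch.tensor(mpp_value, dtype=dtype)
print(mpp.dtype)  # torch.float32 on Apple Silicon with MPS, torch.float64 elsewhere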
16 changes: 16 additions & 0 deletions config/trainer/mps.yaml
@@ -0,0 +1,16 @@
_target_: pytorch_lightning.Trainer

# When using MPS acceleration, set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` in .env
accelerator: mps
devices: 1
max_epochs: 1000
num_nodes: 1

# number of validation steps to execute at the beginning of the training
num_sanity_val_steps: 0
#log_every_n_steps: 2

# if a smaller grid is used, the val check interval should be smaller
#val_check_interval: 2 # Used if you want to check val more than once per epoch
check_val_every_n_epoch: 1 # Used if you want to check val less than once per epoch
accumulate_grad_batches: 1
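
The new trainer config selects the MPS accelerator on a single device. A hedged sketch of the equivalent Trainer constructed directly in Python (the project itself instantiates it through Hydra's `_target_` mechanism, so this is illustrative only; setting `PYTORCH_ENABLE_MPS_FALLBACK` before importing torch follows the comment in the config above and is an assumption, not something the PR itself does):

import os

# Must be set before torch is initialized so unsupported MPS ops fall back to CPU
# (assumption based on the config comment above).
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="mps",
    devices=1,
    max_epochs=1000,
    num_nodes=1,
    num_sanity_val_steps=0,     # skip the validation sanity check at startup
    check_val_every_n_epoch=1,  # validate once per epoch
    accumulate_grad_batches=1,
)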