Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Bugfix/mps acceleration support #10

Merged
merged 3 commits into from
Oct 23, 2023
Merged

Conversation

BPdeRooij
Copy link
Contributor

@BPdeRooij BPdeRooij commented Oct 23, 2023

Fixes #{9} by transforming mpp from float64 to torch.float32 if mps backend is available. Also adds an explicit trainer/mps.yaml config file, based on the trainer/default.yaml

@BPdeRooij BPdeRooij requested a review from jonasteuwen October 23, 2023 13:45
@BPdeRooij BPdeRooij linked an issue Oct 23, 2023 that may be closed by this pull request
@jonasteuwen
Copy link
Contributor

Why is mpp even in there?

Copy link
Contributor

@jonasteuwen jonasteuwen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jonasteuwen jonasteuwen merged commit fca3679 into main Oct 23, 2023
1 check passed
@jonasteuwen jonasteuwen deleted the bugfix/mps-acceleration-support branch October 23, 2023 19:50
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

MPS acceleration support
2 participants