-
Notifications
You must be signed in to change notification settings - Fork 603
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
make default precision equal to jax.config's jax_enable_x64 #1485
Conversation
Codecov Report
@@ Coverage Diff @@
## master #1485 +/- ##
=======================================
Coverage 98.32% 98.32%
=======================================
Files 180 180
Lines 12701 12705 +4
=======================================
+ Hits 12488 12492 +4
Misses 213 213
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great @dwierichs!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dwierichs nice, your changelog entry is better than my original suggestion
Co-authored-by: Josh Izaac <josh146@gmail.com>
Good catch regarding the syntax highlighting! |
The default qubit jax device currently always uses
float32
precision.This change makes the device check the status of
jax.config.config.read('jax_enable_x64')
, which is the variable set toTrue
by jax users to enablefloat64
precision (also see the jax gotcha).Benefit
When users consciously change the default of jax' precision, the PennyLane device should play along and now does.
Possible drawbacks
The behaviour changes when
jax.config.config.read('jax_enable_x64')=True
but into the direction of what users would presumably expect.An alternative method would be to use this update as new default but provide an override
precision
kwarg to allow for explicitly setting the precision.