This RFC proposes to add an API to the standard for returning library defaults (e.g., dtypes and devices).
Currently, the standard requires that conforming array libraries explicitly state their default dtypes in their documentation; no guidance is provided regarding devices. The standard provides no APIs for querying library defaults.
The inability to query defaults requires manual workarounds, such as allocating a fresh array and checking its dtype or device attribute. This is less than ideal, especially for third-party array libraries that want to generically extend an array library's namespace while adhering to the same default behavior (e.g., wrapping library x to expose additional array creation functions for specialized matrices).
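For illustration, the manual workaround described above might look like the following sketch. It assumes only that the namespace xp exposes the standard's empty() and asarray() creation functions and that arrays carry dtype and device attributes. The helper probe_defaults, the key names, and the toy namespace used to run it are hypothetical stand-ins, not part of any library.

```python
from types import SimpleNamespace
from typing import Any, Dict


def probe_defaults(xp: Any) -> Dict[str, Any]:
    """Infer defaults by allocating throwaway arrays (hypothetical helper)."""
    # Without a dtype argument, empty() uses the default real floating-point dtype.
    float_probe = xp.empty((0,))
    # A Python int passed to asarray() yields the default integer dtype.
    int_probe = xp.asarray(0)
    return {
        "real floating": float_probe.dtype,
        "integral": int_probe.dtype,
        # Without a device argument, the array is placed on the default device.
        "device": float_probe.device,
    }


class _ToyArray:
    """Hypothetical stand-in for an array object."""

    def __init__(self, dtype: str, device: str) -> None:
        self.dtype = dtype
        self.device = device


# Hypothetical stand-in namespace so the sketch runs without a real array library.
toy_xp = SimpleNamespace(
    empty=lambda shape, dtype=None, device=None: _ToyArray(dtype or "float64", device or "cpu"),
    asarray=lambda obj, dtype=None, device=None: _ToyArray(
        dtype or ("int64" if isinstance(obj, int) else "float64"), device or "cpu"
    ),
)

print(probe_defaults(toy_xp))
```

Each probe allocates an array only to read a single attribute, which is the overhead a dedicated query API would avoid.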
Prior art
PyTorch
torch.get_default_dtype() → torch.dtype
Returns the current default floating-point dtype.
JAX
jax.default_backend() → str
Returns the platform name of the default XLA backend.
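A minimal sketch of calling the two prior-art functions above, assuming PyTorch and JAX may or may not be installed (each is skipped if it cannot be imported). The helper name query_prior_art is hypothetical.

```python
from typing import Dict


def query_prior_art() -> Dict[str, str]:
    """Collect the library-specific defaults listed under Prior art, if available."""
    found: Dict[str, str] = {}
    try:
        import torch  # optional dependency
    except ImportError:
        pass
    else:
        # Current default floating-point dtype (torch.float32 unless changed).
        found["torch default dtype"] = str(torch.get_default_dtype())
    try:
        import jax  # optional dependency
    except ImportError:
        pass
    else:
        # Platform name of the default XLA backend, e.g. "cpu" or "gpu".
        found["jax default backend"] = jax.default_backend()
    return found


print(query_prior_art())
```

Each library exposes a different function with a different return type (a torch.dtype versus a str), which is the inconsistency a standardized query would remove.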
Proposal
The proposed API is a defaults() function on the array namespace.
If called without arguments, the function would return a dictionary with the following keys:
More keys could be added in the future as the standard evolves.
If provided a device argument, the function would return a dictionary as described above, but specific to the specified device.
If provided a name argument, the function would return only the default for that setting. For example,
>>> d = xp.defaults("device")
If provided both a name and a device argument, the function would return the default for the specified setting on the specified device.
Notes
If we want to support a standardized means for configuring defaults (e.g., setting the default real floating-point dtype to float32 instead of float64), we may want to rename the API to something like get_defaults(), with a corresponding set_default() for setting them.
When invoked without a device argument, the function would return default dtypes for the current device context, as array libraries may use different default dtypes on different devices. Accordingly, users should be advised not to assume that defaults are static.
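The call forms described under Proposal can be sketched as follows. This is a toy, pure-Python illustration under stated assumptions: the key names ("real floating", "integral", "device"), the per-device tables, and the set_default() helper (named after the idea in Notes) are hypothetical, since the RFC does not fix the full set of keys or a setter signature.

```python
from typing import Any, Dict, Optional

# Hypothetical per-device defaults; key names and values are illustrative only.
_DEVICE_DEFAULTS: Dict[str, Dict[str, Any]] = {
    "cpu": {"real floating": "float64", "integral": "int64"},
    "gpu": {"real floating": "float32", "integral": "int32"},
}

# Hypothetical current device context; a real library would track this itself.
_current_device = "cpu"


def defaults(name: Optional[str] = None, device: Optional[str] = None) -> Any:
    """Sketch of the proposed xp.defaults() call forms (not a real API)."""
    # Without a device argument, use the current device context.
    dev = _current_device if device is None else device
    # Copy so callers cannot mutate the stored defaults.
    table = dict(_DEVICE_DEFAULTS[dev])
    # The default device is a library-wide setting, so it is reported as-is.
    table["device"] = _current_device
    # With a name argument, return only that setting; otherwise the whole dict.
    return table if name is None else table[name]


def set_default(name: str, value: Any) -> None:
    """Hypothetical setter in the spirit of the set_default() idea in Notes."""
    global _current_device
    if name == "device":
        if value not in _DEVICE_DEFAULTS:
            raise ValueError(f"unknown device: {value!r}")
        _current_device = value
    else:
        # Changes the default for the current device only.
        _DEVICE_DEFAULTS[_current_device][name] = value


print(defaults())                                # all defaults, current device
print(defaults("device"))                        # a single setting
print(defaults(device="gpu"))                    # all defaults for one device
print(defaults("real floating", device="gpu"))   # one setting on one device
```

Calling set_default("device", "gpu") would change what defaults() returns when no device argument is given, which is why users should not treat defaults as static.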