RFC: add API to return library defaults #638

Closed
kgryte opened this issue Jun 1, 2023 · 2 comments
Labels
API extension Adds new functions or objects to the API.

Comments


kgryte commented Jun 1, 2023

This RFC proposes adding an API to the standard for returning library defaults (e.g., default dtypes and devices).

Currently, the standard requires that conforming array libraries explicitly state their default dtypes in their documentation; no guidance is provided regarding devices. The standard provides no APIs for querying library defaults.

The inability to query defaults forces manual workarounds, such as allocating a fresh array and inspecting its dtype or device attribute. This is less than ideal, especially for third-party array libraries that want to generically extend an array library's namespace while adhering to the same default behavior (e.g., wrapping library x to expose additional array creation functions for specialized matrices).
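
For illustration, the manual workaround might look like the following sketch. It assumes `xp` is a namespace that follows the array API standard; `infer_defaults` is a hypothetical helper, and its keys simply mirror the ones proposed below. Even the index dtype has to be observed indirectly, via `argmax`, which the standard specifies to return the default array index dtype.

```python
from typing import Any


def infer_defaults(xp: Any) -> dict[str, Any]:
    """Hypothetical helper: infer defaults by allocating tiny arrays.

    Assumes `xp` is an array API standard compatible namespace.
    """
    # No dtype/device arguments are passed, so the library must use its defaults.
    real = xp.zeros(())                # default real floating-point dtype
    cplx = xp.asarray(1j)              # Python complex -> default complex floating-point dtype
    intg = xp.asarray(1)               # Python int -> default integral dtype
    idx = xp.argmax(xp.zeros(1))       # argmax returns the default array index dtype
    return {
        # Arrays from older libraries may lack a `.device` attribute.
        "device": getattr(real, "device", None),
        "dtypes.real_floating_point": real.dtype,
        "dtypes.complex_floating_point": cplx.dtype,
        "dtypes.integral": intg.dtype,
        "dtypes.indexing": idx.dtype,
    }
```

Four allocations (and an indexing-style call) just to learn five settings is exactly the overhead a query API would remove.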

Prior art

PyTorch

torch.get_default_dtype() → torch.dtype

Returns the current default floating-point dtype.

JAX

jax.default_backend() -> str

Returns the platform name of the default XLA backend.

Proposal

defaults() -> dict[str, Any]
defaults(device: device) -> dict[str, Any]
defaults(name: str) -> Any
defaults(name: str, device: device) -> Any

If called without arguments, the function would return a dictionary with the following keys:

  • device: default device.
  • dtypes.real_floating_point: default real floating-point dtype.
  • dtypes.complex_floating_point: default complex floating-point dtype.
  • dtypes.integral: default integral dtype.
  • dtypes.indexing: default index dtype.
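
One way a library might describe this no-argument return value for static type checkers is a `TypedDict`. This is only a sketch: it assumes the dotted names are flat keys rather than a nested `dtypes` mapping, and it types every value as `Any` because device and dtype objects are library-specific. The functional `TypedDict` syntax is needed because dotted names are not valid identifiers.

```python
from typing import Any, TypedDict

# Assumption: flat dotted keys, one per proposed default.
Defaults = TypedDict(
    "Defaults",
    {
        "device": Any,                          # default device
        "dtypes.real_floating_point": Any,      # default real floating-point dtype
        "dtypes.complex_floating_point": Any,   # default complex floating-point dtype
        "dtypes.integral": Any,                 # default integral dtype
        "dtypes.indexing": Any,                 # default index dtype
    },
)
```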

More keys could be added in the future, depending on evolution of the standard.

If provided a device argument, the function would return a dictionary with the same keys, but with values specific to the specified device.

If provided a name argument, the function would return the default for that setting only. For example,

>>> d = xp.defaults("device")

If provided both a name and a device argument, the function would return the default for the specified setting on the specified device.
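
To make the four call forms concrete, here is a minimal pure-Python sketch. Everything in it is hypothetical: the per-device table, the `cpu`/`gpu` names and values, and `_current_device`, which stands in for a library's device context. `device` is keyword-only, consistent with the later update describing it as a kwarg, so a lone positional argument is always a setting name.

```python
from typing import Any, Optional

# Hypothetical defaults for an imaginary library with two devices.
_DEFAULTS_BY_DEVICE: dict[str, dict[str, Any]] = {
    "cpu": {
        "device": "cpu",
        "dtypes.real_floating_point": "float64",
        "dtypes.complex_floating_point": "complex128",
        "dtypes.integral": "int64",
        "dtypes.indexing": "int64",
    },
    "gpu": {
        "device": "gpu",
        "dtypes.real_floating_point": "float32",
        "dtypes.complex_floating_point": "complex64",
        "dtypes.integral": "int32",
        "dtypes.indexing": "int32",
    },
}

# Stand-in for the library's current device context (assumption).
_current_device = "cpu"


def defaults(name: Optional[str] = None, *, device: Optional[str] = None) -> Any:
    """Sketch of defaults(), defaults(device=...), defaults(name), defaults(name, device=...)."""
    # Without a device, fall back to the current device context.
    table = _DEFAULTS_BY_DEVICE[_current_device if device is None else device]
    if name is None:
        # Return a copy so callers cannot mutate library state.
        return dict(table)
    # A single named setting; unknown names raise KeyError.
    return table[name]
```

For example, `defaults("device")` returns `"cpu"`, while `defaults("dtypes.real_floating_point", device="gpu")` returns `"float32"` in this toy table.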

Notes

  • If we want to support a standardized means of configuring defaults (e.g., setting the default real floating-point dtype to float32 instead of float64), we may want to rename this API to something like get_defaults() and add a corresponding set_default() for setting values.
  • When invoked without a device argument, the function would return default dtypes based on the current device context, since array libraries may have different default dtypes on different devices. Accordingly, users should be advised not to assume that defaults are static.
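
If setting were standardized, the getter/setter pair might look like this sketch. `get_defaults` and `set_default` are only the names floated in the notes, and the in-memory store is an assumption. The sketch also shows why defaults should not be treated as static: `get_defaults()` returns a snapshot, and a snapshot taken before a `set_default()` call goes stale.

```python
from typing import Any, Optional

# Hypothetical mutable store of library defaults.
_defaults: dict[str, Any] = {
    "device": "cpu",
    "dtypes.real_floating_point": "float64",
}


def get_defaults(name: Optional[str] = None) -> Any:
    """Return a snapshot of all defaults, or a single named default."""
    return dict(_defaults) if name is None else _defaults[name]


def set_default(name: str, value: Any) -> None:
    """Change an existing default; unknown settings are rejected."""
    if name not in _defaults:
        raise KeyError(f"unknown default: {name!r}")
    _defaults[name] = value
```

For example, a snapshot from `get_defaults()` still reports `"float64"` after `set_default("dtypes.real_floating_point", "float32")`, whereas a fresh `get_defaults("dtypes.real_floating_point")` reports `"float32"`.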

Related

kgryte added the API extension label Jun 1, 2023

kgryte commented Jun 15, 2023

Update: updated the OP to include a device kwarg.


kgryte commented Jun 29, 2023

This RFC has been superseded by #640.

kgryte closed this as completed Jun 29, 2023