From 500bdc6f364c4a89ef719fc3e734e93fc4b3d49a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?=
Date: Wed, 28 Feb 2024 16:28:59 +0100
Subject: [PATCH] Remove PRNGKeyArray

---
 src/bmi/estimators/neural/_basic_training.py | 6 ++----
 src/bmi/estimators/neural/_critics.py        | 2 +-
 src/bmi/estimators/neural/_estimators.py     | 2 +-
 src/bmi/estimators/neural/_mine_estimator.py | 6 +++---
 src/bmi/interface.py                         | 2 +-
 src/bmi/samplers/_tfp/_core.py               | 8 +++-----
 6 files changed, 11 insertions(+), 15 deletions(-)

diff --git a/src/bmi/estimators/neural/_basic_training.py b/src/bmi/estimators/neural/_basic_training.py
index 5457a0e1..2e335b4f 100644
--- a/src/bmi/estimators/neural/_basic_training.py
+++ b/src/bmi/estimators/neural/_basic_training.py
@@ -10,9 +10,7 @@
 from bmi.estimators.neural._types import BatchedPoints, Critic, Point
 
 
-def get_batch(
-    xs: BatchedPoints, ys: BatchedPoints, key: jax.random.PRNGKeyArray, batch_size: Optional[int]
-):
+def get_batch(xs: BatchedPoints, ys: BatchedPoints, key: jax.Array, batch_size: Optional[int]):
     if batch_size is not None:
         batch_indices = jax.random.choice(
             key,
@@ -26,7 +24,7 @@
 
 
 def basic_training(
-    rng: jax.random.PRNGKeyArray,
+    rng: jax.Array,
     critic: eqx.Module,
     mi_formula: Callable[[Critic, Point, Point], float],
     xs: BatchedPoints,
diff --git a/src/bmi/estimators/neural/_critics.py b/src/bmi/estimators/neural/_critics.py
index 2d0654bd..e07682c7 100644
--- a/src/bmi/estimators/neural/_critics.py
+++ b/src/bmi/estimators/neural/_critics.py
@@ -17,7 +17,7 @@ class MLP(eqx.Module):
 
     def __init__(
         self,
-        key: jax.random.PRNGKeyArray,
+        key: jax.Array,
         dim_x: int,
         dim_y: int,
         hidden_layers: Sequence[int] = (5,),
diff --git a/src/bmi/estimators/neural/_estimators.py b/src/bmi/estimators/neural/_estimators.py
index 3ee156e3..5045df50 100644
--- a/src/bmi/estimators/neural/_estimators.py
+++ b/src/bmi/estimators/neural/_estimators.py
@@ -44,7 +44,7 @@ def train_test_split(
     xs: BatchedPoints,
     ys: BatchedPoints,
     train_size: Optional[float],
-    key: jax.random.PRNGKeyArray,
+    key: jax.Array,
 ) -> tuple[BatchedPoints, BatchedPoints, BatchedPoints, BatchedPoints]:
     if train_size is None:
         return xs, xs, ys, ys
diff --git a/src/bmi/estimators/neural/_mine_estimator.py b/src/bmi/estimators/neural/_mine_estimator.py
index 474c16a1..e09e7686 100644
--- a/src/bmi/estimators/neural/_mine_estimator.py
+++ b/src/bmi/estimators/neural/_mine_estimator.py
@@ -101,7 +101,7 @@ def _mine_value_neg_grad_log_denom(
 
 
 def _sample_paired_unpaired(
-    key: jax.random.PRNGKeyArray,
+    key: jax.Array,
     xs: BatchedPoints,
     ys: BatchedPoints,
     batch_size: Optional[int],
@@ -133,7 +133,7 @@ def _sample_paired_unpaired(
 
 
 def mine_training(
-    rng: jax.random.PRNGKeyArray,
+    rng: jax.Array,
     critic: eqx.Module,
     xs: BatchedPoints,
     ys: BatchedPoints,
@@ -313,7 +313,7 @@ def trained_critic(self) -> Optional[eqx.Module]:
     def parameters(self) -> MINEParams:
         return self._params
 
-    def _create_critic(self, dim_x: int, dim_y: int, key: jax.random.PRNGKeyArray) -> MLP:
+    def _create_critic(self, dim_x: int, dim_y: int, key: jax.Array) -> MLP:
         return MLP(dim_x=dim_x, dim_y=dim_y, key=key, hidden_layers=self._params.hidden_layers)
 
     def estimate_with_info(self, x: ArrayLike, y: ArrayLike) -> EstimateResult:
diff --git a/src/bmi/interface.py b/src/bmi/interface.py
index e88232c5..3982f178 100644
--- a/src/bmi/interface.py
+++ b/src/bmi/interface.py
@@ -26,7 +26,7 @@ class BaseModel(pydantic.BaseModel):  # pytype: disable=invalid-annotation
     pass
 
 
-# This should be updated to the PRNGKeyArray (or possibly union with Any)
+# This should be updated to the Array (or possibly union with Any)
 # when it becomes a part of public JAX API
 KeyArray = Any
 Pathlike = Union[str, pathlib.Path]
diff --git a/src/bmi/samplers/_tfp/_core.py b/src/bmi/samplers/_tfp/_core.py
index a1f2e6f1..262b61e7 100644
--- a/src/bmi/samplers/_tfp/_core.py
+++ b/src/bmi/samplers/_tfp/_core.py
@@ -32,9 +32,7 @@ class JointDistribution:
     dim_y: int
     analytic_mi: Optional[float] = None
 
-    def sample(
-        self, n_points: int, key: jax.random.PRNGKeyArray
-    ) -> tuple[jnp.ndarray, jnp.ndarray]:
+    def sample(self, n_points: int, key: jax.Array) -> tuple[jnp.ndarray, jnp.ndarray]:
         """Sample from the joint distribution $P_{XY}$.
 
         Args:
@@ -152,7 +150,7 @@ def transform(
     )
 
 
-def pmi_profile(key: jax.random.PRNGKeyArray, dist: JointDistribution, n: int) -> jnp.ndarray:
+def pmi_profile(key: jax.Array, dist: JointDistribution, n: int) -> jnp.ndarray:
     """Monte Carlo draws a sample of size `n` from the PMI distribution.
 
     Args:
@@ -168,7 +166,7 @@
 
 
 def monte_carlo_mi_estimate(
-    key: jax.random.PRNGKeyArray, dist: JointDistribution, n: int
+    key: jax.Array, dist: JointDistribution, n: int
 ) -> tuple[float, float]:
     """Estimates the mutual information $I(X; Y)$ using Monte Carlo sampling.
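
Note (not part of the patch itself): in recent JAX releases, PRNG keys produced by both the legacy `jax.random.PRNGKey` and the new-style `jax.random.key` are plain `jax.Array` values, which is what makes the annotation swap above type-correct. A minimal sketch illustrating this, assuming JAX >= 0.4.16:

    # Hedged usage sketch, assuming JAX >= 0.4.16; not part of the patch.
    import jax

    legacy_key = jax.random.PRNGKey(0)  # legacy uint32 key
    typed_key = jax.random.key(0)       # new-style typed key

    # Both key flavors are plain jax.Array values, so either satisfies
    # the `key: jax.Array` annotations introduced by this patch.
    assert isinstance(legacy_key, jax.Array)
    assert isinstance(typed_key, jax.Array)

    # Keys are consumed exactly as before, e.g. when sampling:
    key, subkey = jax.random.split(typed_key)
    x = jax.random.normal(subkey, shape=(4,))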