Skip to content

Commit

Permalink
Remove PRNGKeyArray
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Feb 28, 2024
1 parent 9cbeac2 commit 500bdc6
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 15 deletions.
6 changes: 2 additions & 4 deletions src/bmi/estimators/neural/_basic_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
from bmi.estimators.neural._types import BatchedPoints, Critic, Point


def get_batch(
xs: BatchedPoints, ys: BatchedPoints, key: jax.random.PRNGKeyArray, batch_size: Optional[int]
):
def get_batch(xs: BatchedPoints, ys: BatchedPoints, key: jax.Array, batch_size: Optional[int]):
if batch_size is not None:
batch_indices = jax.random.choice(
key,
Expand All @@ -26,7 +24,7 @@ def get_batch(


def basic_training(
rng: jax.random.PRNGKeyArray,
rng: jax.Array,
critic: eqx.Module,
mi_formula: Callable[[Critic, Point, Point], float],
xs: BatchedPoints,
Expand Down
2 changes: 1 addition & 1 deletion src/bmi/estimators/neural/_critics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class MLP(eqx.Module):

def __init__(
self,
key: jax.random.PRNGKeyArray,
key: jax.Array,
dim_x: int,
dim_y: int,
hidden_layers: Sequence[int] = (5,),
Expand Down
2 changes: 1 addition & 1 deletion src/bmi/estimators/neural/_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def train_test_split(
xs: BatchedPoints,
ys: BatchedPoints,
train_size: Optional[float],
key: jax.random.PRNGKeyArray,
key: jax.Array,
) -> tuple[BatchedPoints, BatchedPoints, BatchedPoints, BatchedPoints]:
if train_size is None:
return xs, xs, ys, ys
Expand Down
6 changes: 3 additions & 3 deletions src/bmi/estimators/neural/_mine_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _mine_value_neg_grad_log_denom(


def _sample_paired_unpaired(
key: jax.random.PRNGKeyArray,
key: jax.Array,
xs: BatchedPoints,
ys: BatchedPoints,
batch_size: Optional[int],
Expand Down Expand Up @@ -133,7 +133,7 @@ def _sample_paired_unpaired(


def mine_training(
rng: jax.random.PRNGKeyArray,
rng: jax.Array,
critic: eqx.Module,
xs: BatchedPoints,
ys: BatchedPoints,
Expand Down Expand Up @@ -313,7 +313,7 @@ def trained_critic(self) -> Optional[eqx.Module]:
def parameters(self) -> MINEParams:
return self._params

def _create_critic(self, dim_x: int, dim_y: int, key: jax.random.PRNGKeyArray) -> MLP:
def _create_critic(self, dim_x: int, dim_y: int, key: jax.Array) -> MLP:
    """Build the MLP critic for inputs of dimension `dim_x` and `dim_y`.

    Args:
        dim_x: dimensionality of the X points fed to the critic.
        dim_y: dimensionality of the Y points fed to the critic.
        key: JAX PRNG key used to initialize the MLP weights.

    Returns:
        An `MLP` critic whose hidden-layer sizes come from
        `self._params.hidden_layers`.
    """
    # Hidden-layer widths are taken from the estimator's stored MINEParams.
    return MLP(dim_x=dim_x, dim_y=dim_y, key=key, hidden_layers=self._params.hidden_layers)

def estimate_with_info(self, x: ArrayLike, y: ArrayLike) -> EstimateResult:
Expand Down
2 changes: 1 addition & 1 deletion src/bmi/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class BaseModel(pydantic.BaseModel): # pytype: disable=invalid-annotation
pass


# This should be updated to the PRNGKeyArray (or possibly union with Any)
# This should be updated to the Array (or possibly union with Any)
# when it becomes a part of public JAX API
KeyArray = Any
Pathlike = Union[str, pathlib.Path]
Expand Down
8 changes: 3 additions & 5 deletions src/bmi/samplers/_tfp/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ class JointDistribution:
dim_y: int
analytic_mi: Optional[float] = None

def sample(
self, n_points: int, key: jax.random.PRNGKeyArray
) -> tuple[jnp.ndarray, jnp.ndarray]:
def sample(self, n_points: int, key: jax.Array) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Sample from the joint distribution $P_{XY}$.
Args:
Expand Down Expand Up @@ -152,7 +150,7 @@ def transform(
)


def pmi_profile(key: jax.random.PRNGKeyArray, dist: JointDistribution, n: int) -> jnp.ndarray:
def pmi_profile(key: jax.Array, dist: JointDistribution, n: int) -> jnp.ndarray:
"""Monte Carlo draws a sample of size `n` from the PMI distribution.
Args:
Expand All @@ -168,7 +166,7 @@ def pmi_profile(key: jax.random.PRNGKeyArray, dist: JointDistribution, n: int) -


def monte_carlo_mi_estimate(
key: jax.random.PRNGKeyArray, dist: JointDistribution, n: int
key: jax.Array, dist: JointDistribution, n: int
) -> tuple[float, float]:
"""Estimates the mutual information $I(X; Y)$ using Monte Carlo sampling.
Expand Down

0 comments on commit 500bdc6

Please sign in to comment.