Skip to content

Commit

Permalink
Add weight to pick_random()
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Oct 28, 2024
1 parent 0397a6c commit 8cf0629
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def random_permutation(*shape: Union[Shape, Any], dims=non_batch, index_dim=chan
return stack(result, nu)


def pick_random(value: TensorOrTree, dim: DimFilter, count: Union[int, Shape, None] = 1) -> TensorOrTree:
def pick_random(value: TensorOrTree, dim: DimFilter, count: Union[int, Shape, None] = 1, weight: Optional[Tensor] = None) -> TensorOrTree:
"""
Pick one or multiple random entries from `value`.
Expand All @@ -527,22 +527,25 @@ def pick_random(value: TensorOrTree, dim: DimFilter, count: Union[int, Shape, No
You can pass `range` (the type) to retrieve the picked indices.
dim: Dimension along which to pick random entries. `Shape` with one dim.
count: Number of entries to pick. When specified as a `Shape`, lists picked values along `count` instead of `dim`.
weight: Probability weight of each item along `dim`. Will be normalized to sum to 1.
Returns:
`Tensor` or tree equal to `value`.
"""
v_shape = shape(value)
dim = v_shape.only(dim)
idx = random_permutation(dim & v_shape.batch & dim.non_uniform_shape, dims=dim)
if count is None and dim.well_defined:
count = dim.size
if count is not None:
if isinstance(count, int):
idx = idx[{dim: slice(count)}]
else:
assert isinstance(count, Shape)
idx = idx[{dim: slice(count.volume)}]
idx = unpack_dim(idx, dim, count)
n = dim.volume if count is None else (count.volume if isinstance(count, Shape) else count)
if n == dim.volume and weight is None:
idx = random_permutation(dim & v_shape.batch & dim.non_uniform_shape, dims=dim)
idx = unpack_dim(idx, dim, count) if isinstance(count, Shape) else idx
else:
probability = weight / sum_(weight, dim)
np_idx = np.random.choice(dim.volume, size=n, replace=False, p=probability.numpy([dim]))
idx = wrap(np_idx, count if isinstance(count, Shape) else dim.without_sizes())
# idx = ravel_index()
idx = expand(idx, channel(index=dim.name))
return slice_(value, idx)


Expand Down

0 comments on commit 8cf0629

Please sign in to comment.