PyTorch Approx Topk

An alpha implementation of the bucketed top-k algorithm using a priority queue.

Requires: Python 3.11, CUDA toolkit 12.1, Ninja (ninja-build).

pip install git+https://github.com/graphcore-research/pytorch-approx-topk.git

Usage (note that kernel compilation on first use may take a while):

from approx_topk.priority_queue import topk as approx_topk
import torch

x = torch.randn(128, int(2**20), device="cuda")
values, indices = approx_topk(x, k=int(2**16), dim=-1, j=2, k_mult=1)

Here j corresponds to $k_b$ and k_mult to $k_b \cdot b / k$, so the number of buckets is $b = \mathrm{k\_mult} \cdot k / j$.
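To make the relationship between these parameters concrete, here is a minimal plain-Python sketch of a bucketed top-k. It is not the CUDA kernel shipped in this repository and makes no claim about how that kernel lays out its buckets. It assumes that $k_b$ is the number of elements kept per bucket, that $b$ is the number of buckets, and that buckets are contiguous slices of the input. The names bucket_count and bucketed_topk_reference are hypothetical and exist only for this illustration.

```python
import heapq


def bucket_count(k, j, k_mult):
    """Solve k_mult = j * b / k for the number of buckets b."""
    if (k_mult * k) % j != 0:
        raise ValueError("k_mult * k must be divisible by j")
    return k_mult * k // j


def bucketed_topk_reference(xs, k, j=1, k_mult=1):
    """Approximate top-k of a 1-D sequence; returns (values, indices).

    Assumption: buckets are contiguous, equally sized slices of xs.
    """
    b = bucket_count(k, j, k_mult)
    if len(xs) % b != 0:
        raise ValueError("input length must be divisible by the bucket count")
    bucket_size = len(xs) // b
    if j > bucket_size:
        raise ValueError("j cannot exceed the bucket size")

    # Stage 1: keep the j largest (value, index) pairs from every bucket.
    candidates = []
    for start in range(0, len(xs), bucket_size):
        bucket = ((xs[i], i) for i in range(start, start + bucket_size))
        candidates.extend(heapq.nlargest(j, bucket))

    # Stage 2: reduce the j * b candidates to the final k, largest first.
    top = heapq.nlargest(k, candidates)
    return [v for v, _ in top], [i for _, i in top]


xs = [5.0, 1.0, 4.0, 3.0, 9.0, 2.0, 8.0, 7.0]
values, indices = bucketed_topk_reference(xs, k=4, j=2, k_mult=1)
print(values, indices)  # [9.0, 8.0, 5.0, 4.0] [4, 6, 0, 2]
```

With two buckets of four elements, 7.0 is dropped because it is only the third-largest value in its bucket, which is the kind of error an approximate top-k accepts; the exact top four would be 9.0, 8.0, 7.0, 5.0. Under the same reading, the usage example above (k = 2**16, j = 2, k_mult = 1 on rows of length 2**20) gives bucket_count(2**16, 2, 1) = 32768 buckets of 32 elements each.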

Repository highlights:

Development

To set up the environment, install the dependencies:

  • CUDA toolkit 12.1
  • Ninja (ninja-build)
  • Python 3.11
  • Python Poetry

Then run:

poetry install --with benchmarks

License

Copyright (c) 2024 Graphcore Ltd and Oscar Key. Licensed under the MIT License.
