Add benchmark results
adonath committed Dec 17, 2024
1 parent 1ba5ea8 commit 9734145
Showing 8 changed files with 42 additions and 14 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -139,5 +139,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

examples/benchmarks/results

uv.lock
29 changes: 28 additions & 1 deletion README.md
@@ -7,10 +7,37 @@
[![License](https://img.shields.io/github/license/adonath/gmmx)](https://img.shields.io/github/license/adonath/gmmx)

<p align="center">
-<img width="50%" src="docs/gmmx-logo.png" alt="GMMX Logo"/>
+<img width="50%" src="docs/_static/gmmx-logo.png" alt="GMMX Logo"/>
</p>

A minimal implementation of Gaussian Mixture Models in Jax

- **Github repository**: <https://github.com/adonath/gmmx/>
- **Documentation** <https://adonath.github.io/gmmx/>
+
+## Installation
+
+```bash
+pip install gmmx
+```
+
+## Usage
+
+```python
+from gmmx import GaussianMixtureModelJax, EMFitter
+
+# Create a Gaussian Mixture Model with 2 components
+gmm = GaussianMixtureModelJax()
+
+
+
+## Benchmarks
+
+Here are some results from the benchmarks in the `benchmarks` folder comparing against Scikit-Learn. The benchmarks were run on a 2021 MacBook Pro with an M1 Pro chip.
+
+
+### Prediction Time
+| Time vs. Number of Components | Time vs. Number of Samples | Time vs. Number of Features |
+| -------- | ------- | ------- |
+| ![Time vs. Number of Components](docs/_static/time-vs-n-components-predict.png) | ![Time vs. Number of Samples](docs/_static/time-vs-n-samples-predict.png) | ![Time vs. Number of Features](docs/_static/time-vs-n-features-predict.png) |
+```
File renamed without changes
Binary file added docs/_static/time-vs-n-components-predict.png
Binary file added docs/_static/time-vs-n-features-predict.png
Binary file added docs/_static/time-vs-n-samples-predict.png
4 changes: 2 additions & 2 deletions docs/index.md
@@ -1,8 +1,8 @@
-# gmmx
+# GMMX: Gaussian Mixture Models in Jax

[![Release](https://img.shields.io/github/v/release/adonath/gmmx)](https://img.shields.io/github/v/release/adonath/gmmx)
[![Build status](https://img.shields.io/github/actions/workflow/status/adonath/gmmx/main.yml?branch=main)](https://github.com/adonath/gmmx/actions/workflows/main.yml?query=branch%3Amain)
[![Commit activity](https://img.shields.io/github/commit-activity/m/adonath/gmmx)](https://img.shields.io/github/commit-activity/m/adonath/gmmx)
[![License](https://img.shields.io/github/license/adonath/gmmx)](https://img.shields.io/github/license/adonath/gmmx)

-A minimal implementation of Gaussian Mixture Models in Jax
+A minimal implementation of Gaussian Mixture Models in Jax.
22 changes: 11 additions & 11 deletions examples/benchmarks/benchmark-predict.py
@@ -11,9 +11,9 @@
from pathlib import Path
from typing import Optional

+import jax
import matplotlib.pyplot as plt
import numpy as np
-from jax import numpy as jnp
from jax.lib import xla_bridge

from gmmx import GaussianMixtureModelJax
@@ -27,9 +27,9 @@
DPI = 180


-N_SAMPLES = 1024 * 2 ** np.arange(0, 11)
-N_COMPONENTS = 2 ** np.arange(1, 8)
-N_FEATURES = 2 ** np.arange(1, 8)
+N_SAMPLES = 2 ** np.arange(5, 18)
+N_COMPONENTS = 2 ** np.arange(1, 7)
+N_FEATURES = 2 ** np.arange(1, 7)

PATH_TEMPLATE = "{user}-{machine}-{system}-{cpu}-{device-platform}"
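
For reference, the effect of the grid change in this hunk can be checked with NumPy alone. This is a minimal sketch; the `_OLD` and `_NEW` suffixes are introduced here for illustration and are not names used in the benchmark script.

```python
import numpy as np

# Grids before this commit (suffix added for illustration only)
N_SAMPLES_OLD = 1024 * 2 ** np.arange(0, 11)
N_COMPONENTS_OLD = 2 ** np.arange(1, 8)

# Grids after this commit
N_SAMPLES_NEW = 2 ** np.arange(5, 18)
N_COMPONENTS_NEW = 2 ** np.arange(1, 7)

# Print smallest value, largest value and number of grid points
for name, grid in [
    ("n_samples (old)", N_SAMPLES_OLD),
    ("n_samples (new)", N_SAMPLES_NEW),
    ("n_components (old)", N_COMPONENTS_OLD),
    ("n_components (new)", N_COMPONENTS_NEW),
]:
    print(f"{name}: {grid.min()} .. {grid.max()} ({len(grid)} points)")
```

The sample grid now starts at much smaller inputs and stops earlier, and the component grid (like the feature grid, which uses the same expression) drops its largest point.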

@@ -103,9 +103,9 @@ def create_random_gmm(n_components, n_features, random_state=RANDOM_STATE, devic
weights /= weights.sum()

return GaussianMixtureModelJax.from_squeezed(
-means=jnp.device_put(means, device=device),
-covariances=jnp.device_put(covariances, device=device),
-weights=jnp.device_put(weights, device=device),
+means=jax.device_put(means, device=device),
+covariances=jax.device_put(covariances, device=device),
+weights=jax.device_put(weights, device=device),
)
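
The hunk above switches to the top-level `jax.device_put`. As a rough, standalone illustration of that call (not the benchmark code itself), the sketch below places a NumPy array on an explicitly selected CPU device. The array shape and the variable names are assumptions made for the example.

```python
import jax
import numpy as np

# Select the first CPU device explicitly
cpu = jax.devices("cpu")[0]

# Host-side NumPy array standing in for the GMM means (shape is illustrative)
means = np.zeros((2, 3), dtype=np.float32)

# Transfer the array to the chosen device; the result is a jax.Array
means_on_device = jax.device_put(means, device=cpu)

print(type(means_on_device).__name__, means_on_device.devices())
```

Passing a device object obtained from `jax.devices(...)` keeps the placement explicit, which is useful when comparing CPU and GPU timings side by side.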


@@ -144,8 +144,8 @@ def plot_result(result, x_axis, filename, title=""):
ax.scatter(x, result.time_sklearn, color=color)

color = "#405087"
-ax.plot(x, result.time_jax, label=f"jax ({meta})", color=color)
-ax.scatter(x, result.time_jax, color=color)
+ax.plot(x, result.time_jax, label=f"jax ({meta})", color=color, zorder=3)
+ax.scatter(x, result.time_jax, color=color, zorder=3)

if result.time_jax_gpu:
color = "#E58336"
@@ -197,7 +197,7 @@ def measure_time_sklearn_vs_jax(n_components_grid, n_samples_grid, n_features_gr

if INCLUDE_GPU:
gmm_gpu = create_random_gmm(n_component, n_features, device="gpu")
-x_gpu = jnp.device_put(x, device="gpu")
+x_gpu = jax.device_put(x, device="gpu")
time_jax_gpu.append(measure_time_predict_jax(gmm_gpu, x_gpu))

return BenchmarkResult(
@@ -272,5 +272,5 @@ def run_time_vs_n_samples(filename):
if __name__ == "__main__":
path = PATH_RESULTS / PATH_TEMPLATE.format(**get_provenance()["env"])
run_time_vs_n_components(path / "time-vs-n-components-predict.json")
-run_time_vs_n_features(path / "time-vs-n-features-prefict.json")
+run_time_vs_n_features(path / "time-vs-n-features-predict.json")
run_time_vs_n_samples(path / "time-vs-n-samples-predict.json")
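
A small sketch of how the result paths in the `__main__` block are composed from `PATH_TEMPLATE`. The values in `env` and the `results` base directory are hypothetical placeholders; in the script they come from `get_provenance()["env"]` and `PATH_RESULTS`. Note that the hyphenated `device-platform` key can only be supplied through dictionary unpacking, not as a keyword argument.

```python
from pathlib import Path

PATH_TEMPLATE = "{user}-{machine}-{system}-{cpu}-{device-platform}"

# Hypothetical provenance values, for illustration only
env = {
    "user": "alice",
    "machine": "arm64",
    "system": "Darwin",
    "cpu": "arm",
    "device-platform": "cpu",
}

# Assumed base directory standing in for PATH_RESULTS
path = Path("results") / PATH_TEMPLATE.format(**env)

# Same file names as in the __main__ block above
for name in ("n-components", "n-features", "n-samples"):
    print((path / f"time-vs-{name}-predict.json").as_posix())
```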
