Skip to content

Commit

Permalink
Merge pull request #202 from astro-informatics/feature/notebook_plots
Browse files Browse the repository at this point in the history
Feature/notebook plots
  • Loading branch information
jasonmcewen authored May 2, 2024
2 parents bc7cbd8 + 0ff8fb4 commit a7e887e
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 62 deletions.
17 changes: 11 additions & 6 deletions notebooks/JAX_HEALPix_frontend.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
"metadata": {},
"outputs": [],
"source": [
"# Install s2fft\n",
"!pip install s2fft &> /dev/null"
"import sys\n",
"IN_COLAB = 'google.colab' in sys.modules\n",
"\n",
"# Install s2fft and data if running on google colab.\n",
"if IN_COLAB:\n",
" !pip install s2fft &> /dev/null"
]
},
{
Expand All @@ -42,11 +46,12 @@
"import numpy as np\n",
"import s2fft \n",
"\n",
"L = 1024\n",
"nside = 512\n",
"L = 128\n",
"nside = 64\n",
"method = \"jax_healpy\"\n",
"sampling = \"healpix\"\n",
"flm = np.random.randn(L, 2*L-1) + 1j*np.random.randn(L, 2*L-1)\n",
"rng = np.random.default_rng(23457801234570)\n",
"flm = s2fft.utils.signal_generator.generate_flm(rng, L)\n",
"f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)"
]
},
Expand Down Expand Up @@ -183,7 +188,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.10.0"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
17 changes: 11 additions & 6 deletions notebooks/JAX_SSHT_frontend.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
"metadata": {},
"outputs": [],
"source": [
"# Install s2fft\n",
"!pip install s2fft &> /dev/null"
"import sys\n",
"IN_COLAB = 'google.colab' in sys.modules\n",
"\n",
"# Install s2fft and data if running on google colab.\n",
"if IN_COLAB:\n",
" !pip install s2fft &> /dev/null"
]
},
{
Expand All @@ -42,9 +46,10 @@
"import numpy as np\n",
"import s2fft \n",
"\n",
"L = 1024\n",
"L = 128\n",
"method = \"jax_ssht\"\n",
"flm = np.random.randn(L, 2*L-1) + 1j*np.random.randn(L, 2*L-1)\n",
"rng = np.random.default_rng(23457801234570)\n",
"flm = s2fft.utils.signal_generator.generate_flm(rng, L)\n",
"f = s2fft.inverse(flm, L, method=method)"
]
},
Expand Down Expand Up @@ -107,7 +112,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Mean absolute error = 4.909423754134027e-11\n"
"Mean absolute error = 7.784372519411174e-13\n"
]
}
],
Expand Down Expand Up @@ -181,7 +186,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.10.0"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
82 changes: 48 additions & 34 deletions notebooks/spherical_harmonic_transform.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,21 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install s2fft\n",
"!pip install s2fft &> /dev/null"
"import sys\n",
"IN_COLAB = 'google.colab' in sys.modules\n",
"\n",
"# Install a spherical plotting package.\n",
"!pip install cartopy &> /dev/null\n",
"\n",
"# Install s2fft and data if running on google colab.\n",
"if IN_COLAB:\n",
" !pip install s2fft &> /dev/null\n",
" !mkdir data/\n",
" !wget https://github.com/astro-informatics/s2fft/raw/main/notebooks/data/Gaia_EDR3_flux.npy -P data/ &> /dev/null"
]
},
{
Expand All @@ -32,20 +41,41 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt \n",
"import cartopy.crs as ccrs \n",
"import s2fft \n",
"\n",
"L = 256\n",
"sampling = \"mw\"\n",
"flm = np.random.randn(L, 2*L-1) + 1j*np.random.randn(L, 2*L-1)\n",
"f = s2fft.inverse_jax(flm, L)"
"f = np.load('data/Gaia_EDR3_flux.npy')\n",
"L = f.shape[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lets look at the input signal"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(10,5))\n",
"ax = plt.axes(projection=ccrs.Mollweide())\n",
"im = ax.imshow(f, transform=ccrs.PlateCarree(), cmap='magma')\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
Expand All @@ -62,7 +92,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -81,7 +111,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -103,7 +133,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -122,7 +152,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -144,34 +174,18 @@
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean absolute error = 8.478196507592078e-11\n"
]
}
],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(f\"Mean absolute error = {np.nanmean(np.abs(f_recov - f))}\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean absolute error using precomputes = 8.478196507592078e-11\n"
]
}
],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(f\"Mean absolute error using precomputes = {np.nanmean(np.abs(f_recov_pre - f))}\")"
]
Expand All @@ -193,7 +207,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.10.0"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
8 changes: 6 additions & 2 deletions notebooks/spherical_rotation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
"metadata": {},
"outputs": [],
"source": [
"# Install s2fft\n",
"!pip install s2fft &> /dev/null"
"import sys\n",
"IN_COLAB = 'google.colab' in sys.modules\n",
"\n",
"# Install s2fft and data if running on google colab.\n",
"if IN_COLAB:\n",
" !pip install s2fft &> /dev/null"
]
},
{
Expand Down
8 changes: 6 additions & 2 deletions notebooks/torch_frontend.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
"metadata": {},
"outputs": [],
"source": [
"# Install s2fft\n",
"!pip install s2fft &> /dev/null"
"import sys\n",
"IN_COLAB = 'google.colab' in sys.modules\n",
"\n",
"# Install s2fft and data if running on google colab.\n",
"if IN_COLAB:\n",
" !pip install s2fft &> /dev/null"
]
},
{
Expand Down
12 changes: 8 additions & 4 deletions notebooks/wigner_transform.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
"metadata": {},
"outputs": [],
"source": [
"# Install s2fft\n",
"!pip install s2fft &> /dev/null"
"import sys\n",
"IN_COLAB = 'google.colab' in sys.modules\n",
"\n",
"# Install s2fft and data if running on google colab.\n",
"if IN_COLAB:\n",
" !pip install s2fft &> /dev/null"
]
},
{
Expand Down Expand Up @@ -47,7 +51,7 @@
"L = 128\n",
"N = 3\n",
"reality = True\n",
"rng = np.random.default_rng(0)\n",
"rng = np.random.default_rng(83459)\n",
"flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality)"
]
},
Expand Down Expand Up @@ -190,7 +194,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.10.0"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
16 changes: 8 additions & 8 deletions s2fft/utils/signal_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ def generate_flm(

for el in range(max(L_lower, abs(spin)), L):
if reality:
flm[el, 0 + L - 1] = rng.uniform()
flm[el, 0 + L - 1] = rng.normal()
else:
flm[el, 0 + L - 1] = rng.uniform() + 1j * rng.uniform()
flm[el, 0 + L - 1] = rng.normal() + 1j * rng.normal()

for m in range(1, el + 1):
flm[el, m + L - 1] = rng.uniform() + 1j * rng.uniform()
flm[el, m + L - 1] = rng.normal() + 1j * rng.normal()
if reality:
flm[el, -m + L - 1] = (-1) ** m * np.conj(flm[el, m + L - 1])
else:
flm[el, -m + L - 1] = rng.uniform() + 1j * rng.uniform()
flm[el, -m + L - 1] = rng.normal() + 1j * rng.normal()

return torch.from_numpy(flm) if using_torch else flm

Expand Down Expand Up @@ -86,22 +86,22 @@ def generate_flmn(
for n in range(-N + 1, N):
for el in range(max(L_lower, abs(n)), L):
if reality:
flmn[N - 1 + n, el, 0 + L - 1] = rng.uniform()
flmn[N - 1 + n, el, 0 + L - 1] = rng.normal()
flmn[N - 1 - n, el, 0 + L - 1] = (-1) ** n * flmn[
N - 1 + n,
el,
0 + L - 1,
]
else:
flmn[N - 1 + n, el, 0 + L - 1] = rng.uniform() + 1j * rng.uniform()
flmn[N - 1 + n, el, 0 + L - 1] = rng.normal() + 1j * rng.normal()

for m in range(1, el + 1):
flmn[N - 1 + n, el, m + L - 1] = rng.uniform() + 1j * rng.uniform()
flmn[N - 1 + n, el, m + L - 1] = rng.normal() + 1j * rng.normal()
if reality:
flmn[N - 1 - n, el, -m + L - 1] = (-1) ** (m + n) * np.conj(
flmn[N - 1 + n, el, m + L - 1]
)
else:
flmn[N - 1 + n, el, -m + L - 1] = rng.uniform() + 1j * rng.uniform()
flmn[N - 1 + n, el, -m + L - 1] = rng.normal() + 1j * rng.normal()

return torch.from_numpy(flmn) if using_torch else flmn

0 comments on commit a7e887e

Please sign in to comment.