Skip to content

Commit

Permalink
#2472: optimise with numba.
Browse files Browse the repository at this point in the history
  • Loading branch information
alejandrobodas committed Oct 21, 2022
1 parent 014d660 commit 44c14b1
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions esmvaltool/diag_scripts/iht_toa/poisson_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
Convergence is achieved faster by using a preconditioner on the output field.
The meridional heat transport is estimated as the gradient of the scalar
The heat transport is calculated as the gradient of the scalar
p-field output of the Poisson solver.
"""

import numpy as np
from numba import jit


def set_metrics():
Expand Down Expand Up @@ -108,46 +109,47 @@ def calc_Ax(x, A_e, A_w, A_s, A_n, A_p):
Ax = np.zeros([M + 2, N + 2])

x = swap_bounds(x)
for j in range(1, M + 1):
for i in range(1, N + 1):
Ax[j, i] = A_s[j-1, i-1]*x[j-1, i] + A_w[j-1, i-1]*x[j, i-1] + \
A_e[j-1, i-1]*x[j, i+1] + A_n[j-1, i-1]*x[j+1, i] + \
A_p[j-1, i-1]*x[j, i]

# Ax[j, i] = A_s[j-1, i-1]*x[j-1, i] + A_w[j-1, i-1]*x[j, i-1] + \
# A_e[j-1, i-1]*x[j, i+1] + A_n[j-1, i-1]*x[j+1, i] + \
# A_p[j-1, i-1]*x[j, i]
Ax[1:M+1, 1:N+1] = A_s[0:M, 0:N] * x[0:M, 1:N+1] + A_w[0:M, 0:N] * \
x[1:M+1, 0:N] + A_e[0:M, 0:N] * x[1:M+1, 2:N+2] + A_n[0:M, 0:N] * \
x[2:M+2, 1:N+1] + A_p[0:M, 0:N] * x[1:M+1, 1:N+1]
Ax = swap_bounds(Ax)
return Ax


def dot_prod(x, y):
# Calculate dot product of two matrices
dot_prod = 0

for j in range(1, M + 1):
for i in range(1, N + 1):
dot_prod += x[j, i] * y[j, i]

return dot_prod
return (x[1:M + 1, 1:N + 1] * y[1:M + 1, 1:N + 1]).sum()


def precon(x, M_e, M_w, M_s, M_n, M_p):
# Preconditioner
Cx = np.zeros([M + 2, N + 2])
precon_a(x, M_w, M_s, M_p, Cx)
Cx = swap_bounds(Cx)
precon_b(M_e, M_n, Cx)
Cx = swap_bounds(Cx)
return Cx


@jit
def precon_a(x, M_w, M_s, M_p, Cx):
for j in range(1, M + 1):
for i in range(1, N + 1):
Cx[j, i] = M_p[j, i] * (x[j, i] - M_s[j, i] * Cx[j - 1, i] -
M_w[j, i] * Cx[j, i - 1])

Cx = swap_bounds(Cx)

@jit
def precon_b(M_e, M_n, Cx):
for j in range(M, 0, -1):
for i in range(N, 0, -1):
Cx[j,
i] = Cx[j,
i] - M_e[j, i] * Cx[j, i + 1] - M_n[j, i] * Cx[j + 1, i]

Cx = swap_bounds(Cx)
return Cx


def bicgstab(logger, x, b, A_e, A_w, A_s, A_n, A_p, M_e, M_w, M_s, M_n, M_p):
# Bi-conjugate gradient stabilized numerical solver
Expand Down

0 comments on commit 44c14b1

Please sign in to comment.