Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 16, 2023
1 parent 5a7e4c0 commit 8bf521d
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions python/taichi/linalg/matrixfree_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def MatrixFreeBICGSTAB(A, b, x, tol=1e-6, maxiter=5000, quiet=True):
rho_1 = ti.field(dtype=solver_dtype)
scalar_builder.place(alpha, beta, omega, rho, rho_1)
scalar_snode_tree = scalar_builder.finalize()

@ti.kernel
def init():
for I in ti.grouped(x):
Expand All @@ -203,20 +203,20 @@ def reduce(p: ti.template(), q: ti.template()) -> solver_dtype:
return result

@ti.kernel
def copy(orig:ti.template(), dest:ti.template()):
def copy(orig: ti.template(), dest: ti.template()):
for I in ti.grouped(orig):
dest[I] = orig[I]

@ti.kernel
def update_p():
for I in ti.grouped(p):
p[I] = r[I] + beta[None]*(p[I] - omega[None] * Ap[I])
p[I] = r[I] + beta[None] * (p[I] - omega[None] * Ap[I])

@ti.kernel
def update_phat():
for I in ti.grouped(p_hat):
p_hat[I] = p[I]
p_hat[I] = p[I]

@ti.kernel
def update_s():
for I in ti.grouped(s):
Expand All @@ -231,12 +231,12 @@ def update_shat():
def update_x():
for I in ti.grouped(x):
x[I] += alpha[None] * p_hat[I] + omega[None] * s_hat[I]

@ti.kernel
def update_r():
for I in ti.grouped(r):
r[I] = s[I] - omega[None] * t[I]
r[I] = s[I] - omega[None] * t[I]

def solve():
init()
initial_rTr = reduce(r, r)
Expand All @@ -245,7 +245,7 @@ def solve():
for i in range(maxiter):
rho[None] = reduce(r, r_tld)
if rho[None] == 0.0:
print('>>> BICGSTAB failed because r@r_tld = 0.')
print(">>> BICGSTAB failed because r@r_tld = 0.")
break
if i == 0:
copy(orig=r, dest=p)
Expand All @@ -262,20 +262,18 @@ def solve():
copy(orig=Ashat, dest=t)
omega_upper = reduce(t, s)
omega_lower = reduce(t, t)
omega[None] = omega_upper / (omega_lower + 1e-16) if omega_lower == 0.0 \
else omega_upper / omega_lower
omega[None] = omega_upper / (omega_lower + 1e-16) if omega_lower == 0.0 else omega_upper / omega_lower
update_x()
update_r()
rTr = reduce(r, r)
if not quiet:
print(f">>> Iter = {i+1:4}, Residual = {sqrt(rTr):e}")
if sqrt(rTr) < tol:
if not quiet:
print(f'>>> BICGSTAB method converged at #iterations {i}')
print(f">>> BICGSTAB method converged at #iterations {i}")
break
rho_1[None] = rho[None]

solve()
vector_fields_snode_tree.destroy()
scalar_snode_tree.destroy()

0 comments on commit 8bf521d

Please sign in to comment.