From 078df4ae6cc703ff033c096051d1896668840903 Mon Sep 17 00:00:00 2001 From: Qian Bao Date: Fri, 7 Apr 2023 18:46:44 +0800 Subject: [PATCH] [lang] Implement experimental CG(Conjugate Gradient) solver in Taichi-lang (#7690) Issue: #7634 ### Brief Summary This PR implements a matrix-free CG (Conjugate-Gradient) solver in Taichi. The solver targets to solve the linear equation system: $$ Ax = b$$ where $A$ is implicitly represented as a `LinearOperator` instead of a explicitly stored matrix, hence the name "matrix-free". --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- python/taichi/linalg/__init__.py | 1 + python/taichi/linalg/taichi_cg.py | 105 ++++++++++++++++++++++++++++++ tests/python/test_taichi_cg.py | 57 ++++++++++++++++ 3 files changed, 163 insertions(+) create mode 100644 python/taichi/linalg/taichi_cg.py create mode 100644 tests/python/test_taichi_cg.py diff --git a/python/taichi/linalg/__init__.py b/python/taichi/linalg/__init__.py index 0d003667ca284..2bb594ffe273f 100644 --- a/python/taichi/linalg/__init__.py +++ b/python/taichi/linalg/__init__.py @@ -3,3 +3,4 @@ from taichi.linalg.cg import CG from taichi.linalg.sparse_matrix import * from taichi.linalg.sparse_solver import SparseSolver +from taichi.linalg.taichi_cg import * diff --git a/python/taichi/linalg/taichi_cg.py b/python/taichi/linalg/taichi_cg.py new file mode 100644 index 0000000000000..9ca805599007b --- /dev/null +++ b/python/taichi/linalg/taichi_cg.py @@ -0,0 +1,105 @@ +from math import sqrt + +from taichi.lang.exception import TaichiRuntimeError, TaichiTypeError + +import taichi as ti + + +@ti.data_oriented +class LinearOperator: + def __init__(self, matvec_kernel): + self._matvec = matvec_kernel + + def matvec(self, x, Ax): + if x.shape != Ax.shape: + raise TaichiRuntimeError( + f"Dimension mismatch x.shape{x.shape} != Ax.shape{Ax.shape}.") + self._matvec(x, Ax) + + +def taichi_cg_solver(A, b, x, tol=1e-6, maxiter=5000, quiet=True): + if b.dtype != x.dtype: + raise TaichiTypeError( + f"Dtype mismatch b.dtype({b.dtype}) != x.dtype({x.dtype}).") + if str(b.dtype) == 'f32': + solver_dtype = ti.f32 + elif str(b.dtype) == 'f64': + solver_dtype = ti.f64 + else: + raise TaichiTypeError(f"Not supported dtype: {b.dtype}") + if b.shape != x.shape: + raise TaichiRuntimeError( + f"Dimension mismatch b.shape{b.shape} != x.shape{x.shape}.") + + size = b.shape + vector_fields_builder = ti.FieldsBuilder() + p = ti.field(dtype=solver_dtype) + r = ti.field(dtype=solver_dtype) + Ap = ti.field(dtype=solver_dtype) + vector_fields_builder.dense(ti.ij, size).place(p, r, Ap) + vector_fields_snode_tree = vector_fields_builder.finalize() + + scalar_builder = ti.FieldsBuilder() + alpha = ti.field(dtype=solver_dtype) + beta = ti.field(dtype=solver_dtype) + scalar_builder.place(alpha, beta) + scalar_snode_tree = scalar_builder.finalize() + + @ti.kernel + def init(): + for I in ti.grouped(x): + r[I] = b[I] + p[I] = 0.0 + Ap[I] = 0.0 + + @ti.kernel + def reduce(p: ti.template(), q: ti.template()) -> solver_dtype: + result = 0.0 + for I in ti.grouped(p): + result += p[I] * q[I] + return result + + @ti.kernel + def update_x(): + for I in ti.grouped(x): + x[I] += alpha[None] * p[I] + + @ti.kernel + def update_r(): + for I in ti.grouped(r): + r[I] -= alpha[None] * Ap[I] + + @ti.kernel + def update_p(): + for I in ti.grouped(p): + p[I] = r[I] + beta[None] * p[I] + + def solve(): + init() + initial_rTr = reduce(r, r) + if not quiet: + print(f'>>> Initial residual = {initial_rTr:e}') + old_rTr = initial_rTr + update_p() + # -- Main loop -- + for i in range(maxiter): + A._matvec(p, Ap) # compute Ap = A x p + pAp = reduce(p, Ap) + alpha[None] = old_rTr / pAp + update_x() + update_r() + new_rTr = reduce(r, r) + if sqrt(new_rTr) < tol: + if not quiet: + print('>>> Conjugate Gradient method converged.') + print(f'>>> #iterations {i}') + break + beta[None] = new_rTr / old_rTr + update_p() + old_rTr = new_rTr + if not quiet: + print(f'>>> Iter = {i+1:4}, Residual = {sqrt(new_rTr):e}') + + solve() + vector_fields_snode_tree.destroy() + scalar_snode_tree.destroy() diff --git a/tests/python/test_taichi_cg.py b/tests/python/test_taichi_cg.py new file mode 100644 index 0000000000000..226a520a5e78a --- /dev/null +++ b/tests/python/test_taichi_cg.py @@ -0,0 +1,57 @@ +import math + +import pytest +from taichi.linalg import LinearOperator, taichi_cg_solver + +import taichi as ti +from tests import test_utils + +vk_on_mac = (ti.vulkan, 'Darwin') + + +@pytest.mark.parametrize("ti_dtype", [ti.f32, ti.f64]) +@test_utils.test(arch=[ti.cpu, ti.cuda, ti.vulkan], exclude=[vk_on_mac]) +def test_taichi_cg(ti_dtype): + GRID = 32 + Ax = ti.field(dtype=ti_dtype, shape=(GRID, GRID)) + x = ti.field(dtype=ti_dtype, shape=(GRID, GRID)) + b = ti.field(dtype=ti_dtype, shape=(GRID, GRID)) + + @ti.kernel + def init(): + for i, j in ti.ndrange(GRID, GRID): + xl = i / (GRID - 1) + yl = j / (GRID - 1) + b[i, j] = ti.sin(2 * math.pi * xl) * ti.sin(2 * math.pi * yl) + x[i, j] = 0.0 + + @ti.kernel + def compute_Ax(v: ti.template(), mv: ti.template()): + for i, j in v: + l = v[i - 1, j] if i - 1 >= 0 else 0.0 + r = v[i + 1, j] if i + 1 <= GRID - 1 else 0.0 + t = v[i, j + 1] if j + 1 <= GRID - 1 else 0.0 + b = v[i, j - 1] if j - 1 >= 0 else 0.0 + # Avoid ill-conditioned matrix A + mv[i, j] = 20 * v[i, j] - l - r - t - b + + @ti.kernel + def check_solution(sol: ti.template(), ans: ti.template(), + tol: ti_dtype) -> bool: + exit_code = True + for i, j in ti.ndrange(GRID, GRID): + if ti.abs(ans[i, j] - sol[i, j]) < tol: + pass + else: + exit_code = False + return exit_code + + A = LinearOperator(compute_Ax) + init() + taichi_cg_solver(A, b, x, maxiter=10 * GRID * GRID, tol=1e-18, quiet=True) + compute_Ax(x, Ax) + # `tol` can't be < 1e-6 for ti.f32 because of accumulating round-off error; + # see https://en.wikipedia.org/wiki/Conjugate_gradient_method#cite_note-6 + # for more details. + result = check_solution(Ax, b, tol=1e-6) + assert result