diff --git a/python/taichi/__init__.py b/python/taichi/__init__.py
index 439a04a4565bc..605bf1fbf5d45 100644
--- a/python/taichi/__init__.py
+++ b/python/taichi/__init__.py
@@ -4,6 +4,7 @@
 from taichi._lib import core as _ti_core
 from taichi._logging import *
 from taichi._snode import *
+from taichi.ad import clear_all_gradients
 from taichi.lang import *  # pylint: disable=W0622 # TODO(archibate): It's `taichi.lang.core` overriding `taichi.core`
 from taichi.types.annotations import *
 # Provide a shortcut to types since they're commonly used.
@@ -38,7 +39,8 @@
     'imshow': 'tools.imshow',
     'imwrite': 'tools.imwrite',
     'ext_arr': 'types.ndarray',
-    'any_arr': 'types.ndarray'
+    'any_arr': 'types.ndarray',
+    'Tape': 'ad.Tape'
 }
 
 __customized_deprecations__ = {
diff --git a/python/taichi/ad.py b/python/taichi/ad.py
index 7cc3748811495..f941a2ab7e228 100644
--- a/python/taichi/ad.py
+++ b/python/taichi/ad.py
@@ -4,6 +4,106 @@
 gradient computation task.
 """
 from taichi.lang import impl
+from taichi.lang.snode import SNode
+
+from taichi import _snode
+
+
+class Tape:
+    def __init__(self, loss=None, clear_gradients=True):
+        """A context manager for reverse-mode autodiff :class:`~taichi.ad.Tape`.
+        The context manager records all calls to functions decorated with
+        :func:`~taichi.lang.kernel_impl.kernel` or
+        :func:`~taichi.ad.grad_replaced` inside the `with` statement, and
+        computes all the partial gradients of a given loss variable by calling
+        the gradient function of each recorded call in reverse order when the
+        `with` statement ends.
+
+        See also :func:`~taichi.lang.kernel_impl.kernel` and
+        :func:`~taichi.ad.grad_replaced` for gradient functions.
+
+        Args:
+            loss(:class:`~taichi.lang.expr.Expr`): The loss field, whose shape should be ().
+            clear_gradients(Bool): Whether to clear all gradients before entering the `with` body.
+
+        Example::
+
+            >>> @ti.kernel
+            >>> def sum(a: ti.float32):
+            >>>     for I in ti.grouped(x):
+            >>>         y[None] += x[I] ** a
+            >>>
+            >>> with ti.ad.Tape(loss=y):
+            >>>     sum(2)
+        """
+        self.calls = []
+        self.entered = False
+        self.gradient_evaluated = False
+        self.clear_gradients = clear_gradients
+        self.runtime = impl.get_runtime()
+        self.eval_on_exit = loss is not None
+        self.loss = loss
+
+    def __enter__(self):
+        assert not self.entered, "Tape can be entered only once."
+        self.entered = True
+
+        impl.get_runtime().materialize()
+        if len(self.loss.shape) != 0:
+            raise RuntimeError(
+                'The loss of `Tape` must be a 0-D field, i.e. scalar')
+        if not self.loss.snode.ptr.has_adjoint():
+            raise RuntimeError(
+                'Gradients of loss are not allocated, please use ti.field(..., needs_grad=True)'
+                ' for all fields that are required by autodiff.')
+        if self.clear_gradients:
+            clear_all_gradients()
+
+        from taichi._kernels import clear_loss  # pylint: disable=C0415
+        clear_loss(self.loss)
+
+        # Attach the context manager to the runtime
+        self.runtime.target_tape = self
+
+    def __exit__(self, _type, value, tb):
+        self.runtime.target_tape = None
+        if self.eval_on_exit:
+            self.grad()
+
+    def insert(self, func, args):
+        self.calls.append((func, args))
+
+    def grad(self):
+        assert self.entered, "The tape must be entered before evaluating gradients."
+        assert not self.gradient_evaluated, "Gradients can be evaluated only once."
+        for func, args in reversed(self.calls):
+            func.grad(*args)
+        self.gradient_evaluated = True
+
+
+def clear_all_gradients():
+    """Sets the gradients of all fields to zero.
+ """ + impl.get_runtime().materialize() + + def visit(node): + places = [] + for _i in range(node.ptr.get_num_ch()): + ch = node.ptr.get_ch(_i) + if not ch.is_place(): + visit(SNode(ch)) + else: + if not ch.is_primal(): + places.append(ch.get_expr()) + + places = tuple(places) + if places: + from taichi._kernels import \ + clear_gradients # pylint: disable=C0415 + clear_gradients(places) + + for root_fb in _snode.FieldsBuilder._finalized_roots(): + visit(root_fb) def grad_replaced(func): diff --git a/python/taichi/lang/__init__.py b/python/taichi/lang/__init__.py index 32d4345f81965..bc7cac68655d9 100644 --- a/python/taichi/lang/__init__.py +++ b/python/taichi/lang/__init__.py @@ -20,6 +20,6 @@ 'any_array', 'ast', 'common_ops', 'enums', 'exception', 'expr', 'impl', 'inspect', 'kernel_arguments', 'kernel_impl', 'matrix', 'mesh', 'misc', 'ops', 'platform', 'runtime_ops', 'shell', 'snode', 'source_builder', - 'struct', 'tape', 'util' + 'struct', 'util' ] ] diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index e950d53adba0c..c7927b7ef2c9b 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -21,7 +21,6 @@ MeshReorderedScalarFieldProxy, element_type_name) from taichi.lang.snode import SNode from taichi.lang.struct import Struct, StructField, _IntermediateStruct -from taichi.lang.tape import TapeImpl from taichi.lang.util import (cook_dtype, get_traceback, is_taichi_class, python_scope, taichi_scope, warning) from taichi.types.primitive_types import f16, f32, f64, i32, i64, types @@ -333,9 +332,6 @@ def clear(self): self._signal_handler_registry = None self.materialized = False - def get_tape(self, loss=None): - return TapeImpl(self, loss) - def sync(self): self.materialize() self.prog.synchronize() diff --git a/python/taichi/lang/misc.py b/python/taichi/lang/misc.py index 2d03a935063d4..fb69fe8080c7d 100644 --- a/python/taichi/lang/misc.py +++ b/python/taichi/lang/misc.py @@ -11,7 +11,6 @@ from taichi.lang import impl from taichi.lang.expr import Expr from taichi.lang.impl import axes, get_runtime -from taichi.lang.snode import SNode from taichi.profiler.kernel_profiler import get_default_kernel_profiler from taichi.types.primitive_types import f32, f64, i32, i64 @@ -659,77 +658,6 @@ def mesh_patch_idx(): ) -def Tape(loss, clear_gradients=True): - """Returns a context manager of :class:`~taichi.lang.tape.TapeImpl`. The - context manager would catching all of the callings of functions that - decorated by :func:`~taichi.lang.kernel_impl.kernel` or - :func:`~taichi.ad.grad_replaced` under `with` statement, and calculate - all the partial gradients of a given loss variable by calling all of the - gradient function of the callings caught in reverse order while `with` - statement ended. - - See also :func:`~taichi.lang.kernel_impl.kernel` and - :func:`~taichi.ad.grad_replaced` for gradient functions. - - Args: - loss(:class:`~taichi.lang.expr.Expr`): The loss field, which shape should be (). - clear_gradients(Bool): Before `with` body start, clear all gradients or not. - - Returns: - :class:`~taichi.lang.tape.TapeImpl`: The context manager. - - Example:: - - >>> @ti.kernel - >>> def sum(a: ti.float32): - >>> for I in ti.grouped(x): - >>> y[None] += x[I] ** a - >>> - >>> with ti.Tape(loss = y): - >>> sum(2) - """ - impl.get_runtime().materialize() - if len(loss.shape) != 0: - raise RuntimeError( - 'The loss of `Tape` must be a 0-D field, i.e. 
scalar') - if not loss.snode.ptr.has_adjoint(): - raise RuntimeError( - 'Gradients of loss are not allocated, please use ti.field(..., needs_grad=True)' - ' for all fields that are required by autodiff.') - if clear_gradients: - clear_all_gradients() - - from taichi._kernels import clear_loss # pylint: disable=C0415 - clear_loss(loss) - - return impl.get_runtime().get_tape(loss) - - -def clear_all_gradients(): - """Sets the gradients of all fields to zero. - """ - impl.get_runtime().materialize() - - def visit(node): - places = [] - for _i in range(node.ptr.get_num_ch()): - ch = node.ptr.get_ch(_i) - if not ch.is_place(): - visit(SNode(ch)) - else: - if not ch.is_primal(): - places.append(ch.get_expr()) - - places = tuple(places) - if places: - from taichi._kernels import \ - clear_gradients # pylint: disable=C0415 - clear_gradients(places) - - for root_fb in _snode.FieldsBuilder._finalized_roots(): - visit(root_fb) - - def is_arch_supported(arch, use_gles=False): """Checks whether an arch is supported on the machine. @@ -787,7 +715,6 @@ def get_host_arch_list(): 'i', 'ij', 'ijk', 'ijkl', 'ijl', 'ik', 'ikl', 'il', 'j', 'jk', 'jkl', 'jl', 'k', 'kl', 'l', 'x86_64', 'x64', 'dx11', 'wasm', 'arm64', 'cc', 'cpu', 'cuda', 'gpu', 'metal', 'opengl', 'vulkan', 'extension', 'loop_config', - 'global_thread_idx', 'Tape', 'assume_in_range', 'block_local', - 'cache_read_only', 'clear_all_gradients', 'init', 'mesh_local', - 'no_activate', 'reset', 'mesh_patch_idx' + 'global_thread_idx', 'assume_in_range', 'block_local', 'cache_read_only', + 'init', 'mesh_local', 'no_activate', 'reset', 'mesh_patch_idx' ] diff --git a/python/taichi/lang/tape.py b/python/taichi/lang/tape.py deleted file mode 100644 index c9101d306cbd6..0000000000000 --- a/python/taichi/lang/tape.py +++ /dev/null @@ -1,28 +0,0 @@ -class TapeImpl: - def __init__(self, runtime, loss=None): - self.calls = [] - self.entered = False - self.gradient_evaluated = False - self.runtime = runtime - self.eval_on_exit = loss is not None - - def __enter__(self): - self.runtime.target_tape = self - assert not self.entered, "Tape can be entered only once." - self.entered = True - - def __exit__(self, _type, value, tb): - # print('# kernel calls', len(self.calls)) - self.runtime.target_tape = None - if self.eval_on_exit: - self.grad() - - def insert(self, func, args): - self.calls.append((func, args)) - - def grad(self): - assert self.entered, "Before evaluating gradients tape must be entered." - assert not self.gradient_evaluated, "Gradients of grad can be evaluated only once." 
-        for func, args in reversed(self.calls):
-            func.grad(*args)
-        self.gradient_evaluated = True
diff --git a/tests/python/test_api.py b/tests/python/test_api.py
index 67def75173424..eece804c0b960 100644
--- a/tests/python/test_api.py
+++ b/tests/python/test_api.py
@@ -65,8 +65,8 @@ def _get_expected_matrix_apis():
     'SNode', 'ScalarField', 'ScalarNdarray', 'Struct', 'StructField', 'TRACE',
     'TaichiAssertionError', 'TaichiCompilationError', 'TaichiNameError',
     'TaichiRuntimeError', 'TaichiRuntimeTypeError', 'TaichiSyntaxError',
-    'TaichiTypeError', 'Tape', 'TetMesh', 'TriMesh', 'Vector', 'VectorNdarray',
-    'WARN', 'abs', 'acos', 'activate', 'ad', 'aot', 'append', 'arm64', 'asin',
+    'TaichiTypeError', 'TetMesh', 'TriMesh', 'Vector', 'VectorNdarray', 'WARN',
+    'abs', 'acos', 'activate', 'ad', 'aot', 'append', 'arm64', 'asin',
     'assume_in_range', 'atan2', 'atomic_add', 'atomic_and', 'atomic_max',
     'atomic_min', 'atomic_or', 'atomic_sub', 'atomic_xor', 'axes', 'bit_cast',
     'bit_shr', 'block_local', 'cache_read_only', 'cast', 'cc', 'ceil',
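
For reviewers, a minimal end-to-end sketch of the relocated API follows. It assumes the post-merge layout from this diff; the fields `x` and `y`, the kernel `compute_loss`, and the chosen shapes are hypothetical illustration, not part of the change.

# Usage sketch for the relocated autodiff API (illustrative only).
import taichi as ti

ti.init(arch=ti.cpu)

x = ti.field(dtype=ti.f32, shape=8, needs_grad=True)
y = ti.field(dtype=ti.f32, shape=(), needs_grad=True)  # 0-D loss field


@ti.kernel
def compute_loss(p: ti.f32):
    for i in x:
        y[None] += x[i]**p


x.fill(1.0)
# New canonical spelling; `ti.Tape(...)` still resolves through
# __deprecated_names__ and forwards to `ad.Tape`.
with ti.ad.Tape(loss=y):
    compute_loss(2.0)

print(x.grad.to_numpy())  # dy/dx_i = p * x_i**(p - 1) = 2.0 for every i
ti.ad.clear_all_gradients()  # also re-exported as ti.clear_all_gradients

Because `Tape` is now an ordinary class in `taichi.ad` rather than a factory in `taichi.lang.misc`, user code constructs the context manager directly, while the deprecated top-level alias keeps existing call sites working.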