diff --git a/python/taichi/__init__.py b/python/taichi/__init__.py index 605bf1fbf5d45..8870837b8d8c3 100644 --- a/python/taichi/__init__.py +++ b/python/taichi/__init__.py @@ -4,7 +4,6 @@ from taichi._lib import core as _ti_core from taichi._logging import * from taichi._snode import * -from taichi.ad import clear_all_gradients from taichi.lang import * # pylint: disable=W0622 # TODO(archibate): It's `taichi.lang.core` overriding `taichi.core` from taichi.types.annotations import * # Provide a shortcut to types since they're commonly used. @@ -40,7 +39,8 @@ 'imwrite': 'tools.imwrite', 'ext_arr': 'types.ndarray', 'any_arr': 'types.ndarray', - 'Tape': 'ad.Tape' + 'Tape': 'ad.Tape', + 'clear_all_gradients': 'ad.clear_all_gradients' } __customized_deprecations__ = { diff --git a/python/taichi/ad/__init__.py b/python/taichi/ad/__init__.py new file mode 100644 index 0000000000000..058736c34901e --- /dev/null +++ b/python/taichi/ad/__init__.py @@ -0,0 +1 @@ +from taichi.ad._ad import * diff --git a/python/taichi/ad.py b/python/taichi/ad/_ad.py similarity index 99% rename from python/taichi/ad.py rename to python/taichi/ad/_ad.py index e641dcbc0dda3..9145a263e7ce2 100644 --- a/python/taichi/ad.py +++ b/python/taichi/ad/_ad.py @@ -320,3 +320,9 @@ def allocate_dual(x, dual_root): x._set_grad(x_dual, reverse_mode=False) x._get_field_members()[0].ptr.set_dual(x_dual._get_field_members()[0].ptr) dual_root.dense(impl.index_nd(dim), shape).place(x_dual) + + +__all__ = [ + 'FwdMode', 'Tape', 'clear_all_gradients', 'grad_for', 'grad_replaced', + 'no_grad' +] diff --git a/tests/python/test_api.py b/tests/python/test_api.py index eece804c0b960..f13d73070cd15 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -69,24 +69,28 @@ def _get_expected_matrix_apis(): 'abs', 'acos', 'activate', 'ad', 'aot', 'append', 'arm64', 'asin', 'assume_in_range', 'atan2', 'atomic_add', 'atomic_and', 'atomic_max', 'atomic_min', 'atomic_or', 'atomic_sub', 'atomic_xor', 'axes', 'bit_cast', - 'bit_shr', 'block_local', 'cache_read_only', 'cast', 'cc', 'ceil', - 'clear_all_gradients', 'cos', 'cpu', 'cuda', 'data_oriented', 'deactivate', - 'deactivate_all_snodes', 'dx11', 'eig', 'exp', 'experimental', 'extension', - 'f16', 'f32', 'f64', 'field', 'float16', 'float32', 'float64', 'floor', - 'func', 'get_addr', 'global_thread_idx', 'gpu', 'graph', 'grouped', - 'hex_to_rgb', 'i', 'i16', 'i32', 'i64', 'i8', 'ij', 'ijk', 'ijkl', 'ijl', - 'ik', 'ikl', 'il', 'init', 'int16', 'int32', 'int64', 'int8', 'is_active', - 'is_logging_effective', 'j', 'jk', 'jkl', 'jl', 'k', 'kernel', 'kl', 'l', - 'lang', 'length', 'linalg', 'log', 'loop_config', 'math', 'max', - 'mesh_local', 'mesh_patch_idx', 'metal', 'min', 'ndarray', 'ndrange', - 'no_activate', 'one', 'opengl', 'polar_decompose', 'pow', 'profiler', - 'randn', 'random', 'raw_div', 'raw_mod', 'ref', 'rescale_index', 'reset', - 'rgb_to_hex', 'root', 'round', 'rsqrt', 'select', 'set_logging_level', - 'simt', 'sin', 'solve', 'sparse_matrix_builder', 'sqrt', 'static', - 'static_assert', 'static_print', 'stop_grad', 'struct_class', 'svd', - 'swizzle_generator', 'sym_eig', 'sync', 'tan', 'tanh', 'template', 'tools', - 'types', 'u16', 'u32', 'u64', 'u8', 'ui', 'uint16', 'uint32', 'uint64', - 'uint8', 'vulkan', 'wasm', 'x64', 'x86_64', 'zero' + 'bit_shr', 'block_local', 'cache_read_only', 'cast', 'cc', 'ceil', 'cos', + 'cpu', 'cuda', 'data_oriented', 'deactivate', 'deactivate_all_snodes', + 'dx11', 'eig', 'exp', 'experimental', 'extension', 'f16', 'f32', 'f64', + 'field', 'float16', 'float32', 'float64', 'floor', 'func', 'get_addr', + 'global_thread_idx', 'gpu', 'graph', 'grouped', 'hex_to_rgb', 'i', 'i16', + 'i32', 'i64', 'i8', 'ij', 'ijk', 'ijkl', 'ijl', 'ik', 'ikl', 'il', 'init', + 'int16', 'int32', 'int64', 'int8', 'is_active', 'is_logging_effective', + 'j', 'jk', 'jkl', 'jl', 'k', 'kernel', 'kl', 'l', 'lang', 'length', + 'linalg', 'log', 'loop_config', 'math', 'max', 'mesh_local', + 'mesh_patch_idx', 'metal', 'min', 'ndarray', 'ndrange', 'no_activate', + 'one', 'opengl', 'polar_decompose', 'pow', 'profiler', 'randn', 'random', + 'raw_div', 'raw_mod', 'ref', 'rescale_index', 'reset', 'rgb_to_hex', + 'root', 'round', 'rsqrt', 'select', 'set_logging_level', 'simt', 'sin', + 'solve', 'sparse_matrix_builder', 'sqrt', 'static', 'static_assert', + 'static_print', 'stop_grad', 'struct_class', 'svd', 'swizzle_generator', + 'sym_eig', 'sync', 'tan', 'tanh', 'template', 'tools', 'types', 'u16', + 'u32', 'u64', 'u8', 'ui', 'uint16', 'uint32', 'uint64', 'uint8', 'vulkan', + 'wasm', 'x64', 'x86_64', 'zero' +] +user_api[ti.ad] = [ + 'FwdMode', 'Tape', 'clear_all_gradients', 'grad_for', 'grad_replaced', + 'no_grad' ] user_api[ti.Field] = [ 'copy_from', 'dtype', 'fill', 'from_numpy', 'from_paddle', 'from_torch',