Commit

[refactor] [autodiff] Refactor autodiff api and add corresponding tests (#5175)

- Make ad a submodule instead of a single Python file.
- Deprecate ti.clear_all_gradients; use ti.ad.clear_all_gradients instead.
- Add an API test for the ad module (a usage sketch of the new surface follows this list).
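
A minimal usage sketch of the reorganized API (the field, kernel, and values are illustrative, not part of the commit):

    import taichi as ti

    ti.init(arch=ti.cpu)

    x = ti.field(dtype=ti.f32, shape=(), needs_grad=True)
    loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)

    @ti.kernel
    def compute():
        loss[None] = x[None] ** 2

    x[None] = 3.0
    with ti.ad.Tape(loss=loss):  # Tape also lives under the ad submodule
        compute()
    print(x.grad[None])          # d(loss)/dx = 2 * x = 6.0

    ti.ad.clear_all_gradients()  # preferred spelling after this commit
    # ti.clear_all_gradients()   # deprecated alias; still works, but warns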
erizmr authored Jun 15, 2022
1 parent d0d70ae commit d976ddd
Showing 4 changed files with 31 additions and 20 deletions.
4 changes: 2 additions & 2 deletions python/taichi/__init__.py
@@ -4,7 +4,6 @@
 from taichi._lib import core as _ti_core
 from taichi._logging import *
 from taichi._snode import *
-from taichi.ad import clear_all_gradients
 from taichi.lang import *  # pylint: disable=W0622 # TODO(archibate): It's `taichi.lang.core` overriding `taichi.core`
 from taichi.types.annotations import *
 # Provide a shortcut to types since they're commonly used.
@@ -40,7 +39,8 @@
     'imwrite': 'tools.imwrite',
     'ext_arr': 'types.ndarray',
     'any_arr': 'types.ndarray',
-    'Tape': 'ad.Tape'
+    'Tape': 'ad.Tape',
+    'clear_all_gradients': 'ad.clear_all_gradients'
 }
 
 __customized_deprecations__ = {
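
For context, a name map like __deprecated_names__ is typically consumed by a module-level __getattr__ (PEP 562) that warns and forwards on first access; a minimal sketch, not Taichi's actual shim:

    import importlib
    import warnings

    __deprecated_names__ = {'clear_all_gradients': 'ad.clear_all_gradients'}

    def __getattr__(name):
        if name in __deprecated_names__:
            target = __deprecated_names__[name]
            warnings.warn(f'ti.{name} is deprecated, use ti.{target} instead',
                          DeprecationWarning, stacklevel=2)
            mod_name, attr = target.rsplit('.', 1)
            # e.g. resolves taichi.ad, then returns its clear_all_gradients
            mod = importlib.import_module(f'{__name__}.{mod_name}')
            return getattr(mod, attr)
        raise AttributeError(f'module {__name__!r} has no attribute {name!r}')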
1 change: 1 addition & 0 deletions python/taichi/ad/__init__.py
@@ -0,0 +1 @@
+from taichi.ad._ad import *
6 changes: 6 additions & 0 deletions python/taichi/ad.py → python/taichi/ad/_ad.py
@@ -320,3 +320,9 @@ def allocate_dual(x, dual_root):
     x._set_grad(x_dual, reverse_mode=False)
     x._get_field_members()[0].ptr.set_dual(x_dual._get_field_members()[0].ptr)
     dual_root.dense(impl.index_nd(dim), shape).place(x_dual)
+
+
+__all__ = [
+    'FwdMode', 'Tape', 'clear_all_gradients', 'grad_for', 'grad_replaced',
+    'no_grad'
+]
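
Because the new package __init__ re-exports with a star import, this __all__ is exactly what bounds taichi.ad's public surface: only the six listed names become attributes of the package. The same interplay in miniature (module names illustrative):

    # pkg/_impl.py
    __all__ = ['public_fn']

    def public_fn():
        return 42

    def _helper():  # not exported: excluded by __all__
        return None

    # pkg/__init__.py
    from pkg._impl import *  # binds only 'public_fn'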
40 changes: 22 additions & 18 deletions tests/python/test_api.py
@@ -69,24 +69,28 @@ def _get_expected_matrix_apis():
     'abs', 'acos', 'activate', 'ad', 'aot', 'append', 'arm64', 'asin',
     'assume_in_range', 'atan2', 'atomic_add', 'atomic_and', 'atomic_max',
     'atomic_min', 'atomic_or', 'atomic_sub', 'atomic_xor', 'axes', 'bit_cast',
-    'bit_shr', 'block_local', 'cache_read_only', 'cast', 'cc', 'ceil',
-    'clear_all_gradients', 'cos', 'cpu', 'cuda', 'data_oriented', 'deactivate',
-    'deactivate_all_snodes', 'dx11', 'eig', 'exp', 'experimental', 'extension',
-    'f16', 'f32', 'f64', 'field', 'float16', 'float32', 'float64', 'floor',
-    'func', 'get_addr', 'global_thread_idx', 'gpu', 'graph', 'grouped',
-    'hex_to_rgb', 'i', 'i16', 'i32', 'i64', 'i8', 'ij', 'ijk', 'ijkl', 'ijl',
-    'ik', 'ikl', 'il', 'init', 'int16', 'int32', 'int64', 'int8', 'is_active',
-    'is_logging_effective', 'j', 'jk', 'jkl', 'jl', 'k', 'kernel', 'kl', 'l',
-    'lang', 'length', 'linalg', 'log', 'loop_config', 'math', 'max',
-    'mesh_local', 'mesh_patch_idx', 'metal', 'min', 'ndarray', 'ndrange',
-    'no_activate', 'one', 'opengl', 'polar_decompose', 'pow', 'profiler',
-    'randn', 'random', 'raw_div', 'raw_mod', 'ref', 'rescale_index', 'reset',
-    'rgb_to_hex', 'root', 'round', 'rsqrt', 'select', 'set_logging_level',
-    'simt', 'sin', 'solve', 'sparse_matrix_builder', 'sqrt', 'static',
-    'static_assert', 'static_print', 'stop_grad', 'struct_class', 'svd',
-    'swizzle_generator', 'sym_eig', 'sync', 'tan', 'tanh', 'template', 'tools',
-    'types', 'u16', 'u32', 'u64', 'u8', 'ui', 'uint16', 'uint32', 'uint64',
-    'uint8', 'vulkan', 'wasm', 'x64', 'x86_64', 'zero'
+    'bit_shr', 'block_local', 'cache_read_only', 'cast', 'cc', 'ceil', 'cos',
+    'cpu', 'cuda', 'data_oriented', 'deactivate', 'deactivate_all_snodes',
+    'dx11', 'eig', 'exp', 'experimental', 'extension', 'f16', 'f32', 'f64',
+    'field', 'float16', 'float32', 'float64', 'floor', 'func', 'get_addr',
+    'global_thread_idx', 'gpu', 'graph', 'grouped', 'hex_to_rgb', 'i', 'i16',
+    'i32', 'i64', 'i8', 'ij', 'ijk', 'ijkl', 'ijl', 'ik', 'ikl', 'il', 'init',
+    'int16', 'int32', 'int64', 'int8', 'is_active', 'is_logging_effective',
+    'j', 'jk', 'jkl', 'jl', 'k', 'kernel', 'kl', 'l', 'lang', 'length',
+    'linalg', 'log', 'loop_config', 'math', 'max', 'mesh_local',
+    'mesh_patch_idx', 'metal', 'min', 'ndarray', 'ndrange', 'no_activate',
+    'one', 'opengl', 'polar_decompose', 'pow', 'profiler', 'randn', 'random',
+    'raw_div', 'raw_mod', 'ref', 'rescale_index', 'reset', 'rgb_to_hex',
+    'root', 'round', 'rsqrt', 'select', 'set_logging_level', 'simt', 'sin',
+    'solve', 'sparse_matrix_builder', 'sqrt', 'static', 'static_assert',
+    'static_print', 'stop_grad', 'struct_class', 'svd', 'swizzle_generator',
+    'sym_eig', 'sync', 'tan', 'tanh', 'template', 'tools', 'types', 'u16',
+    'u32', 'u64', 'u8', 'ui', 'uint16', 'uint32', 'uint64', 'uint8', 'vulkan',
+    'wasm', 'x64', 'x86_64', 'zero'
 ]
+user_api[ti.ad] = [
+    'FwdMode', 'Tape', 'clear_all_gradients', 'grad_for', 'grad_replaced',
+    'no_grad'
+]
 user_api[ti.Field] = [
     'copy_from', 'dtype', 'fill', 'from_numpy', 'from_paddle', 'from_torch',
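
The golden lists above presumably get checked against each module's actual public attributes; a sketch of that style of assertion (the helper name and filtering rule are assumptions, not the repo's code):

    import taichi as ti

    def _public_names(obj):
        # Treat every non-underscore attribute as public API.
        return sorted(n for n in dir(obj) if not n.startswith('_'))

    def test_ad_api_surface():
        assert _public_names(ti.ad) == [
            'FwdMode', 'Tape', 'clear_all_gradients', 'grad_for',
            'grad_replaced', 'no_grad'
        ]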
