Skip to content

Commit

Permalink
feat: better internal error handling (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan authored Nov 27, 2022
1 parent a4272c2 commit 9fda41e
Show file tree
Hide file tree
Showing 12 changed files with 296 additions and 194 deletions.
32 changes: 24 additions & 8 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ project(optree LANGUAGES CXX)
include(FetchContent)
set(PYBIND11_VERSION v2.10.1)
set(ABSEIL_CPP_VERSION 20220623.1)
set(THIRD_PARTY_DIR "${CMAKE_SOURCE_DIR}/third-party")

if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
Expand All @@ -31,15 +32,28 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON) # -fPIC
set(CMAKE_CXX_VISIBILITY_PRESET hidden) # -fvisibility=hidden

if(MSVC)
string(APPEND CMAKE_CXX_FLAGS " /Wall")
string(
APPEND CMAKE_CXX_FLAGS
" /Zc:preprocessor"
" /Wall"
" /WX /wd4365 /wd4514 /wd4710 /wd4711 /wd4820 /wd4868 /wd5045"
" /experimental:external /external:anglebrackets /external:W0"
)
string(APPEND CMAKE_CXX_FLAGS_DEBUG " /Zi")
string(APPEND CMAKE_CXX_FLAGS_RELEASE " /O2 /Ob2")
else()
string(APPEND CMAKE_CXX_FLAGS " -Wall")
string(
APPEND CMAKE_CXX_FLAGS
" -Wall -Wextra"
" -Werror -Wno-attributes"
)
string(APPEND CMAKE_CXX_FLAGS_DEBUG " -g -Og")
string(APPEND CMAKE_CXX_FLAGS_RELEASE " -O3")
endif()

string(LENGTH "${CMAKE_SOURCE_DIR}/" SOURCE_PATH_PREFIX_SIZE)
add_definitions("-DSOURCE_PATH_PREFIX_SIZE=${SOURCE_PATH_PREFIX_SIZE}")

function(system)
set(options STRIP)
set(oneValueArgs OUTPUT_VARIABLE ERROR_VARIABLE WORKING_DIRECTORY)
Expand Down Expand Up @@ -130,11 +144,12 @@ if("${PYBIND11_CMAKE_DIR}" STREQUAL "")
GIT_REPOSITORY https://github.com/pybind/pybind11.git
GIT_TAG "${PYBIND11_VERSION}"
GIT_SHALLOW TRUE
SOURCE_DIR "${CMAKE_SOURCE_DIR}/third-party/pybind11"
BINARY_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/pybind11/build"
STAMP_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/pybind11/stamp"
SOURCE_DIR "${THIRD_PARTY_DIR}/pybind11"
BINARY_DIR "${THIRD_PARTY_DIR}/.cmake/pybind11/build"
STAMP_DIR "${THIRD_PARTY_DIR}/.cmake/pybind11/stamp"
)
FetchContent_GetProperties(pybind11)

if(NOT pybind11_POPULATED)
message(STATUS "Populating Git repository pybind11@${PYBIND11_VERSION} to third-party/pybind11...")
FetchContent_MakeAvailable(pybind11)
Expand All @@ -152,11 +167,12 @@ FetchContent_Declare(
GIT_REPOSITORY https://github.com/abseil/abseil-cpp.git
GIT_TAG "${ABSEIL_CPP_VERSION}"
GIT_SHALLOW TRUE
SOURCE_DIR "${CMAKE_SOURCE_DIR}/third-party/abseil-cpp"
BINARY_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/abseil-cpp/build"
STAMP_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/abseil-cpp/stamp"
SOURCE_DIR "${THIRD_PARTY_DIR}/abseil-cpp"
BINARY_DIR "${THIRD_PARTY_DIR}/.cmake/abseil-cpp/build"
STAMP_DIR "${THIRD_PARTY_DIR}/.cmake/abseil-cpp/stamp"
)
FetchContent_GetProperties(abseilcpp)

if(NOT abseilcpp_POPULATED)
message(STATUS "Populating Git repository abseil-cpp@${ABSEIL_CPP_VERSION} to third-party/abseil-cpp...")
FetchContent_MakeAvailable(abseilcpp)
Expand Down
107 changes: 107 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ The `metadata` is some necessary data apart from the children to reconstruct the

The `entries` can be omitted (only returns a pair) or is optional to implement (returns `None`). If so, use `range(len(children))` (i.e., flat indices) as path entries of the current node. The function signature can be `flatten_func(container) -> (children, metadata)` or `flatten_func(container) -> (children, metadata, None)`.

The following examples show how to register custom types and utilize them for `tree_flatten` and `tree_map`. Please refer to section [Notes about the PyTree Type Registry](#notes-about-the-pytree-type-registry) for more information.

```python
# Registry a Python type with lambda functions
register_pytree_node(
Expand Down Expand Up @@ -292,6 +294,111 @@ There are several key attributes of the pytree type registry:
)
```

5. **Be careful about the potential infinite recursion of the custom flatten function.** The returned `children` from the custom flatten function are considered subtrees. They will be further flattened recursively. The `children` can have the same type as the current node. Users must design their termination condition carefully.

```python
import numpy as np
import torch

optree.register_pytree_node(
np.ndarray,
# Children are nest lists of Python objects
lambda array: (np.atleast_1d(array).tolist(), array.ndim == 0),
lambda scalar, rows: np.asarray(rows) if not scalar else np.asarray(rows[0]),
namespace='numpy1',
)

optree.register_pytree_node(
np.ndarray,
# Children are Python objects
lambda array: (
list(array.ravel()), # list(NDArray[T]) -> List[T]
dict(shape=array.shape, dtype=array.dtype)
),
lambda metadata, children: np.asarray(children, dtype=metadata['dtype']).reshape(metadata['shape']),
namespace='numpy2',
)

optree.register_pytree_node(
np.ndarray,
# Returns a list of `np.ndarray`s without termination condition
lambda array: ([array.ravel()], array.dtype),
lambda shape, children: children[0].reshape(shape),
namespace='numpy3',
)

optree.register_pytree_node(
torch.Tensor,
# Children are nest lists of Python objects
lambda tensor: (torch.atleast_1d(tensor).tolist(), tensor.ndim == 0),
lambda scalar, rows: torch.tensor(rows) if not scalar else torch.tensor(rows[0])),
namespace='torch1',
)

optree.register_pytree_node(
torch.Tensor,
# Returns a list of `torch.Tensor`s without termination condition
lambda tensor: (
list(tensor.view(-1)), # list(NDTensor[T]) -> List[0DTensor[T]] (STILL TENSORS!)
tensor.shape
),
lambda shape, children: torch.stack(children).reshape(shape),
namespace='torch2',
)
```

```python
>>> optree.tree_flatten(np.arange(9).reshape(3, 3), namespace='numpy1')
(
[0, 1, 2, 3, 4, 5, 6, 7, 8],
PyTreeSpec(
CustomTreeNode(ndarray[False], [[*, *, *], [*, *, *], [*, *, *]]),
namespace='numpy1'
)
)
# Implicitly casts `float`s to `np.float64`
>>> optree.tree_map(lambda x: x + 1.5, np.arange(9).reshape(3, 3), namespace='numpy1')
array([[1.5, 2.5, 3.5],
[4.5, 5.5, 6.5],
[7.5, 8.5, 9.5]])

>>> optree.tree_flatten(np.arange(9).reshape(3, 3), namespace='numpy2')
(
[0, 1, 2, 3, 4, 5, 6, 7, 8],
PyTreeSpec(
CustomTreeNode(ndarray[{'shape': (3, 3), 'dtype': dtype('int64')}], [*, *, *, *, *, *, *, *, *]),
namespace='numpy2'
)
)
# Explicitly casts `float`s to `np.int64`
>>> optree.tree_map(lambda x: x + 1.5, np.arange(9).reshape(3, 3), namespace='numpy2')
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])

# Children are also `np.ndarray`s, recurse without termination condition.
>>> optree.tree_flatten(np.arange(9).reshape(3, 3), namespace='numpy3')
RecursionError: maximum recursion depth exceeded during flattening the tree

>>> optree.tree_flatten(torch.arange(9).reshape(3, 3), namespace='torch1')
(
[0, 1, 2, 3, 4, 5, 6, 7, 8],
PyTreeSpec(
CustomTreeNode(Tensor[False], [[*, *, *], [*, *, *], [*, *, *]]),
namespace='torch1'
)
)
# Implicitly casts `float`s to `torch.float32`
>>> optree.tree_map(lambda x: x + 1.5, torch.arange(9).reshape(3, 3), namespace='torch1')
tensor([[1.5000, 2.5000, 3.5000],
[4.5000, 5.5000, 6.5000],
[7.5000, 8.5000, 9.5000]])

# Children are also `torch.Tensor`s, recurse without termination condition.
>>> optree.tree_flatten(torch.arange(9).reshape(3, 3), namespace='torch2')
RecursionError: maximum recursion depth exceeded during flattening the tree
```

### `None` is Non-leaf Node vs. `None` is Leaf

The [`None`](https://docs.python.org/3/library/constants.html#None) object is a special object in the Python language.
Expand Down
6 changes: 6 additions & 0 deletions include/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ class PyTreeTypeRegistry {
template <bool NoneIsLeaf>
static PyTreeTypeRegistry *Singleton();

template <bool NoneIsLeaf>
static void RegisterImpl(const py::object &cls,
const py::function &to_iterable,
const py::function &from_iterable,
const std::string &registry_namespace);

class TypeHash {
public:
using is_transparent = void;
Expand Down
4 changes: 4 additions & 0 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ limitations under the License.
namespace optree {

// The maximum depth of a pytree.
#ifdef _WIN32
constexpr ssize_t MAX_RECURSION_DEPTH = 4000;
#else
constexpr ssize_t MAX_RECURSION_DEPTH = 10000;
#endif

// A PyTreeSpec describes the tree structure of a PyTree. A PyTree is a tree of Python values, where
// the interior nodes are tuples, lists, dictionaries, or user-defined containers, and the leaves
Expand Down
43 changes: 37 additions & 6 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,43 @@ limitations under the License.
#include <utility>
#include <vector>

#define CHECK(condition) \
if (!(condition)) [[likely]] \
throw std::runtime_error(std::string(#condition) + " failed at " __FILE__ + ':' + \
std::to_string(__LINE__))
#ifndef SOURCE_PATH_PREFIX_SIZE
#define SOURCE_PATH_PREFIX_SIZE 0
#endif
#ifndef __FILENAME__
#define __FILENAME__ (&(__FILE__[SOURCE_PATH_PREFIX_SIZE]))
#endif

#define VFUNC2(__0, __1, NAME, ...) NAME
#define VFUNC3(__0, __1, __2, NAME, ...) NAME

#define INTERNAL_ERROR1(message) \
throw std::logic_error(absl::StrFormat("%s (at file %s:%lu)", message, __FILENAME__, __LINE__))
#define INTERNAL_ERROR0() INTERNAL_ERROR1("Unreachable code.")
#define INTERNAL_ERROR(...) /* NOLINTNEXTLINE[whitespace/parens] */ \
VFUNC2(__0 __VA_OPT__(, ) __VA_ARGS__, INTERNAL_ERROR1, INTERNAL_ERROR0)(__VA_ARGS__)

#define EXPECT2(condition, message) \
if (!(condition)) [[unlikely]] { \
INTERNAL_ERROR1(message); \
}
#define EXPECT0() INTERNAL_ERROR0()
#define EXPECT1(condition) EXPECT2(condition, "`" #condition "` failed.")
#define EXPECT(...) /* NOLINTNEXTLINE[whitespace/parens] */ \
VFUNC3(__0 __VA_OPT__(, ) __VA_ARGS__, EXPECT2, EXPECT1, EXPECT0)(__VA_ARGS__)

#define DCHECK(condition) CHECK(condition)
#define EXPECT_EQ(a, b, ...) \
EXPECT((a) == (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens]
#define EXPECT_NE(a, b, ...) \
EXPECT((a) != (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens]
#define EXPECT_LT(a, b, ...) \
EXPECT((a) < (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens]
#define EXPECT_LE(a, b, ...) \
EXPECT((a) <= (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens]
#define EXPECT_GT(a, b, ...) \
EXPECT((a) > (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens]
#define EXPECT_GE(a, b, ...) \
EXPECT((a) >= (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens]

#define NONE_IS_LEAF true
#define NONE_IS_NODE false
Expand Down Expand Up @@ -306,7 +337,7 @@ inline void SET_ITEM<py::list>(const py::handle& container,
template <typename PyType>
inline void AssertExact(const py::handle& object) {
if (!py::isinstance<PyType>(object)) [[unlikely]] {
throw std::runtime_error(absl::StrFormat(
throw std::invalid_argument(absl::StrFormat(
"Expected an instance of %s, got %s.", typeid(PyType).name(), py::repr(object)));
}
}
Expand Down
21 changes: 9 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@


class CMakeExtension(Extension):
def __init__(self, name, source_dir='.', **kwargs):
def __init__(self, name, source_dir='.', target=None, **kwargs):
super().__init__(name, sources=[], **kwargs)
self.source_dir = os.path.abspath(source_dir)
self.target = target if target is not None else name.rpartition('.')[-1]


class cmake_build_ext(build_ext):
Expand All @@ -39,18 +40,15 @@ def build_extension(self, ext):
if cmake is None:
raise RuntimeError('Cannot find CMake executable.')

build_temp = pathlib.Path(self.build_temp)
ext_path = pathlib.Path(self.get_ext_fullpath(ext.name)).absolute()
build_temp = pathlib.Path(self.build_temp).absolute()
build_temp.mkdir(parents=True, exist_ok=True)

config = 'Debug' if self.debug else 'Release'

extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
print(self.get_ext_fullpath(ext.name))

cmake_args = [
f'-DCMAKE_BUILD_TYPE={config}',
f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{config.upper()}={extdir}',
f'-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY_{config.upper()}={self.build_temp}',
f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{config.upper()}={ext_path.parent}',
f'-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY_{config.upper()}={build_temp}',
f'-DPYTHON_EXECUTABLE={sys.executable}',
f'-DPYTHON_INCLUDE_DIR={sysconfig.get_path("platinclude")}',
]
Expand All @@ -64,14 +62,11 @@ def build_extension(self, ext):
try:
import pybind11

cmake_args.append(
f'-DPYBIND11_CMAKE_DIR={pybind11.get_cmake_dir()}',
)
cmake_args.append(f'-DPYBIND11_CMAKE_DIR={pybind11.get_cmake_dir()}')
except ImportError:
pass

build_args = ['--config', config]

if (
'CMAKE_BUILD_PARALLEL_LEVEL' not in os.environ
and hasattr(self, 'parallel')
Expand All @@ -81,6 +76,8 @@ def build_extension(self, ext):
else:
build_args.append('--parallel')

build_args.extend([f'--target={ext.target}', '--'])

try:
os.chdir(build_temp)
self.spawn(['cmake', ext.source_dir] + cmake_args)
Expand Down
Loading

0 comments on commit 9fda41e

Please sign in to comment.