Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: better internal error handling #17

Merged
merged 4 commits into from
Nov 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ project(optree LANGUAGES CXX)
include(FetchContent)
# Git tags passed as GIT_TAG to the FetchContent declarations for pybind11 and abseil-cpp.
set(PYBIND11_VERSION v2.10.1)
set(ABSEIL_CPP_VERSION 20220623.1)
# Root directory under which the fetched dependencies' SOURCE_DIR, BINARY_DIR and STAMP_DIR are placed.
set(THIRD_PARTY_DIR "${CMAKE_SOURCE_DIR}/third-party")

if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
Expand All @@ -31,15 +32,28 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON) # -fPIC
set(CMAKE_CXX_VISIBILITY_PRESET hidden) # -fvisibility=hidden

# Compiler warning, debug and optimization flags, selected per toolchain.
# Each flag set is appended exactly once so no option is repeated on the command line.
if(MSVC)
    # /Zc:preprocessor : use the standards-conforming preprocessor
    # /Wall /WX        : enable all warnings and treat them as errors
    # /wdNNNN          : disable the individual warnings listed
    # /external:*      : treat angle-bracket includes as external headers and
    #                    compile them at warning level 0
    string(
        APPEND CMAKE_CXX_FLAGS
        " /Zc:preprocessor"
        " /Wall"
        " /WX /wd4365 /wd4514 /wd4710 /wd4711 /wd4820 /wd4868 /wd5045"
        " /experimental:external /external:anglebrackets /external:W0"
    )
    string(APPEND CMAKE_CXX_FLAGS_DEBUG " /Zi")
    string(APPEND CMAKE_CXX_FLAGS_RELEASE " /O2 /Ob2")
else()
    # -Werror promotes warnings to errors; -Wno-attributes silences warnings
    # about attributes.
    string(
        APPEND CMAKE_CXX_FLAGS
        " -Wall -Wextra"
        " -Werror -Wno-attributes"
    )
    string(APPEND CMAKE_CXX_FLAGS_DEBUG " -g -Og")
    string(APPEND CMAKE_CXX_FLAGS_RELEASE " -O3")
endif()

# Compute the length of the source-tree prefix ("${CMAKE_SOURCE_DIR}/") and pass it to the
# compiler as SOURCE_PATH_PREFIX_SIZE. include/utils.h uses it in __FILENAME__ to skip that
# many leading characters of __FILE__.
string(LENGTH "${CMAKE_SOURCE_DIR}/" SOURCE_PATH_PREFIX_SIZE)
add_definitions("-DSOURCE_PATH_PREFIX_SIZE=${SOURCE_PATH_PREFIX_SIZE}")

function(system)
set(options STRIP)
set(oneValueArgs OUTPUT_VARIABLE ERROR_VARIABLE WORKING_DIRECTORY)
Expand Down Expand Up @@ -130,11 +144,12 @@ if("${PYBIND11_CMAKE_DIR}" STREQUAL "")
GIT_REPOSITORY https://github.com/pybind/pybind11.git
GIT_TAG "${PYBIND11_VERSION}"
GIT_SHALLOW TRUE
SOURCE_DIR "${CMAKE_SOURCE_DIR}/third-party/pybind11"
BINARY_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/pybind11/build"
STAMP_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/pybind11/stamp"
SOURCE_DIR "${THIRD_PARTY_DIR}/pybind11"
BINARY_DIR "${THIRD_PARTY_DIR}/.cmake/pybind11/build"
STAMP_DIR "${THIRD_PARTY_DIR}/.cmake/pybind11/stamp"
)
FetchContent_GetProperties(pybind11)

if(NOT pybind11_POPULATED)
message(STATUS "Populating Git repository pybind11@${PYBIND11_VERSION} to third-party/pybind11...")
FetchContent_MakeAvailable(pybind11)
Expand All @@ -152,11 +167,12 @@ FetchContent_Declare(
GIT_REPOSITORY https://github.com/abseil/abseil-cpp.git
GIT_TAG "${ABSEIL_CPP_VERSION}"
GIT_SHALLOW TRUE
SOURCE_DIR "${CMAKE_SOURCE_DIR}/third-party/abseil-cpp"
BINARY_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/abseil-cpp/build"
STAMP_DIR "${CMAKE_SOURCE_DIR}/third-party/.cmake/abseil-cpp/stamp"
SOURCE_DIR "${THIRD_PARTY_DIR}/abseil-cpp"
BINARY_DIR "${THIRD_PARTY_DIR}/.cmake/abseil-cpp/build"
STAMP_DIR "${THIRD_PARTY_DIR}/.cmake/abseil-cpp/stamp"
)
FetchContent_GetProperties(abseilcpp)

if(NOT abseilcpp_POPULATED)
message(STATUS "Populating Git repository abseil-cpp@${ABSEIL_CPP_VERSION} to third-party/abseil-cpp...")
FetchContent_MakeAvailable(abseilcpp)
Expand Down
107 changes: 107 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ The `metadata` is some necessary data apart from the children to reconstruct the

The `entries` are optional: the flatten function may omit them (returning only a pair) or return `None` in their place. In either case, `range(len(children))` (i.e., flat indices) is used as the path entries of the current node. The function signature can be `flatten_func(container) -> (children, metadata)` or `flatten_func(container) -> (children, metadata, None)`.

The following examples show how to register custom types and utilize them for `tree_flatten` and `tree_map`. Please refer to section [Notes about the PyTree Type Registry](#notes-about-the-pytree-type-registry) for more information.

```python
# Register a Python type with lambda functions
register_pytree_node(
Expand Down Expand Up @@ -292,6 +294,111 @@ There are several key attributes of the pytree type registry:
)
```

5. **Be careful about the potential infinite recursion of the custom flatten function.** The returned `children` from the custom flatten function are considered subtrees. They will be further flattened recursively. The `children` can have the same type as the current node. Users must design their termination condition carefully.

```python
import numpy as np
import torch

optree.register_pytree_node(
np.ndarray,
# Children are nested lists of Python objects
lambda array: (np.atleast_1d(array).tolist(), array.ndim == 0),
lambda scalar, rows: np.asarray(rows) if not scalar else np.asarray(rows[0]),
namespace='numpy1',
)

optree.register_pytree_node(
np.ndarray,
# Children are Python objects
lambda array: (
list(array.ravel()), # list(NDArray[T]) -> List[T]
dict(shape=array.shape, dtype=array.dtype)
),
lambda metadata, children: np.asarray(children, dtype=metadata['dtype']).reshape(metadata['shape']),
namespace='numpy2',
)

optree.register_pytree_node(
    np.ndarray,
    # Returns a list of `np.ndarray`s without termination condition
    # (the metadata is the original shape, which the unflatten function passes to `reshape`)
    lambda array: ([array.ravel()], array.shape),
    lambda shape, children: children[0].reshape(shape),
    namespace='numpy3',
)

optree.register_pytree_node(
    torch.Tensor,
    # Children are nested lists of Python objects
    # (the metadata records whether the input was a 0-dim tensor, so it can be unwrapped again)
    lambda tensor: (torch.atleast_1d(tensor).tolist(), tensor.ndim == 0),
    lambda scalar, rows: torch.tensor(rows) if not scalar else torch.tensor(rows[0]),
    namespace='torch1',
)

optree.register_pytree_node(
torch.Tensor,
# Returns a list of `torch.Tensor`s without termination condition
lambda tensor: (
list(tensor.view(-1)), # list(NDTensor[T]) -> List[0DTensor[T]] (STILL TENSORS!)
tensor.shape
),
lambda shape, children: torch.stack(children).reshape(shape),
namespace='torch2',
)
```

```python
>>> optree.tree_flatten(np.arange(9).reshape(3, 3), namespace='numpy1')
(
[0, 1, 2, 3, 4, 5, 6, 7, 8],
PyTreeSpec(
CustomTreeNode(ndarray[False], [[*, *, *], [*, *, *], [*, *, *]]),
namespace='numpy1'
)
)
# Implicitly casts `float`s to `np.float64`
>>> optree.tree_map(lambda x: x + 1.5, np.arange(9).reshape(3, 3), namespace='numpy1')
array([[1.5, 2.5, 3.5],
[4.5, 5.5, 6.5],
[7.5, 8.5, 9.5]])

>>> optree.tree_flatten(np.arange(9).reshape(3, 3), namespace='numpy2')
(
[0, 1, 2, 3, 4, 5, 6, 7, 8],
PyTreeSpec(
CustomTreeNode(ndarray[{'shape': (3, 3), 'dtype': dtype('int64')}], [*, *, *, *, *, *, *, *, *]),
namespace='numpy2'
)
)
# Explicitly casts `float`s to `np.int64`
>>> optree.tree_map(lambda x: x + 1.5, np.arange(9).reshape(3, 3), namespace='numpy2')
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])

# Children are also `np.ndarray`s, recurse without termination condition.
>>> optree.tree_flatten(np.arange(9).reshape(3, 3), namespace='numpy3')
RecursionError: maximum recursion depth exceeded during flattening the tree

>>> optree.tree_flatten(torch.arange(9).reshape(3, 3), namespace='torch1')
(
[0, 1, 2, 3, 4, 5, 6, 7, 8],
PyTreeSpec(
CustomTreeNode(Tensor[False], [[*, *, *], [*, *, *], [*, *, *]]),
namespace='torch1'
)
)
# Implicitly casts `float`s to `torch.float32`
>>> optree.tree_map(lambda x: x + 1.5, torch.arange(9).reshape(3, 3), namespace='torch1')
tensor([[1.5000, 2.5000, 3.5000],
[4.5000, 5.5000, 6.5000],
[7.5000, 8.5000, 9.5000]])

# Children are also `torch.Tensor`s, recurse without termination condition.
>>> optree.tree_flatten(torch.arange(9).reshape(3, 3), namespace='torch2')
RecursionError: maximum recursion depth exceeded during flattening the tree
```

### `None` is Non-leaf Node vs. `None` is Leaf

The [`None`](https://docs.python.org/3/library/constants.html#None) object is a special object in the Python language.
Expand Down
6 changes: 6 additions & 0 deletions include/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ class PyTreeTypeRegistry {
template <bool NoneIsLeaf>
static PyTreeTypeRegistry *Singleton();

template <bool NoneIsLeaf>
static void RegisterImpl(const py::object &cls,
const py::function &to_iterable,
const py::function &from_iterable,
const std::string &registry_namespace);

class TypeHash {
public:
using is_transparent = void;
Expand Down
4 changes: 4 additions & 0 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ limitations under the License.
namespace optree {

// The maximum depth of a pytree.
// NOTE(review): a lower limit is used on Windows -- presumably because of a smaller default
// stack size there; confirm the rationale.
#ifdef _WIN32
constexpr ssize_t MAX_RECURSION_DEPTH = 4000;
#else
constexpr ssize_t MAX_RECURSION_DEPTH = 10000;
#endif

// A PyTreeSpec describes the tree structure of a PyTree. A PyTree is a tree of Python values, where
// the interior nodes are tuples, lists, dictionaries, or user-defined containers, and the leaves
Expand Down
43 changes: 37 additions & 6 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,43 @@ limitations under the License.
#include <utility>
#include <vector>

// Throws std::runtime_error when `condition` evaluates to false. The message contains the
// stringified condition followed by the source file and line of the call site.
// The failure branch is marked [[unlikely]]: a failing check is the exceptional path.
// NOTE: the macro expands to a bare `if` statement, so do not use it as the unbraced body of
// an outer `if` that has an `else` branch.
#define CHECK(condition)                                                              \
    if (!(condition)) [[unlikely]]                                                    \
    throw std::runtime_error(std::string(#condition) + " failed at " __FILE__ + ':' + \
                             std::to_string(__LINE__))
// Number of leading characters of __FILE__ to skip when reporting a source location.
// Falls back to 0 (report the full path) when the build system does not define it.
#ifndef SOURCE_PATH_PREFIX_SIZE
#define SOURCE_PATH_PREFIX_SIZE 0
#endif
// __FILE__ with its first SOURCE_PATH_PREFIX_SIZE characters skipped.
#ifndef __FILENAME__
#define __FILENAME__ (&(__FILE__[SOURCE_PATH_PREFIX_SIZE]))
#endif

// Argument-count dispatch helpers: VFUNC2 expands to its 3rd argument and VFUNC3 to its 4th.
// Callers pass a placeholder, the user arguments, and then the candidate macro names, so the
// number of user arguments determines which candidate lands in the NAME slot.
#define VFUNC2(__0, __1, NAME, ...) NAME
#define VFUNC3(__0, __1, __2, NAME, ...) NAME

// Throws std::logic_error with `message` followed by the file name and line of the call site.
#define INTERNAL_ERROR1(message) \
throw std::logic_error(absl::StrFormat("%s (at file %s:%lu)", message, __FILENAME__, __LINE__))
// Same as INTERNAL_ERROR1 with the fixed message "Unreachable code.".
#define INTERNAL_ERROR0() INTERNAL_ERROR1("Unreachable code.")
// INTERNAL_ERROR() -> INTERNAL_ERROR0(); INTERNAL_ERROR(message) -> INTERNAL_ERROR1(message).
#define INTERNAL_ERROR(...) /* NOLINTNEXTLINE[whitespace/parens] */ \
VFUNC2(__0 __VA_OPT__(, ) __VA_ARGS__, INTERNAL_ERROR1, INTERNAL_ERROR0)(__VA_ARGS__)

// Raises an internal error with `message` when `condition` is false.
// NOTE(review): expands to a bare `if` block, so `EXPECT(...);` used as the unbraced body of an
// outer `if` that has an `else` would not parse as intended -- confirm call sites avoid this.
#define EXPECT2(condition, message) \
if (!(condition)) [[unlikely]] { \
INTERNAL_ERROR1(message); \
}
// With no arguments, unconditionally raises the "Unreachable code." internal error.
#define EXPECT0() INTERNAL_ERROR0()
// With only a condition, the message is the stringified condition followed by " failed.".
#define EXPECT1(condition) EXPECT2(condition, "`" #condition "` failed.")
// EXPECT() -> EXPECT0(); EXPECT(condition) -> EXPECT1; EXPECT(condition, message) -> EXPECT2.
#define EXPECT(...) /* NOLINTNEXTLINE[whitespace/parens] */ \
VFUNC3(__0 __VA_OPT__(, ) __VA_ARGS__, EXPECT2, EXPECT1, EXPECT0)(__VA_ARGS__)

// DCHECK is an alias of CHECK.
#define DCHECK(condition) CHECK(condition)
// Comparison shorthands: each forwards `(a) <op> (b)` and any optional trailing message
// argument to EXPECT.
#define EXPECT_EQ(a, b, ...) \
EXPECT((a) == (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens]
#define EXPECT_NE(a, b, ...) \
EXPECT((a) != (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens]
#define EXPECT_LT(a, b, ...) \
EXPECT((a) < (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens]
#define EXPECT_LE(a, b, ...) \
EXPECT((a) <= (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens]
#define EXPECT_GT(a, b, ...) \
EXPECT((a) > (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens]
#define EXPECT_GE(a, b, ...) \
EXPECT((a) >= (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens]

#define NONE_IS_LEAF true
#define NONE_IS_NODE false
Expand Down Expand Up @@ -306,7 +337,7 @@ inline void SET_ITEM<py::list>(const py::handle& container,
// Verifies that `object` passes `py::isinstance<PyType>`.
// Throws std::invalid_argument naming the expected type and the repr of the received object
// when the check fails.
template <typename PyType>
inline void AssertExact(const py::handle& object) {
    if (!py::isinstance<PyType>(object)) [[unlikely]] {
        throw std::invalid_argument(absl::StrFormat(
            "Expected an instance of %s, got %s.", typeid(PyType).name(), py::repr(object)));
    }
}
Expand Down
21 changes: 9 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@


class CMakeExtension(Extension):
    """An extension module that is built by CMake rather than from a list of source files.

    Args:
        name: Dotted name of the extension module.
        source_dir: Directory passed to CMake as the source directory; stored as an absolute path.
        target: Name of the CMake target to build. Defaults to the last component of ``name``.
        **kwargs: Forwarded to ``Extension.__init__``.
    """

    def __init__(self, name, source_dir='.', target=None, **kwargs):
        # The sources list is left empty: the build command reads `source_dir` and `target`.
        super().__init__(name, sources=[], **kwargs)
        self.source_dir = os.path.abspath(source_dir)
        self.target = target if target is not None else name.rpartition('.')[-1]


class cmake_build_ext(build_ext):
Expand All @@ -39,18 +40,15 @@ def build_extension(self, ext):
if cmake is None:
raise RuntimeError('Cannot find CMake executable.')

build_temp = pathlib.Path(self.build_temp)
ext_path = pathlib.Path(self.get_ext_fullpath(ext.name)).absolute()
build_temp = pathlib.Path(self.build_temp).absolute()
build_temp.mkdir(parents=True, exist_ok=True)

config = 'Debug' if self.debug else 'Release'

extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
print(self.get_ext_fullpath(ext.name))

cmake_args = [
f'-DCMAKE_BUILD_TYPE={config}',
f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{config.upper()}={extdir}',
f'-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY_{config.upper()}={self.build_temp}',
f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{config.upper()}={ext_path.parent}',
f'-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY_{config.upper()}={build_temp}',
f'-DPYTHON_EXECUTABLE={sys.executable}',
f'-DPYTHON_INCLUDE_DIR={sysconfig.get_path("platinclude")}',
]
Expand All @@ -64,14 +62,11 @@ def build_extension(self, ext):
try:
import pybind11

cmake_args.append(
f'-DPYBIND11_CMAKE_DIR={pybind11.get_cmake_dir()}',
)
cmake_args.append(f'-DPYBIND11_CMAKE_DIR={pybind11.get_cmake_dir()}')
except ImportError:
pass

build_args = ['--config', config]

if (
'CMAKE_BUILD_PARALLEL_LEVEL' not in os.environ
and hasattr(self, 'parallel')
Expand All @@ -81,6 +76,8 @@ def build_extension(self, ext):
else:
build_args.append('--parallel')

build_args.extend([f'--target={ext.target}', '--'])

try:
os.chdir(build_temp)
self.spawn(['cmake', ext.source_dir] + cmake_args)
Expand Down
Loading