Skip to content

Commit

Permalink
Add support for differentiating calls to mem fns in the reverse mode
Browse files Browse the repository at this point in the history
This commit adds support for differentiating calls to member functions
in the reverse mode AD.

This commit introduces a breaking change in the signature of derived member
functions. Earlier we were considering data members to be constant while differentiating
member functions. This is no longer the case, and thus now we need a separate variable
to store derivative of the function value with respect to implicit `this` pointer. To allow
users to obtain derivatives with respect to the implicit `this` object, derived member function
prototype has been changed to include a derivative parameter for the implicit `this` pointer.

For example:

```cpp
struct SomeStruct {
  double memFn(double i, long double j) {...}
};
```

This function will have the gradient function as follows:

```cpp
void memFn(double i, long double j, clad::array_ref<SomeStruct> _d_this, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {...}
```
  • Loading branch information
parth-07 committed Mar 23, 2022
1 parent 19baccd commit e22cc9e
Show file tree
Hide file tree
Showing 21 changed files with 902 additions and 264 deletions.
17 changes: 14 additions & 3 deletions include/clad/Differentiator/ArrayRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,24 @@ namespace clad {
public:
// delete the default constructor
array_ref() = delete;
CUDA_HOST_DEVICE array_ref(void* arr, std::size_t size = 1)
: m_arr(arr), m_size(size) {}
// Here we are using C-style cast instead of `static_cast` because
// we may also need to remove qualifiers (`const`, `volatile`, etc) while
// converting to `void*` type.
// We cannot create specialisation of `array_ref<void>` with qualifiers
// (such as `array_ref<const void>`, `array_ref<volatile void>` etc) because
// each derivative parameter has to be of the same type in the overloaded
// gradient for the overloaded gradient mechanism to work and this class is
// used as the placeholder type for the common derivative parameter type.
template <typename T, class = typename std::enable_if<
std::is_pointer<T>::value ||
std::is_same<T, std::nullptr_t>::value>::type>
CUDA_HOST_DEVICE array_ref(T arr, std::size_t size = 1)
: m_arr((void*)arr), m_size(size) {}
template <typename T>
CUDA_HOST_DEVICE array_ref(const array_ref<T>& other)
: m_arr(other.ptr()), m_size(other.size()) {}
template <typename T> CUDA_HOST_DEVICE operator array_ref<T>() {
return array_ref<T>(static_cast<T*>(m_arr), m_size);
return array_ref<T>((T*)(m_arr), m_size);
}
CUDA_HOST_DEVICE void* ptr() const { return m_arr; }
CUDA_HOST_DEVICE std::size_t size() const { return m_size; }
Expand Down
6 changes: 3 additions & 3 deletions include/clad/Differentiator/CladConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,16 @@ void trap(int code) {

#ifdef __CUDACC__
template<typename T>
__device__ T* addressof(T& r) {
__device__ T* clad_addressof(T& r) {
return __builtin_addressof(r);
}
template<typename T>
__host__ T* addressof(T& r) {
__host__ T* clad_addressof(T& r) {
return std::addressof(r);
}
#else
template<typename T>
T* addressof(T& r) {
T* clad_addressof(T& r) {
return std::addressof(r);
}
#endif
Expand Down
11 changes: 9 additions & 2 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ namespace clad {
clang::DeclarationNameInfo BuildDeclarationNameInfo(clang::Sema& S,
llvm::StringRef name);

/// Returns true if the function has any reference or pointer parameter;
/// otherwise returns false.
bool HasAnyReferenceOrPointerArgument(const clang::FunctionDecl* FD);

/// Returns true if `T` is a reference, pointer or array type.
Expand Down Expand Up @@ -196,8 +198,9 @@ namespace clad {
clang::TypeSourceInfo* TSI = nullptr);

/// If `T` represents an array or a pointer type then returns the
/// corresponding array element or the pointee type. Otherwise, if `T` is
/// neither an array nor a pointer type, then simply returns `T`.
/// corresponding array element or the pointee type. If `T` is a reference
/// type then return the corresponding non-reference type. Otherwise, if `T`
/// is neither an array nor a pointer type, then simply returns `T`.
clang::QualType GetValueType(clang::QualType T);

/// Builds and returns the init expression to initialise `clad::array` and
Expand All @@ -207,6 +210,10 @@ namespace clad {
/// `{arr, arrSize}`
clang::Expr* BuildCladArrayInitByConstArray(clang::Sema& semaRef,
clang::Expr* constArrE);

/// Returns true if `FD` is a class static method; otherwise returns
/// false.
bool IsStaticMethod(const clang::FunctionDecl* FD);
} // namespace utils
}

Expand Down
17 changes: 17 additions & 0 deletions include/clad/Differentiator/Compatibility.h
Original file line number Diff line number Diff line change
Expand Up @@ -577,5 +577,22 @@ static inline Expr* GetSubExpr(const MaterializeTemporaryExpr* MTE) {
#else
#define CLAD_COMPAT_IS_LIST_INITIALIZATION_PARAM(E) , E->isListInitialization()
#endif

#if CLANG_VERSION_MAJOR < 9
static inline QualType
CXXMethodDecl_GetThisObjectType(Sema& semaRef, const CXXMethodDecl* MD) {
ASTContext& C = semaRef.getASTContext();
const CXXRecordDecl* RD = MD->getParent();
auto RDType = RD->getTypeForDecl();
auto thisObjectQType = C.getQualifiedType(
RDType, clad_compat::CXXMethodDecl_getMethodQualifiers(MD));
return thisObjectQType;
}
#else
static inline QualType
CXXMethodDecl_GetThisObjectType(Sema& semaRef, const CXXMethodDecl* MD) {
return MD->getThisObjectType();
}
#endif
} // namespace clad_compat
#endif //CLAD_COMPATIBILITY
4 changes: 2 additions & 2 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ namespace clad {
typename std::enable_if<EnablePadding, bool>::type = true>
CUDA_HOST_DEVICE return_type_t<F>
execute_with_default_args(list<Rest...>, F f, Args&&... args) {
return f(static_cast<Args>(args)..., static_cast<Rest>(0)...);
return f(static_cast<Args>(args)..., static_cast<Rest>(nullptr)...);
}

template <bool EnablePadding, class... Rest, class F, class... Args,
Expand All @@ -122,7 +122,7 @@ namespace clad {
Args&&... args)
-> return_type_t<decltype(f)> {
return (static_cast<Obj>(obj).*f)(static_cast<Args>(args)...,
static_cast<Rest>(0)...);
static_cast<Rest>(nullptr)...);
}

template <bool EnablePadding, class... Rest, class ReturnType, class C,
Expand Down
5 changes: 3 additions & 2 deletions include/clad/Differentiator/FunctionTraits.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,9 @@ namespace clad {
#define GradientDerivedFnTraits_AddSPECS(var, cv, vol, ref, noex) \
template <typename R, typename C, typename... Args> \
struct GradientDerivedFnTraits<R (C::*)(Args...) cv vol ref noex> { \
using type = void (C::*)(Args..., \
OutputParamType_t<Args, void>...) cv vol ref noex; \
using type = \
void (C::*)(Args..., OutputParamType_t<C, void>, \
OutputParamType_t<Args, void>...) cv vol ref noex; \
};

#if __cpp_noexcept_function_type > 0
Expand Down
5 changes: 5 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ namespace clad {
unsigned outputArrayCursor = 0;
unsigned numParams = 0;
bool isVectorValued = false;
/// Stores derivative expression of the implicit `this` pointer.
///
/// \note `this` pointer derivative expression is always of the class object
/// type rather than the pointer type.
clang::Expr* m_ThisExprDerivative = nullptr;
// FIXME: Should we make this an object instead of a pointer?
// Downside of making it an object: We will need to include
// 'MultiplexExternalRMVSource.h' file
Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/Tape.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ namespace clad {
// allocation properly.
for (; first != last; ++first, (void)++current) {
::new (const_cast<void*>(
static_cast<const volatile void*>(addressof(*current))))
static_cast<const volatile void*>(clad_addressof(*current))))
T(std::move(*first));
}
}
Expand Down
6 changes: 3 additions & 3 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,9 +403,9 @@ namespace clad {
/// \returns Built member function call expression
/// Base.MemberFunction(ArgExprs) or Base->MemberFunction(ArgExprs)
clang::Expr*
BuildCallExprToMemFn(clang::Expr* Base, bool isArrow,
llvm::StringRef MemberFunctionName,
llvm::MutableArrayRef<clang::Expr*> ArgExprs);
BuildCallExprToMemFn(clang::Expr* Base, llvm::StringRef MemberFunctionName,
llvm::MutableArrayRef<clang::Expr*> ArgExprs,
clang::ValueDecl* memberDecl = nullptr);

/// Build a call to member function through this pointer.
///
Expand Down
41 changes: 37 additions & 4 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,31 @@ namespace clad {
DC2 = DC2->getParent();
continue;
}
// We don't want to 'extend' the DC1 context with class declarations.
// There are 2 main reasons for this:
// - Class declaration context cannot be extended the way namespace
// declaration contexts can.
//
// - Primary usage of `FindDeclarationContext` is to create the correct
// declaration context for searching some particular custom derivative.
// But for class functions, we would 'need' to create custom derivative
// in the original declaration context only (We don't support custom
// derivatives for class functions yet). We cannot use the default
// context that starts from `clad::custom_derivatives::`. This is
// because custom derivatives of class functions need to have same
// access permissions as the original member function.
//
// We may need to change this behaviour if in the future
// `FindDeclarationContext` function is being used for some place other
// than finding custom derivative declaration context as well.
//
// Silently return nullptr if DC2 contains any CXXRecord declaration
// context.
if (isa<CXXRecordDecl>(DC2))
return nullptr;
assert(isa<NamespaceDecl>(DC2) &&
"DC2 should only contain namespace (and "
"translation unit) declaration.");
"DC2 should only consists of namespace, CXXRecord and "
"translation unit declaration context.");
contexts.push_back(DC2);
DC2 = DC2->getParent();
}
Expand Down Expand Up @@ -237,10 +259,14 @@ namespace clad {

clang::QualType GetValueType(clang::QualType T) {
QualType valueType = T;
if (isArrayOrPointerType(T)) {
if (T->isPointerType())
valueType = T->getPointeeType();
else if (T->isReferenceType())
valueType = T.getNonReferenceType();
// FIXME: `QualType::getPointeeOrArrayElementType` loses type qualifiers.
else if (T->isArrayType())
valueType =
T->getPointeeOrArrayElementType()->getCanonicalTypeInternal();
}
return valueType;
}

Expand All @@ -255,5 +281,12 @@ namespace clad {
llvm::SmallVector<Expr*, 2> args = {constArrE, sizeE};
return semaRef.ActOnInitList(noLoc, args, noLoc).get();
}

bool IsStaticMethod(const clang::FunctionDecl* FD) {
if (auto MD = dyn_cast<CXXMethodDecl>(FD)) {
return MD->isStatic();
}
return false;
}
} // namespace utils
} // namespace clad
27 changes: 27 additions & 0 deletions lib/Differentiator/HessianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,33 @@ namespace clad {
});
DeclRefToParams.pop_back();

/// If we are differentiating a member function then create a parameter
/// that can represent the derivative for the implicit `this` pointer. It
/// is required because reverse mode derived function expects an explicit
/// parameter for storing derivative with respect to `implicit` this
/// object.
///
// FIXME: Add support for class type in the hessian matrix. For this, we
// need to add a way to represent hessian matrix when class type objects
// are involved.
if (auto MD = dyn_cast<CXXMethodDecl>(m_Function)) {
const CXXRecordDecl* RD = MD->getParent();
if (MD->isInstance() && !RD->isLambda()) {
QualType thisObjectType =
clad_compat::CXXMethodDecl_GetThisObjectType(m_Sema, MD);
// Derivatives should never be of `const` types. Even if the original
// variable is of `const` type.
thisObjectType.removeLocalConst();
auto dThisVD = BuildVarDecl(thisObjectType, "_d_this",
/*Init=*/nullptr, false, /*TSI=*/nullptr,
VarDecl::InitializationStyle::CallInit);
CompStmtSave.push_back(BuildDeclStmt(dThisVD));
Expr* dThisExpr = BuildDeclRef(dThisVD);
DeclRefToParams.push_back(
BuildOp(UnaryOperatorKind::UO_AddrOf, dThisExpr));
}
}

size_t columnIndex = 0;
// Create Expr parameters for each independent arg in the CallExpr
for (size_t j = 0, n = IndependentArgsSize.size(); j < n; j++) {
Expand Down
Loading

0 comments on commit e22cc9e

Please sign in to comment.