Skip to content

Commit

Permalink
Handle forward, reverse preprocess cache collision (rust-lang#398)
Browse files Browse the repository at this point in the history
* Loose type analysis for insertvalue

* Fix forward/reverse preprocess cache collision
  • Loading branch information
wsmoses authored Dec 15, 2021
1 parent 875022f commit 394992e
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 14 deletions.
24 changes: 18 additions & 6 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -1530,12 +1530,24 @@ class AdjointGenerator
7) /
8;

Type *flt = nullptr;
if (!gutils->isConstantValue(orig_inserted) &&
(flt = TR.intType(size0, orig_inserted).isFloat())) {
auto prediff = diffe(&IVI, Builder2);
auto dindex = Builder2.CreateExtractValue(prediff, IVI.getIndices());
addToDiffe(orig_inserted, dindex, Builder2, flt);
if (!gutils->isConstantValue(orig_inserted)) {
auto it =
TR.intType(size0, orig_inserted, /*errIfFalse*/ !looseTypeAnalysis);
Type *flt = it.isFloat();
if (!it.isKnown()) {
assert(looseTypeAnalysis);
if (orig_inserted->getType()->isFPOrFPVectorTy())
flt = orig_inserted->getType()->getScalarType();
else if (orig_inserted->getType()->isIntOrIntVectorTy())
flt = nullptr;
else
TR.intType(size0, orig_inserted);
}
if (flt) {
auto prediff = diffe(&IVI, Builder2);
auto dindex = Builder2.CreateExtractValue(prediff, IVI.getIndices());
addToDiffe(orig_inserted, dindex, Builder2, flt);
}
}

size_t size1 = 1;
Expand Down
19 changes: 12 additions & 7 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,9 @@ IsFunctionRecursive(Function *F,
static inline bool OnlyUsedInOMP(AllocaInst *AI) {
bool ompUse = false;
for (auto U : AI->users()) {
if (isa<StoreInst>(U))
continue;
if (auto SI = dyn_cast<StoreInst>(U))
if (SI->getPointerOperand() == AI)
continue;
if (auto CI = dyn_cast<CallInst>(U)) {
if (auto F = CI->getCalledFunction()) {
if (F->getName() == "__kmpc_for_static_init_4" ||
Expand Down Expand Up @@ -873,12 +874,16 @@ PreProcessCache::getAAResultsFromFunction(llvm::Function *NewF) {
Function *PreProcessCache::preprocessForClone(Function *F,
DerivativeMode mode) {

if (mode == DerivativeMode::ReverseModeGradient)
mode = DerivativeMode::ReverseModePrimal;
if (mode == DerivativeMode::ForwardModeVector ||
mode == DerivativeMode::ForwardModeSplit)
mode = DerivativeMode::ForwardMode;

// If we've already processed this, return the previous version
// and derive aliasing information
if (cache.find(std::make_pair(
F, mode == DerivativeMode::ReverseModeCombined)) != cache.end()) {
Function *NewF =
cache[std::make_pair(F, mode == DerivativeMode::ReverseModeCombined)];
if (cache.find(std::make_pair(F, mode)) != cache.end()) {
Function *NewF = cache[std::make_pair(F, mode)];
return NewF;
}

Expand Down Expand Up @@ -1546,7 +1551,7 @@ Function *PreProcessCache::preprocessForClone(Function *F,
llvm::errs() << *NewF << "\n";
report_fatal_error("function failed verification (1)");
}
cache[std::make_pair(F, mode == DerivativeMode::ReverseModeCombined)] = NewF;
cache[std::make_pair(F, mode)] = NewF;
return NewF;
}

Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/FunctionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class PreProcessCache {
llvm::FunctionAnalysisManager FAM;
llvm::ModuleAnalysisManager MAM;

std::map<std::pair<llvm::Function *, bool>, llvm::Function *> cache;
std::map<std::pair<llvm::Function *, DerivativeMode>, llvm::Function *> cache;
std::map<llvm::Function *, llvm::Function *> CloneOrigin;

llvm::Function *preprocessForClone(llvm::Function *F, DerivativeMode mode);
Expand Down
89 changes: 89 additions & 0 deletions enzyme/test/Integration/ForwardMode/fwdandrev.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// RUN: %clang++ -std=c++11 -fno-exceptions -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang++ -std=c++11 -fno-exceptions -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang++ -std=c++11 -fno-exceptions -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang++ -std=c++11 -fno-exceptions -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang++ -std=c++11 -fno-exceptions -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
// RUN: %clang++ -std=c++11 -fno-exceptions -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
// RUN: %clang++ -std=c++11 -fno-exceptions -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
// RUN: %clang++ -std=c++11 -fno-exceptions -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -

#include <cmath>

template <typename T, int... n>
struct tensor;

template <typename T, int n>
struct tensor<T, n> {

T& operator[](int i) { return value[0]; };

T value[n];
};

template <typename T, int first, int... rest>
struct tensor<T, first, rest...> {

tensor<T, rest...>& operator[](int i) { return value[0]; };

tensor<T, rest...> value[first];
};

int enzyme_dup;
int enzyme_out;
int enzyme_const;

template<typename return_type, typename... Args>
return_type __enzyme_autodiff(Args...);

template<typename return_type, typename... Args>
return_type __enzyme_fwddiff(Args...);

extern "C" {
__attribute__((noinline))
constexpr double ptr(double* A) {
return A[0];
}
}

template <int n>
__attribute__((noinline))
constexpr tensor<double, 1, 1> pdev(const tensor<double, 1, 1>& A) {
auto devA = A;
auto trA = ptr((double*)&A);
devA[0][0] -= trA;
devA[0][0] -= trA;
devA[0][0] -= trA;
return devA;
}

extern "C" {
double mystress_calculation(void* __restrict__ D, const tensor<double, 1, 1> & __restrict__ du_dx) {
auto devB = pdev<2>(du_dx);

return 2 * devB[0][0];
}
}

int main(int argc, char * argv[]) {
tensor<double, 1, 1> dudx = {{{0.0}}};
tensor<double, 1, 1> ddudxi = {{{0.0}}};

tensor<double, 1, 1> gradient{};
tensor<double, 1, 1> sigma{};
tensor<double, 1, 1> dir{};

dir[0][0] = 1;
// Forward pass of gradient can segfault if forward and reverse preprocess
// functions collide in cache.
for (int i = 0; i < 2; i++)
{

__enzyme_autodiff<void>(mystress_calculation,
&enzyme_const, nullptr,
&enzyme_dup, &dudx, &gradient);
}

__enzyme_fwddiff<void>(mystress_calculation,
&enzyme_const, nullptr,
&enzyme_dup, &dudx, &ddudxi);
}

0 comments on commit 394992e

Please sign in to comment.