From e32d11610b5fa4ac450a8da9a91b596adffe3145 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 4 Mar 2022 00:11:15 -0500 Subject: [PATCH] Initial Forward Split Mode (#539) * Begin forward split * Add atomicadd test * Starting to function split fwd * Start tests * Modref * Get most of fwd split working * memmove * Fix test * Adjust tests for higher llvm --- enzyme/Enzyme/AdjointGenerator.h | 229 ++++---- enzyme/Enzyme/CApi.cpp | 25 +- enzyme/Enzyme/CApi.h | 12 +- enzyme/Enzyme/Enzyme.cpp | 181 ++++-- enzyme/Enzyme/EnzymeLogic.cpp | 532 ++++++++++-------- enzyme/Enzyme/EnzymeLogic.h | 14 +- enzyme/Enzyme/FunctionUtils.cpp | 2 +- enzyme/Enzyme/GradientUtils.cpp | 208 ++++--- enzyme/Enzyme/GradientUtils.h | 41 +- enzyme/Enzyme/Utils.h | 27 +- enzyme/test/Enzyme/CMakeLists.txt | 3 +- .../Enzyme/ForwardModeSplit/CMakeLists.txt | 12 + .../Enzyme/ForwardModeSplit/Faddeeva_erf.ll | 53 ++ .../Enzyme/ForwardModeSplit/Faddeeva_erfc.ll | 53 ++ .../Enzyme/ForwardModeSplit/Faddeeva_erfi.ll | 51 ++ enzyme/test/Enzyme/ForwardModeSplit/add.ll | 29 + .../test/Enzyme/ForwardModeSplit/addOneMem.ll | 46 ++ .../test/Enzyme/ForwardModeSplit/badcall.ll | 65 +++ .../test/Enzyme/ForwardModeSplit/badcall2.ll | 82 +++ .../test/Enzyme/ForwardModeSplit/badcall3.ll | 83 +++ .../test/Enzyme/ForwardModeSplit/badcall4.ll | 83 +++ .../test/Enzyme/ForwardModeSplit/badcallsq.ll | 71 +++ .../Enzyme/ForwardModeSplit/badcallused.ll | 67 +++ .../Enzyme/ForwardModeSplit/badcallused2.ll | 84 +++ .../test/Enzyme/ForwardModeSplit/bitcast.ll | 22 + .../test/Enzyme/ForwardModeSplit/bsearch.ll | 77 +++ .../test/Enzyme/ForwardModeSplit/bsearch2.ll | 87 +++ enzyme/test/Enzyme/ForwardModeSplit/call.ll | 53 ++ .../ForwardModeSplit/callmincacheunwrap.ll | 99 ++++ .../test/Enzyme/ForwardModeSplit/callmod.ll | 116 ++++ enzyme/test/Enzyme/ForwardModeSplit/calloc.ll | 46 ++ .../test/Enzyme/ForwardModeSplit/constant.ll | 22 + .../Enzyme/ForwardModeSplit/constselect.ll | 49 ++ 
enzyme/test/Enzyme/ForwardModeSplit/cos.ll | 30 + enzyme/test/Enzyme/ForwardModeSplit/cosh.ll | 28 + .../test/Enzyme/ForwardModeSplit/custom0.ll | 55 ++ .../test/Enzyme/ForwardModeSplit/custom1.ll | 60 ++ .../test/Enzyme/ForwardModeSplit/custom2.ll | 44 ++ enzyme/test/Enzyme/ForwardModeSplit/div.ll | 28 + .../test/Enzyme/ForwardModeSplit/divreduce.ll | 122 ++++ .../Enzyme/ForwardModeSplit/divreduce2.ll | 80 +++ .../ForwardModeSplit/enzyme_inactive.ll | 27 + .../ForwardModeSplit/enzyme_inactive2.ll | 27 + enzyme/test/Enzyme/ForwardModeSplit/erf.ll | 29 + enzyme/test/Enzyme/ForwardModeSplit/erfc.ll | 29 + enzyme/test/Enzyme/ForwardModeSplit/erfi.ll | 28 + enzyme/test/Enzyme/ForwardModeSplit/exp.ll | 26 + enzyme/test/Enzyme/ForwardModeSplit/exp2.ll | 27 + .../experimental_vector_reduce_v2_fadd.ll | 26 + enzyme/test/Enzyme/ForwardModeSplit/fabs.ll | 29 + enzyme/test/Enzyme/ForwardModeSplit/fneg.ll | 31 + enzyme/test/Enzyme/ForwardModeSplit/fpext.ll | 22 + enzyme/test/Enzyme/ForwardModeSplit/ge.ll | 96 ++++ enzyme/test/Enzyme/ForwardModeSplit/global.ll | 95 ++++ .../test/Enzyme/ForwardModeSplit/globalfn.ll | 96 ++++ .../Enzyme/ForwardModeSplit/globallower.ll | 42 ++ .../Enzyme/ForwardModeSplit/insertvalue.ll | 31 + enzyme/test/Enzyme/ForwardModeSplit/intsum.ll | 55 ++ .../Enzyme/ForwardModeSplit/invertselect.ll | 33 ++ enzyme/test/Enzyme/ForwardModeSplit/log.ll | 26 + enzyme/test/Enzyme/ForwardModeSplit/log10.ll | 27 + enzyme/test/Enzyme/ForwardModeSplit/log2.ll | 27 + .../Enzyme/ForwardModeSplit/maskedload.ll | 30 + .../Enzyme/ForwardModeSplit/maskedstore.ll | 30 + enzyme/test/Enzyme/ForwardModeSplit/max.ll | 28 + .../ForwardModeSplit/maxnum-inactive.ll | 29 + enzyme/test/Enzyme/ForwardModeSplit/maxnum.ll | 29 + .../Enzyme/ForwardModeSplit/memcpy-flt.ll | 37 ++ .../ForwardModeSplit/memcpy-intstruct.ll | 38 ++ .../Enzyme/ForwardModeSplit/memcpy-ptr.ll | 35 ++ enzyme/test/Enzyme/ForwardModeSplit/minnum.ll | 25 + enzyme/test/Enzyme/ForwardModeSplit/mul.ll | 26 + 
.../Enzyme/ForwardModeSplit/negbithack.ll | 26 + .../Enzyme/ForwardModeSplit/negbithack2.ll | 26 + .../Enzyme/ForwardModeSplit/negbithack3.ll | 34 ++ enzyme/test/Enzyme/ForwardModeSplit/pow.ll | 35 ++ enzyme/test/Enzyme/ForwardModeSplit/powi13.ll | 34 ++ .../test/Enzyme/ForwardModeSplit/ptr-ret.ll | 63 +++ enzyme/test/Enzyme/ForwardModeSplit/relu.ll | 62 ++ enzyme/test/Enzyme/ForwardModeSplit/rwloop.ll | 166 ++++++ enzyme/test/Enzyme/ForwardModeSplit/sin.ll | 29 + enzyme/test/Enzyme/ForwardModeSplit/sqrelu.ll | 72 +++ enzyme/test/Enzyme/ForwardModeSplit/sqrt.ll | 29 + enzyme/test/Enzyme/ForwardModeSplit/square.ll | 32 ++ .../test/Enzyme/ForwardModeSplit/square2.ll | 92 +++ .../Enzyme/ForwardModeSplit/square_array.ll | 34 ++ enzyme/test/Enzyme/ForwardModeSplit/sret.ll | 100 ++++ enzyme/test/Enzyme/ForwardModeSplit/sret12.ll | 88 +++ enzyme/test/Enzyme/ForwardModeSplit/store3.ll | 29 + .../Enzyme/ForwardModeSplit/storeconstexpr.ll | 24 + enzyme/test/Enzyme/ForwardModeSplit/sub.ll | 24 + .../test/Enzyme/ForwardModeSplit/sumnlist.ll | 130 +++++ .../test/Enzyme/ForwardModeSplit/sumsimple.ll | 64 +++ .../ForwardModeSplit/sumsimpleoptnone.ll | 64 +++ .../test/Enzyme/ForwardModeSplit/sumsquare.ll | 60 ++ .../Enzyme/ForwardModeSplit/sumwithbreak.ll | 76 +++ .../test/Enzyme/ForwardModeSplit/vecsquare.ll | 43 ++ .../ForwardModeSplit/vector_reduce_fadd.ll | 26 + enzyme/test/Enzyme/ReverseMode/atomicadd.ll | 33 ++ enzyme/test/Enzyme/ReverseMode/bitcastfn.ll | 3 +- enzyme/test/Enzyme/ReverseMode/callvalue.ll | 4 +- enzyme/test/Enzyme/ReverseMode/phiswitch.ll | 10 +- 102 files changed, 5214 insertions(+), 505 deletions(-) create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/CMakeLists.txt create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/Faddeeva_erf.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/Faddeeva_erfc.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/Faddeeva_erfi.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/add.ll create mode 
100644 enzyme/test/Enzyme/ForwardModeSplit/addOneMem.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/badcall.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/badcall2.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/badcall3.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/badcall4.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/badcallsq.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/badcallused.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/badcallused2.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/bitcast.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/bsearch.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/bsearch2.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/call.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/callmincacheunwrap.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/callmod.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/calloc.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/constant.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/constselect.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/cos.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/cosh.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/custom0.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/custom1.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/custom2.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/div.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/divreduce.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/divreduce2.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/enzyme_inactive.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/enzyme_inactive2.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/erf.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/erfc.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/erfi.ll 
create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/exp.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/exp2.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/experimental_vector_reduce_v2_fadd.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/fabs.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/fneg.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/fpext.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/ge.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/global.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/globalfn.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/globallower.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/insertvalue.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/intsum.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/invertselect.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/log.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/log10.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/log2.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/maskedload.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/maskedstore.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/max.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/maxnum-inactive.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/maxnum.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/memcpy-flt.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/memcpy-intstruct.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/memcpy-ptr.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/minnum.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/mul.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/negbithack.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/negbithack2.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/negbithack3.ll create mode 100644 
enzyme/test/Enzyme/ForwardModeSplit/pow.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/powi13.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/ptr-ret.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/relu.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/rwloop.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/sin.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/sqrelu.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/sqrt.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/square.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/square2.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/square_array.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/sret.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/sret12.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/store3.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/storeconstexpr.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/sub.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/sumnlist.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/sumsimple.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/sumsimpleoptnone.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/sumsquare.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/sumwithbreak.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/vecsquare.ll create mode 100644 enzyme/test/Enzyme/ForwardModeSplit/vector_reduce_fadd.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/atomicadd.ll diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 0a46c8b69ac75..870ab32b6733c 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -489,7 +489,6 @@ class AdjointGenerator } break; } - case DerivativeMode::ForwardModeSplit: case DerivativeMode::ForwardMode: { newip = gutils->invertPointerM(&I, BuilderZ); assert(newip->getType() == type); @@ 
-499,6 +498,7 @@ class AdjointGenerator (const Value *)&I, InvertedPointerVH(gutils, newip))); break; } + case DerivativeMode::ForwardModeSplit: case DerivativeMode::ReverseModeGradient: { if (!needShadow) { gutils->erase(placeholder); @@ -528,11 +528,6 @@ class AdjointGenerator } } - // Allow forcing cache reads to be on or off using flags. - assert(!(cache_reads_always && cache_reads_never) && - "Both cache_reads_always and cache_reads_never are true. This " - "doesn't make sense."); - Value *inst = newi; //! Store loads that need to be cached for use in reverse pass @@ -540,27 +535,30 @@ class AdjointGenerator // Only cache value here if caching decision isn't precomputed. // Otherwise caching will be done inside EnzymeLogic.cpp at // the end of the function jointly. - if ((Mode != DerivativeMode::ForwardMode && - gutils->knownRecomputeHeuristic.count(&I) == 0 && - !gutils->unnecessaryIntermediates.count(&I) && can_modref && - !cache_reads_never) || - cache_reads_always) { + if (Mode != DerivativeMode::ForwardMode && + !gutils->knownRecomputeHeuristic.count(&I) && can_modref && + !gutils->unnecessaryIntermediates.count(&I)) { // we can pre initialize all the knownRecomputeHeuristic values to false // (not needing) as we may assume that minCutCache already preserves // everything it requires. 
std::map Seen; + bool primalNeededInReverse = false; for (auto pair : gutils->knownRecomputeHeuristic) - Seen[UsageKey(pair.first, ValueType::Primal)] = false; - bool primalNeededInReverse = - is_value_needed_in_reverse(TR, gutils, &I, Mode, - Seen, oldUnreachable); + if (!pair.second) { + Seen[UsageKey(pair.first, ValueType::Primal)] = false; + if (pair.first == &I) + primalNeededInReverse = true; + } + primalNeededInReverse |= is_value_needed_in_reverse( + TR, gutils, &I, Mode, Seen, oldUnreachable); if (primalNeededInReverse) { IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&I)); inst = gutils->cacheForReverse(BuilderZ, newi, getIndex(&I, CacheType::Self)); assert(inst->getType() == type); - if (Mode == DerivativeMode::ReverseModeGradient) { + if (Mode == DerivativeMode::ReverseModeGradient || + Mode == DerivativeMode::ForwardModeSplit) { assert(inst != newi); } else { assert(inst == newi); @@ -2695,6 +2693,7 @@ class AdjointGenerator }; applyChainRule(Builder2, rule, ddst, dsrc); + eraseIfUnused(MTI); return; } @@ -4134,8 +4133,9 @@ class AdjointGenerator if (called) { subdata = &gutils->Logic.CreateAugmentedPrimal( cast(called), subretType, argsInverted, - TR.analyzer.interprocedural, /*return is used*/ false, nextTypeInfo, - uncacheable_args, false, /*AtomicAdd*/ true, + TR.analyzer.interprocedural, /*return is used*/ false, + /*shadowReturnUsed*/ false, nextTypeInfo, uncacheable_args, false, + /*AtomicAdd*/ true, /*OpenMP*/ true); if (Mode == DerivativeMode::ReverseModePrimal) { assert(augmentedReturn); @@ -7735,6 +7735,7 @@ class AdjointGenerator subretused = true; } } + bool shadowReturnUsed = false; DIFFE_TYPE subretType; if (gutils->isConstantValue(orig)) { @@ -7743,13 +7744,15 @@ class AdjointGenerator if (Mode == DerivativeMode::ForwardMode || Mode == DerivativeMode::ForwardModeSplit) { subretType = DIFFE_TYPE::DUP_ARG; + shadowReturnUsed = true; } else { if (!orig->getType()->isFPOrFPVectorTy() && TR.query(orig).Inner0().isPossiblePointer()) { if 
(is_value_needed_in_reverse( - TR, gutils, orig, Mode, oldUnreachable)) + TR, gutils, orig, Mode, oldUnreachable)) { subretType = DIFFE_TYPE::DUP_ARG; - else + shadowReturnUsed = true; + } else subretType = DIFFE_TYPE::CONSTANT; } else { subretType = DIFFE_TYPE::OUT_DIFF; @@ -8564,21 +8567,22 @@ class AdjointGenerator return; switch (Mode) { - default: - llvm_unreachable("unhandled mode"); case DerivativeMode::ReverseModePrimal: return; case DerivativeMode::ForwardMode: + case DerivativeMode::ForwardModeSplit: case DerivativeMode::ReverseModeGradient: case DerivativeMode::ReverseModeCombined: { IRBuilder<> Builder2(&call); - if (Mode == DerivativeMode::ForwardMode) + if (Mode == DerivativeMode::ForwardMode || + Mode == DerivativeMode::ForwardModeSplit) getForwardBuilder(Builder2); else getReverseBuilder(Builder2); Value *x = gutils->getNewFromOriginal(orig->getArgOperand(0)); - if (Mode != DerivativeMode::ForwardMode) + if (Mode != DerivativeMode::ForwardMode && + Mode != DerivativeMode::ForwardModeSplit) x = lookup(x, Builder2); Value *sq; @@ -8657,8 +8661,8 @@ class AdjointGenerator cal = Builder2.CreateFMul(cal, factor); } - Value *dfactor = Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeSplit + Value *dfactor = (Mode == DerivativeMode::ForwardMode || + Mode == DerivativeMode::ForwardModeSplit) ? 
diffe(orig->getArgOperand(0), Builder2) : diffe(orig, Builder2); @@ -8694,7 +8698,8 @@ class AdjointGenerator cal = applyChainRule(call.getType(), Builder2, rule2, dfactor); } - if (Mode == DerivativeMode::ForwardMode) { + if (Mode == DerivativeMode::ForwardMode || + Mode == DerivativeMode::ForwardModeSplit) { setDiffe(orig, cal, Builder2); } else { setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); @@ -9264,7 +9269,8 @@ class AdjointGenerator if (Mode == DerivativeMode::ReverseModeCombined || Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModePrimal) { + Mode == DerivativeMode::ReverseModePrimal || + Mode == DerivativeMode::ForwardModeSplit) { Value *anti = placeholder; // If rematerializable allocations and split mode, we can @@ -9429,6 +9435,8 @@ class AdjointGenerator (Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) || (Mode == DerivativeMode::ReverseModeGradient && + backwardsShadow) || + (Mode == DerivativeMode::ForwardModeSplit && backwardsShadow)) { if (!inLoop) zeroKnownAllocation(bb, anti, args, *called, gutils->TLI); @@ -9564,7 +9572,8 @@ class AdjointGenerator // Otherwise if in reverse pass, free the newly created allocation. if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) { + Mode == DerivativeMode::ReverseModeCombined || + Mode == DerivativeMode::ForwardModeSplit) { IRBuilder<> Builder2(call.getParent()); getReverseBuilder(Builder2); auto dbgLoc = gutils->getNewFromOriginal(orig->getDebugLoc()); @@ -9615,7 +9624,8 @@ class AdjointGenerator // as can we can only guarantee that we don't erase those frees. 
bool hasPDFree = gutils->allocationsWithGuaranteedFree.count(orig); if (!primalNeededInReverse && hasPDFree) { - if (Mode == DerivativeMode::ReverseModeGradient) { + if (Mode == DerivativeMode::ReverseModeGradient || + Mode == DerivativeMode::ForwardModeSplit) { eraseIfUnused(*orig, /*erase*/ true, /*check*/ false); } else { if (auto MD = hasMetadata(orig, "enzyme_fromstack")) { @@ -9651,7 +9661,8 @@ class AdjointGenerator funcName == "jl_alloc_array_3d" || funcName == "jl_array_copy" || funcName == "julia.gc_alloc_obj") { if (!primalNeededInReverse) { - if (Mode == DerivativeMode::ReverseModeGradient) { + if (Mode == DerivativeMode::ReverseModeGradient || + Mode == DerivativeMode::ForwardModeSplit) { auto pn = BuilderZ.CreatePHI( orig->getType(), 1, (orig->getName() + "_replacementJ").str()); gutils->fictiousPHIs[pn] = orig; @@ -9671,38 +9682,32 @@ class AdjointGenerator // TODO enable this if we need to free the memory // NOTE THAT TOPLEVEL IS THERE SIMPLY BECAUSE THAT WAS PREVIOUS ATTITUTE // TO FREE'ing - if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModePrimal) { - if ((primalNeededInReverse && - !gutils->unnecessaryIntermediates.count(orig)) || - hasPDFree) { - Value *nop = gutils->cacheForReverse(BuilderZ, newCall, - getIndex(orig, CacheType::Self)); - if (Mode == DerivativeMode::ReverseModeGradient && hasPDFree && - shouldFree()) { - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - auto dbgLoc = gutils->getNewFromOriginal(orig->getDebugLoc()); - freeKnownAllocation(Builder2, lookup(nop, Builder2), *called, - dbgLoc, gutils->TLI); - } - } else if (Mode == DerivativeMode::ReverseModeGradient || - Mode == DerivativeMode::ReverseModeCombined) { - // Note that here we cannot simply replace with null as users who - // try to find the shadow pointer will use the shadow of null rather - // than the true shadow of this - auto pn = BuilderZ.CreatePHI( - orig->getType(), 1, (orig->getName() + 
"_replacementB").str()); - gutils->fictiousPHIs[pn] = orig; - gutils->replaceAWithB(newCall, pn); - gutils->erase(newCall); - } - } else if (Mode == DerivativeMode::ReverseModeCombined && shouldFree()) { - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - auto dbgLoc = gutils->getNewFromOriginal(orig)->getDebugLoc(); - freeKnownAllocation(Builder2, lookup(newCall, Builder2), *called, - dbgLoc, gutils->TLI); + if ((primalNeededInReverse && + !gutils->unnecessaryIntermediates.count(orig)) || + hasPDFree) { + Value *nop = gutils->cacheForReverse(BuilderZ, newCall, + getIndex(orig, CacheType::Self)); + if (hasPDFree && + ((Mode == DerivativeMode::ReverseModeGradient && shouldFree()) || + Mode == DerivativeMode::ReverseModeCombined || + (Mode == DerivativeMode::ForwardModeSplit && shouldFree()))) { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + auto dbgLoc = gutils->getNewFromOriginal(orig->getDebugLoc()); + freeKnownAllocation(Builder2, lookup(nop, Builder2), *called, dbgLoc, + gutils->TLI); + } + } else if (Mode == DerivativeMode::ReverseModeGradient || + Mode == DerivativeMode::ReverseModeCombined || + Mode == DerivativeMode::ForwardModeSplit) { + // Note that here we cannot simply replace with null as users who + // try to find the shadow pointer will use the shadow of null rather + // than the true shadow of this + auto pn = BuilderZ.CreatePHI(orig->getType(), 1, + (orig->getName() + "_replacementB").str()); + gutils->fictiousPHIs[pn] = orig; + gutils->replaceAWithB(newCall, pn); + gutils->erase(newCall); } return; @@ -9994,6 +9999,7 @@ class AdjointGenerator return; } } + // If we need this value and it is illegal to recompute it (it writes or // may load uncacheable data) // Store and reload it @@ -10014,7 +10020,8 @@ class AdjointGenerator // Any uses of it should be handled by the case above so it is safe to // RAUW if (orig->mayWriteToMemory() && - Mode == DerivativeMode::ReverseModeGradient) { + (Mode == 
DerivativeMode::ReverseModeGradient || + Mode == DerivativeMode::ForwardModeSplit)) { eraseIfUnused(*orig, /*erase*/ true, /*check*/ false); return; } @@ -10036,7 +10043,20 @@ class AdjointGenerator nextTypeInfo = TR.getCallInfo(*orig, *called); } - if (Mode == DerivativeMode::ForwardMode) { + const AugmentedReturn *subdata = nullptr; + if (Mode == DerivativeMode::ReverseModeGradient || + Mode == DerivativeMode::ForwardModeSplit) { + assert(augmentedReturn); + if (augmentedReturn) { + auto fd = augmentedReturn->subaugmentations.find(&call); + if (fd != augmentedReturn->subaugmentations.end()) { + subdata = fd->second; + } + } + } + + if (Mode == DerivativeMode::ForwardMode || + Mode == DerivativeMode::ForwardModeSplit) { IRBuilder<> Builder2(&call); getForwardBuilder(Builder2); @@ -10103,13 +10123,41 @@ class AdjointGenerator } } + Optional tapeIdx; + if (subdata) { + auto found = subdata->returns.find(AugmentedStruct::Tape); + if (found != subdata->returns.end()) { + tapeIdx = found->second; + } + } + Value *tape = nullptr; + if (tapeIdx.hasValue()) { + + FunctionType *FT = cast( + cast(subdata->fn->getType())->getElementType()); + + tape = BuilderZ.CreatePHI( + (tapeIdx == -1) ? FT->getReturnType() + : cast(FT->getReturnType()) + ->getElementType(tapeIdx.getValue()), + 1, "tapeArg"); + + assert(!tape->getType()->isEmptyTy()); + gutils->TapesToPreventRecomputation.insert(cast(tape)); + tape = gutils->cacheForReverse(BuilderZ, tape, + getIndex(orig, CacheType::Tape)); + args.push_back(tape); + } + Value *newcalled = nullptr; if (called) { newcalled = gutils->Logic.CreateForwardDiff( cast(called), subretType, argsInverted, TR.analyzer.interprocedural, /*returnValue*/ subretused, Mode, - gutils->getWidth(), nullptr, nextTypeInfo, {}); + ((DiffeGradientUtils *)gutils)->FreeMemory, gutils->getWidth(), + tape ? 
tape->getType() : nullptr, nextTypeInfo, {}, + /*augmented*/ subdata); } else { #if LLVM_VERSION_MAJOR >= 11 auto callval = orig->getCalledOperand(); @@ -10131,9 +10179,9 @@ class AdjointGenerator ? (retActive ? ReturnType::TwoReturns : ReturnType::Return) : (retActive ? ReturnType::Return : ReturnType::Void); - FunctionType *FTy = - getFunctionTypeForClone(ft, Mode, gutils->getWidth(), nullptr, - argsInverted, false, subretVal, subretType); + FunctionType *FTy = getFunctionTypeForClone( + ft, Mode, gutils->getWidth(), tape ? tape->getType() : nullptr, + argsInverted, false, subretVal, subretType); PointerType *fptype = PointerType::getUnqual(FTy); newcalled = BuilderZ.CreatePointerCast(newcalled, PointerType::getUnqual(fptype)); @@ -10350,17 +10398,6 @@ class AdjointGenerator Optional returnIdx; Optional differetIdx; - const AugmentedReturn *subdata = nullptr; - if (Mode == DerivativeMode::ReverseModeGradient) { - assert(augmentedReturn); - if (augmentedReturn) { - auto fd = augmentedReturn->subaugmentations.find(&call); - if (fd != augmentedReturn->subaugmentations.end()) { - subdata = fd->second; - } - } - } - if (modifyPrimal) { Value *newcalled = nullptr; @@ -10408,7 +10445,8 @@ class AdjointGenerator subdata = &gutils->Logic.CreateAugmentedPrimal( cast(called), subretType, argsInverted, TR.analyzer.interprocedural, /*return is used*/ subretused, - nextTypeInfo, uncacheable_args, false, gutils->AtomicAdd); + shadowReturnUsed, nextTypeInfo, uncacheable_args, false, + gutils->AtomicAdd); if (Mode == DerivativeMode::ReverseModePrimal) { assert(augmentedReturn); auto subaugmentations = @@ -10729,12 +10767,8 @@ class AdjointGenerator IRBuilder<> Builder2(call.getParent()); getReverseBuilder(Builder2); - bool retUsed = replaceFunction && subretused; Value *newcalled = nullptr; - bool subdretptr = (subretType == DIFFE_TYPE::DUP_ARG || - subretType == DIFFE_TYPE::DUP_NONEED) && - replaceFunction; // && (call.getNumUses() != 0); DerivativeMode subMode = 
(replaceFunction || !modifyPrimal) ? DerivativeMode::ReverseModeCombined : DerivativeMode::ReverseModeGradient; @@ -10744,8 +10778,9 @@ class AdjointGenerator .retType = subretType, .constant_args = argsInverted, .uncacheable_args = uncacheable_args, - .returnUsed = retUsed, - .shadowReturnUsed = subdretptr, + .returnUsed = replaceFunction && subretused, + .shadowReturnUsed = + shadowReturnUsed && replaceFunction, .mode = subMode, .width = gutils->getWidth(), .freeMemory = true, @@ -10866,9 +10901,13 @@ class AdjointGenerator } #endif - unsigned structidx = retUsed ? 1 : 0; - if (subdretptr) - ++structidx; + unsigned structidx = 0; + if (replaceFunction) { + if (subretused) + structidx++; + if (shadowReturnUsed) + structidx++; + } #if LLVM_VERSION_MAJOR >= 14 for (unsigned i = 0; i < orig->arg_size(); ++i) @@ -10897,8 +10936,8 @@ class AdjointGenerator if (structidx != 0) { llvm::errs() << *gutils->oldFunc->getParent() << "\n"; llvm::errs() << "diffes: " << *diffes << " structidx=" << structidx - << " retUsed=" << retUsed << " subretptr=" << subdretptr - << "\n"; + << " subretused=" << subretused + << " shadowReturnUsed=" << shadowReturnUsed << "\n"; } assert(structidx == 0); } else { @@ -10917,10 +10956,10 @@ class AdjointGenerator if (ifound != gutils->invertedPointers.end()) { auto placeholder = cast(&*ifound->second); gutils->invertedPointers.erase(ifound); - if (subdretptr) { + if (shadowReturnUsed) { dumpMap(gutils->invertedPointers); - auto dretval = - cast(Builder2.CreateExtractValue(diffes, {1})); + auto dretval = cast( + Builder2.CreateExtractValue(diffes, {subretused ? 
1U : 0U})); /* todo handle this case later */ assert(!subretused); gutils->invertedPointers.insert(std::make_pair( diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index cc912d67f6cde..7b086f60e76f4 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -379,8 +379,9 @@ LLVMValueRef EnzymeCreateForwardDiff( EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode, - unsigned width, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo, - uint8_t *_uncacheable_args, size_t uncacheable_args_size) { + uint8_t freeMemory, unsigned width, LLVMTypeRef additionalArg, + CFnTypeInfo typeInfo, uint8_t *_uncacheable_args, + size_t uncacheable_args_size, EnzymeAugmentedReturnPtr augmented) { std::vector nconstant_args((DIFFE_TYPE *)constant_args, (DIFFE_TYPE *)constant_args + constant_args_size); @@ -393,9 +394,9 @@ LLVMValueRef EnzymeCreateForwardDiff( } return wrap(eunwrap(Logic).CreateForwardDiff( cast(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args, - eunwrap(TA), returnValue, (DerivativeMode)mode, width, + eunwrap(TA), returnValue, (DerivativeMode)mode, freeMemory, width, unwrap(additionalArg), eunwrap(typeInfo, cast(unwrap(todiff))), - uncacheable_args)); + uncacheable_args, eunwrap(augmented))); } LLVMValueRef EnzymeCreatePrimalAndGradient( EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType, @@ -432,12 +433,14 @@ LLVMValueRef EnzymeCreatePrimalAndGradient( }, eunwrap(TA), eunwrap(augmented))); } -EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal( - EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType, - CDIFFE_TYPE *constant_args, size_t constant_args_size, - EnzymeTypeAnalysisRef TA, uint8_t returnUsed, CFnTypeInfo typeInfo, - uint8_t *_uncacheable_args, size_t uncacheable_args_size, - uint8_t forceAnonymousTape, uint8_t AtomicAdd) { +EnzymeAugmentedReturnPtr 
+EnzymeCreateAugmentedPrimal(EnzymeLogicRef Logic, LLVMValueRef todiff, + CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, + size_t constant_args_size, EnzymeTypeAnalysisRef TA, + uint8_t returnUsed, uint8_t shadowReturnUsed, + CFnTypeInfo typeInfo, uint8_t *_uncacheable_args, + size_t uncacheable_args_size, + uint8_t forceAnonymousTape, uint8_t AtomicAdd) { std::vector nconstant_args((DIFFE_TYPE *)constant_args, (DIFFE_TYPE *)constant_args + @@ -451,7 +454,7 @@ EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal( } return ewrap(eunwrap(Logic).CreateAugmentedPrimal( cast(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args, - eunwrap(TA), returnUsed, + eunwrap(TA), returnUsed, shadowReturnUsed, eunwrap(typeInfo, cast(unwrap(todiff))), uncacheable_args, forceAnonymousTape, AtomicAdd)); } diff --git a/enzyme/Enzyme/CApi.h b/enzyme/Enzyme/CApi.h index e6c474165f2b0..e7e5256b22418 100644 --- a/enzyme/Enzyme/CApi.h +++ b/enzyme/Enzyme/CApi.h @@ -121,8 +121,9 @@ LLVMValueRef EnzymeCreateForwardDiff( EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode, - unsigned width, LLVMTypeRef additionalArg, struct CFnTypeInfo typeInfo, - uint8_t *_uncacheable_args, size_t uncacheable_args_size); + uint8_t freeMemory, unsigned width, LLVMTypeRef additionalArg, + struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args, + size_t uncacheable_args_size, EnzymeAugmentedReturnPtr augmented); LLVMValueRef EnzymeCreatePrimalAndGradient( EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType, @@ -136,9 +137,10 @@ LLVMValueRef EnzymeCreatePrimalAndGradient( EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal( EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, size_t constant_args_size, - EnzymeTypeAnalysisRef TA, uint8_t returnUsed, struct CFnTypeInfo typeInfo, - uint8_t *_uncacheable_args, size_t uncacheable_args_size, - 
uint8_t forceAnonymousTape, uint8_t AtomicAdd); + EnzymeTypeAnalysisRef TA, uint8_t returnUsed, uint8_t shadowReturnUsed, + struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args, + size_t uncacheable_args_size, uint8_t forceAnonymousTape, + uint8_t AtomicAdd); typedef uint8_t (*CustomRuleType)(int /*direction*/, CTypeTreeRef /*return*/, CTypeTreeRef * /*args*/, diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 79d43a7dc98cb..b314d54e842ff 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -92,7 +92,7 @@ llvm::cl::opt EnzymeOMPOpt("enzyme-omp-opt", cl::init(false), cl::Hidden, #endif namespace { -template +template static void handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g, std::vector &globalsToErase) { @@ -130,7 +130,8 @@ handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g, } } - if (numargs == 3) { + if (Mode == DerivativeMode::ReverseModeGradient) { + assert(numargs == 3); Fs[0]->setMetadata( "enzyme_augment", llvm::MDTuple::get(Fs[0]->getContext(), @@ -139,12 +140,24 @@ handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g, "enzyme_gradient", llvm::MDTuple::get(Fs[0]->getContext(), {llvm::ValueAsMetadata::get(Fs[2])})); - } else if (numargs == 2) { + } else if (Mode == DerivativeMode::ForwardMode) { + assert(numargs == 2); Fs[0]->setMetadata( "enzyme_derivative", llvm::MDTuple::get(Fs[0]->getContext(), {llvm::ValueAsMetadata::get(Fs[1])})); - } + } else if (Mode == DerivativeMode::ForwardModeSplit) { + assert(numargs == 3); + Fs[0]->setMetadata( + "enzyme_augment", + llvm::MDTuple::get(Fs[0]->getContext(), + {llvm::ValueAsMetadata::get(Fs[1])})); + Fs[0]->setMetadata( + "enzyme_splitderivative", + llvm::MDTuple::get(Fs[0]->getContext(), + {llvm::ValueAsMetadata::get(Fs[2])})); + } else + assert("Unknown mode"); } } else { llvm::errs() << M << "\n"; @@ -444,6 +457,8 @@ class Enzyme : public ModulePass { IRBuilder<> Builder(CI); unsigned truei = 0; unsigned width = 1; + bool 
returnUsed = !cast(fn)->getReturnType()->isVoidTy() && + !cast(fn)->getReturnType()->isEmptyTy(); // determine width #if LLVM_VERSION_MAJOR >= 14 @@ -455,44 +470,44 @@ class Enzyme : public ModulePass { { Value *arg = CI->getArgOperand(i); - if (getMetadataName(arg) && *getMetadataName(arg) == "enzyme_width") { - assert(mode == DerivativeMode::ForwardMode); - - if (found) { - EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI, - "vector width declared more than once", - *CI->getArgOperand(i), " in", *CI); - return false; - } + if (auto MDName = getMetadataName(arg)) { + if (*MDName == "enzyme_width") { + if (found) { + EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI, + "vector width declared more than once", + *CI->getArgOperand(i), " in", *CI); + return false; + } #if LLVM_VERSION_MAJOR >= 14 - if (i + 1 >= CI->arg_size()) + if (i + 1 >= CI->arg_size()) #else - if (i + 1 >= CI->getNumArgOperands()) + if (i + 1 >= CI->getNumArgOperands()) #endif - { - EmitFailure("MissingVectorWidth", CI->getDebugLoc(), CI, - "constant integer followong enzyme_width is missing", - *CI->getArgOperand(i), " in", *CI); - return false; - } + { + EmitFailure("MissingVectorWidth", CI->getDebugLoc(), CI, + "constant integer followong enzyme_width is missing", + *CI->getArgOperand(i), " in", *CI); + return false; + } - Value *width_arg = CI->getArgOperand(i + 1); - if (auto cint = dyn_cast(width_arg)) { - width = cint->getZExtValue(); - found = true; - } else { - EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI, - "enzyme_width must be a constant integer", - *CI->getArgOperand(i), " in", *CI); - return false; - } + Value *width_arg = CI->getArgOperand(i + 1); + if (auto cint = dyn_cast(width_arg)) { + width = cint->getZExtValue(); + found = true; + } else { + EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI, + "enzyme_width must be a constant integer", + *CI->getArgOperand(i), " in", *CI); + return false; + } - if (!found) { - EmitFailure("IllegalVectorWidth", 
CI->getDebugLoc(), CI, - "illegal enzyme vector argument width ", - *CI->getArgOperand(i), " in", *CI); - return false; + if (!found) { + EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI, + "illegal enzyme vector argument width ", + *CI->getArgOperand(i), " in", *CI); + return false; + } } } } @@ -579,8 +594,8 @@ class Enzyme : public ModulePass { DIFFE_TYPE retType = whatType(cast(fn)->getReturnType(), mode); - bool differentialReturn = mode != DerivativeMode::ForwardMode && - mode != DerivativeMode::ReverseModePrimal && + bool differentialReturn = (mode == DerivativeMode::ReverseModeCombined || + mode == DerivativeMode::ReverseModeGradient) && (retType == DIFFE_TYPE::OUT_DIFF); std::map byVal; @@ -598,7 +613,9 @@ class Enzyme : public ModulePass { Value *res = CI->getArgOperand(i); if (truei >= FT->getNumParams()) { - if (mode == DerivativeMode::ReverseModeGradient) { + if (!isa(res) && + (mode == DerivativeMode::ReverseModeGradient || + mode == DerivativeMode::ForwardModeSplit)) { if (differentialReturn && differet == nullptr) { differet = res; if (CI->paramHasAttr(i, Attribute::ByVal)) { @@ -644,6 +661,9 @@ class Enzyme : public ModulePass { ty = DIFFE_TYPE::OUT_DIFF; } else if (*metaString == "enzyme_const") { ty = DIFFE_TYPE::CONSTANT; + } else if (*metaString == "enzyme_noret") { + returnUsed = false; + continue; } else if (*metaString == "enzyme_allocated") { assert(!sizeOnly); ++i; @@ -814,13 +834,57 @@ class Enzyme : public ModulePass { Type *tapeType = nullptr; const AugmentedReturn *aug; switch (mode) { - case DerivativeMode::ForwardModeSplit: case DerivativeMode::ForwardMode: newFunc = Logic.CreateForwardDiff( cast(fn), retType, constants, TA, - /*should return*/ false, mode, width, - /*addedType*/ nullptr, type_args, volatile_args); + /*should return*/ false, mode, freeMemory, width, + /*addedType*/ nullptr, type_args, volatile_args, + /*augmented*/ nullptr); + break; + case DerivativeMode::ForwardModeSplit: { + bool forceAnonymousTape = 
!sizeOnly && allocatedTapeSize == -1; + aug = &Logic.CreateAugmentedPrimal( + cast(fn), retType, constants, TA, + /*returnUsed*/ false, /*shadowReturnUsed*/ false, type_args, + volatile_args, forceAnonymousTape, /*atomicAdd*/ AtomicAdd); + auto &DL = cast(fn)->getParent()->getDataLayout(); + if (!forceAnonymousTape) { + assert(!aug->tapeType); + if (aug->returns.find(AugmentedStruct::Tape) != aug->returns.end()) { + auto tapeIdx = aug->returns.find(AugmentedStruct::Tape)->second; + tapeType = (tapeIdx == -1) + ? aug->fn->getReturnType() + : cast(aug->fn->getReturnType()) + ->getElementType(tapeIdx); + } else { + if (sizeOnly) { + CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0, false)); + CI->eraseFromParent(); + return true; + } + } + if (sizeOnly) { + auto size = DL.getTypeSizeInBits(tapeType) / 8; + CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), size, false)); + CI->eraseFromParent(); + return true; + } + if (tapeType && + DL.getTypeSizeInBits(tapeType) < 8 * (size_t)allocatedTapeSize) { + auto bytes = DL.getTypeSizeInBits(tapeType) / 8; + EmitFailure("Insufficient tape allocation size", CI->getDebugLoc(), + CI, "need ", bytes, " bytes have ", allocatedTapeSize, + " bytes"); + } + } else { + tapeType = PointerType::getInt8PtrTy(fn->getContext()); + } + newFunc = Logic.CreateForwardDiff( + cast(fn), retType, constants, TA, + /*should return*/ false, mode, freeMemory, width, + /*addedType*/ tapeType, type_args, volatile_args, aug); break; + } case DerivativeMode::ReverseModeCombined: assert(freeMemory); newFunc = Logic.CreatePrimalAndGradient( @@ -841,12 +905,12 @@ class Enzyme : public ModulePass { case DerivativeMode::ReverseModePrimal: case DerivativeMode::ReverseModeGradient: { bool forceAnonymousTape = !sizeOnly && allocatedTapeSize == -1; - bool returnUsed = !cast(fn)->getReturnType()->isVoidTy() && - !cast(fn)->getReturnType()->isEmptyTy(); + bool shadowReturnUsed = returnUsed && (retType == DIFFE_TYPE::DUP_ARG || + retType == 
DIFFE_TYPE::DUP_NONEED); aug = &Logic.CreateAugmentedPrimal( - cast(fn), retType, constants, TA, - /*returnUsed*/ returnUsed, type_args, volatile_args, - forceAnonymousTape, /*atomicAdd*/ AtomicAdd); + cast(fn), retType, constants, TA, returnUsed, + shadowReturnUsed, type_args, volatile_args, forceAnonymousTape, + /*atomicAdd*/ AtomicAdd); auto &DL = cast(fn)->getParent()->getDataLayout(); if (!forceAnonymousTape) { assert(!aug->tapeType); @@ -918,7 +982,9 @@ class Enzyme : public ModulePass { } } - if (mode == DerivativeMode::ReverseModeGradient && tape && tapeType) { + if ((mode == DerivativeMode::ReverseModeGradient || + mode == DerivativeMode::ForwardModeSplit) && + tape && tapeType) { auto &DL = cast(fn)->getParent()->getDataLayout(); if (tapeIsPointer) { tape = Builder.CreateBitCast( @@ -1227,6 +1293,7 @@ class Enzyme : public ModulePass { Fn->getName().contains("__enzyme_call_inactive") || Fn->getName().contains("__enzyme_autodiff") || Fn->getName().contains("__enzyme_fwddiff") || + Fn->getName().contains("__enzyme_fwdsplit") || Fn->getName().contains("__enzyme_augmentfwd") || Fn->getName().contains("__enzyme_augmentsize") || Fn->getName().contains("__enzyme_reverse"))) @@ -1484,6 +1551,9 @@ class Enzyme : public ModulePass { } else if (Fn->getName().contains("__enzyme_fwddiff")) { enableEnzyme = true; mode = DerivativeMode::ForwardMode; + } else if (Fn->getName().contains("__enzyme_fwdsplit")) { + enableEnzyme = true; + mode = DerivativeMode::ForwardModeSplit; } else if (Fn->getName().contains("__enzyme_augmentfwd")) { enableEnzyme = true; mode = DerivativeMode::ReverseModePrimal; @@ -1679,6 +1749,8 @@ class Enzyme : public ModulePass { "__enzyme_register_gradient"; constexpr static const char derivative_handler_name[] = "__enzyme_register_derivative"; + constexpr static const char splitderivative_handler_name[] = + "__enzyme_register_splitderivative"; Logic.clear(); @@ -1686,10 +1758,17 @@ class Enzyme : public ModulePass { std::vector globalsToErase; for 
(GlobalVariable &g : M.globals()) { if (g.getName().contains(gradient_handler_name)) { - handleCustomDerivative(M, g, globalsToErase); + handleCustomDerivative( + M, g, globalsToErase); } else if (g.getName().contains(derivative_handler_name)) { - handleCustomDerivative(M, g, - globalsToErase); + handleCustomDerivative(M, g, + globalsToErase); + } else if (g.getName().contains(splitderivative_handler_name)) { + handleCustomDerivative( + M, g, globalsToErase); } else if (g.getName().contains("__enzyme_inactivefn")) { handleInactiveFunction(M, g, globalsToErase); } diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index f2ec4361f84ef..9c7cb1d4bd820 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -82,14 +82,6 @@ cl::opt looseTypeAnalysis("enzyme-loose-types", cl::init(false), cl::Hidden, cl::desc("Allow looser use of types")); -cl::opt cache_reads_always("enzyme-cache-always", cl::init(false), - cl::Hidden, - cl::desc("Force always caching of all reads")); - -cl::opt cache_reads_never("enzyme-cache-never", cl::init(false), - cl::Hidden, - cl::desc("Disable caching of all reads")); - cl::opt nonmarkedglobals_inactiveloads( "enzyme_nonmarkedglobals_inactiveloads", cl::init(true), cl::Hidden, cl::desc("Consider loads of nonmarked globals to be inactive")); @@ -1480,11 +1472,167 @@ static FnTypeInfo preventTypeAnalysisLoops(const FnTypeInfo &oldTypeInfo_, return oldTypeInfo; } +void restoreCache( + DiffeGradientUtils *gutils, + const std::map, int> &mapping, + const SmallPtrSetImpl &guaranteedUnreachable) { + // One must use this temporary map to first create all the replacements + // prior to actually replacing to ensure that getSubLimits has the same + // behavior and unwrap behavior for all replacements. 
+ std::vector> newIToNextI; + + for (const auto &m : mapping) { + if (m.first.second == CacheType::Self && + gutils->knownRecomputeHeuristic.count(m.first.first)) { + assert(gutils->knownRecomputeHeuristic.count(m.first.first)); + if (!isa(m.first.first)) { + auto newi = gutils->getNewFromOriginal(m.first.first); + if (auto PN = dyn_cast(newi)) + if (gutils->fictiousPHIs.count(PN)) { + assert(gutils->fictiousPHIs[PN] == m.first.first); + gutils->fictiousPHIs.erase(PN); + } + IRBuilder<> BuilderZ(newi->getNextNode()); + if (isa(m.first.first)) { + BuilderZ.SetInsertPoint( + cast(newi)->getParent()->getFirstNonPHI()); + } + Value *nexti = + gutils->cacheForReverse(BuilderZ, newi, m.second, + /*ignoreType*/ false, /*replace*/ false); + newIToNextI.emplace_back(newi, nexti); + } else { + auto newi = gutils->getNewFromOriginal((Value *)m.first.first); + newIToNextI.emplace_back(newi, newi); + } + } + } + + std::map> unwrapToOrig; + for (auto pair : gutils->unwrappedLoads) + unwrapToOrig[pair.second].push_back(const_cast(pair.first)); + gutils->unwrappedLoads.clear(); + + for (auto pair : newIToNextI) { + auto newi = pair.first; + auto nexti = pair.second; + if (newi != nexti) { + gutils->replaceAWithB(newi, nexti); + } + } + + // This most occur after all the replacements have been made + // in the previous loop, lest a loop bound being unwrapped use + // a value being replaced. 
+ for (auto pair : newIToNextI) { + auto newi = pair.first; + auto nexti = pair.second; + for (auto V : unwrapToOrig[newi]) { + ValueToValueMapTy available; + if (auto MD = hasMetadata(V, "enzyme_available")) { + for (auto &pair : MD->operands()) { + auto tup = cast(pair); + auto val = cast(tup->getOperand(1))->getValue(); + assert(val); + available[cast(tup->getOperand(0))->getValue()] = + val; + } + } + IRBuilder<> lb(V); + // This must disallow caching here as otherwise performing the loop in + // the wrong order may result in first replacing the later unwrapped + // value, caching it, then attempting to reuse it for an earlier + // replacement. + Value *nval = gutils->unwrapM(nexti, lb, available, + UnwrapMode::LegalFullUnwrapNoTapeReplace, + /*scope*/ nullptr, /*permitCache*/ false); + assert(nval); + V->replaceAllUsesWith(nval); + V->eraseFromParent(); + } + } + + // Erasure happens after to not erase the key of unwrapToOrig + for (auto pair : newIToNextI) { + auto newi = pair.first; + auto nexti = pair.second; + if (newi != nexti) { + if (auto inst = dyn_cast(newi)) + gutils->erase(inst); + } + } + + // TODO also can consider switch instance as well + // TODO can also insert to topLevel as well [note this requires putting the + // intrinsic at the correct location] + for (auto &BB : *gutils->oldFunc) { + std::vector unreachables; + std::vector reachables; + for (auto Succ : successors(&BB)) { + if (guaranteedUnreachable.find(Succ) != guaranteedUnreachable.end()) { + unreachables.push_back(Succ); + } else { + reachables.push_back(Succ); + } + } + + if (unreachables.size() == 0 || reachables.size() == 0) + continue; + + if (auto bi = dyn_cast(BB.getTerminator())) { + + Value *condition = gutils->getNewFromOriginal(bi->getCondition()); + + Constant *repVal = (bi->getSuccessor(0) == unreachables[0]) + ? 
ConstantInt::getFalse(condition->getContext()) + : ConstantInt::getTrue(condition->getContext()); + + for (auto UI = condition->use_begin(), E = condition->use_end(); + UI != E;) { + Use &U = *UI; + ++UI; + U.set(repVal); + } + } + if (reachables.size() == 1) + if (auto si = dyn_cast(BB.getTerminator())) { + Value *condition = gutils->getNewFromOriginal(si->getCondition()); + + Constant *repVal = nullptr; + if (si->getDefaultDest() == reachables[0]) { + std::set cases; + for (auto c : si->cases()) { + // TODO this doesnt work with unsigned 64 bit ints or higher + // integer widths + cases.insert(cast(c.getCaseValue())->getSExtValue()); + } + int64_t legalNot = 0; + while (cases.count(legalNot)) + legalNot++; + repVal = ConstantInt::getSigned(condition->getType(), legalNot); + } else { + for (auto c : si->cases()) { + if (c.getCaseSuccessor() == reachables[0]) { + repVal = c.getCaseValue(); + } + } + } + assert(repVal); + for (auto UI = condition->use_begin(), E = condition->use_end(); + UI != E;) { + Use &U = *UI; + ++UI; + U.set(repVal); + } + } + } +} + //! 
return structtype if recursive function const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( Function *todiff, DIFFE_TYPE retType, const std::vector &constant_args, TypeAnalysis &TA, - bool returnUsed, const FnTypeInfo &oldTypeInfo_, + bool returnUsed, bool shadowReturnUsed, const FnTypeInfo &oldTypeInfo_, const std::map _uncacheable_args, bool forceAnonymousTape, bool AtomicAdd, bool omp) { if (returnUsed) @@ -1497,11 +1645,12 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( FnTypeInfo oldTypeInfo = preventTypeAnalysisLoops(oldTypeInfo_, todiff); assert(constant_args.size() == todiff->getFunctionType()->getNumParams()); - AugmentedCacheKey tup = std::make_tuple( - todiff, retType, constant_args, - std::map(_uncacheable_args.begin(), - _uncacheable_args.end()), - returnUsed, oldTypeInfo, forceAnonymousTape, AtomicAdd, omp); + AugmentedCacheKey tup = + std::make_tuple(todiff, retType, constant_args, + std::map(_uncacheable_args.begin(), + _uncacheable_args.end()), + returnUsed, shadowReturnUsed, oldTypeInfo, + forceAnonymousTape, AtomicAdd, omp); auto found = AugmentedCachedFunctions.find(tup); if (found != AugmentedCachedFunctions.end()) { return found->second; @@ -1510,8 +1659,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( // TODO make default typing (not just constant) - if (hasMetadata(todiff, "enzyme_augment")) { - auto md = todiff->getMetadata("enzyme_augment"); + if (auto md = hasMetadata(todiff, "enzyme_augment")) { if (!isa(md)) { llvm::errs() << *todiff << "\n"; llvm::errs() << *md << "\n"; @@ -1532,8 +1680,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( } if (hasconstant) { - EmitWarning("NoCustom", todiff->getEntryBlock().begin()->getDebugLoc(), - todiff, &todiff->getEntryBlock(), + EmitWarning("NoCustom", todiff, "Massaging provided custom augmented forward pass to handle " "constant argumented"); SmallVector dupargs; @@ -1591,8 +1738,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( 
act_idx++; } auto &aug = CreateAugmentedPrimal( - todiff, retType, next_constant_args, TA, returnUsed, oldTypeInfo_, - _uncacheable_args, forceAnonymousTape, AtomicAdd, omp); + todiff, retType, next_constant_args, TA, returnUsed, shadowReturnUsed, + oldTypeInfo_, _uncacheable_args, forceAnonymousTape, AtomicAdd, omp); auto cal = bb.CreateCall(aug.fn, fwdargs); cal->setCallingConv(aug.fn->getCallingConv()); @@ -1682,7 +1829,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( GradientUtils *gutils = GradientUtils::CreateFromClone( *this, todiff, TLI, TA, retType, constant_args, - /*returnUsed*/ returnUsed, returnMapping, omp); + /*returnUsed*/ returnUsed, /*shadowReturnUsed*/ shadowReturnUsed, + returnMapping, omp); gutils->AtomicAdd = AtomicAdd; const SmallPtrSet guaranteedUnreachable = getGuaranteedUnreachable(gutils->oldFunc); @@ -1952,7 +2100,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( //! Keep track of inverted pointers we may need to return ValueToValueMapTy invertedRetPs; - if (retType == DIFFE_TYPE::DUP_ARG || retType == DIFFE_TYPE::DUP_NONEED) { + if (shadowReturnUsed) { for (BasicBlock &BB : *gutils->oldFunc) { if (auto ri = dyn_cast(BB.getTerminator())) { if (Value *orig_oldval = ri->getReturnValue()) { @@ -2109,60 +2257,66 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( Value *tapeMemory; if (recursive && !omp) { auto i64 = Type::getInt64Ty(NewF->getContext()); - ConstantInt *size; - tapeMemory = CallInst::CreateMalloc( - NewF->getEntryBlock().getFirstNonPHI(), i64, tapeType, - size = ConstantInt::get( - i64, NewF->getParent()->getDataLayout().getTypeAllocSizeInBits( - tapeType) / - 8), - nullptr, nullptr, "tapemem"); - CallInst *malloccall = dyn_cast(tapeMemory); - if (malloccall == nullptr) { - malloccall = - cast(cast(tapeMemory)->getOperand(0)); - } - for (auto attr : {Attribute::NoAlias, Attribute::NonNull}) { + ConstantInt *size = ConstantInt::get( + i64, + 
NewF->getParent()->getDataLayout().getTypeAllocSizeInBits(tapeType) / + 8); + Value *memory; + if (!size->isZero()) { + tapeMemory = + CallInst::CreateMalloc(NewF->getEntryBlock().getFirstNonPHI(), i64, + tapeType, size, nullptr, nullptr, "tapemem"); + CallInst *malloccall = dyn_cast(tapeMemory); + if (malloccall == nullptr) { + malloccall = + cast(cast(tapeMemory)->getOperand(0)); + } + for (auto attr : {Attribute::NoAlias, Attribute::NonNull}) { #if LLVM_VERSION_MAJOR >= 14 - malloccall->addRetAttr(attr); + malloccall->addRetAttr(attr); #else - malloccall->addAttribute(AttributeList::ReturnIndex, attr); + malloccall->addAttribute(AttributeList::ReturnIndex, attr); #endif - } - if (EnzymeZeroCache) { - IRBuilder<> B(malloccall->getNextNode()); - Value *args[] = { - malloccall, - ConstantInt::get(Type::getInt8Ty(malloccall->getContext()), 0), - malloccall->getArgOperand(0), - ConstantInt::getFalse(malloccall->getContext())}; - Type *tys[] = {args[0]->getType(), args[2]->getType()}; - - B.CreateCall(Intrinsic::getDeclaration(NewF->getParent(), - Intrinsic::memset, tys), - args); - } + } + if (EnzymeZeroCache) { + IRBuilder<> B(malloccall->getNextNode()); + Value *args[] = { + malloccall, + ConstantInt::get(Type::getInt8Ty(malloccall->getContext()), 0), + malloccall->getArgOperand(0), + ConstantInt::getFalse(malloccall->getContext())}; + Type *tys[] = {args[0]->getType(), args[2]->getType()}; + + B.CreateCall(Intrinsic::getDeclaration(NewF->getParent(), + Intrinsic::memset, tys), + args); + } #if LLVM_VERSION_MAJOR >= 14 - malloccall->addDereferenceableRetAttr(size->getLimitedValue()); + malloccall->addDereferenceableRetAttr(size->getLimitedValue()); #ifndef FLANG - AttrBuilder B(malloccall->getContext()); + AttrBuilder B(malloccall->getContext()); #else - AttrBuilder B; + AttrBuilder B; #endif - B.addDereferenceableOrNullAttr(size->getLimitedValue()); - malloccall->setAttributes(malloccall->getAttributes().addRetAttributes( - malloccall->getContext(), B)); + 
B.addDereferenceableOrNullAttr(size->getLimitedValue()); + malloccall->setAttributes(malloccall->getAttributes().addRetAttributes( + malloccall->getContext(), B)); #else - malloccall->addDereferenceableAttr(llvm::AttributeList::ReturnIndex, - size->getLimitedValue()); - malloccall->addDereferenceableOrNullAttr(llvm::AttributeList::ReturnIndex, - size->getLimitedValue()); + malloccall->addDereferenceableAttr(llvm::AttributeList::ReturnIndex, + size->getLimitedValue()); + malloccall->addDereferenceableOrNullAttr( + llvm::AttributeList::ReturnIndex, size->getLimitedValue()); #endif + memory = malloccall; + } else { + memory = + ConstantPointerNull::get(Type::getInt8PtrTy(NewF->getContext())); + } Value *Idxs[] = { ib.getInt32(0), ib.getInt32(returnMapping.find(AugmentedStruct::Tape)->second), }; - assert(malloccall); + assert(memory); assert(ret); Value *gep = ret; if (!removeStruct) { @@ -2174,7 +2328,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( #endif cast(gep)->setIsInBounds(true); } - ib.CreateStore(malloccall, gep); + ib.CreateStore(memory, gep); } else if (omp) { j->setName("tape"); tapeMemory = j; @@ -2256,7 +2410,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( ib.CreateStore(actualrv, gep); } - if (retType == DIFFE_TYPE::DUP_ARG || retType == DIFFE_TYPE::DUP_NONEED) { + if (shadowReturnUsed) { assert(invertedRetPs[ri]); if (!isa(invertedRetPs[ri])) { Value *gep = @@ -2888,8 +3042,8 @@ Function *EnzymeLogic::CreatePrimalAndGradient( auto &aug = CreateAugmentedPrimal( key.todiff, key.retType, key.constant_args, TA, key.returnUsed, - key.typeInfo, key.uncacheable_args, /*forceAnonymousTape*/ false, - key.AtomicAdd, omp); + key.shadowReturnUsed, key.typeInfo, key.uncacheable_args, + /*forceAnonymousTape*/ false, key.AtomicAdd, omp); SmallVector fwdargs; for (auto &a : NewF->args()) @@ -3481,159 +3635,8 @@ Function *EnzymeLogic::CreatePrimalAndGradient( key.retType); } - if (key.mode != DerivativeMode::ReverseModeCombined) { - // 
One must use this temporary map to first create all the replacements - // prior to actually replacing to ensure that getSubLimits has the same - // behavior and unwrap behavior for all replacements. - std::vector> newIToNextI; - - for (const auto &m : mapping) { - if (m.first.second == CacheType::Self && - gutils->knownRecomputeHeuristic.count(m.first.first)) { - assert(gutils->knownRecomputeHeuristic.count(m.first.first)); - if (!isa(m.first.first)) { - auto newi = gutils->getNewFromOriginal(m.first.first); - if (auto PN = dyn_cast(newi)) - if (gutils->fictiousPHIs.count(PN)) { - assert(gutils->fictiousPHIs[PN] == m.first.first); - gutils->fictiousPHIs.erase(PN); - } - IRBuilder<> BuilderZ(newi->getNextNode()); - if (isa(m.first.first)) { - BuilderZ.SetInsertPoint( - cast(newi)->getParent()->getFirstNonPHI()); - } - Value *nexti = - gutils->cacheForReverse(BuilderZ, newi, m.second, - /*ignoreType*/ false, /*replace*/ false); - newIToNextI.emplace_back(newi, nexti); - } else { - auto newi = gutils->getNewFromOriginal((Value *)m.first.first); - newIToNextI.emplace_back(newi, newi); - } - } - } - - std::map> unwrapToOrig; - for (auto pair : gutils->unwrappedLoads) - unwrapToOrig[pair.second].push_back( - const_cast(pair.first)); - gutils->unwrappedLoads.clear(); - - for (auto pair : newIToNextI) { - auto newi = pair.first; - auto nexti = pair.second; - if (newi != nexti) { - gutils->replaceAWithB(newi, nexti); - } - } - - // This most occur after all the replacements have been made - // in the previous loop, lest a loop bound being unwrapped use - // a value being replaced. 
- for (auto pair : newIToNextI) { - auto newi = pair.first; - auto nexti = pair.second; - for (auto V : unwrapToOrig[newi]) { - ValueToValueMapTy available; - if (auto MD = hasMetadata(V, "enzyme_available")) { - for (auto &pair : MD->operands()) { - auto tup = cast(pair); - auto val = cast(tup->getOperand(1))->getValue(); - assert(val); - available[cast(tup->getOperand(0))->getValue()] = - val; - } - } - IRBuilder<> lb(V); - // This must disallow caching here as otherwise performing the loop in - // the wrong order may result in first replacing the later unwrapped - // value, caching it, then attempting to reuse it for an earlier - // replacement. - Value *nval = gutils->unwrapM(nexti, lb, available, - UnwrapMode::LegalFullUnwrapNoTapeReplace, - /*scope*/ nullptr, /*permitCache*/ false); - assert(nval); - V->replaceAllUsesWith(nval); - V->eraseFromParent(); - } - } - - // Erasure happens after to not erase the key of unwrapToOrig - for (auto pair : newIToNextI) { - auto newi = pair.first; - auto nexti = pair.second; - if (newi != nexti) { - if (auto inst = dyn_cast(newi)) - gutils->erase(inst); - } - } - - // TODO also can consider switch instance as well - // TODO can also insert to topLevel as well [note this requires putting the - // intrinsic at the correct location] - for (auto &BB : *gutils->oldFunc) { - std::vector unreachables; - std::vector reachables; - for (auto Succ : successors(&BB)) { - if (guaranteedUnreachable.find(Succ) != guaranteedUnreachable.end()) { - unreachables.push_back(Succ); - } else { - reachables.push_back(Succ); - } - } - - if (unreachables.size() == 0 || reachables.size() == 0) - continue; - - if (auto bi = dyn_cast(BB.getTerminator())) { - - Value *condition = gutils->getNewFromOriginal(bi->getCondition()); - - Constant *repVal = (bi->getSuccessor(0) == unreachables[0]) - ? 
ConstantInt::getFalse(condition->getContext()) - : ConstantInt::getTrue(condition->getContext()); - - for (auto UI = condition->use_begin(), E = condition->use_end(); - UI != E;) { - Use &U = *UI; - ++UI; - U.set(repVal); - } - } - if (reachables.size() == 1) - if (auto si = dyn_cast(BB.getTerminator())) { - Value *condition = gutils->getNewFromOriginal(si->getCondition()); - - Constant *repVal = nullptr; - if (si->getDefaultDest() == reachables[0]) { - std::set cases; - for (auto c : si->cases()) { - // TODO this doesnt work with unsigned 64 bit ints or higher - // integer widths - cases.insert(cast(c.getCaseValue())->getSExtValue()); - } - int64_t legalNot = 0; - while (cases.count(legalNot)) - legalNot++; - repVal = ConstantInt::getSigned(condition->getType(), legalNot); - } else { - for (auto c : si->cases()) { - if (c.getCaseSuccessor() == reachables[0]) { - repVal = c.getCaseValue(); - } - } - } - assert(repVal); - for (auto UI = condition->use_begin(), E = condition->use_end(); - UI != E;) { - Use &U = *UI; - ++UI; - U.set(repVal); - } - } - } - } + if (key.mode == DerivativeMode::ReverseModeGradient) + restoreCache(gutils, mapping, guaranteedUnreachable); gutils->eraseFictiousPHIs(); @@ -3759,9 +3762,10 @@ Function *EnzymeLogic::CreatePrimalAndGradient( Function *EnzymeLogic::CreateForwardDiff( Function *todiff, DIFFE_TYPE retType, const std::vector &constant_args, TypeAnalysis &TA, - bool returnUsed, DerivativeMode mode, unsigned width, + bool returnUsed, DerivativeMode mode, bool freeMemory, unsigned width, llvm::Type *additionalArg, const FnTypeInfo &oldTypeInfo_, - const std::map _uncacheable_args, bool omp) { + const std::map _uncacheable_args, + const AugmentedReturn *augmenteddata, bool omp) { assert(retType != DIFFE_TYPE::OUT_DIFF); assert(mode == DerivativeMode::ForwardMode || @@ -3793,8 +3797,9 @@ Function *EnzymeLogic::CreateForwardDiff( } } - if (hasMetadata(todiff, "enzyme_derivative")) { - auto md = todiff->getMetadata("enzyme_derivative"); + 
if (auto md = hasMetadata(todiff, (mode == DerivativeMode::ForwardMode) + ? "enzyme_derivative" + : "enzyme_splitderivative")) { if (!isa(md)) { llvm::errs() << *todiff << "\n"; llvm::errs() << *md << "\n"; @@ -3862,6 +3867,11 @@ Function *EnzymeLogic::CreateForwardDiff( break; } } + if (augmenteddata && augmenteddata->returns.find(AugmentedStruct::Tape) != + augmenteddata->returns.end()) { + assert(additionalArg); + curTypes.push_back(additionalArg); + } if (legal) { Type *RT = todiff->getReturnType(); if (returnUsed && retType != DIFFE_TYPE::CONSTANT) { @@ -3906,6 +3916,12 @@ Function *EnzymeLogic::CreateForwardDiff( break; } } + if (augmenteddata && augmenteddata->returns.find(AugmentedStruct::Tape) != + augmenteddata->returns.end()) { + foundArg->setName("tapeArg"); + nextArgs.push_back(foundArg); + foundArg++; + } BasicBlock *BB = BasicBlock::Create(NewF->getContext(), "entry", NewF); IRBuilder<> bb(BB); @@ -3955,6 +3971,8 @@ Function *EnzymeLogic::CreateForwardDiff( insert_or_assign2(ForwardCachedFunctions, tup, gutils->newFunc); + gutils->FreeMemory = freeMemory; + const SmallPtrSet guaranteedUnreachable = getGuaranteedUnreachable(gutils->oldFunc); @@ -4035,8 +4053,9 @@ Function *EnzymeLogic::CreateForwardDiff( } AdjointGenerator *maker; - if (mode == DerivativeMode::ForwardModeSplit) { + std::unique_ptr> can_modref_map; + if (mode == DerivativeMode::ForwardModeSplit) { std::map _uncacheable_argsPP; { auto in_arg = todiff->arg_begin(); @@ -4048,7 +4067,7 @@ Function *EnzymeLogic::CreateForwardDiff( } } - // TODO gutils->computeGuaranteedFrees(guaranteedUnreachable, TR); + gutils->computeGuaranteedFrees(guaranteedUnreachable, TR); CacheAnalysis CA( gutils->allocationsWithGuaranteedFree, gutils->rematerializableAllocations, TR, gutils->OrigAA, @@ -4058,15 +4077,13 @@ Function *EnzymeLogic::CreateForwardDiff( _uncacheable_argsPP, mode, omp); const std::map> uncacheable_args_map = CA.compute_uncacheable_args_for_callsites(); - - const std::map can_modref_map = - 
CA.compute_uncacheable_load_map(); - gutils->can_modref_map = &can_modref_map; - - std::map, int> mapping; + can_modref_map = std::make_unique>( + CA.compute_uncacheable_load_map()); + gutils->can_modref_map = can_modref_map.get(); auto getIndex = [&](Instruction *I, CacheType u) -> unsigned { - return gutils->getIndex(std::make_pair(I, u), mapping); + assert(augmenteddata); + return gutils->getIndex(std::make_pair(I, u), augmenteddata->tapeIndices); }; gutils->computeMinCache(TR, guaranteedUnreachable); @@ -4074,9 +4091,53 @@ Function *EnzymeLogic::CreateForwardDiff( maker = new AdjointGenerator( mode, gutils, constant_args, retType, TR, getIndex, uncacheable_args_map, - /*returnuses*/ nullptr, nullptr, nullptr, unnecessaryValues, + /*returnuses*/ nullptr, augmenteddata, nullptr, unnecessaryValues, unnecessaryInstructions, unnecessaryStores, guaranteedUnreachable, nullptr); + + if (additionalArg) { + auto v = gutils->newFunc->arg_end(); + v--; + Value *additionalValue = v; + assert(augmenteddata); + + // TODO VERIFY THIS + if (augmenteddata->tapeType && + augmenteddata->tapeType != additionalValue->getType()) { + IRBuilder<> BuilderZ(gutils->inversionAllocs); + if (!augmenteddata->tapeType->isEmptyTy()) { + auto tapep = BuilderZ.CreatePointerCast( + additionalValue, PointerType::getUnqual(augmenteddata->tapeType)); +#if LLVM_VERSION_MAJOR > 7 + LoadInst *truetape = + BuilderZ.CreateLoad(augmenteddata->tapeType, tapep, "truetape"); +#else + LoadInst *truetape = BuilderZ.CreateLoad(tapep, "truetape"); +#endif + truetape->setMetadata("enzyme_mustcache", + MDNode::get(truetape->getContext(), {})); + + if (!omp && gutils->FreeMemory) { + CallInst *ci = + cast(CallInst::CreateFree(additionalValue, truetape)); + ci->moveAfter(truetape); + ci->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull); + } + additionalValue = truetape; + } else { + if (gutils->FreeMemory) { + CallInst *ci = cast( + CallInst::CreateFree(additionalValue, gutils->inversionAllocs)); + 
ci->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull); + BuilderZ.Insert(ci); + } + additionalValue = UndefValue::get(augmenteddata->tapeType); + } + } + + // TODO here finish up making recursive structs simply pass in i8* + gutils->setTape(additionalValue); + } } else { maker = new AdjointGenerator( @@ -4137,6 +4198,9 @@ Function *EnzymeLogic::CreateForwardDiff( createTerminator(TR, gutils, &oBB, retType, retVal); } + if (mode == DerivativeMode::ForwardModeSplit && augmenteddata) + restoreCache(gutils, augmenteddata->tapeIndices, guaranteedUnreachable); + gutils->eraseFictiousPHIs(); BasicBlock *entry = &gutils->newFunc->getEntryBlock(); diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 39871391650be..7097ef572d320 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -237,7 +237,8 @@ class EnzymeLogic { std::tuple /*constant_args*/, std::map /*uncacheable_args*/, - bool /*returnUsed*/, const FnTypeInfo, bool, bool, bool>; + bool /*returnUsed*/, bool /*shadowReturnUsed*/, + const FnTypeInfo, bool, bool, bool>; std::map AugmentedCachedFunctions; std::map AugmentedCachedFinished; @@ -255,7 +256,7 @@ class EnzymeLogic { const AugmentedReturn &CreateAugmentedPrimal( llvm::Function *todiff, DIFFE_TYPE retType, const std::vector &constant_args, TypeAnalysis &TA, - bool returnUsed, const FnTypeInfo &typeInfo, + bool returnUsed, bool shadowReturnUsed, const FnTypeInfo &typeInfo, const std::map _uncacheable_args, bool forceAnonymousTape, bool AtomicAdd, bool omp = false); @@ -291,21 +292,16 @@ class EnzymeLogic { CreateForwardDiff(llvm::Function *todiff, DIFFE_TYPE retType, const std::vector &constant_args, TypeAnalysis &TA, bool returnValue, DerivativeMode mode, - unsigned width, llvm::Type *additionalArg, + bool freeMemory, unsigned width, llvm::Type *additionalArg, const FnTypeInfo &typeInfo, const std::map _uncacheable_args, - bool omp = false); + const AugmentedReturn *augmented, bool omp = false); void 
clear(); }; extern "C" { extern llvm::cl::opt looseTypeAnalysis; - -extern llvm::cl::opt cache_reads_always; - -extern llvm::cl::opt cache_reads_never; - extern llvm::cl::opt nonmarkedglobals_inactiveloads; }; diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 47aa31031f449..2e700a73b64c0 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -1057,7 +1057,7 @@ Function *PreProcessCache::preprocessForClone(Function *F, if (mode == DerivativeMode::ReverseModeGradient) mode = DerivativeMode::ReverseModePrimal; if (mode == DerivativeMode::ForwardModeSplit) - mode = DerivativeMode::ForwardMode; + mode = DerivativeMode::ReverseModePrimal; // If we've already processed this, return the previous version // and derive aliasing information diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 2b15f03f3fac8..191b3225d86a4 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -1934,8 +1934,9 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc, LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, BuilderQ.GetInsertBlock()); - AllocaInst *cache = createCacheForScope( - lctx, innerType, "mdyncache_fromtape", true, false); + AllocaInst *cache = + createCacheForScope(lctx, innerType, "mdyncache_fromtape", + ((DiffeGradientUtils *)this)->FreeMemory, false); assert(malloc); bool isi1 = !ignoreType && malloc->getType()->isIntegerTy() && cast(malloc->getType())->getBitWidth() == 1; @@ -3335,7 +3336,8 @@ GradientUtils *GradientUtils::CreateFromClone( EnzymeLogic &Logic, Function *todiff, TargetLibraryInfo &TLI, TypeAnalysis &TA, DIFFE_TYPE retType, const std::vector &constant_args, bool returnUsed, - std::map &returnMapping, bool omp) { + bool shadowReturnUsed, std::map &returnMapping, + bool omp) { assert(!todiff->empty()); // Since this is forward pass this should always return the tape (at index 0) @@ -3352,10 +3354,10 @@ 
GradientUtils *GradientUtils::CreateFromClone( // We don't need to differentially return something that we know is not a // pointer (or somehow needed for shadow analysis) - if (retType == DIFFE_TYPE::DUP_ARG || retType == DIFFE_TYPE::DUP_NONEED) { + if (shadowReturnUsed) { + assert(retType == DIFFE_TYPE::DUP_ARG || retType == DIFFE_TYPE::DUP_NONEED); assert(!todiff->getReturnType()->isEmptyTy()); assert(!todiff->getReturnType()->isVoidTy()); - assert(!todiff->getReturnType()->isFPOrFPVectorTy()); returnMapping[AugmentedStruct::DifferentialReturn] = returnCount + 1; ++returnCount; } @@ -3402,7 +3404,8 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( assert(!todiff->empty()); assert(mode == DerivativeMode::ReverseModeGradient || mode == DerivativeMode::ReverseModeCombined || - mode == DerivativeMode::ForwardMode); + mode == DerivativeMode::ForwardMode || + mode == DerivativeMode::ForwardModeSplit); ValueToValueMapTy invertedPointers; SmallPtrSet constants; SmallPtrSet nonconstant; @@ -3636,12 +3639,11 @@ Constant *GradientUtils::GetOrCreateShadowFunction( switch (mode) { case DerivativeMode::ForwardMode: { - Constant *newf = - Logic.CreateForwardDiff(fn, retType, types, TA, false, mode, width, - nullptr, type_args, uncacheable_args); + Constant *newf = Logic.CreateForwardDiff( + fn, retType, types, TA, false, mode, /*freeMemory*/ true, width, + nullptr, type_args, uncacheable_args, /*augmented*/ nullptr); - if (!newf) - newf = UndefValue::get(fn->getType()); + assert(newf); std::string prefix = "_enzyme_forward"; @@ -3660,15 +3662,52 @@ Constant *GradientUtils::GetOrCreateShadowFunction( return ConstantExpr::getPointerCast(GV, fn->getType()); } + case DerivativeMode::ForwardModeSplit: { + auto &augdata = Logic.CreateAugmentedPrimal( + fn, retType, /*constant_args*/ types, TA, + /*returnUsed*/ !fn->getReturnType()->isEmptyTy() && + !fn->getReturnType()->isVoidTy(), + /*shadowReturnUsed*/ false, type_args, uncacheable_args, + /*forceAnonymousTape*/ true, 
AtomicAdd); + Constant *newf = Logic.CreateForwardDiff( + fn, retType, types, TA, false, mode, /*freeMemory*/ true, width, + nullptr, type_args, uncacheable_args, /*augmented*/ &augdata); + + assert(newf); + + std::string prefix = "_enzyme_forwardsplit"; + + if (width > 1) { + prefix += std::to_string(width); + } + + auto cdata = ConstantStruct::get( + StructType::get(newf->getContext(), + {augdata.fn->getType(), newf->getType()}), + {augdata.fn, newf}); + + std::string globalname = (prefix + "_" + fn->getName() + "'").str(); + auto GV = fn->getParent()->getNamedValue(globalname); + + if (GV == nullptr) { + GV = new GlobalVariable(*fn->getParent(), cdata->getType(), true, + GlobalValue::LinkageTypes::InternalLinkage, cdata, + globalname); + } + + return ConstantExpr::getPointerCast(GV, fn->getType()); + } case DerivativeMode::ReverseModeCombined: case DerivativeMode::ReverseModeGradient: case DerivativeMode::ReverseModePrimal: { // TODO re atomic add consider forcing it to be atomic always as fallback if // used in a parallel context + bool returnUsed = + !fn->getReturnType()->isEmptyTy() && !fn->getReturnType()->isVoidTy(); + bool shadowReturnUsed = returnUsed && (retType == DIFFE_TYPE::DUP_ARG || + retType == DIFFE_TYPE::DUP_NONEED); auto &augdata = Logic.CreateAugmentedPrimal( - fn, retType, /*constant_args*/ types, TA, - /*returnUsed*/ !fn->getReturnType()->isEmptyTy() && - !fn->getReturnType()->isVoidTy(), + fn, retType, /*constant_args*/ types, TA, returnUsed, shadowReturnUsed, type_args, uncacheable_args, /*forceAnonymousTape*/ true, AtomicAdd); Constant *newf = Logic.CreatePrimalAndGradient( (ReverseCacheKey){.todiff = fn, @@ -3686,8 +3725,7 @@ Constant *GradientUtils::GetOrCreateShadowFunction( .typeInfo = type_args}, TA, /*map*/ &augdata); - if (!newf) - newf = UndefValue::get(fn->getType()); + assert(newf); auto cdata = ConstantStruct::get( StructType::get(newf->getContext(), {augdata.fn->getType(), newf->getType()}), @@ -3702,9 +3740,6 @@ Constant 
*GradientUtils::GetOrCreateShadowFunction( } return ConstantExpr::getPointerCast(GV, fn->getType()); } - default: { - report_fatal_error("Invalid derivative mode"); - } } } @@ -6447,46 +6482,58 @@ void SubTransferHelper(GradientUtils *gutils, DerivativeMode mode, if (secretty) { // no change to forward pass if represents floats if (mode == DerivativeMode::ReverseModeGradient || - mode == DerivativeMode::ReverseModeCombined) { - IRBuilder<> Builder2(MTI->getParent()); - gutils->getReverseBuilder(Builder2, /*original*/ true); + mode == DerivativeMode::ReverseModeCombined || + mode == DerivativeMode::ForwardModeSplit) { + IRBuilder<> Builder2(MTI); + if (mode == DerivativeMode::ForwardModeSplit) + gutils->getForwardBuilder(Builder2); + else + gutils->getReverseBuilder(Builder2); // If the src is constant simply zero d_dst and don't propagate to d_src // (which thus == src and may be illegal) if (srcConstant) { - Value *args[] = { - shadowsLookedUp ? shadow_dst : gutils->lookupM(shadow_dst, Builder2), - ConstantInt::get(Type::getInt8Ty(MTI->getContext()), 0), - gutils->lookupM(length, Builder2), + // Don't zero in forward mode. + if (mode != DerivativeMode::ForwardModeSplit) { + + Value *args[] = { + shadowsLookedUp ? 
shadow_dst + : gutils->lookupM(shadow_dst, Builder2), + ConstantInt::get(Type::getInt8Ty(MTI->getContext()), 0), + gutils->lookupM(length, Builder2), #if LLVM_VERSION_MAJOR <= 6 - ConstantInt::get(Type::getInt32Ty(MTI->getContext()), - max(1U, dstalign)), + ConstantInt::get(Type::getInt32Ty(MTI->getContext()), + max(1U, dstalign)), #endif - ConstantInt::getFalse(MTI->getContext()) - }; - - if (args[0]->getType()->isIntegerTy()) - args[0] = Builder2.CreateIntToPtr( - args[0], Type::getInt8PtrTy(MTI->getContext())); + ConstantInt::getFalse(MTI->getContext()) + }; - Type *tys[] = {args[0]->getType(), args[2]->getType()}; - auto memsetIntr = Intrinsic::getDeclaration( - MTI->getParent()->getParent()->getParent(), Intrinsic::memset, tys); - auto cal = Builder2.CreateCall(memsetIntr, args); - cal->setCallingConv(memsetIntr->getCallingConv()); - if (dstalign != 0) { + if (args[0]->getType()->isIntegerTy()) + args[0] = Builder2.CreateIntToPtr( + args[0], Type::getInt8PtrTy(MTI->getContext())); + + Type *tys[] = {args[0]->getType(), args[2]->getType()}; + auto memsetIntr = Intrinsic::getDeclaration( + MTI->getParent()->getParent()->getParent(), Intrinsic::memset, + tys); + auto cal = Builder2.CreateCall(memsetIntr, args); + cal->setCallingConv(memsetIntr->getCallingConv()); + if (dstalign != 0) { #if LLVM_VERSION_MAJOR >= 10 - cal->addParamAttr(0, Attribute::getWithAlignment(MTI->getContext(), - Align(dstalign))); + cal->addParamAttr(0, Attribute::getWithAlignment(MTI->getContext(), + Align(dstalign))); #else - cal->addParamAttr( - 0, Attribute::getWithAlignment(MTI->getContext(), dstalign)); + cal->addParamAttr( + 0, Attribute::getWithAlignment(MTI->getContext(), dstalign)); #endif + } } } else { - auto dsto = shadowsLookedUp ? shadow_dst - : gutils->lookupM(shadow_dst, Builder2); + auto dsto = + (shadowsLookedUp || mode == DerivativeMode::ForwardModeSplit) + ? 
shadow_dst + : gutils->lookupM(shadow_dst, Builder2); if (dsto->getType()->isIntegerTy()) dsto = Builder2.CreateIntToPtr( dsto, Type::getInt8PtrTy(dsto->getContext())); @@ -6501,8 +6548,12 @@ void SubTransferHelper(GradientUtils *gutils, DerivativeMode mode, dsto = Builder2.CreateConstInBoundsGEP1_64(dsto, offset); #endif } - auto srco = shadowsLookedUp ? shadow_src - : gutils->lookupM(shadow_src, Builder2); + auto srco = + (shadowsLookedUp || mode == DerivativeMode::ForwardModeSplit) + ? shadow_src + : gutils->lookupM(shadow_src, Builder2); + if (mode != DerivativeMode::ForwardModeSplit) + dsto = Builder2.CreatePointerCast(dsto, secretpt); if (srco->getType()->isIntegerTy()) srco = Builder2.CreateIntToPtr( srco, Type::getInt8PtrTy(srco->getContext())); @@ -6517,26 +6568,49 @@ void SubTransferHelper(GradientUtils *gutils, DerivativeMode mode, srco = Builder2.CreateConstInBoundsGEP1_64(srco, offset); #endif } + if (mode != DerivativeMode::ForwardModeSplit) + srco = Builder2.CreatePointerCast(srco, secretpt); - Value *args[]{ - Builder2.CreatePointerCast(dsto, secretpt), - Builder2.CreatePointerCast(srco, secretpt), - Builder2.CreateUDiv( - gutils->lookupM(length, Builder2), - ConstantInt::get(length->getType(), - Builder2.GetInsertBlock() - ->getParent() - ->getParent() - ->getDataLayout() - .getTypeAllocSizeInBits(secretty) / - 8))}; - - auto dmemcpy = ((intrinsic == Intrinsic::memcpy) - ? 
getOrInsertDifferentialFloatMemcpy - : getOrInsertDifferentialFloatMemmove)( - *MTI->getParent()->getParent()->getParent(), secretty, dstalign, - srcalign, dstaddr, srcaddr); - Builder2.CreateCall(dmemcpy, args); + if (mode == DerivativeMode::ForwardModeSplit) { + CallInst *call; +#if LLVM_VERSION_MAJOR >= 11 + MaybeAlign dalign; + if (dstalign) + dalign = MaybeAlign(dstalign); + MaybeAlign salign; + if (srcalign) + salign = MaybeAlign(srcalign); +#else + auto dalign = dstalign; + auto salign = srcalign; +#endif + + if (intrinsic == Intrinsic::memmove) { + call = Builder2.CreateMemMove(dsto, dalign, srco, salign, length); + } else { + call = Builder2.CreateMemCpy(dsto, dalign, srco, salign, length); + } + } else { + Value *args[]{ + Builder2.CreatePointerCast(dsto, secretpt), + Builder2.CreatePointerCast(srco, secretpt), + Builder2.CreateUDiv( + gutils->lookupM(length, Builder2), + ConstantInt::get(length->getType(), + Builder2.GetInsertBlock() + ->getParent() + ->getParent() + ->getDataLayout() + .getTypeAllocSizeInBits(secretty) / + 8))}; + + auto dmemcpy = ((intrinsic == Intrinsic::memcpy) + ? 
getOrInsertDifferentialFloatMemcpy + : getOrInsertDifferentialFloatMemmove)( + *MTI->getParent()->getParent()->getParent(), secretty, dstalign, + srcalign, dstaddr, srcaddr); + Builder2.CreateCall(dmemcpy, args); + } } } } else { diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index 5cc91a77b9d4e..0133949a5340c 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -1008,25 +1008,33 @@ class GradientUtils : public CacheUtility { llvm::errs() << "end invertedPointers\n"; } + int getIndex( + std::pair idx, + const std::map, int> &mapping) { + assert(tape); + auto found = mapping.find(idx); + if (found == mapping.end()) { + llvm::errs() << "oldFunc: " << *oldFunc << "\n"; + llvm::errs() << "newFunc: " << *newFunc << "\n"; + llvm::errs() << " \n"; + for (auto &p : mapping) { + llvm::errs() << " idx: " << *p.first.first << ", " << p.first.second + << " pos=" << p.second << "\n"; + } + llvm::errs() << " \n"; + + llvm::errs() << "idx: " << *idx.first << ", " << idx.second << "\n"; + assert(0 && "could not find index in mapping"); + } + return found->second; + } + int getIndex(std::pair idx, std::map, int> &mapping) { if (tape) { - if (mapping.find(idx) == mapping.end()) { - llvm::errs() << "oldFunc: " << *oldFunc << "\n"; - llvm::errs() << "newFunc: " << *newFunc << "\n"; - llvm::errs() << " \n"; - for (auto &p : mapping) { - llvm::errs() << " idx: " << *p.first.first << ", " << p.first.second - << " pos=" << p.second << "\n"; - } - llvm::errs() << " \n"; - - if (mapping.find(idx) == mapping.end()) { - llvm::errs() << "idx: " << *idx.first << ", " << idx.second << "\n"; - assert(0 && "could not find index in mapping"); - } - } - return mapping[idx]; + return getIndex( + idx, + (const std::map, int> &)mapping); } else { if (mapping.find(idx) != mapping.end()) { return mapping[idx]; @@ -1143,6 +1151,7 @@ class GradientUtils : public CacheUtility { CreateFromClone(EnzymeLogic &Logic, Function *todiff, TargetLibraryInfo 
&TLI, TypeAnalysis &TA, DIFFE_TYPE retType, const std::vector &constant_args, bool returnUsed, + bool shadowReturnUsed, std::map &returnMapping, bool omp); #if LLVM_VERSION_MAJOR >= 10 diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index ed5f2354a2b8c..7f975e08f473d 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -104,6 +104,21 @@ void EmitWarning(llvm::StringRef RemarkName, (llvm::errs() << ... << args) << "\n"; } +template +void EmitWarning(llvm::StringRef RemarkName, const llvm::Function *F, + const Args &...args) { + + llvm::OptimizationRemarkEmitter ORE(F); + ORE.emit([&]() { + std::string str; + llvm::raw_string_ostream ss(str); + (ss << ... << args); + return llvm::OptimizationRemark("enzyme", RemarkName, F) << ss.str(); + }); + if (EnzymePrintPerf) + (llvm::errs() << ... << args) << "\n"; +} + class EnzymeFailure : public llvm::DiagnosticInfoIROptimization { public: EnzymeFailure(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, @@ -202,9 +217,9 @@ getNextNonDebugInstruction(llvm::Instruction *Z) { } /// Check if a global has metadata -static inline bool hasMetadata(const llvm::GlobalObject *O, - llvm::StringRef kind) { - return O->getMetadata(kind) != nullptr; +static inline llvm::MDNode *hasMetadata(const llvm::GlobalObject *O, + llvm::StringRef kind) { + return O->getMetadata(kind); } /// Check if an instruction has metadata @@ -402,8 +417,10 @@ static inline DIFFE_TYPE whatType(llvm::Type *arg, DerivativeMode mode, } else if (arg->isIntOrIntVectorTy() || arg->isFunctionTy()) { return DIFFE_TYPE::CONSTANT; } else if (arg->isFPOrFPVectorTy()) { - return (mode == DerivativeMode::ForwardMode) ? DIFFE_TYPE::DUP_ARG - : DIFFE_TYPE::OUT_DIFF; + return (mode == DerivativeMode::ForwardMode || + mode == DerivativeMode::ForwardModeSplit) + ? 
DIFFE_TYPE::DUP_ARG + : DIFFE_TYPE::OUT_DIFF; } else { assert(arg); llvm::errs() << "arg: " << *arg << "\n"; diff --git a/enzyme/test/Enzyme/CMakeLists.txt b/enzyme/test/Enzyme/CMakeLists.txt index f5e2e0658aeb9..2706c5d9aaca3 100644 --- a/enzyme/test/Enzyme/CMakeLists.txt +++ b/enzyme/test/Enzyme/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(ReverseMode) add_subdirectory(ForwardMode) +add_subdirectory(ForwardModeSplit) add_subdirectory(ForwardModeVector) # Run regression and unit tests @@ -9,4 +10,4 @@ add_lit_testsuite(check-enzyme "Running enzyme regression tests" ARGS -v ) -set_target_properties(check-enzyme PROPERTIES FOLDER "Tests") \ No newline at end of file +set_target_properties(check-enzyme PROPERTIES FOLDER "Tests") diff --git a/enzyme/test/Enzyme/ForwardModeSplit/CMakeLists.txt b/enzyme/test/Enzyme/ForwardModeSplit/CMakeLists.txt new file mode 100644 index 0000000000000..e0678cfefda14 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/CMakeLists.txt @@ -0,0 +1,12 @@ +# Run regression and unit tests +add_lit_testsuite(check-enzyme-forward-split "Running enzyme forward mode regression tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${ENZYME_TEST_DEPS} + ARGS -v +) + +set_target_properties(check-enzyme-forward-split PROPERTIES FOLDER "Tests") + +#add_lit_testsuites(ENZYME ${CMAKE_CURRENT_SOURCE_DIR} + # DEPENDS ${ENZYME_TEST_DEPS} +#) diff --git a/enzyme/test/Enzyme/ForwardModeSplit/Faddeeva_erf.ll b/enzyme/test/Enzyme/ForwardModeSplit/Faddeeva_erf.ll new file mode 100644 index 0000000000000..9776d95fbf107 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/Faddeeva_erf.ll @@ -0,0 +1,53 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +declare { double, double } @Faddeeva_erf({ double, double }, double) + +define { double, double } @tester({ double, double } %in) { +entry: + %call = call { double, double } @Faddeeva_erf({ double, double } %in, double 0.000000e+00) + 
ret { double, double } %call +} + +define { double, double } @test_derivative({ double, double } %x) { +entry: + %0 = tail call { double, double } ({ double, double } ({ double, double })*, ...) @__enzyme_fwdsplit({ double, double } ({ double, double })* @tester, { double, double } %x, { double, double } { double 1.0, double 1.0 }, i8* null) + ret { double, double } %0 +} + +; Function Attrs: nounwind +declare { double, double } @__enzyme_fwdsplit({ double, double } ({ double, double })*, ...) + + +; CHECK: define internal { double, double } @fwddiffetester({ double, double } %in, { double, double } %"in'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = extractvalue { double, double } %in, 0 +; CHECK-NEXT: %1 = extractvalue { double, double } %in, 1 +; CHECK-DAG: %[[a2:.+]] = fmul fast double %0, %0 +; CHECK-DAG: %[[a3:.+]] = fmul fast double %1, %1 +; CHECK-NEXT: %4 = fsub fast double %[[a2]], %[[a3]] +; CHECK-NEXT: %5 = fmul fast double %0, %1 +; CHECK-NEXT: %6 = fadd fast double %5, %5 +; CHECK-NEXT: %7 = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %4 +; CHECK-NEXT: %8 = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %6 +; CHECK-NEXT: %9 = call fast double @llvm.exp.f64(double %7) +; CHECK-NEXT: %10 = call fast double @llvm.cos.f64(double %8) +; CHECK-NEXT: %11 = fmul fast double %9, %10 +; CHECK-NEXT: %12 = call fast double @llvm.sin.f64(double %8) +; CHECK-NEXT: %13 = fmul fast double %9, %12 +; CHECK-NEXT: %14 = fmul fast double %11, 0x3FF20DD750429B6D +; CHECK-NEXT: %15 = insertvalue { double, double } undef, double %14, 0 +; CHECK-NEXT: %16 = fmul fast double %13, 0x3FF20DD750429B6D +; CHECK-NEXT: %17 = insertvalue { double, double } %15, double %16, 1 +; CHECK-NEXT: %18 = extractvalue { double, double } %"in'", 0 +; CHECK-NEXT: %19 = extractvalue { double, double } %"in'", 1 +; CHECK-DAG: %[[a20:.+]] = fmul fast double %14, %18 +; CHECK-DAG: %[[a21:.+]] = fmul fast 
double %16, %19 +; CHECK-NEXT: %22 = fsub fast double %[[a20]], %[[a21]] +; CHECK-NEXT: %23 = insertvalue { double, double } %17, double %22, 0 +; CHECK-DAG: %[[a24:.+]] = fmul fast double %16, %18 +; CHECK-DAG: %[[a25:.+]] = fmul fast double %14, %19 +; CHECK-NEXT: %26 = fadd fast double %[[a24]], %[[a25]] +; CHECK-NEXT: %27 = insertvalue { double, double } %23, double %26, 1 +; CHECK-NEXT: ret { double, double } %27 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/Faddeeva_erfc.ll b/enzyme/test/Enzyme/ForwardModeSplit/Faddeeva_erfc.ll new file mode 100644 index 0000000000000..3642bef019074 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/Faddeeva_erfc.ll @@ -0,0 +1,53 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +declare { double, double } @Faddeeva_erfc({ double, double }, double) + +define { double, double } @tester({ double, double } %in) { +entry: + %call = call { double, double } @Faddeeva_erfc({ double, double } %in, double 0.000000e+00) + ret { double, double } %call +} + +define { double, double } @test_derivative({ double, double } %x) { +entry: + %0 = tail call { double, double } ({ double, double } ({ double, double })*, ...) @__enzyme_fwdsplit({ double, double } ({ double, double })* @tester, { double, double } %x, { double, double } { double 1.0, double 1.0 }, i8* null) + ret { double, double } %0 +} + +; Function Attrs: nounwind +declare { double, double } @__enzyme_fwdsplit({ double, double } ({ double, double })*, ...) 
+ + +; CHECK: define internal { double, double } @fwddiffetester({ double, double } %in, { double, double } %"in'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = extractvalue { double, double } %in, 0 +; CHECK-NEXT: %1 = extractvalue { double, double } %in, 1 +; CHECK-DAG: %[[a2:.+]] = fmul fast double %0, %0 +; CHECK-DAG: %[[a3:.+]] = fmul fast double %1, %1 +; CHECK-NEXT: %4 = fsub fast double %[[a2]], %[[a3]] +; CHECK-NEXT: %5 = fmul fast double %0, %1 +; CHECK-NEXT: %6 = fadd fast double %5, %5 +; CHECK-NEXT: %7 = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %4 +; CHECK-NEXT: %8 = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %6 +; CHECK-NEXT: %9 = call fast double @llvm.exp.f64(double %7) +; CHECK-NEXT: %10 = call fast double @llvm.cos.f64(double %8) +; CHECK-NEXT: %11 = fmul fast double %9, %10 +; CHECK-NEXT: %12 = call fast double @llvm.sin.f64(double %8) +; CHECK-NEXT: %13 = fmul fast double %9, %12 +; CHECK-NEXT: %14 = fmul fast double %11, 0xBFF20DD750429B6D +; CHECK-NEXT: %15 = insertvalue { double, double } undef, double %14, 0 +; CHECK-NEXT: %16 = fmul fast double %13, 0xBFF20DD750429B6D +; CHECK-NEXT: %17 = insertvalue { double, double } %15, double %16, 1 +; CHECK-NEXT: %18 = extractvalue { double, double } %"in'", 0 +; CHECK-NEXT: %19 = extractvalue { double, double } %"in'", 1 +; CHECK-DAG: %[[a20:.+]] = fmul fast double %14, %18 +; CHECK-DAG: %[[a21:.+]] = fmul fast double %16, %19 +; CHECK-NEXT: %22 = fsub fast double %[[a20]], %[[a21]] +; CHECK-NEXT: %23 = insertvalue { double, double } %17, double %22, 0 +; CHECK-DAG: %[[a24:.+]] = fmul fast double %16, %18 +; CHECK-DAG: %[[a25:.+]] = fmul fast double %14, %19 +; CHECK-NEXT: %26 = fadd fast double %[[a24]], %[[a25]] +; CHECK-NEXT: %27 = insertvalue { double, double } %23, double %26, 1 +; CHECK-NEXT: ret { double, double } %27 +; CHECK-NEXT: } diff --git 
a/enzyme/test/Enzyme/ForwardModeSplit/Faddeeva_erfi.ll b/enzyme/test/Enzyme/ForwardModeSplit/Faddeeva_erfi.ll new file mode 100644 index 0000000000000..f55f9bf9fb491 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/Faddeeva_erfi.ll @@ -0,0 +1,51 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +declare { double, double } @Faddeeva_erfi({ double, double }, double) + +define { double, double } @tester({ double, double } %in) { +entry: + %call = call { double, double } @Faddeeva_erfi({ double, double } %in, double 0.000000e+00) + ret { double, double } %call +} + +define { double, double } @test_derivative({ double, double } %x) { +entry: + %0 = tail call { double, double } ({ double, double } ({ double, double })*, ...) @__enzyme_fwdsplit({ double, double } ({ double, double })* nonnull @tester, { double, double } %x, { double, double } { double 1.0, double 1.0 }, i8* null) + ret { double, double } %0 +} + +; Function Attrs: nounwind +declare { double, double } @__enzyme_fwdsplit({ double, double } ({ double, double })*, ...) 
+ + +; CHECK: define internal { double, double } @fwddiffetester({ double, double } %in, { double, double } %"in'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = extractvalue { double, double } %in, 0 +; CHECK-NEXT: %1 = extractvalue { double, double } %in, 1 +; CHECK-DAG: %[[a2:.+]] = fmul fast double %0, %0 +; CHECK-DAG: %[[a3:.+]] = fmul fast double %1, %1 +; CHECK-NEXT: %4 = fsub fast double %[[a2]], %[[a3]] +; CHECK-DAG: %[[a5:.+]] = fmul fast double %0, {{(%1|%5)}} +; CHECK-DAG: %[[a6:.+]] = fadd fast double {{(%5|%1)}}, {{(%5|%1)}} +; CHECK-NEXT: %7 = call fast double @llvm.exp.f64(double %4) +; CHECK-NEXT: %8 = call fast double @llvm.cos.f64(double %6) +; CHECK-NEXT: %9 = fmul fast double %7, %8 +; CHECK-NEXT: %10 = call fast double @llvm.sin.f64(double %6) +; CHECK-NEXT: %11 = fmul fast double %7, %10 +; CHECK-NEXT: %12 = fmul fast double %9, 0x3FF20DD750429B6D +; CHECK-NEXT: %13 = insertvalue { double, double } undef, double %12, 0 +; CHECK-NEXT: %14 = fmul fast double %11, 0x3FF20DD750429B6D +; CHECK-NEXT: %15 = insertvalue { double, double } %13, double %14, 1 +; CHECK-NEXT: %16 = extractvalue { double, double } %"in'", 0 +; CHECK-NEXT: %17 = extractvalue { double, double } %"in'", 1 +; CHECK-NEXT: %18 = fmul fast double %14, %17 +; CHECK-NEXT: %19 = fmul fast double %12, %16 +; CHECK-NEXT: %20 = fsub fast double %19, %18 +; CHECK-NEXT: %21 = insertvalue { double, double } %15, double %20, 0 +; CHECK-NEXT: %22 = fmul fast double %12, %17 +; CHECK-NEXT: %23 = fmul fast double %14, %16 +; CHECK-NEXT: %24 = fadd fast double %23, %22 +; CHECK-NEXT: %25 = insertvalue { double, double } %21, double %24, 1 +; CHECK-NEXT: ret { double, double } %25 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/add.ll b/enzyme/test/Enzyme/ForwardModeSplit/add.ll new file mode 100644 index 0000000000000..2935a88fa43a2 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/add.ll @@ -0,0 +1,29 
@@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %0 = fadd fast double %x, %y + ret double %0 +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fwdsplit(double (double, double)* nonnull @tester, double %x, double 1.0, double %y, double 0.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double, double)*, ...) + +; CHECK: define internal i8* @augmented_tester(double %x, double %"x'", double %y, double %"y'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret i8* null +; CHECK-NEXT: } + +; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = fadd fast double %"x'", %"y'" +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/addOneMem.ll b/enzyme/test/Enzyme/ForwardModeSplit/addOneMem.ll new file mode 100644 index 0000000000000..209027f9abe26 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/addOneMem.ll @@ -0,0 +1,46 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -inline -mem2reg -instsimplify -gvn -dse -dse -S | FileCheck %s + +; __attribute__((noinline)) +; void addOneMem(double *x) { +; *x += 1; +; } +; +; void test_derivative(double *x, double *xp) { +; __builtin_autodiff(addOneMem, x, xp); +; } + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local void @addOneMem(double* nocapture %x) { +entry: + %0 = load double, double* %x, align 8, !tbaa !2 + %add = fadd fast double %0, 1.000000e+00 + store double %add, double* %x, align 8, !tbaa !2 + ret void +} + +; Function Attrs: nounwind uwtable +define dso_local void 
@test_derivative(double* %x, double* %xp, i8* %tape) local_unnamed_addr { +entry: + %0 = tail call double (void (double*)*, ...) @__enzyme_fwdsplit(void (double*)* nonnull @addOneMem, double* %x, double* %xp, i8* %tape) + ret void +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(void (double*)*, ...) + +!llvm.module.flags = !{!0} +!llvm.ident = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{!"clang version 7.1.0 "} +!2 = !{!3, !3, i64 0} +!3 = !{!"double", !4, i64 0} +!4 = !{!"omnipotent char", !5, i64 0} +!5 = !{!"Simple C/C++ TBAA"} + + +; CHECK: define {{(dso_local )?}}void @test_derivative(double* %x, double* %xp, i8* %tape) +; CHECK-NEXT: entry: +; CHECK-NEXT: call void @free(i8* nonnull %tape) +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/badcall.ll b/enzyme/test/Enzyme/ForwardModeSplit/badcall.ll new file mode 100644 index 0000000000000..e6416ab72ea68 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/badcall.ll @@ -0,0 +1,65 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local zeroext i1 @metasubf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %arrayidx = getelementptr inbounds double, double* %x, i64 1 + store double 3.000000e+00, double* %arrayidx, align 8 + %0 = load double, double* %x, align 8 + %cmp = fcmp fast oeq double %0, 2.000000e+00 + ret i1 %cmp +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local zeroext i1 @subf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %0 = load double, double* %x, align 8 + %mul = fmul fast double %0, 2.000000e+00 + store double %mul, double* %x, align 8 + %call = tail call zeroext i1 @metasubf(double* %x) + ret i1 %call +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local void @f(double* nocapture %x) #0 { +entry: 
+ %call = tail call zeroext i1 @subf(double* %x) + store double 2.000000e+00, double* %x, align 8 + ret void +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @dsumsquare(double* %x, double* %xp) local_unnamed_addr #1 { +entry: + %call = tail call fast double @__enzyme_fwdsplit(i8* bitcast (void (double*)* @f to i8*), double* %x, double* %xp, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(i8*, double*, double*, i8*) local_unnamed_addr + +attributes #0 = { noinline norecurse nounwind uwtable } +attributes #1 = { noinline nounwind uwtable } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffef(double* nocapture %x, double* nocapture %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: call void @fwddiffesubf(double* %x, double* %"x'") +; CHECK-NEXT: store double 0.000000e+00, double* %"x'" +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffesubf(double* nocapture %x, double* nocapture %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = load double, double* %"x'" +; CHECK-NEXT: %1 = fmul fast double %0, 2.000000e+00 +; CHECK-NEXT: store double %1, double* %"x'" +; CHECK-NEXT: call void @fwddiffemetasubf(double* %x, double* %"x'") +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffemetasubf(double* nocapture %x, double* nocapture %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 1 +; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg" +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/badcall2.ll b/enzyme/test/Enzyme/ForwardModeSplit/badcall2.ll new file mode 100644 index 0000000000000..00515bed2a47e --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/badcall2.ll @@ -0,0 +1,82 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false 
-mem2reg -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local zeroext i1 @metasubf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %arrayidx = getelementptr inbounds double, double* %x, i64 1 + store double 3.000000e+00, double* %arrayidx, align 8 + %0 = load double, double* %x, align 8 + %cmp = fcmp fast oeq double %0, 2.000000e+00 + ret i1 %cmp +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local zeroext i1 @othermetasubf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %arrayidx = getelementptr inbounds double, double* %x, i64 1 + store double 4.000000e+00, double* %arrayidx, align 8 + %0 = load double, double* %x, align 8 + %cmp = fcmp fast oeq double %0, 3.000000e+00 + ret i1 %cmp +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local zeroext i1 @subf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %0 = load double, double* %x, align 8 + %mul = fmul fast double %0, 2.000000e+00 + store double %mul, double* %x, align 8 + %call = tail call zeroext i1 @metasubf(double* %x) + %call1 = tail call zeroext i1 @othermetasubf(double* %x) + ret i1 %call1 +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local void @f(double* nocapture %x) #0 { +entry: + %call = tail call zeroext i1 @subf(double* %x) + store double 2.000000e+00, double* %x, align 8 + ret void +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @dsumsquare(double* %x, double* %xp) local_unnamed_addr #1 { +entry: + %call = tail call fast double @__enzyme_fwdsplit(i8* bitcast (void (double*)* @f to i8*), double* %x, double* %xp, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(i8*, double*, double*, i8*) local_unnamed_addr + + +; CHECK: define internal {{(dso_local )?}}void @fwddiffef(double* nocapture %x, double* nocapture %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; 
CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: call void @fwddiffesubf(double* %x, double* %"x'") +; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffesubf(double* nocapture %x, double* nocapture %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = load double, double* %"x'" +; CHECK-NEXT: %1 = fmul fast double %0, 2.000000e+00 +; CHECK-NEXT: store double %1, double* %"x'", align 8 +; CHECK-NEXT: call void @fwddiffemetasubf(double* %x, double* %"x'") +; CHECK-NEXT: call void @fwddiffeothermetasubf(double* %x, double* %"x'") +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffemetasubf(double* nocapture %x, double* nocapture %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 1 +; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffeothermetasubf(double* nocapture %x, double* nocapture %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 1 +; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/badcall3.ll b/enzyme/test/Enzyme/ForwardModeSplit/badcall3.ll new file mode 100644 index 0000000000000..1464e977d97f6 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/badcall3.ll @@ -0,0 +1,83 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local zeroext i1 @metasubf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %arrayidx = getelementptr inbounds double, double* %x, i64 1 + store 
double 3.000000e+00, double* %arrayidx, align 8 + %0 = load double, double* %x, align 8 + %cmp = fcmp fast oeq double %0, 2.000000e+00 + ret i1 %cmp +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local zeroext i1 @othermetasubf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %arrayidx = getelementptr inbounds double, double* %x, i64 1 + store double 4.000000e+00, double* %arrayidx, align 8 + %0 = load double, double* %x, align 8 + %cmp = fcmp fast oeq double %0, 3.000000e+00 + ret i1 %cmp +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local void @subf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %0 = load double, double* %x, align 8 + %mul = fmul fast double %0, 2.000000e+00 + store double %mul, double* %x, align 8 + %call = tail call zeroext i1 @metasubf(double* %x) + %call1 = tail call zeroext i1 @othermetasubf(double* %x) + ret void +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local void @f(double* nocapture %x) #0 { +entry: + tail call void @subf(double* %x) + store double 2.000000e+00, double* %x, align 8 + ret void +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @dsumsquare(double* %x, double* %xp) local_unnamed_addr #1 { +entry: + %call = tail call fast double @__enzyme_fwdsplit(i8* bitcast (void (double*)* @f to i8*), double* %x, double* %xp, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(i8*, double*, double*, i8*) local_unnamed_addr + +; CHECK: define internal {{(dso_local )?}}void @fwddiffef(double* nocapture %x, double* nocapture %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: call void @fwddiffesubf(double* %x, double* %"x'") +; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } + + +; CHECK: define internal {{(dso_local )?}}void @fwddiffesubf(double* nocapture %x, double* nocapture 
%"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = load double, double* %"x'" +; CHECK-NEXT: %1 = fmul fast double %0, 2.000000e+00 +; CHECK-NEXT: store double %1, double* %"x'", align 8 +; CHECK-NEXT: call void @fwddiffemetasubf(double* %x, double* %"x'") +; CHECK-NEXT: call void @fwddiffeothermetasubf(double* %x, double* %"x'") +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffemetasubf(double* nocapture %x, double* nocapture %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 1 +; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffeothermetasubf(double* nocapture %x, double* nocapture %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 1 +; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardModeSplit/badcall4.ll b/enzyme/test/Enzyme/ForwardModeSplit/badcall4.ll new file mode 100644 index 0000000000000..02d7942f3538e --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/badcall4.ll @@ -0,0 +1,83 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local zeroext i1 @metasubf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %arrayidx = getelementptr inbounds double, double* %x, i64 1 + store double 3.000000e+00, double* %arrayidx, align 8 + %0 = load double, double* %x, align 8 + %cmp = fcmp fast oeq double %0, 2.000000e+00 + ret i1 %cmp +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local zeroext i1 @othermetasubf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %arrayidx = 
getelementptr inbounds double, double* %x, i64 1 + store double 4.000000e+00, double* %arrayidx, align 8 + %0 = load double, double* %x, align 8 + %cmp = fcmp fast oeq double %0, 3.000000e+00 + ret i1 %cmp +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local zeroext i1 @subf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %0 = load double, double* %x, align 8 + %mul = fmul fast double %0, 2.000000e+00 + store double %mul, double* %x, align 8 + %call = tail call zeroext i1 @metasubf(double* %x) + %call1 = tail call zeroext i1 @othermetasubf(double* %x) + %res = and i1 %call, %call1 + ret i1 %res +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local void @f(double* nocapture %x) #0 { +entry: + %call = tail call zeroext i1 @subf(double* %x) + store double 2.000000e+00, double* %x, align 8 + ret void +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @dsumsquare(double* %x, double* %xp) local_unnamed_addr #1 { +entry: + %call = tail call fast double @__enzyme_fwdsplit(i8* bitcast (void (double*)* @f to i8*), double* %x, double* %xp, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(i8*, double*, double*, i8*) local_unnamed_addr + + +; CHECK: define internal {{(dso_local )?}}void @fwddiffef(double* nocapture %x, double* nocapture %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: call void @fwddiffesubf(double* %x, double* %"x'") +; CHECK-NEXT: store double 0.000000e+00, double* %"x'" +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffesubf(double* nocapture %x, double* nocapture %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = load double, double* %"x'" +; CHECK-NEXT: %1 = fmul fast double %0, 2.000000e+00 +; CHECK-NEXT: store double %1, double* %"x'", align 8 +; CHECK-NEXT: call void @fwddiffemetasubf(double* %x, double* %"x'") +; CHECK-NEXT: call 
void @fwddiffeothermetasubf(double* %x, double* %"x'") +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffemetasubf(double* nocapture %x, double* nocapture %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 1 +; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffeothermetasubf(double* nocapture %x, double* nocapture %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 1 +; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/badcallsq.ll b/enzyme/test/Enzyme/ForwardModeSplit/badcallsq.ll new file mode 100644 index 0000000000000..86b3364973ff5 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/badcallsq.ll @@ -0,0 +1,71 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local zeroext i1 @metasubf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %arrayidx = getelementptr inbounds double, double* %x, i64 1 + store double 3.000000e+00, double* %arrayidx, align 8 + %0 = load double, double* %x, align 8 + %cmp = fcmp fast oeq double %0, 2.000000e+00 + ret i1 %cmp +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local zeroext i1 @subf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %0 = load double, double* %x, align 8 + %mul = fmul fast double %0, %0 + store double %mul, double* %x, align 8 + %call = tail call zeroext i1 @metasubf(double* %x) + ret i1 %call +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local void @f(double* nocapture %x) #0 { +entry: + 
%call = tail call zeroext i1 @subf(double* %x) + store double 2.000000e+00, double* %x, align 8 + ret void +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @dsumsquare(double* %x, double* %xp) local_unnamed_addr #1 { +entry: + %call = tail call fast double @__enzyme_fwdsplit(i8* bitcast (void (double*)* @f to i8*), double* %x, double* %xp, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(i8*, double*, double*, i8*) local_unnamed_addr + +attributes #0 = { noinline norecurse nounwind uwtable } +attributes #1 = { noinline nounwind uwtable } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffef(double* nocapture %x, double* nocapture %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to double* +; CHECK-NEXT: %tapeArg1 = load double, double* %0 +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: call void @fwddiffesubf(double* %x, double* %"x'", double %tapeArg1) +; CHECK-NEXT: store double 0.000000e+00, double* %"x'" +; CHECK-NEXT: ret void +; CHECK-NEXT: } + + +; CHECK: define internal {{(dso_local )?}}void @fwddiffesubf(double* nocapture %x, double* nocapture %"x'", double +; CHECK-NEXT: entry: +; CHECK-NEXT: %1 = load double, double* %"x'" +; CHECK-NEXT: %2 = fmul fast double %1, %0 +; CHECK-NEXT: %3 = fmul fast double %1, %0 +; CHECK-NEXT: %4 = fadd fast double %2, %3 +; CHECK-NEXT: store double %4, double* %"x'", align 8 +; CHECK-NEXT: call void @fwddiffemetasubf(double* %x, double* %"x'") +; CHECK-NEXT: ret void +; CHECK-NEXT: } + + +; CHECK: define internal {{(dso_local )?}}void @fwddiffemetasubf(double* nocapture %x, double* nocapture %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 1 +; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/badcallused.ll 
b/enzyme/test/Enzyme/ForwardModeSplit/badcallused.ll new file mode 100644 index 0000000000000..8a925c50a1fe5 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/badcallused.ll @@ -0,0 +1,67 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local zeroext i1 @metasubf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %arrayidx = getelementptr inbounds double, double* %x, i64 1 + store double 3.000000e+00, double* %arrayidx, align 8 + %0 = load double, double* %x, align 8 + %cmp = fcmp fast oeq double %0, 2.000000e+00 + ret i1 %cmp +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local zeroext i1 @subf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %0 = load double, double* %x, align 8 + %mul = fmul fast double %0, 2.000000e+00 + store double %mul, double* %x, align 8 + %call = tail call zeroext i1 @metasubf(double* %x) + ret i1 %call +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local void @f(double* nocapture %x) #0 { +entry: + %call = tail call zeroext i1 @subf(double* %x) + %sel = select i1 %call, double 2.000000e+00, double 3.000000e+00 + store double %sel, double* %x, align 8 + ret void +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @dsumsquare(double* %x, double* %xp) local_unnamed_addr #1 { +entry: + %call = tail call fast double @__enzyme_fwdsplit(i8* bitcast (void (double*)* @f to i8*), double* %x, double* %xp, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(i8*, double*, double*, i8*) local_unnamed_addr + +attributes #0 = { noinline norecurse nounwind uwtable } +attributes #1 = { noinline nounwind uwtable } + + +; CHECK: define internal {{(dso_local )?}}void @fwddiffef(double* nocapture %x, double* nocapture %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail 
call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: call void @fwddiffesubf(double* %x, double* %"x'") +; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffesubf(double* nocapture %x, double* nocapture %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = load double, double* %"x'" +; CHECK-NEXT: %1 = fmul fast double %0, 2.000000e+00 +; CHECK-NEXT: store double %1, double* %"x'", align 8 +; CHECK-NEXT: call void @fwddiffemetasubf(double* %x, double* %"x'") +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffemetasubf(double* nocapture %x, double* nocapture %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 1 +; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg", align 8 +; CHECK-NEXT: ret +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/badcallused2.ll b/enzyme/test/Enzyme/ForwardModeSplit/badcallused2.ll new file mode 100644 index 0000000000000..09c6e78b8ed6e --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/badcallused2.ll @@ -0,0 +1,84 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local zeroext i1 @metasubf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %arrayidx = getelementptr inbounds double, double* %x, i64 1 + store double 3.000000e+00, double* %arrayidx, align 8 + %0 = load double, double* %x, align 8 + %cmp = fcmp fast oeq double %0, 2.000000e+00 + ret i1 %cmp +} + +define dso_local zeroext i1 @omegasubf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %arrayidx = getelementptr inbounds double, double* %x, i64 1 + store double 3.000000e+00, double* %arrayidx, align 8 + %0 = load double, double* %x, align 8 + %cmp = fcmp fast 
oeq double %0, 2.000000e+00 + ret i1 %cmp +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local zeroext i1 @subf(double* nocapture %x) local_unnamed_addr #0 { +entry: + %0 = load double, double* %x, align 8 + %mul = fmul fast double %0, 2.000000e+00 + store double %mul, double* %x, align 8 + %call = tail call zeroext i1 @omegasubf(double* %x) + %call2 = tail call zeroext i1 @metasubf(double* %x) + ret i1 %call2 +} + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local void @f(double* nocapture %x) #0 { +entry: + %call = tail call zeroext i1 @subf(double* %x) + %sel = select i1 %call, double 2.000000e+00, double 3.000000e+00 + store double %sel, double* %x, align 8 + ret void +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @dsumsquare(double* %x, double* %xp) local_unnamed_addr #1 { +entry: + %call = tail call fast double @__enzyme_fwdsplit(i8* bitcast (void (double*)* @f to i8*), double* %x, double* %xp, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(i8*, double*, double*, i8*) local_unnamed_addr + +attributes #0 = { noinline norecurse nounwind uwtable } +attributes #1 = { noinline nounwind uwtable } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffef(double* nocapture %x, double* nocapture %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: call void @fwddiffesubf(double* %x, double* %"x'") +; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffesubf(double* nocapture %x, double* nocapture %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = load double, double* %"x'" +; CHECK-NEXT: %1 = fmul fast double %0, 2.000000e+00 +; CHECK-NEXT: store double %1, double* %"x'", align 8 +; CHECK-NEXT: call void @fwddiffeomegasubf(double* %x, double* %"x'") +; CHECK-NEXT: call void 
@fwddiffemetasubf(double* %x, double* %"x'") +; CHECK-NEXT: ret +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffeomegasubf(double* nocapture %x, double* nocapture %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 1 +; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}void @fwddiffemetasubf(double* nocapture %x, double* nocapture %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 1 +; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/bitcast.ll b/enzyme/test/Enzyme/ForwardModeSplit/bitcast.ll new file mode 100644 index 0000000000000..b5bb158527753 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/bitcast.ll @@ -0,0 +1,22 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -adce -instsimplify -S | FileCheck %s + +define double @tester(double %x) { +entry: + %y = bitcast double %x to i64 + %z = bitcast i64 %y to double + ret double %z +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +declare double @__enzyme_fwdsplit(double (double)*, ...) 
+ +; CHECK: define internal {{(dso_local )?}}double @fwddiffetester(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: ret double %"x'" +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/bsearch.ll b/enzyme/test/Enzyme/ForwardModeSplit/bsearch.ll new file mode 100644 index 0000000000000..17d26dcf4128d --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/bsearch.ll @@ -0,0 +1,77 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -inline -mem2reg -gvn -instsimplify -correlated-propagation -adce -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline norecurse nounwind uwtable +define double @f(double* nocapture %x, i64 %n) #0 { +entry: + br label %loop + +loop: + %j = phi i64 [ %nj, %end ], [ 0, %entry ] + %sum = phi double [ %nsum, %end ], [ 0.000000e+00, %entry ] + %nj = add nsw nuw i64 %j, 1 + %g0 = getelementptr inbounds double, double* %x, i64 %j + br label %body + +body: ; preds = %entry, %for.cond.cleanup6 + %i = phi i64 [ %next, %body ], [ 0, %loop ] + %gep = getelementptr inbounds double, double* %g0, i64 %i + %ld = load double, double* %gep, align 8 + %cmp = fcmp oeq double %ld, 3.141592e+00 + %next = add nuw i64 %i, 1 + br i1 %cmp, label %body, label %end + +end: + %gep2 = getelementptr inbounds double, double* %x, i64 %i + %ld2 = load double, double* %gep2, align 8 + %nsum = fadd double %ld2, %sum + %cmp2 = icmp ne i64 %nj, 10 + br i1 %cmp2, label %loop, label %exit + +exit: + ret double %nsum +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @dsumsquare(double* %x, double* %xp, i64 %n) local_unnamed_addr #1 { +entry: + %call = tail call double (...) @__enzyme_fwdsplit(i8* bitcast (double (double*, i64)* @f to i8*), metadata !"enzyme_nofree", double* %x, double* %xp, i64 %n, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(...) 
+ +attributes #0 = { noinline norecurse nounwind uwtable } +attributes #1 = { noinline nounwind uwtable } + + +; CHECK: define internal double @fwddiffef(double* nocapture %x, double* nocapture %"x'", i64 %n, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to i1*** +; CHECK-NEXT: %truetape = load i1**, i1*** %0 +; CHECK-NEXT: br label %loop + +; CHECK: loop: ; preds = %end, %entry +; CHECK-NEXT: %iv = phi i64 [ %iv.next, %end ], [ 0, %entry ] +; CHECK-NEXT: %"sum'" = phi {{(fast )?}}double [ %4, %end ], [ 0.000000e+00, %entry ] +; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 +; CHECK-NEXT: %1 = getelementptr inbounds i1*, i1** %truetape, i64 %iv +; CHECK-NEXT: %.pre = load i1*, i1** %1, align 8, !invariant.group !1 +; CHECK-NEXT: br label %body + +; CHECK: body: ; preds = %body, %loop +; CHECK-NEXT: %iv1 = phi i64 [ %iv.next2, %body ], [ 0, %loop ] +; CHECK-NEXT: %iv.next2 = add nuw nsw i64 %iv1, 1 +; CHECK-NEXT: %2 = getelementptr inbounds i1, i1* %.pre, i64 %iv1 +; CHECK-NEXT: %cmp = load i1, i1* %2, align 1, !invariant.group !2 +; CHECK-NEXT: br i1 %cmp, label %body, label %end + +; CHECK: end: ; preds = %body +; CHECK-NEXT: %"gep2'ipg" = getelementptr inbounds double, double* %"x'", i64 %iv1 +; CHECK-NEXT: %3 = load double, double* %"gep2'ipg" +; CHECK-NEXT: %4 = fadd fast double %3, %"sum'" +; CHECK-NEXT: %cmp2 = icmp ne i64 %iv.next, 10 +; CHECK-NEXT: br i1 %cmp2, label %loop, label %exit + +; CHECK: exit: ; preds = %end +; CHECK-NEXT: ret double %4 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/bsearch2.ll b/enzyme/test/Enzyme/ForwardModeSplit/bsearch2.ll new file mode 100644 index 0000000000000..22ed11dee19d8 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/bsearch2.ll @@ -0,0 +1,87 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -inline -mem2reg -gvn -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline norecurse nounwind uwtable +define 
double @f(double* nocapture %x, i64 %n) #0 { +entry: + br label %loop + +loop: + %j = phi i64 [ %nj, %end ], [ 0, %entry ] + %sum = phi double [ %nsum, %end ], [ 0.000000e+00, %entry ] + %nj = add nsw nuw i64 %j, 1 + %g0 = getelementptr inbounds double, double* %x, i64 %j + br label %body + +body: ; preds = %body, %loop + %i = phi i64 [ %next, %body ], [ 0, %loop ] + %idx = phi i64 [ %nidx, %body ], [ 0, %loop ] + %gep = getelementptr inbounds double, double* %g0, i64 %i + %ld = load double, double* %gep, align 8 + %cmp = fcmp oeq double %ld, 3.141592e+00 + %next = add nuw i64 %i, 1 + %int = fptoui double %ld to i64 + %nidx = add nuw i64 %idx, %int + br i1 %cmp, label %body, label %end + +end: + %gep2 = getelementptr inbounds double, double* %x, i64 %idx + %ld2 = load double, double* %gep2, align 8 + %nsum = fadd double %ld2, %sum + %cmp2 = icmp ne i64 %nj, 10 + br i1 %cmp2, label %loop, label %exit + +exit: + ret double %nsum +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @dsumsquare(double* %x, double* %xp, i64 %n) local_unnamed_addr #1 { +entry: + %call = tail call double (...) @__enzyme_fwdsplit(i8* bitcast (double (double*, i64)* @f to i8*), metadata !"enzyme_nofree", double* %x, double* %xp, i64 %n, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(...) 
+ +attributes #0 = { noinline norecurse nounwind uwtable } +attributes #1 = { noinline nounwind uwtable } + + +; CHECK: define internal double @fwddiffef(double* nocapture %x, double* nocapture %"x'", i64 %n, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to { i64**, double** }* +; CHECK-NEXT: %truetape = load { i64**, double** }, { i64**, double** }* %0 +; CHECK-NEXT: %1 = extractvalue { i64**, double** } %truetape, 0 +; CHECK-NEXT: %2 = extractvalue { i64**, double** } %truetape, 1 +; CHECK-NEXT: br label %loop + +; CHECK: loop: ; preds = %end, %entry +; CHECK-NEXT: %iv = phi i64 [ %iv.next, %end ], [ 0, %entry ] +; CHECK-NEXT: %"sum'" = phi {{(fast )?}}double [ %8, %end ], [ 0.000000e+00, %entry ] +; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 +; CHECK-NEXT: %3 = getelementptr inbounds i64*, i64** %1, i64 %iv +; CHECK-NEXT: %4 = getelementptr inbounds double*, double** %2, i64 %iv +; CHECK-NEXT: %.pre = load i64*, i64** %3, align 8, !invariant.group !1 +; CHECK-NEXT: %.pre2 = load double*, double** %4, align 8, !invariant.group !2 +; CHECK-NEXT: br label %body + +; CHECK: body: ; preds = %body, %loop +; CHECK-NEXT: %iv1 = phi i64 [ %iv.next2, %body ], [ 0, %loop ] +; CHECK-NEXT: %5 = getelementptr inbounds i64, i64* %.pre, i64 %iv1 +; CHECK-NEXT: %idx = load i64, i64* %5, align 8, !invariant.group !3 +; CHECK-NEXT: %iv.next2 = add nuw nsw i64 %iv1, 1 +; CHECK-NEXT: %6 = getelementptr inbounds double, double* %.pre2, i64 %iv1 +; CHECK-NEXT: %ld = load double, double* %6, align 8, !invariant.group !4 +; CHECK-NEXT: %cmp = fcmp oeq double %ld, 0x400921FAFC8B007A +; CHECK-NEXT: br i1 %cmp, label %body, label %end + +; CHECK: end: ; preds = %body +; CHECK-NEXT: %"gep2'ipg" = getelementptr inbounds double, double* %"x'", i64 %idx +; CHECK-NEXT: %7 = load double, double* %"gep2'ipg" +; CHECK-NEXT: %8 = fadd fast double %7, %"sum'" +; CHECK-NEXT: %cmp2 = icmp ne i64 %iv.next, 10 +; CHECK-NEXT: br i1 %cmp2, label %loop, label %exit + +; 
CHECK: exit: ; preds = %end +; CHECK-NEXT: ret double %8 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/call.ll b/enzyme/test/Enzyme/ForwardModeSplit/call.ll new file mode 100644 index 0000000000000..c5841b9fe3750 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/call.ll @@ -0,0 +1,53 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +; extern double __enzyme_fwdsplit(double (double), double, double, void*); + +; __attribute__((noinline)) +; double add2(double x) { +; return 2 + x; +; } + +; __attribute__((noinline)) +; double add4(double x) { +; return add2(x) + 2; +; } + +; double dadd4(double x) { +; return __enzyme_fwdsplit(add4, x, 1.0, NULL); +; } + + +define dso_local double @add2(double %x) { +entry: + %add = fadd double %x, 2.000000e+00 + ret double %add +} + +define dso_local double @add4(double %x) { +entry: + %call = call double @add2(double %x) + %add = fadd double %call, 2.000000e+00 + ret double %add +} + +define dso_local double @dadd4(double %x) { +entry: + %call = call double @__enzyme_fwdsplit(double (double)* nonnull @add4, double %x, double 1.000000e+00, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(double (double)*, double, double, i8*) + + + +; CHECK: define internal {{(dso_local )?}}double @fwddiffeadd4(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = call fast double @fwddiffeadd2(double %x, double %"x'") +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}double @fwddiffeadd2(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret double %"x'" +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/callmincacheunwrap.ll b/enzyme/test/Enzyme/ForwardModeSplit/callmincacheunwrap.ll new file mode 100644 index 0000000000000..50fc649b79d49 --- /dev/null +++ 
b/enzyme/test/Enzyme/ForwardModeSplit/callmincacheunwrap.ll @@ -0,0 +1,99 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -adce -S | FileCheck %s + +source_filename = "/mnt/pci4/wmdata/Enzyme2/enzyme/test/Integration/ReverseMode/eigensumsqdyn.cpp" +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +$_ZNK5Eigen9EigenBaseINS_6MatrixIdLin1ELin1ELi0ELin1ELin1EEEE4colsEv = comdat any + +define void @caller(i8* %a, i8* %b, i8* %c) local_unnamed_addr { +entry: + call void (...) @__enzyme_fwdsplit(i8* bitcast (void (double**, i64*)* @_ZL6matvecPKN5Eigen6MatrixIdLin1ELin1ELi0ELin1ELin1EEES3_ to i8*), metadata !"enzyme_nofree", i8* %a, i8* %b, i8* %c, i8* null) + ret void +} + +declare void @__enzyme_fwdsplit(...) + +; Function Attrs: noinline nounwind uwtable +define internal void @_ZL6matvecPKN5Eigen6MatrixIdLin1ELin1ELi0ELin1ELin1EEES3_(double** noalias %m_data.i.i.i, i64* noalias %m_rows) #0 { +entry: + call void @subcall(double** nonnull %m_data.i.i.i, i64* nonnull %m_rows) #3 + store i64 0, i64* %m_rows, align 8 + ret void +} + +; Function Attrs: inlinehint norecurse nounwind uwtable +define linkonce_odr dso_local i64 @_ZNK5Eigen9EigenBaseINS_6MatrixIdLin1ELin1ELi0ELin1ELin1EEEE4colsEv(i64* %m_cols) local_unnamed_addr #1 comdat align 2 { +entry: + %tmp.i.i = load i64, i64* %m_cols, align 8, !tbaa !2 + ret i64 %tmp.i.i +} + +; Function Attrs: nounwind uwtable +define void @subcall(double** %m_data.i.i.i, i64* %tmp7) local_unnamed_addr #2 { +entry: + %mat = load double*, double** %m_data.i.i.i, align 8, !tbaa !8 + %cols = call i64 @_ZNK5Eigen9EigenBaseINS_6MatrixIdLin1ELin1ELi0ELin1ELin1EEEE4colsEv(i64* %tmp7) #3 + br label %for.body + +for.body: ; preds = %for.body, %entry + %i = phi i64 [ %inc, %for.body ], [ 0, %entry ] + %call = getelementptr inbounds double, double* %mat, i64 %i + %ld = load double, double* %call, 
align 8 + %fmul = fmul double %ld, %ld + store double %fmul, double* %call, align 8, !tbaa !11 + %inc = add nuw nsw i64 %i, 1 + %exitcond = icmp eq i64 %inc, %cols + br i1 %exitcond, label %exit, label %for.body + +exit: ; preds = %for.body + ret void +} + +attributes #0 = { noinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #1 = { inlinehint norecurse nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #3 = { nounwind } + +!llvm.module.flags = !{!0} +!llvm.ident = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{!"clang version 7.1.0 "} +!2 = !{!3, !7, i64 16} +!3 = !{!"_ZTSN5Eigen12DenseStorageIdLin1ELin1ELin1ELi0EEE", !4, i64 0, !7, i64 8, !7, i64 16} +!4 = !{!"any pointer", !5, i64 0} +!5 
= !{!"omnipotent char", !6, i64 0} +!6 = !{!"Simple C++ TBAA"} +!7 = !{!"long", !5, i64 0} +!8 = !{!9, !4, i64 0} +!9 = !{!"_ZTSN5Eigen8internal9evaluatorINS_15PlainObjectBaseINS_6MatrixIdLin1ELin1ELi0ELin1ELin1EEEEEEE", !4, i64 0, !10, i64 8} +!10 = !{!"_ZTSN5Eigen8internal19variable_if_dynamicIlLin1EEE", !7, i64 0} +!11 = !{!12, !12, i64 0} +!12 = !{!"double", !5, i64 0} + + +; CHECK: define internal void @fwddiffesubcall(double** %m_data.i.i.i, double** %"m_data.i.i.i'", i64* %tmp7, { i64, double*, double* } %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = extractvalue { i64, double*, double* } %tapeArg, 2 +; CHECK-NEXT: %"mat'il_phi" = extractvalue { i64, double*, double* } %tapeArg, 1 +; CHECK-NEXT: %cols = extractvalue { i64, double*, double* } %tapeArg, 0 +; CHECK-NEXT: br label %for.body + +; CHECK: for.body: ; preds = %for.body, %entry +; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.body ], [ 0, %entry ] +; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 +; CHECK-NEXT: %"call'ipg" = getelementptr inbounds double, double* %"mat'il_phi", i64 %iv +; CHECK-NEXT: %1 = getelementptr inbounds double, double* %0, i64 %iv +; CHECK-NEXT: %ld = load double, double* %1, align 8, !invariant.group !15 +; CHECK-NEXT: %2 = load double, double* %"call'ipg", align 8 +; CHECK-NEXT: %3 = fmul fast double %2, %ld +; CHECK-NEXT: %4 = fmul fast double %2, %ld +; CHECK-NEXT: %5 = fadd fast double %3, %4 +; CHECK-NEXT: store double %5, double* %"call'ipg", align 8 +; CHECK-NEXT: %exitcond = icmp eq i64 %iv.next, %cols +; CHECK-NEXT: br i1 %exitcond, label %exit, label %for.body + +; CHECK: exit: ; preds = %for.body +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/callmod.ll b/enzyme/test/Enzyme/ForwardModeSplit/callmod.ll new file mode 100644 index 0000000000000..59e438e7dc010 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/callmod.ll @@ -0,0 +1,116 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa 
-instsimplify -simplifycfg -S | FileCheck %s + +; extern double readDouble(); +; +; __attribute__((noinline)) +; double sub(double x) { +; return x * readDouble(); +; } +; +; double read2(); +; +; __attribute__((noinline)) +; double foo(double x) { +; double res = sub(x); +; +; return res + read2(); +; } +; +; double dsumsquare(double x) { +; return __enzyme_fwdsplit(foo, x, 1.0, NULL); +; } + +@.str = private unnamed_addr constant [4 x i8] c"%f\0A\00", align 1 +@.str.1 = private unnamed_addr constant [5 x i8] c"%f \0A\00", align 1 + +; Function Attrs: noinline nounwind uwtable +define dso_local double @readDouble() local_unnamed_addr #0 { +entry: + %x = alloca double, align 8 + %0 = bitcast double* %x to i8* + call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %0) #4 + %call = call i32 (i8*, ...) @scanf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @.str, i64 0, i64 0), double* nonnull %x) + %1 = load double, double* %x, align 8, !tbaa !2 + call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %0) #4 + ret double %1 +} + +; Function Attrs: argmemonly nounwind +declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture) #1 + +; Function Attrs: nounwind +declare dso_local i32 @scanf(i8* nocapture readonly, ...) local_unnamed_addr #2 + +; Function Attrs: argmemonly nounwind +declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture) #1 + +; Function Attrs: noinline nounwind uwtable +define dso_local double @sub(double %x) local_unnamed_addr #0 { +entry: + %call = tail call fast double @readDouble() + %mul = fmul fast double %call, %x + ret double %mul +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @read2() local_unnamed_addr #0 { +entry: + %x = alloca double, align 8 + %0 = bitcast double* %x to i8* + call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %0) #4 + %call = call i32 (i8*, ...) 
@scanf(i8* getelementptr inbounds ([5 x i8], [5 x i8]* @.str.1, i64 0, i64 0), double* nonnull %x) + %1 = load double, double* %x, align 8, !tbaa !2 + call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %0) #4 + ret double %1 +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @foo(double %x) #0 { +entry: + %call = tail call fast double @sub(double %x) + %call1 = tail call fast double @read2() + %add = fadd fast double %call1, %call + ret double %add +} + +; Function Attrs: nounwind uwtable +define dso_local double @dsumsquare(double %x) local_unnamed_addr #3 { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @foo, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) #4 + +attributes #0 = { noinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-jump-tables"="false" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="true" "use-soft-float"="false" } +attributes #1 = { argmemonly nounwind } +attributes #2 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="true" "use-soft-float"="false" } +attributes #3 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" 
"no-jump-tables"="false" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="true" "use-soft-float"="false" } +attributes #4 = { nounwind } + +!llvm.module.flags = !{!0} +!llvm.ident = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{!"clang version 7.1.0 "} +!2 = !{!3, !3, i64 0} +!3 = !{!"double", !4, i64 0} +!4 = !{!"omnipotent char", !5, i64 0} +!5 = !{!"Simple C/C++ TBAA"} + +; CHECK: define internal {{(dso_local )?}}double @fwddiffefoo(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to double* +; CHECK-NEXT: %tapeArg1 = load double, double* %0, +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %1 = call fast double @fwddiffesub(double %x, double %"x'", double %tapeArg1) +; CHECK-NEXT: ret double %1 +; CHECK-NEXT: } + + +; CHECK: define internal {{(dso_local )?}}double @fwddiffesub(double %x, double %"x'", double %call) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast double %"x'", %call +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/calloc.ll b/enzyme/test/Enzyme/ForwardModeSplit/calloc.ll new file mode 100644 index 0000000000000..ca566106f10fb --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/calloc.ll @@ -0,0 +1,46 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + + +@enzyme_dupnoneed = dso_local global i32 0, align 4 + +define dso_local double @f(double %x, i64 %arg) { +entry: + %call = call noalias i8* @calloc(i64 8, i64 %arg) + %0 = bitcast i8* %call to double* + store double %x, double* %0, align 8 + %1 = load double, double* %0, align 8 + ret double %1 +} + +declare dso_local noalias i8* @calloc(i64, i64) + +define dso_local double @df(double %x) { +entry: + %x.addr = alloca double, align 8 + store 
double %x, double* %x.addr, align 8 + %0 = load i32, i32* @enzyme_dupnoneed, align 4 + %1 = load double, double* %x.addr, align 8 + %call = call double (i8*, ...) @__enzyme_fwdsplit(i8* bitcast (double (double,i64)* @f to i8*), metadata !"enzyme_nofree", i32 %0, double %1, double 1.000000e+00, i64 1, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(i8*, ...) + +; TODO check correctness + +; CHECK: define internal i8* @augmented_f(double %x, double %"x'", i64 %arg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %call = call noalias i8* @calloc(i64 8, i64 %arg) +; CHECK-NEXT: %0 = bitcast i8* %call to double* +; CHECK-NEXT: store double %x, double* %0, align 8 +; CHECK-NEXT: ret i8* null +; CHECK-NEXT: } + +; CHECK: define internal double @fwddiffef(double %x, double %"x'", i64 %arg, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK: %"call'mi" = call noalias nonnull i8* @calloc(i64 8, i64 %arg) +; CHECK-NEXT: %"'ipc" = bitcast i8* %"call'mi" to double* +; CHECK-NEXT: store double %"x'", double* %"'ipc", align 8 +; CHECK-NEXT: %0 = load double, double* %"'ipc", align 8 +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/constant.ll b/enzyme/test/Enzyme/ForwardModeSplit/constant.ll new file mode 100644 index 0000000000000..8d267ba6ae0b6 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/constant.ll @@ -0,0 +1,22 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x) { +entry: + ret double 1.000000e+00 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) 
+ +; CHECK: define internal {{(dso_local )?}}double @fwddiffetester(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: ret double 0.000000e+00 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/constselect.ll b/enzyme/test/Enzyme/ForwardModeSplit/constselect.ll new file mode 100644 index 0000000000000..3359916d7ed51 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/constselect.ll @@ -0,0 +1,49 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -S | FileCheck %s + +; ModuleID = 'inp.c' +source_filename = "inp.c" +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@.str = private unnamed_addr constant [20 x i8] c"dfun/dx = %f, x=%d\0A\00", align 1 + +; Function Attrs: norecurse nounwind readnone uwtable +define double @fun2(double %x) { +entry: + %cmp.inv = fcmp oge double %x, 0.000000e+00 + %.x = select i1 %cmp.inv, double %x, double 0.000000e+00 + ret double %.x +} + +; Function Attrs: nounwind uwtable +define i32 @main() { +entry: + %call3.4 = tail call double (i8*, ...) @__enzyme_fwdsplit(i8* bitcast (double (double)* @fun2 to i8*), double 2.000000e+00, double 1.0, i8* null) + %call4.4 = tail call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([20 x i8], [20 x i8]* @.str, i64 0, i64 0), double %call3.4, i32 2) + ret i32 0 +} + +; Function Attrs: nounwind +declare dso_local i32 @printf(i8* nocapture readonly, ...) + +declare double @__enzyme_fwdsplit(i8*, ...) 
+ +attributes #0 = { norecurse nounwind readnone uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #1 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #2 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #3 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #4 = { nounwind } + +!llvm.module.flags = !{!0} +!llvm.ident = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{!"clang version 7.1.0 "} + +; CHECK: define internal 
double @fwddiffefun2(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %cmp.inv = fcmp oge double %x, 0.000000e+00 +; CHECK-NEXT: %0 = select{{( fast)?}} i1 %cmp.inv, double %"x'", double 0.000000e+00 +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/cos.ll b/enzyme/test/Enzyme/ForwardModeSplit/cos.ll new file mode 100644 index 0000000000000..fcaa94f6c5e58 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/cos.ll @@ -0,0 +1,30 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -O3 -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @llvm.cos.f64(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.cos.f64(double) + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.sin.f64(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) 
+ +; CHECK: define double @test_derivative(double %x) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = tail call fast double @llvm.sin.f64(double %x) +; CHECK-NEXT: %1 = {{(fsub fast double -0.000000e\+00,|fneg fast double)}} %0 +; CHECK-NEXT: ret double %1 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/cosh.ll b/enzyme/test/Enzyme/ForwardModeSplit/cosh.ll new file mode 100644 index 0000000000000..8731cef37102c --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/cosh.ll @@ -0,0 +1,28 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @cosh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @cosh(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) 
+ +; CHECK: define internal double @fwddiffetester(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = call fast double @sinh(double %x) +; CHECK-NEXT: %1 = fmul fast double %"x'", %0 +; CHECK-NEXT: ret double %1 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/custom0.ll b/enzyme/test/Enzyme/ForwardModeSplit/custom0.ll new file mode 100644 index 0000000000000..44e7b41933d06 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/custom0.ll @@ -0,0 +1,55 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse -S | FileCheck %s + +source_filename = "exer2.c" +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__enzyme_register_splitderivative_add = dso_local local_unnamed_addr global [3 x i8*] [i8* bitcast (double (double, double)* @add to i8*), i8* bitcast ({ i8*, double, double } (double, double)* @add_aug to i8*), i8* bitcast ({ double, double } (double, double, double, double, i8*)* @add_err to i8*)], align 16 + +declare double @add(double %x, double %y) #0 + +declare { i8*, double, double } @add_aug(double %v1, double %v2) + +declare { double, double } @add_err(double %v1, double %v1err, double %v2, double %v2err, i8* %tape) + +; Function Attrs: norecurse nounwind readnone uwtable willreturn +define double @f(double %x) { +entry: + %call = call double @add(double %x, double %x) + ret double %call +} + +; Function Attrs: nounwind uwtable +define double @caller(double %x, double %dx) { +entry: + %call = call double (i8*, ...) @__enzyme_fwdsplit(i8* bitcast (double (double)* @f to i8*), double %x, double %dx, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(i8*, ...) 
+ +attributes #0 = { norecurse nounwind readnone } + + +; CHECK: define internal i8* @augmented_f(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %malloccall = tail call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8) +; CHECK-NEXT: %tapemem = bitcast i8* %malloccall to i8** +; CHECK-NEXT: %call_augmented = call { i8*, double, double } @add_aug(double %x, double %x) +; CHECK-NEXT: %subcache = extractvalue { i8*, double, double } %call_augmented, 0 +; CHECK-NEXT: store i8* %subcache, i8** %tapemem +; CHECK-NEXT: ret i8* %malloccall +; CHECK-NEXT: } + +; CHECK: define internal double @fwddiffef(double %x, double %"x'", i8* %tapeArg1) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast double @fixderivative_add(double %x, double %"x'", double %x, double %"x'", i8* %tapeArg1) +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } + +; CHECK: define internal double @fixderivative_add(double %v1, double %v1err, double %v2, double %v2err, i8* %tape) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call { double, double } @add_err(double %v1, double %v1err, double %v2, double %v2err, i8* %tape) +; CHECK-NEXT: %1 = extractvalue { double, double } %0, 1 +; CHECK-NEXT: ret double %1 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/custom1.ll b/enzyme/test/Enzyme/ForwardModeSplit/custom1.ll new file mode 100644 index 0000000000000..01c7e8e33fcb4 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/custom1.ll @@ -0,0 +1,60 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse -S | FileCheck %s + +source_filename = "exer2.c" +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__enzyme_register_splitderivative_add = dso_local local_unnamed_addr global [3 x i8*] [i8* bitcast (double (double, double)* @add to i8*), i8* bitcast ({ i8*, double, double } (double, double)* @add_aug to i8*), i8* bitcast ({ 
double, double } (double, double, double, double, i8*)* @add_err to i8*)], align 16 + +declare double @add(double %x, double %y) #0 + +declare { i8*, double, double } @add_aug(double %v1, double %v2) + +declare { double, double } @add_err(double %v1, double %v1err, double %v2, double %v2err, i8* %tape) + +; Function Attrs: norecurse nounwind readnone uwtable willreturn +define double @f(double %x) { +entry: + %call = call double @add(double %x, double %x) + %mul = fmul double %call, %call + ret double %mul +} + +; Function Attrs: nounwind uwtable +define double @caller(double %x, double %dx) { +entry: + %call = call double (i8*, ...) @__enzyme_fwdsplit(i8* bitcast (double (double)* @f to i8*), double %x, double %dx, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(i8*, ...) + +attributes #0 = { norecurse nounwind readnone } + +; CHECK: define internal i8* @augmented_f(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %malloccall = tail call noalias nonnull dereferenceable(16) dereferenceable_or_null(16) i8* @malloc(i64 16) +; CHECK-NEXT: %tapemem = bitcast i8* %malloccall to { i8*, double }* +; CHECK-NEXT: %call_augmented = call { i8*, double, double } @add_aug(double %x, double %x) +; CHECK-NEXT: %subcache = extractvalue { i8*, double, double } %call_augmented, 0 +; CHECK-NEXT: %0 = getelementptr inbounds { i8*, double }, { i8*, double }* %tapemem, i32 0, i32 0 +; CHECK-NEXT: store i8* %subcache, i8** %0 +; CHECK-NEXT: %call = extractvalue { i8*, double, double } %call_augmented, 1 +; CHECK-NEXT: %1 = getelementptr inbounds { i8*, double }, { i8*, double }* %tapemem, i32 0, i32 1 +; CHECK-NEXT: store double %call, double* %1 +; CHECK-NEXT: ret i8* %malloccall +; CHECK-NEXT: } + +; CHECK: define internal double @fwddiffef(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to { i8*, double }* +; CHECK-NEXT: %truetape = load { i8*, double }, { i8*, double }* %0 +; CHECK-NEXT: 
tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %tapeArg1 = extractvalue { i8*, double } %truetape, 0 +; CHECK-NEXT: %1 = call { double, double } @add_err(double %x, double %"x'", double %x, double %"x'", i8* %tapeArg1) +; CHECK-NEXT: %2 = extractvalue { double, double } %1, 0 +; CHECK-NEXT: %3 = extractvalue { double, double } %1, 1 +; CHECK-NEXT: %4 = fmul fast double %3, %2 +; CHECK-NEXT: %5 = fadd fast double %4, %4 +; CHECK-NEXT: ret double %5 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/custom2.ll b/enzyme/test/Enzyme/ForwardModeSplit/custom2.ll new file mode 100644 index 0000000000000..f64e95b6279da --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/custom2.ll @@ -0,0 +1,44 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse -S | FileCheck %s + +source_filename = "exer2.c" +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__enzyme_register_splitderivative_add = dso_local local_unnamed_addr global [3 x i8*] [i8* bitcast (double (double, double)* @add to i8*), i8* bitcast ({ i8*, double, double } (double, double)* @add_aug to i8*), i8* bitcast ({ double, double } (double, double, double, double, i8*)* @add_err to i8*)], align 16 + +declare double @add(double %x, double %y) #0 + +declare { i8*, double, double } @add_aug(double %v1, double %v2) + +declare { double, double } @add_err(double %v1, double %v1err, double %v2, double %v2err, i8* %tape) + +; Function Attrs: norecurse nounwind readnone uwtable willreturn +define double @f(double %x) { +entry: + %call = call double @add(double %x, double 2.000000e+00) + ret double %call +} + +; Function Attrs: nounwind uwtable +define double @caller(double %x, double %dx, i8* %t) { +entry: + %call = call double (i8*, ...) 
@__enzyme_fwdsplit(i8* bitcast (double (double)* @f to i8*), double %x, double %dx, i8* %t) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(i8*, ...) + +attributes #0 = { norecurse nounwind readnone } + +; CHECK: define internal double @fwddiffef(double %x, double %"x'", i8* %tapeArg1) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast double @fixderivative_add(double %x, double %"x'", double 2.000000e+00, i8* %tapeArg1) +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } + +; CHECK: define internal double @fixderivative_add(double %x, double %"x'", double %y, i8* %tapeArg) { +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call { double, double } @add_err(double %x, double %"x'", double %y, double 0.000000e+00, i8* %tapeArg) +; CHECK-NEXT: %1 = extractvalue { double, double } %0, 1 +; CHECK-NEXT: ret double %1 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/div.ll b/enzyme/test/Enzyme/ForwardModeSplit/div.ll new file mode 100644 index 0000000000000..7f262658f389f --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/div.ll @@ -0,0 +1,28 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -early-cse -instsimplify -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %0 = fdiv fast double %x, %y + ret double %0 +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fwdsplit(double (double, double)* nonnull @tester, double %x, double 1.0, double %y, double 0.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double, double)*, ...) 
+ +; CHECK: define internal {{(dso_local )?}}double @fwddiffetester(double %x, double %"x'", double %y, double %"y'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = fmul fast double %"x'", %y +; CHECK-NEXT: %1 = fmul fast double %x, %"y'" +; CHECK-NEXT: %2 = fsub fast double %0, %1 +; CHECK-NEXT: %3 = fmul fast double %y, %y +; CHECK-NEXT: %4 = fdiv fast double %2, %3 +; CHECK-NEXT: ret double %4 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/divreduce.ll b/enzyme/test/Enzyme/ForwardModeSplit/divreduce.ll new file mode 100644 index 0000000000000..ee5b0804905cf --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/divreduce.ll @@ -0,0 +1,122 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse-memssa -instsimplify -correlated-propagation -adce -S | FileCheck %s + +; Function Attrs: norecurse nounwind readonly uwtable +define double @alldiv(double* nocapture readonly %A, i64 %N, double %start) { +entry: + br label %loop + +loop: ; preds = %9, %5 + %i = phi i64 [ 0, %entry ], [ %next, %loop ] + %reduce = phi double [ %start, %entry ], [ %div, %loop ] + %gep = getelementptr inbounds double, double* %A, i64 %i + %ld = load double, double* %gep, align 8, !tbaa !2 + %div = fdiv double %reduce, %ld + %next = add nuw nsw i64 %i, 1 + %cmp = icmp eq i64 %next, %N + br i1 %cmp, label %end, label %loop + +end: ; preds = %9, %3 + ret double %div +} + +define double @alldiv2(double* nocapture readonly %A, i64 %N) { +entry: + br label %loop + +loop: ; preds = %9, %5 + %i = phi i64 [ 0, %entry ], [ %next, %loop ] + %reduce = phi double [ 2.000000e+00, %entry ], [ %div, %loop ] + %gep = getelementptr inbounds double, double* %A, i64 %i + %ld = load double, double* %gep, align 8, !tbaa !2 + %div = fdiv double %reduce, %ld + %next = add nuw nsw i64 %i, 1 + %cmp = icmp eq i64 %next, %N + br i1 %cmp, label %end, label %loop + +end: ; preds = %9, %3 + ret 
double %div +} + +; Function Attrs: nounwind uwtable +define double @main(double* %A, double* %dA, i64 %N, double %start) { + %r = call double (...) @__enzyme_fwdsplit(i8* bitcast (double (double*, i64, double)* @alldiv to i8*), metadata !"enzyme_nofree", double* %A, double* %dA, i64 %N, double %start, double 1.0, i8* null) + %r2 = call double (...) @__enzyme_fwdsplit2(i8* bitcast (double (double*, i64)* @alldiv2 to i8*), metadata !"enzyme_nofree", double* %A, double* %dA, i64 %N, i8* null) + ret double %r +} + +declare double @__enzyme_fwdsplit(...) +declare double @__enzyme_fwdsplit2(...) + +!llvm.module.flags = !{!0} +!llvm.ident = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{!"Ubuntu clang version 10.0.1-++20200809072545+ef32c611aa2-1~exp1~20200809173142.193"} +!2 = !{!3, !3, i64 0} +!3 = !{!"double", !4, i64 0} +!4 = !{!"omnipotent char", !5, i64 0} +!5 = !{!"Simple C/C++ TBAA"} +!6 = !{!7, !7, i64 0} +!7 = !{!"any pointer", !4, i64 0} + + +; CHECK: define internal double @fwddiffealldiv(double* nocapture readonly %A, double* nocapture %"A'", i64 %N, double %start, double %"start'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to { double*, double* }* +; CHECK-NEXT: %truetape = load { double*, double* }, { double*, double* }* %0 +; CHECK-DAG: %[[i1:.+]] = extractvalue { double*, double* } %truetape, 0 +; CHECK-DAG: %[[i2:.+]] = extractvalue { double*, double* } %truetape, 1 +; CHECK-NEXT: br label %loop + +; CHECK: loop: ; preds = %loop, %entry +; CHECK-NEXT: %iv = phi i64 [ %iv.next, %loop ], [ 0, %entry ] +; CHECK-NEXT: %"reduce'" = phi {{(fast )?}}double [ %"start'", %entry ], [ %10, %loop ] +; CHECK-NEXT: %3 = getelementptr inbounds double, double* %[[i1]], i64 %iv +; CHECK-NEXT: %reduce = load double, double* %3, align 8, !invariant.group !9 +; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 +; CHECK-NEXT: %"gep'ipg" = getelementptr inbounds double, double* %"A'", i64 %iv +; CHECK-NEXT: %4 = getelementptr inbounds 
double, double* %[[i2]], i64 %iv +; CHECK-NEXT: %ld = load double, double* %4, align 8, !invariant.group !10 +; CHECK-NEXT: %5 = load double, double* %"gep'ipg", align 8 +; CHECK-NEXT: %6 = fmul fast double %"reduce'", %ld +; CHECK-NEXT: %7 = fmul fast double %reduce, %5 +; CHECK-NEXT: %8 = fsub fast double %6, %7 +; CHECK-NEXT: %9 = fmul fast double %ld, %ld +; CHECK-NEXT: %10 = fdiv fast double %8, %9 +; CHECK-NEXT: %cmp = icmp eq i64 %iv.next, %N +; CHECK-NEXT: br i1 %cmp, label %end, label %loop + +; CHECK: end: ; preds = %loop +; CHECK-NEXT: ret double %10 +; CHECK-NEXT: } + + +; CHECK: define internal double @fwddiffealldiv2(double* nocapture readonly %A, double* nocapture %"A'", i64 %N, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to { double*, double* }* +; CHECK-NEXT: %truetape = load { double*, double* }, { double*, double* }* %0 +; CHECK-DAG: %[[i1:.+]] = extractvalue { double*, double* } %truetape, 0 +; CHECK-DAG: %[[i2:.+]] = extractvalue { double*, double* } %truetape, 1 +; CHECK-NEXT: br label %loop + +; CHECK: loop: ; preds = %loop, %entry +; CHECK-NEXT: %iv = phi i64 [ %iv.next, %loop ], [ 0, %entry ] +; CHECK-NEXT: %"reduce'" = phi {{(fast )?}}double [ 0.000000e+00, %entry ], [ %10, %loop ] +; CHECK-NEXT: %3 = getelementptr inbounds double, double* %[[i1]], i64 %iv +; CHECK-NEXT: %reduce = load double, double* %3, align 8, !invariant.group !13 +; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 +; CHECK-NEXT: %"gep'ipg" = getelementptr inbounds double, double* %"A'", i64 %iv +; CHECK-NEXT: %4 = getelementptr inbounds double, double* %[[i2]], i64 %iv +; CHECK-NEXT: %ld = load double, double* %4, align 8, !invariant.group !14 +; CHECK-NEXT: %5 = load double, double* %"gep'ipg", align 8 +; CHECK-NEXT: %6 = fmul fast double %"reduce'", %ld +; CHECK-NEXT: %7 = fmul fast double %reduce, %5 +; CHECK-NEXT: %8 = fsub fast double %6, %7 +; CHECK-NEXT: %9 = fmul fast double %ld, %ld +; CHECK-NEXT: %10 = fdiv fast double %8, %9 +; 
CHECK-NEXT: %cmp = icmp eq i64 %iv.next, %N +; CHECK-NEXT: br i1 %cmp, label %end, label %loop + +; CHECK: end: ; preds = %loop +; CHECK-NEXT: ret double %10 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/divreduce2.ll b/enzyme/test/Enzyme/ForwardModeSplit/divreduce2.ll new file mode 100644 index 0000000000000..1e14160043fc4 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/divreduce2.ll @@ -0,0 +1,80 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse-memssa -instsimplify -correlated-propagation -adce -S | FileCheck %s + +; TODO optimize this style reduction + +; Function Attrs: norecurse nounwind readonly uwtable +define double @alldiv(double* nocapture readonly %A, i64 %N, double %start) { +entry: + br label %loop + +loop: ; preds = %9, %5 + %i = phi i64 [ 0, %entry ], [ %next, %body ] + %reduce = phi double [ %start, %entry ], [ %div, %body ] + %cmp = icmp ult i64 %i, %N + br i1 %cmp, label %body, label %end + +body: + %gep = getelementptr inbounds double, double* %A, i64 %i + %ld = load double, double* %gep, align 8, !tbaa !2 + %div = fdiv double %reduce, %ld + %next = add nuw nsw i64 %i, 1 + br label %loop + +end: ; preds = %9, %3 + ret double %reduce +} + +; Function Attrs: nounwind uwtable +define double @main(double* %A, double* %dA, i64 %N, double %start) { + %r = call double (...) @__enzyme_fwdsplit(i8* bitcast (double (double*, i64, double)* @alldiv to i8*), metadata !"enzyme_nofree", double* %A, double* %dA, i64 %N, double %start, double 1.0, i8* null) + ret double %r +} + +declare double @__enzyme_fwdsplit(...) 
+ +!llvm.module.flags = !{!0} +!llvm.ident = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{!"Ubuntu clang version 10.0.1-++20200809072545+ef32c611aa2-1~exp1~20200809173142.193"} +!2 = !{!3, !3, i64 0} +!3 = !{!"double", !4, i64 0} +!4 = !{!"omnipotent char", !5, i64 0} +!5 = !{!"Simple C/C++ TBAA"} +!6 = !{!7, !7, i64 0} +!7 = !{!"any pointer", !4, i64 0} + + +; CHECK: define internal double @fwddiffealldiv(double* nocapture readonly %A, double* nocapture %"A'", i64 %N, double %start, double %"start'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to { double*, double* }* +; CHECK-NEXT: %truetape = load { double*, double* }, { double*, double* }* %0 +; CHECK-DAG: %[[i1:.+]] = extractvalue { double*, double* } %truetape, 0 +; CHECK-DAG: %[[i2:.+]] = extractvalue { double*, double* } %truetape, 1 +; CHECK-NEXT: br label %loop + +; CHECK: loop: ; preds = %body, %entry +; CHECK-NEXT: %iv = phi i64 [ %iv.next, %body ], [ 0, %entry ] +; CHECK-NEXT: %"reduce'" = phi {{(fast )?}}double [ %"start'", %entry ], [ %10, %body ] +; CHECK-NEXT: %3 = getelementptr inbounds double, double* %[[i1]], i64 %iv +; CHECK-NEXT: %reduce = load double, double* %3, align 8, +; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 +; CHECK-NEXT: %cmp = icmp ne i64 %iv, %N +; CHECK-NEXT: br i1 %cmp, label %body, label %end + +; CHECK: body: ; preds = %loop +; CHECK-NEXT: %"gep'ipg" = getelementptr inbounds double, double* %"A'", i64 %iv +; CHECK-NEXT: %4 = getelementptr inbounds double, double* %[[i2]], i64 %iv +; TODO this should keep tbaa +; CHECK-NEXT: %ld = load double, double* %4, align 8 +; CHECK-NEXT: %5 = load double, double* %"gep'ipg", align 8 +; CHECK-NEXT: %6 = fmul fast double %"reduce'", %ld +; CHECK-NEXT: %7 = fmul fast double %reduce, %5 +; CHECK-NEXT: %8 = fsub fast double %6, %7 +; CHECK-NEXT: %9 = fmul fast double %ld, %ld +; CHECK-NEXT: %10 = fdiv fast double %8, %9 +; CHECK-NEXT: br label %loop + +; CHECK: end: ; preds = %loop +; CHECK-NEXT: 
ret double %"reduce'" +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/enzyme_inactive.ll b/enzyme/test/Enzyme/ForwardModeSplit/enzyme_inactive.ll new file mode 100644 index 0000000000000..0ae3e772b5f50 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/enzyme_inactive.ll @@ -0,0 +1,27 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x) { +entry: + tail call void @myprint(double %x) + ret double %x +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +declare void @myprint(double %x) #0 + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) + +attributes #0 = { "enzyme_inactive" } + +; CHECK: define internal {{(dso_local )?}}double @fwddiffetester(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: ret double %"x'" +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/enzyme_inactive2.ll b/enzyme/test/Enzyme/ForwardModeSplit/enzyme_inactive2.ll new file mode 100644 index 0000000000000..9490f0ac8c332 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/enzyme_inactive2.ll @@ -0,0 +1,27 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x) { +entry: + tail call void @myprint(double %x) #0 + ret double %x +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) 
@__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +declare void @myprint(double %x) + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) + +attributes #0 = { "enzyme_inactive" } + +; CHECK: define internal {{(dso_local )?}}double @fwddiffetester(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: ret double %"x'" +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/erf.ll b/enzyme/test/Enzyme/ForwardModeSplit/erf.ll new file mode 100644 index 0000000000000..5c84301efe637 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/erf.ll @@ -0,0 +1,29 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +declare double @erf(double) + +define double @tester(double %x) { +entry: + %call = call double @erf(double %x) + ret double %call +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) 
+ +; CHECK: define internal double @fwddiffetester(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = fmul fast double %x, %x +; CHECK-NEXT: %1 = {{(fsub fast double \-?0.000000e\+00,|fneg fast double)}} %0 +; CHECK-NEXT: %2 = call fast double @llvm.exp.f64(double %1) +; CHECK-NEXT: %3 = fmul fast double %2, 0x3FF20DD750429B6D +; CHECK-NEXT: %4 = fmul fast double %3, %"x'" +; CHECK-NEXT: ret double %4 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/erfc.ll b/enzyme/test/Enzyme/ForwardModeSplit/erfc.ll new file mode 100644 index 0000000000000..e8999e1b4dd39 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/erfc.ll @@ -0,0 +1,29 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +declare double @erfc(double) + +define double @tester(double %x) { +entry: + %call = call double @erfc(double %x) + ret double %call +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) 
+ +; CHECK: define internal double @fwddiffetester(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = fmul fast double %x, %x +; CHECK-NEXT: %1 = {{(fsub fast double \-?0.000000e\+00,|fneg fast double)}} %0 +; CHECK-NEXT: %2 = call fast double @llvm.exp.f64(double %1) +; CHECK-NEXT: %3 = fmul fast double %2, 0xBFF20DD750429B6D +; CHECK-NEXT: %4 = fmul fast double %3, %"x'" +; CHECK-NEXT: ret double %4 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/erfi.ll b/enzyme/test/Enzyme/ForwardModeSplit/erfi.ll new file mode 100644 index 0000000000000..e8a60e4da2edc --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/erfi.ll @@ -0,0 +1,28 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +declare double @erfi(double) + +define double @tester(double %x) { +entry: + %call = call double @erfi(double %x) + ret double %call +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) 
+ +; CHECK: define internal double @fwddiffetester(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = fmul fast double %x, %x +; CHECK-NEXT: %1 = call fast double @llvm.exp.f64(double %0) +; CHECK-NEXT: %2 = fmul fast double %1, 0x3FF20DD750429B6D +; CHECK-NEXT: %3 = fmul fast double %2, %"x'" +; CHECK-NEXT: ret double %3 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/exp.ll b/enzyme/test/Enzyme/ForwardModeSplit/exp.ll new file mode 100644 index 0000000000000..52821af09db57 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/exp.ll @@ -0,0 +1,26 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -O3 -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @llvm.exp.f64(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.exp.f64(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) 
+ +; CHECK: define double @test_derivative(double %x) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = tail call fast double @llvm.exp.f64(double %x) +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/exp2.ll b/enzyme/test/Enzyme/ForwardModeSplit/exp2.ll new file mode 100644 index 0000000000000..16a5315d17c1a --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/exp2.ll @@ -0,0 +1,27 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -O3 -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @llvm.exp2.f64(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.exp2.f64(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) 
+ +; CHECK: define double @test_derivative(double %x) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = tail call fast double @llvm.exp2.f64(double %x) +; CHECK-NEXT: %1 = fmul fast double %0, 0x3FE62E42FEFA39EF +; CHECK-NEXT: ret double %1 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/experimental_vector_reduce_v2_fadd.ll b/enzyme/test/Enzyme/ForwardModeSplit/experimental_vector_reduce_v2_fadd.ll new file mode 100644 index 0000000000000..4eaafdff8dc52 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/experimental_vector_reduce_v2_fadd.ll @@ -0,0 +1,26 @@ +; RUN: if [ %llvmver -ge 9 ] && [ %llvmver -le 11 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi + +define float @tester(float %start_value, <4 x float> %input) { +entry: + %ord = call float @llvm.experimental.vector.reduce.v2.fadd.f32.v4f32(float %start_value, <4 x float> %input) + ret float %ord +} + +define float @test_derivative(float %start_value, <4 x float> %input) { +entry: + %0 = tail call float (float (float, <4 x float>)*, ...) @__enzyme_fwdsplit(float (float, <4 x float>)* nonnull @tester, float %start_value, float 1.0, <4 x float> %input, <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, i8* null) + ret float %0 +} + +declare float @llvm.experimental.vector.reduce.v2.fadd.f32.v4f32(float, <4 x float>) + +; Function Attrs: nounwind +declare float @__enzyme_fwdsplit(float (float, <4 x float>)*, ...)
+ + +; CHECK: define internal {{(dso_local )?}}float @fwddiffetester(float %start_value, float %"start_value'", <4 x float> %input, <4 x float> %"input'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = call fast float @llvm.experimental.vector.reduce.v2.fadd.f32.v4f32(float %"start_value'", <4 x float> %"input'") +; CHECK-NEXT: ret float %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/fabs.ll b/enzyme/test/Enzyme/ForwardModeSplit/fabs.ll new file mode 100644 index 0000000000000..66913e34a2108 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/fabs.ll @@ -0,0 +1,29 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -early-cse -simplifycfg -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @llvm.fabs.f64(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.fabs.f64(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) 
+ +; CHECK: define internal {{(dso_local )?}}double @fwddiffetester(double %x, double %[[differet:.+]], i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = fcmp fast olt double %x, 0.000000e+00 +; CHECK-NEXT: %1 = select{{( fast)?}} i1 %0, double -1.000000e+00, double 1.000000e+00 +; CHECK-NEXT: %2 = fmul fast double %1, %[[differet]] +; CHECK-NEXT: ret double %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/fneg.ll b/enzyme/test/Enzyme/ForwardModeSplit/fneg.ll new file mode 100644 index 0000000000000..67d3b980f45c2 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/fneg.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -ge 10 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi + +; extern double __enzyme_fwdsplit(void*, double, double); +; +; double fneg(double x) { +; return -x; +; } +; +; double dfneg(double x) { +; return __enzyme_fwdsplit((void*)fneg, x, 1.0); +; } + + +define double @fneg(double %x) { + %fneg = fneg double %x + ret double %fneg +} + +define double @dfneg(double %x) { + %1 = call double @__enzyme_fwdsplit(double (double)* @fneg, double %x, double 1.0, i8* null) + ret double %1 +} + +declare double @__enzyme_fwdsplit(double (double)*, double, double, i8*) + + +; CHECK: define internal double @fwddiffefneg(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %1 = fneg fast double %"x'" +; CHECK-NEXT: ret double %1 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/fpext.ll b/enzyme/test/Enzyme/ForwardModeSplit/fpext.ll new file mode 100644 index 0000000000000..0923b1c63fbbe --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/fpext.ll @@ -0,0 +1,22 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -adce -instsimplify -S | FileCheck %s + +define double @tester(float %x) { +entry: + %y = fpext 
float %x to double + ret double %y +} + +define double @test_derivative(float %x) { +entry: + %0 = tail call double (double (float)*, ...) @__enzyme_fwdsplit(double (float)* nonnull @tester, float %x, float 1.0, i8* null) + ret double %0 +} + +declare double @__enzyme_fwdsplit(double (float)*, ...) + +; CHECK: define internal {{(dso_local )?}}double @fwddiffetester(float %x, float %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = fpext float %"x'" to double +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/ge.ll b/enzyme/test/Enzyme/ForwardModeSplit/ge.ll new file mode 100644 index 0000000000000..6bd2a308df3b7 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/ge.ll @@ -0,0 +1,96 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -early-cse -S | FileCheck %s + +; void __enzyme_autodiff(void*, ...); + +; double cache(double* x, unsigned N) { +; double sum = 0.0; +; for(unsigned i=0; i<=N; i++) { +; sum += x[i] * x[i]; +; } +; x[0] = 0.0; +; return sum; +; } + +; void ad(double* in, double* din, unsigned N) { +; __enzyme_autodiff(cache, in, din, N); +; } + +; ModuleID = 'foo.c' +source_filename = "foo.c" +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +; Function Attrs: norecurse nounwind uwtable +define dso_local double @cache(double* nocapture %x, i32 %N) #0 { +entry: + br label %for.body + +for.cond.cleanup: ; preds = %for.body + store double 0.000000e+00, double* %x, align 8, !tbaa !2 + ret double %add + +for.body: ; preds = %entry, %for.body + %i.013 = phi i32 [ 0, %entry ], [ %inc, %for.body ] + %sum.012 = phi double [ 0.000000e+00, %entry ], [ %add, %for.body ] + %idxprom = zext i32 %i.013 to i64 + %arrayidx = getelementptr inbounds double, double* %x, i64 %idxprom + %0 = load double, double* 
%arrayidx, align 8, !tbaa !2 + %mul = fmul double %0, %0 + %add = fadd double %sum.012, %mul + %inc = add i32 %i.013, 1 + %cmp = icmp ugt i32 %inc, %N + br i1 %cmp, label %for.cond.cleanup, label %for.body +} + +; Function Attrs: nounwind uwtable +define dso_local void @ad(double* %in, double* %din, i32 %N) local_unnamed_addr #1 { +entry: + tail call double (i8*, ...) @__enzyme_fwdsplit(i8* bitcast (double (double*, i32)* @cache to i8*), metadata !"enzyme_nofree", double* %in, double* %din, i32 %N, i8* null) #3 + ret void +} + +declare dso_local double @__enzyme_fwdsplit(i8*, ...) local_unnamed_addr #2 + +attributes #0 = { norecurse nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #1 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #2 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #3 = { 
nounwind } + +!llvm.module.flags = !{!0} +!llvm.ident = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{!"clang version 7.0.0 (trunk 336729)"} +!2 = !{!3, !3, i64 0} +!3 = !{!"double", !4, i64 0} +!4 = !{!"omnipotent char", !5, i64 0} +!5 = !{!"Simple C/C++ TBAA"} + + +; CHECK: define internal double @fwddiffecache(double* nocapture %x, double* nocapture %"x'", i32 %N, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to double** +; CHECK-NEXT: %truetape = load double*, double** %0 +; CHECK-NEXT: br label %for.body + +; CHECK: for.cond.cleanup: ; preds = %for.body +; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8 +; CHECK-NEXT: ret double %7 + +; CHECK: for.body: ; preds = %for.body, %entry +; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.body ], [ 0, %entry ] +; CHECK-NEXT: %"sum.012'" = phi {{(fast )?}}double [ 0.000000e+00, %entry ], [ %7, %for.body ] +; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 +; CHECK-NEXT: %1 = trunc i64 %iv to i32 +; CHECK-NEXT: %idxprom = zext i32 %1 to i64 +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 %idxprom +; CHECK-NEXT: %2 = getelementptr inbounds double, double* %truetape, i64 %iv +; CHECK-NEXT: %3 = load double, double* %2, align 8, !invariant.group !8 +; CHECK-NEXT: %4 = load double, double* %"arrayidx'ipg", align 8 +; CHECK-NEXT: %5 = fmul fast double %4, %3 +; CHECK-NEXT: %6 = fadd fast double %5, %5 +; CHECK-NEXT: %7 = fadd fast double %"sum.012'", %6 +; CHECK-NEXT: %inc = add i32 %1, 1 +; CHECK-NEXT: %cmp = icmp ugt i32 %inc, %N +; CHECK-NEXT: br i1 %cmp, label %for.cond.cleanup, label %for.body +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/global.ll b/enzyme/test/Enzyme/ForwardModeSplit/global.ll new file mode 100644 index 0000000000000..823b701be4b38 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/global.ll @@ -0,0 +1,95 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify 
-simplifycfg -S | FileCheck %s + +; #include <stdio.h> +; +; extern double global; +; +; __attribute__((noinline)) +; double mulglobal(double x) { +; return x * global; +; } +; +; __attribute__((noinline)) +; double derivative(double x) { +; return __builtin_fwddiff(mulglobal, x, 1.0); +; } +; +; void main(int argc, char** argv) { +; double x = atof(argv[1]); +; printf("x=%f\n", x); +; double xp = derivative(x); +; printf("xp=%f\n", xp); +; } + +@global = external dso_local local_unnamed_addr global double, align 8, !enzyme_shadow !{double* @dglobal} +@dglobal = external dso_local local_unnamed_addr global double, align 8 + +@.str = private unnamed_addr constant [6 x i8] c"x=%f\0A\00", align 1 +@.str.1 = private unnamed_addr constant [7 x i8] c"xp=%f\0A\00", align 1 + +; Function Attrs: noinline norecurse nounwind readonly uwtable +define dso_local double @mulglobal(double %x) #0 { +entry: + %0 = load double, double* @global, align 8, !tbaa !2 + %mul = fmul fast double %0, %x + ret double %mul +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @derivative(double %x) local_unnamed_addr #1 { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @mulglobal, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) #2 + +; Function Attrs: nounwind uwtable +define dso_local void @main(i32 %argc, i8** nocapture readonly %argv) local_unnamed_addr #3 { +entry: + %arrayidx = getelementptr inbounds i8*, i8** %argv, i64 1 + %0 = load i8*, i8** %arrayidx, align 8, !tbaa !6 + %call.i = tail call fast double @strtod(i8* nocapture nonnull %0, i8** null) #2 + %call1 = tail call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([6 x i8], [6 x i8]* @.str, i64 0, i64 0), double %call.i) + %call2 = tail call fast double @derivative(double %call.i) + %call3 = tail call i32 (i8*, ...)
@printf(i8* getelementptr inbounds ([7 x i8], [7 x i8]* @.str.1, i64 0, i64 0), double %call2) + ret void +} + +; Function Attrs: nounwind +declare dso_local i32 @printf(i8* nocapture readonly, ...) local_unnamed_addr #4 + +; Function Attrs: nounwind +declare dso_local double @strtod(i8* readonly, i8** nocapture) local_unnamed_addr #4 + +attributes #0 = { noinline norecurse nounwind readonly uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-jump-tables"="false" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="true" "use-soft-float"="false" } +attributes #1 = { noinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-jump-tables"="false" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="true" "use-soft-float"="false" } +attributes #2 = { nounwind } +attributes #3 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-jump-tables"="false" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="true" "use-soft-float"="false" } +attributes #4 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" 
"no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="true" "use-soft-float"="false" } + +!llvm.module.flags = !{!0} +!llvm.ident = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{!"clang version 7.1.0 "} +!2 = !{!3, !3, i64 0} +!3 = !{!"double", !4, i64 0} +!4 = !{!"omnipotent char", !5, i64 0} +!5 = !{!"Simple C/C++ TBAA"} +!6 = !{!7, !7, i64 0} +!7 = !{!"any pointer", !4, i64 0} +!8 = !{double* @dglobal} + +; CHECK: define internal {{(dso_local )?}}double @fwddiffemulglobal(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to double* +; CHECK-NEXT: %1 = load double, double* %0 +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %2 = load double, double* @dglobal +; CHECK-NEXT: %3 = fmul fast double %2, %x +; CHECK-NEXT: %4 = fmul fast double %"x'", %1 +; CHECK-NEXT: %5 = fadd fast double %3, %4 +; CHECK-NEXT: ret double %5 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/globalfn.ll b/enzyme/test/Enzyme/ForwardModeSplit/globalfn.ll new file mode 100644 index 0000000000000..8f2e8e2babb8e --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/globalfn.ll @@ -0,0 +1,96 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +@global = private unnamed_addr constant [1 x void (double*)*] [void (double*)* @ipmul] + +@.str = private unnamed_addr constant [6 x i8] c"x=%f\0A\00", align 1 +@.str.1 = private unnamed_addr constant [7 x i8] c"xp=%f\0A\00", align 1 + +define void @ipmul(double* %x) { +entry: + %0 = load double, double* %x, align 8, !tbaa !2 + %mul = fmul fast double %0, %0 + store double %mul, double* %x + ret void +} + +; Function Attrs: noinline norecurse nounwind readonly uwtable +define dso_local double @mulglobal(double %x, i64 %idx) #0 { +entry: + 
%alloc = alloca double + store double %x, double* %alloc + %arrayidx = getelementptr inbounds [1 x void (double*)*], [1 x void (double*)*]* @global, i64 0, i64 %idx + %fp = load void (double*)*, void (double*)** %arrayidx, align 8 + call void %fp(double* %alloc) + %ret = load double, double* %alloc, !tbaa !2 + ret double %ret +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @derivative(double %x) local_unnamed_addr #1 { +entry: + %0 = tail call double (double (double, i64)*, ...) @__enzyme_fwdsplit(double (double, i64)* nonnull @mulglobal, metadata !"enzyme_nofree", double %x, double 1.0, i64 0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double, i64)*, ...) #2 + +; Function Attrs: nounwind uwtable +define dso_local void @main(i32 %argc, i8** nocapture readonly %argv) local_unnamed_addr #3 { +entry: + %arrayidx = getelementptr inbounds i8*, i8** %argv, i64 1 + %0 = load i8*, i8** %arrayidx, align 8, !tbaa !6 + %call.i = tail call fast double @strtod(i8* nocapture nonnull %0, i8** null) #2 + %call1 = tail call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([6 x i8], [6 x i8]* @.str, i64 0, i64 0), double %call.i) + %call2 = tail call fast double @derivative(double %call.i) + %call3 = tail call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([7 x i8], [7 x i8]* @.str.1, i64 0, i64 0), double %call2) + ret void +} + +; Function Attrs: nounwind +declare dso_local i32 @printf(i8* nocapture readonly, ...) 
local_unnamed_addr #4 + +; Function Attrs: nounwind +declare dso_local double @strtod(i8* readonly, i8** nocapture) local_unnamed_addr #4 + +attributes #0 = { noinline norecurse nounwind readonly uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-jump-tables"="false" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="true" "use-soft-float"="false" } +attributes #1 = { noinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-jump-tables"="false" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="true" "use-soft-float"="false" } +attributes #2 = { nounwind } +attributes #3 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-jump-tables"="false" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="true" "use-soft-float"="false" } +attributes #4 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="true" 
"use-soft-float"="false" } + +!llvm.module.flags = !{!0} +!llvm.ident = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{!"clang version 7.1.0 "} +!2 = !{!3, !3, i64 0} +!3 = !{!"double", !4, i64 0} +!4 = !{!"omnipotent char", !5, i64 0} +!5 = !{!"Simple C/C++ TBAA"} +!6 = !{!7, !7, i64 0} +!7 = !{!"any pointer", !4, i64 0} + +; XFAIL: * +; TODO + +; CHECK: @global_shadow = private unnamed_addr constant [1 x void (double*)*] [void (double*)* bitcast (void (double*, double*)** @"_enzyme_forward_ipmul'" to void (double*)*)] +; CHECK:@"_enzyme_forward_ipmul'" = internal constant void (double*, double*)* @fwddiffeipmul + +; CHECK: define internal double @fwddiffemulglobal(double %x, double %"x'", i64 %idx, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to { i8*, i8*, i8* }* +; CHECK-NEXT: %truetape = load { i8*, i8*, i8* }, { i8*, i8*, i8* }* %0, !enzyme_mustcache !10 +; CHECK-NEXT: %malloccall = extractvalue { i8*, i8*, i8* } %truetape, 2 +; CHECK-NEXT: %"malloccall'mi" = extractvalue { i8*, i8*, i8* } %truetape, 1 +; CHECK-NEXT: %"alloc'ipc" = bitcast i8* %"malloccall'mi" to double* +; CHECK-NEXT: %alloc = bitcast i8* %malloccall to double* +; CHECK-NEXT: store double %"x'", double* %"alloc'ipc" +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds [1 x void (double*)*], [1 x void (double*)*]* @global_shadow, i64 0, i64 %idx +; CHECK-NEXT: %"fp'ipl" = load void (double*)*, void (double*)** %"arrayidx'ipg", align 8 +; CHECK-NEXT: %1 = bitcast void (double*)* %"fp'ipl" to void (double*, double*)** +; CHECK-NEXT: %2 = load void (double*, double*)*, void (double*, double*)** %1 +; CHECK-NEXT: call void %2(double* %alloc, double* %"alloc'ipc") +; CHECK-NEXT: %3 = load double, double* %"alloc'ipc" +; CHECK-NEXT: ret double %3 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/globallower.ll b/enzyme/test/Enzyme/ForwardModeSplit/globallower.ll new file mode 100644 index 0000000000000..8da6a674b18a4 --- /dev/null +++ 
b/enzyme/test/Enzyme/ForwardModeSplit/globallower.ll @@ -0,0 +1,42 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-lower-globals -mem2reg -sroa -simplifycfg -instsimplify -S | FileCheck %s + +; TODO +; XFAIL: * + +@global = external dso_local local_unnamed_addr global double, align 8 + +; Function Attrs: noinline norecurse nounwind uwtable +define double @mulglobal(double %x) { +entry: + %l1 = load double, double* @global, align 8 + %mul = fmul fast double %l1, %x + store double %mul, double* @global, align 8 + %l2 = load double, double* @global, align 8 + %mul2 = fmul fast double %l2, %l2 + store double %mul2, double* @global, align 8 + %l3 = load double, double* @global, align 8 + ret double %l3 +} + +; Function Attrs: noinline nounwind uwtable +define double @derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @mulglobal, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...)
+ +; CHECK: define internal double @fwddiffemulglobal(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %global_local.0.copyload = load double, double* @global, align 8 +; CHECK-NEXT: %mul = fmul fast double %global_local.0.copyload, %x +; CHECK-NEXT: %0 = fmul fast double %"x'", %global_local.0.copyload +; CHECK-NEXT: %mul2 = fmul fast double %mul, %mul +; CHECK-NEXT: %1 = fmul fast double %0, %mul +; CHECK-NEXT: %2 = fmul fast double %0, %mul +; CHECK-NEXT: %3 = fadd fast double %1, %2 +; CHECK-NEXT: store double %mul2, double* @global, align 8 +; CHECK-NEXT: ret double %3 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/insertvalue.ll b/enzyme/test/Enzyme/ForwardModeSplit/insertvalue.ll new file mode 100644 index 0000000000000..730f17b16c7ac --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/insertvalue.ll @@ -0,0 +1,31 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x) { +entry: + %agg1 = insertvalue [3 x double] undef, double %x, 0 + %mul = fmul double %x, %x + %agg2 = insertvalue [3 x double] %agg1, double %mul, 1 + %add = fadd double %mul, 2.0 + %agg3 = insertvalue [3 x double] %agg2, double %add, 2 + %res = extractvalue [3 x double] %agg2, 1 + ret double %res +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) 
+ +; CHECK: define internal double @fwddiffetester(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = fmul fast double %"x'", %x +; CHECK-NEXT: %1 = fmul fast double %"x'", %x +; CHECK-NEXT: %2 = fadd fast double %0, %1 +; CHECK-NEXT: ret double %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/intsum.ll b/enzyme/test/Enzyme/ForwardModeSplit/intsum.ll new file mode 100644 index 0000000000000..4d07a52c4550e --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/intsum.ll @@ -0,0 +1,55 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -correlated-propagation -simplifycfg -adce -S | FileCheck %s + +define dso_local void @sum(float* %array, float* %ret) #4 { +entry: + br label %do.body + +do.body: ; preds = %do.body, %entry + %i = phi i64 [ %inc, %do.body ], [ 0, %entry ] + %intsum = phi i32 [ 0, %entry ], [ %intadd, %do.body ] + %arrayidx = getelementptr inbounds float, float* %array, i64 %i + %loaded = load float, float* %arrayidx + %fltload = bitcast i32 %intsum to float + %add = fadd float %fltload, %loaded + %intadd = bitcast float %add to i32 + %inc = add nuw nsw i64 %i, 1 + %cmp = icmp eq i64 %inc, 5 + br i1 %cmp, label %do.end, label %do.body + +do.end: ; preds = %do.body + %lcssa = phi float [ %add, %do.body ] + store float %lcssa, float* %ret, align 4 + ret void +} + +; Function Attrs: nounwind uwtable +define dso_local void @dsum(float* %x, float* %xp, float* %n, float* %np) local_unnamed_addr #1 { +entry: + %0 = tail call double (void (float*, float*)*, ...) @__enzyme_fwdsplit(void (float*, float*)* nonnull @sum, float* %x, float* %xp, float* %n, float* %np, i8* null) + ret void +} + +declare double @__enzyme_fwdsplit(void (float*, float*)*, ...) 
#2 + + +; CHECK: define internal void @fwddiffesum(float* %array, float* %"array'", float* %ret, float* %"ret'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: br label %do.body + +; CHECK: do.body: ; preds = %do.body, %entry +; CHECK-NEXT: %iv = phi i64 [ %iv.next, %do.body ], [ 0, %entry ] +; CHECK-NEXT: %"intsum'" = phi i32 [ 0, %entry ], [ %3, %do.body ] +; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds float, float* %"array'", i64 %iv +; CHECK-NEXT: %0 = load float, float* %"arrayidx'ipg" +; CHECK-NEXT: %1 = bitcast i32 %"intsum'" to float +; CHECK-NEXT: %2 = fadd fast float %1, %0 +; CHECK-NEXT: %3 = bitcast float %2 to i32 +; CHECK-NEXT: %cmp = icmp eq i64 %iv.next, 5 +; CHECK-NEXT: br i1 %cmp, label %do.end, label %do.body + +; CHECK: do.end: ; preds = %do.body +; CHECK-NEXT: store float %2, float* %"ret'", align 4 +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/invertselect.ll b/enzyme/test/Enzyme/ForwardModeSplit/invertselect.ll new file mode 100644 index 0000000000000..99219199cc988 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/invertselect.ll @@ -0,0 +1,33 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -early-cse -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind uwtable +define dso_local float @man_max(float* %a, float* %b) #0 { +entry: + %0 = load float, float* %a, align 4 + %1 = load float, float* %b, align 4 + %cmp = fcmp ogt float %0, %1 + %a.b = select i1 %cmp, float* %a, float* %b + %retval.0 = load float, float* %a.b, align 4 + ret float %retval.0 +} + +define void @dman_max(float* %a, float* %da, float* %b, float* %db) { +entry: + call float (...) @__enzyme_fwdsplit.f64(float (float*, float*)* @man_max, float* %a, float* %da, float* %b, float* %db, i8* null) + ret void +} + +declare float @__enzyme_fwdsplit.f64(...) 
+ +attributes #0 = { noinline } + + +; CHECK: define internal float @fwddiffeman_max(float* %a, float* %"a'", float* %b, float* %"b'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to i1* +; CHECK-NEXT: %cmp = load i1, i1* %0 +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %"a.b'ipse" = select i1 %cmp, float* %"a'", float* %"b'" +; CHECK-NEXT: %1 = load float, float* %"a.b'ipse" +; CHECK-NEXT: ret float %1 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/log.ll b/enzyme/test/Enzyme/ForwardModeSplit/log.ll new file mode 100644 index 0000000000000..0304c1e90ff25 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/log.ll @@ -0,0 +1,26 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -O3 -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @llvm.log.f64(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.log.f64(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) 
+ +; CHECK: define double @test_derivative(double %x) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fdiv fast double 1.000000e+00, %x +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/log10.ll b/enzyme/test/Enzyme/ForwardModeSplit/log10.ll new file mode 100644 index 0000000000000..ac32a77096c18 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/log10.ll @@ -0,0 +1,27 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -O3 -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @llvm.log10.f64(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.log10.f64(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) + +; CHECK: define double @test_derivative(double %x) +; CHECK-NEXT: entry: +; equivalent to 1/log(10) / x +; CHECK-NEXT: %0 = fdiv fast double 0x3FDBCB7B1526E50D, %x +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/log2.ll b/enzyme/test/Enzyme/ForwardModeSplit/log2.ll new file mode 100644 index 0000000000000..1e4c77d049af9 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/log2.ll @@ -0,0 +1,27 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -O3 -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @llvm.log2.f64(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) 
@__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.log2.f64(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) + +; CHECK: define double @test_derivative(double %x) +; CHECK-NEXT: entry: +; equivalent to 1/log(2) / x +; CHECK-NEXT: %0 = fdiv fast double 0x3FF71547652B82FE, %x +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/maskedload.ll b/enzyme/test/Enzyme/ForwardModeSplit/maskedload.ll new file mode 100644 index 0000000000000..00f08511a9616 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/maskedload.ll @@ -0,0 +1,30 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -dce -instcombine -S | FileCheck %s + +declare <2 x double> @llvm.masked.load.v2f64.p0v2f64 (<2 x double>*, i32, <2 x i1>, <2 x double>) + +; Function Attrs: nounwind uwtable +define dso_local <2 x double> @loader(<2 x double>* %ptr, <2 x i1> %mask, <2 x double> %other) { +entry: + %res = call <2 x double> @llvm.masked.load.v2f64.p0v2f64(<2 x double>* %ptr, i32 16, <2 x i1> %mask, <2 x double> %other) + ret <2 x double> %res +} + + +; Function Attrs: argmemonly nounwind +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) #1 + +; Function Attrs: nounwind uwtable +define <2 x double> @dloader(i8* %ptr, i8* %dptr, <2 x i1> %mask, <2 x double> %other, <2 x double> %dother) { +entry: + %res = tail call <2 x double> (...) @__enzyme_fwdsplit.f64(<2 x double> (<2 x double>*, <2 x i1>, <2 x double>)* @loader, i8* %ptr, i8* %dptr, <2 x i1> %mask, <2 x double> %other, <2 x double> %dother, i8* null) + ret <2 x double> %res +} + +declare <2 x double> @__enzyme_fwdsplit.f64(...) 
+ +; CHECK: define internal <2 x double> @fwddiffeloader(<2 x double>* %ptr, <2 x double>* %"ptr'", <2 x i1> %mask, <2 x double> %other, <2 x double> %"other'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = call fast <2 x double> @llvm.masked.load.v2f64.p0v2f64(<2 x double>* %"ptr'", i32 16, <2 x i1> %mask, <2 x double> %"other'") +; CHECK-NEXT: ret <2 x double> %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/maskedstore.ll b/enzyme/test/Enzyme/ForwardModeSplit/maskedstore.ll new file mode 100644 index 0000000000000..2755f48224642 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/maskedstore.ll @@ -0,0 +1,30 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -dce -instcombine -S | FileCheck %s + +declare void @llvm.masked.store.v2f64.p0v2f64 (<2 x double>, <2 x double>*, i32, <2 x i1>) + +; Function Attrs: nounwind uwtable +define dso_local void @loader(<2 x double>* %ptr, <2 x i1> %mask, <2 x double> %val) { +entry: + call void @llvm.masked.store.v2f64.p0v2f64(<2 x double> %val, <2 x double>* %ptr, i32 16, <2 x i1> %mask) + ret void +} + + +; Function Attrs: argmemonly nounwind +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) #1 + +; Function Attrs: nounwind uwtable +define <2 x double> @dloader(i8* %ptr, i8* %dptr, <2 x i1> %mask, <2 x double> %other, <2 x double> %dother) { +entry: + %res = tail call <2 x double> (...) @__enzyme_fwdsplit.f64(void (<2 x double>*, <2 x i1>, <2 x double>)* @loader, i8* %ptr, i8* %dptr, <2 x i1> %mask, <2 x double> %other, <2 x double> %dother, i8* null) + ret <2 x double> %res +} + +declare <2 x double> @__enzyme_fwdsplit.f64(...) 
+ +; CHECK: define internal void @fwddiffeloader(<2 x double>* %ptr, <2 x double>* %"ptr'", <2 x i1> %mask, <2 x double> %val, <2 x double> %"val'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: call void @llvm.masked.store.v2f64.p0v2f64(<2 x double> %"val'", <2 x double>* %"ptr'", i32 16, <2 x i1> %mask) +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/max.ll b/enzyme/test/Enzyme/ForwardModeSplit/max.ll new file mode 100644 index 0000000000000..078fd8228a772 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/max.ll @@ -0,0 +1,28 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +; Function Attrs: norecurse nounwind readnone uwtable +define dso_local double @max(double %x, double %y) #0 { +entry: + %cmp = fcmp fast ogt double %x, %y + %cond = select i1 %cmp, double %x, double %y + ret double %cond +} + +; Function Attrs: nounwind uwtable +define dso_local double @test_derivative(double %x, double %y) local_unnamed_addr #1 { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fwdsplit(double (double, double)* nonnull @max, double %x, double 1.0, double %y, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double, double)*, ...) 
+ + +; CHECK: define internal double @fwddiffemax(double %x, double %"x'", double %y, double %"y'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %cmp = fcmp fast ogt double %x, %y +; CHECK-NEXT: %0 = select {{(fast )?}}i1 %cmp, double %"x'", double %"y'" +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/maxnum-inactive.ll b/enzyme/test/Enzyme/ForwardModeSplit/maxnum-inactive.ll new file mode 100644 index 0000000000000..cb92d3a7d4056 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/maxnum-inactive.ll @@ -0,0 +1,29 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -early-cse -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %0 = tail call double @llvm.maxnum.f64(double %x, double %y) + ret double %0 +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fwdsplit(double (double, double)* nonnull @tester, double %x, double 1.0, metadata !"enzyme_const", double %y, i8* null) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.maxnum.f64(double, double) + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double, double)*, ...) 
+ + +; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = fcmp fast olt double %x, %y +; CHECK-NEXT: %1 = select {{(fast )?}}i1 %0, double %"x'", double 0.000000e+00 +; CHECK-NEXT: ret double %1 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/maxnum.ll b/enzyme/test/Enzyme/ForwardModeSplit/maxnum.ll new file mode 100644 index 0000000000000..948feede35a93 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/maxnum.ll @@ -0,0 +1,29 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -early-cse -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %0 = tail call double @llvm.maxnum.f64(double %x, double %y) + ret double %0 +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fwdsplit(double (double, double)* nonnull @tester, double %x, double 1.0, double %y, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.maxnum.f64(double, double) + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double, double)*, ...) 
+ +; CHECK: define internal {{(dso_local )?}}double @fwddiffetester(double %x, double %"x'", double %y, double %"y'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = fcmp fast olt double %x, %y +; CHECK-NEXT: %1 = select {{(fast )?}}i1 %0, double %"x'", double %"y'" +; CHECK-NEXT: ret double %1 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardModeSplit/memcpy-flt.ll b/enzyme/test/Enzyme/ForwardModeSplit/memcpy-flt.ll new file mode 100644 index 0000000000000..2e22169854663 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/memcpy-flt.ll @@ -0,0 +1,37 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -dce -instcombine -S | FileCheck %s + +; Function Attrs: nounwind uwtable +define dso_local void @memcpy_float(double* nocapture %dst, double* nocapture readonly %src, i64 %num) #0 { +entry: + %0 = bitcast double* %dst to i8* + %1 = bitcast double* %src to i8* + tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 1 %0, i8* align 1 %1, i64 %num, i1 false) + ret void +} + +; Function Attrs: argmemonly nounwind +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) #1 + +; Function Attrs: nounwind uwtable +define dso_local void @dmemcpy_float(double* %dst, double* %dstp, double* %src, double* %srcp, i64 %n) local_unnamed_addr #0 { +entry: + tail call void (...) @__enzyme_fwdsplit.f64(void (double*, double*, i64)* nonnull @memcpy_float, double* %dst, double* %dstp, double* %src, double* %srcp, i64 %n, i8* null) #3 + ret void +} + +declare void @__enzyme_fwdsplit.f64(...) 
local_unnamed_addr + + +attributes #0 = { nounwind uwtable } +attributes #1 = { argmemonly nounwind } +attributes #2 = { noinline nounwind uwtable } + + +; CHECK: define internal void @fwddiffememcpy_float(double* nocapture %dst, double* nocapture %"dst'", double* nocapture readonly %src, double* nocapture %"src'", i64 %num, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %"'ipc" = bitcast double* %"dst'" to i8* +; CHECK-NEXT: %"'ipc2" = bitcast double* %"src'" to i8* +; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 1 %"'ipc", i8* align 1 %"'ipc2", i64 %num, i1 false) +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/memcpy-intstruct.ll b/enzyme/test/Enzyme/ForwardModeSplit/memcpy-intstruct.ll new file mode 100644 index 0000000000000..317d8d7baacda --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/memcpy-intstruct.ll @@ -0,0 +1,38 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -dce -S | FileCheck %s + +; Function Attrs: nounwind uwtable +define dso_local void @memcpy_ptr(i8* nocapture %dst, i8* nocapture readonly %src, i64 %num) { +entry: + tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %dst, i8* align 8 %src, i64 %num, i1 false), !tbaa !17, !tbaa.struct !19 + ret void +} + +; Function Attrs: argmemonly nounwind +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) #0 + +; Function Attrs: nounwind uwtable +define dso_local void @dmemcpy_ptr(i8* %dst, i8* %dstp, i8* %src, i8* %srcp, i64 %n) { +entry: + %0 = tail call double (...) @__enzyme_fwdsplit.f64(void (i8*, i8*, i64)* nonnull @memcpy_ptr, metadata !"enzyme_dup", i8* %dst, i8* %dstp, metadata !"enzyme_dup", i8* %src, i8* %srcp, i64 %n, i8* null) + ret void +} + +declare double @__enzyme_fwdsplit.f64(...) 
local_unnamed_addr + +attributes #0 = { argmemonly nounwind } + +!17 = !{!18, !18, i64 0, i64 32} +!18 = !{!4, i64 32, !"_ZTSSt5arrayIlLm4EE", !9, i64 0, i64 32} + +!19 = !{i64 0, i64 32, !20} +!20 = !{!9, !9, i64 0, i64 32} +!9 = !{!4, i64 8, !"long"} +!4 = !{!5, i64 1, !"omnipotent char"} +!5 = !{!"Simple C++ TBAA"} + + +; CHECK: define internal void @fwddiffememcpy_ptr(i8* nocapture %dst, i8* nocapture %"dst'", i8* nocapture readonly %src, i8* nocapture %"src'", i64 %num, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/memcpy-ptr.ll b/enzyme/test/Enzyme/ForwardModeSplit/memcpy-ptr.ll new file mode 100644 index 0000000000000..e1e4ed18c9e89 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/memcpy-ptr.ll @@ -0,0 +1,35 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -dce -S | FileCheck %s + +; Function Attrs: nounwind uwtable +define dso_local void @memcpy_ptr(double** nocapture %dst, double** nocapture readonly %src, i64 %num) #0 { +entry: + %0 = bitcast double** %dst to i8* + %1 = bitcast double** %src to i8* + tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 1 %0, i8* align 1 %1, i64 %num, i1 false) + ret void +} + +; Function Attrs: argmemonly nounwind +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) #1 + +; Function Attrs: nounwind uwtable +define dso_local void @dmemcpy_ptr(double** %dst, double** %dstp, double** %src, double** %srcp, i64 %n) local_unnamed_addr #0 { +entry: + %0 = tail call double (...) @__enzyme_fwdsplit.f64(void (double**, double**, i64)* nonnull @memcpy_ptr, double** %dst, double** %dstp, double** %src, double** %srcp, i64 %n, i8* null) #3 + ret void +} + +declare double @__enzyme_fwdsplit.f64(...) 
local_unnamed_addr + + +attributes #0 = { nounwind uwtable } +attributes #1 = { argmemonly nounwind } +attributes #2 = { noinline nounwind uwtable } +attributes #3 = { nounwind } + + +; CHECK: define internal void @fwddiffememcpy_ptr(double** nocapture %dst, double** nocapture %"dst'", double** nocapture readonly %src, double** nocapture %"src'", i64 %num, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/minnum.ll b/enzyme/test/Enzyme/ForwardModeSplit/minnum.ll new file mode 100644 index 0000000000000..1f798b9268e6b --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/minnum.ll @@ -0,0 +1,25 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -early-cse -simplifycfg -S | FileCheck %s + +define double @tester(double %x, double %y) { +entry: + %0 = tail call double @llvm.minnum.f64(double %x, double %y) + ret double %0 +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fwdsplit(double (double, double)* nonnull @tester, double %x, double 1.0, double %y, double 1.0, i8* null) + ret double %0 +} + +declare double @llvm.minnum.f64(double, double) + +declare double @__enzyme_fwdsplit(double (double, double)*, ...) 
+ +; CHECK: define internal {{(dso_local )?}}double @fwddiffetester(double %x, double %"x'", double %y, double %"y'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = fcmp fast olt double %x, %y +; CHECK-NEXT: %1 = select{{( fast)?}} i1 %0, double %"x'", double %"y'" +; CHECK-NEXT: ret double %1 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/mul.ll b/enzyme/test/Enzyme/ForwardModeSplit/mul.ll new file mode 100644 index 0000000000000..e5ae711880caa --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/mul.ll @@ -0,0 +1,26 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %0 = fmul fast double %x, %y + ret double %0 +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fwdsplit(double (double, double)* nonnull @tester, double %x, double 1.0, double %y, double 0.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double, double)*, ...) 
+ +; CHECK: define internal {{(dso_local )?}}double @fwddiffetester(double %x, double %"x'", double %y, double %"y'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = fmul fast double %"x'", %y +; CHECK-NEXT: %1 = fmul fast double %"y'", %x +; CHECK-NEXT: %2 = fadd fast double %0, %1 +; CHECK-NEXT: ret double %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/negbithack.ll b/enzyme/test/Enzyme/ForwardModeSplit/negbithack.ll new file mode 100644 index 0000000000000..46c3bf582b54e --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/negbithack.ll @@ -0,0 +1,26 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x) { +entry: + %cstx = bitcast double %x to i64 + %negx = xor i64 %cstx, -9223372036854775808 + %csty = bitcast i64 %negx to double + ret double %csty +} + +define double @test_derivative(double %x, double %dx) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double %dx, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) 
+ +; CHECK: define internal double @fwddiffetester(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = {{(fsub fast double \-?0.000000e\+00,|fneg fast double)}} %"x'" +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/negbithack2.ll b/enzyme/test/Enzyme/ForwardModeSplit/negbithack2.ll new file mode 100644 index 0000000000000..7920591d48f1f --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/negbithack2.ll @@ -0,0 +1,26 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define <2 x double> @tester(<2 x double> %x) { +entry: + %cstx = bitcast <2 x double> %x to <2 x i64> + %negx = xor <2 x i64> %cstx, <i64 -9223372036854775808, i64 -9223372036854775808> + %csty = bitcast <2 x i64> %negx to <2 x double> + ret <2 x double> %csty +} + +define <2 x double> @test_derivative(<2 x double> %x, <2 x double> %dx) { +entry: + %0 = tail call <2 x double> (<2 x double> (<2 x double>)*, ...) @__enzyme_fwdsplit(<2 x double> (<2 x double>)* nonnull @tester, <2 x double> %x, <2 x double> %dx, i8* null) + ret <2 x double> %0 +} + +; Function Attrs: nounwind +declare <2 x double> @__enzyme_fwdsplit(<2 x double> (<2 x double>)*, ...) 
+ +; CHECK: define internal <2 x double> @fwddiffetester(<2 x double> %x, <2 x double> %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = {{(fsub fast <2 x double> <double \-?0.000000e\+00, double \-?0.000000e\+00>,|fneg fast <2 x double>)}} %"x'" +; CHECK-NEXT: ret <2 x double> %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/negbithack3.ll b/enzyme/test/Enzyme/ForwardModeSplit/negbithack3.ll new file mode 100644 index 0000000000000..b3ad9526b5e12 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/negbithack3.ll @@ -0,0 +1,34 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define <2 x double> @tester(<2 x double> %x) { +entry: + %cstx = bitcast <2 x double> %x to <2 x i64> + %negx = xor <2 x i64> %cstx, <i64 -9223372036854775808, i64 0> + %csty = bitcast <2 x i64> %negx to <2 x double> + ret <2 x double> %csty +} + +define <2 x double> @test_derivative(<2 x double> %x, <2 x double> %dx) { +entry: + %0 = tail call <2 x double> (<2 x double> (<2 x double>)*, ...) @__enzyme_fwdsplit(<2 x double> (<2 x double>)* nonnull @tester, <2 x double> %x, <2 x double> %dx, i8* null) + ret <2 x double> %0 +} + +; Function Attrs: nounwind +declare <2 x double> @__enzyme_fwdsplit(<2 x double> (<2 x double>)*, ...) 
+ +; CHECK: define internal <2 x double> @fwddiffetester(<2 x double> %x, <2 x double> %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = bitcast <2 x double> %"x'" to <2 x i64> +; CHECK-NEXT: %1 = extractelement <2 x i64> %0, i64 0 +; CHECK-NEXT: %2 = bitcast i64 %1 to double +; CHECK-NEXT: %3 = {{(fsub fast double -?0.000000e\+00,|fneg fast double)}} %2 +; CHECK-NEXT: %4 = bitcast double %3 to i64 +; CHECK-NEXT: %5 = insertelement <2 x i64> undef, i64 %4, i64 0 +; CHECK-NEXT: %6 = extractelement <2 x i64> %0, i64 1 +; CHECK-NEXT: %7 = insertelement <2 x i64> %5, i64 %6, i64 1 +; CHECK-NEXT: %8 = bitcast <2 x i64> %7 to <2 x double> +; CHECK-NEXT: ret <2 x double> %8 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/pow.ll b/enzyme/test/Enzyme/ForwardModeSplit/pow.ll new file mode 100644 index 0000000000000..2130f6b1fd26c --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/pow.ll @@ -0,0 +1,35 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %0 = tail call fast double @llvm.pow.f64(double %x, double %y) + ret double %0 +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fwdsplit(double (double, double)* nonnull @tester, double %x, double 1.0, double %y, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.pow.f64(double, double) + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double, double)*, ...) 
+ +; CHECK: define internal {{(dso_local )?}}double @fwddiffetester(double %x, double %"x'", double %y, double %"y'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %[[i0:.+]] = fsub fast double %y, 1.000000e+00 +; CHECK-NEXT: %[[i1:.+]] = call fast double @llvm.pow.f64(double %x, double %[[i0]]) +; CHECK-NEXT: %[[i2:.+]] = fmul fast double %y, %[[i1]] +; CHECK-NEXT: %[[dx:.+]] = fmul fast double %[[i2]], %"x'" +; CHECK-NEXT: %[[i3:.+]] = call fast double @llvm.pow.f64(double %x, double %y) +; CHECK-NEXT: %[[i4:.+]] = call fast double @llvm.log.f64(double %x) +; CHECK-DAG: %[[i5:.+]] = fmul fast double %[[i3]], %[[i4]] +; CHECK-NEXT: %[[dy:.+]] = fmul fast double %[[i5]], %"y'" +; CHECK-DAG: %[[i6:.+]] = fadd fast double %[[dx]], %[[dy]] +; CHECK-NEXT: ret double %[[i6]] +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/powi13.ll b/enzyme/test/Enzyme/ForwardModeSplit/powi13.ll new file mode 100644 index 0000000000000..c532078ec829a --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/powi13.ll @@ -0,0 +1,34 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, i32 %y) { +entry: + %0 = tail call fast double @llvm.powi.f64.i32(double %x, i32 %y) + ret double %0 +} + +define double @test_derivative(double %x, i32 %y) { +entry: + %0 = tail call double (double (double, i32)*, ...) @__enzyme_fwdsplit(double (double, i32)* nonnull @tester, double %x, double 1.0, i32 %y, i8* null) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.powi.f64.i32(double, i32) + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double, i32)*, ...) 
+ +; CHECK: define internal {{(dso_local )?}}double @fwddiffetester(double %x, double %"x'", i32 %y, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %[[ym1:.+]] = sub i32 %y, 1 +; CHECK-NEXT: %[[newpow:.+]] = call fast double @llvm.powi.f64{{(\.i32)?}}(double %x, i32 %[[ym1]]) +; CHECK-DAG: %[[sitofp:.+]] = sitofp i32 %y to double +; CHECK-DAG: %[[cmp:.+]] = icmp eq i32 0, %y +; CHECK-DAG: %[[newpowdret:.+]] = fmul fast double %"x'", %[[newpow]] +; CHECK-NEXT: %[[dx:.+]] = fmul fast double %[[newpowdret]], %[[sitofp]] +; CHECK-NEXT: %[[res:.+]] = select {{(fast )?}}i1 %[[cmp]], double 0.000000e+00, double %[[dx]] +; CHECK-NEXT: ret double %[[res]] +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardModeSplit/ptr-ret.ll b/enzyme/test/Enzyme/ForwardModeSplit/ptr-ret.ll new file mode 100644 index 0000000000000..0c9d2eb5ba85f --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/ptr-ret.ll @@ -0,0 +1,63 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +define dso_local noalias nonnull double* @_Z6toHeapd(double %x) { +entry: + %call = call noalias nonnull dereferenceable(8) i8* @_Znwm(i64 8) + %0 = bitcast i8* %call to double* + store double %x, double* %0, align 8 + ret double* %0 +} + +declare dso_local nonnull i8* @_Znwm(i64) + +define dso_local double @_Z6squared(double %x) { +entry: + %call = call double* @_Z6toHeapd(double %x) + %0 = load double, double* %call, align 8 + %mul = fmul double %0, %x + ret double %mul +} + +define dso_local double @_Z7dsquared(double %x) { +entry: + %call = call double (...) @_Z16__enzyme_fwdsplitz(i8* bitcast (double (double)* @_Z6squared to i8*), metadata !"enzyme_nofree", double %x, double 1.000000e+00, i8* null) + ret double %call +} + +declare dso_local double @_Z16__enzyme_fwdsplitz(...) 
+ + + +; CHECK: define dso_local double @_Z7dsquared(double %x) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast double @fwddiffe_Z6squared(double %x, double 1.000000e+00, i8* null) +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } + + +; CHECK: define internal double @fwddiffe_Z6squared(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to { { i8*, i8* }, double*, double }* +; CHECK-NEXT: %truetape = load { { i8*, i8* }, double*, double }, { { i8*, i8* }, double*, double }* %0 +; CHECK-NEXT: %tapeArg1 = extractvalue { { i8*, i8* }, double*, double } %truetape, 0 +; CHECK-NEXT: %1 = call { double*, double* } @fwddiffe_Z6toHeapd(double %x, double %"x'", { i8*, i8* } %tapeArg1) +; CHECK-NEXT: %2 = extractvalue { double*, double* } %1, 1 +; CHECK-NEXT: %3 = extractvalue { { i8*, i8* }, double*, double } %truetape, 2 +; CHECK-NEXT: %4 = load double, double* %2, align 8 +; CHECK-NEXT: %5 = fmul fast double %4, %x +; CHECK-NEXT: %6 = fmul fast double %"x'", %3 +; CHECK-NEXT: %7 = fadd fast double %5, %6 +; CHECK-NEXT: ret double %7 +; CHECK-NEXT: } + +; CHECK: define internal { double*, double* } @fwddiffe_Z6toHeapd(double %x, double %"x'", { i8*, i8* } %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %call = extractvalue { i8*, i8* } %tapeArg, 1 +; CHECK-NEXT: %"call'mi" = extractvalue { i8*, i8* } %tapeArg, 0 +; CHECK-NEXT: %"'ipc" = bitcast i8* %"call'mi" to double* +; CHECK-NEXT: %0 = bitcast i8* %call to double* +; CHECK-NEXT: store double %"x'", double* %"'ipc", align 8 +; CHECK-NEXT: %1 = insertvalue { double*, double* } undef, double* %0, 0 +; CHECK-NEXT: %2 = insertvalue { double*, double* } %1, double* %"'ipc", 1 +; CHECK-NEXT: ret { double*, double* } %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/relu.ll b/enzyme/test/Enzyme/ForwardModeSplit/relu.ll new file mode 100644 index 0000000000000..dcbd42d5de06b --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/relu.ll @@ -0,0 +1,62 @@ +; 
RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -inline -early-cse -instcombine -simplifycfg -S | FileCheck %s + +; __attribute__((noinline)) +; double f(double x) { +; return x; +; } +; +; double relu(double x) { +; return (x > 0) ? f(x) : 0; +; } +; +; double drelu(double x) { +; return __builtin_autodiff(relu, x); +; } + +define dso_local double @f(double %x) #1 { +entry: + ret double %x +} + +define dso_local double @relu(double %x) { +entry: + %cmp = fcmp fast ogt double %x, 0.000000e+00 + br i1 %cmp, label %cond.true, label %cond.end + +cond.true: ; preds = %entry + %call = tail call fast double @f(double %x) + br label %cond.end + +cond.end: ; preds = %entry, %cond.true + %cond = phi double [ %call, %cond.true ], [ 0.000000e+00, %entry ] + ret double %cond +} + +define dso_local double @drelu(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @relu, double %x, double 1.0, i8* null) + ret double %0 +} + +declare double @__enzyme_fwdsplit(double (double)*, ...) 
#0 + +attributes #0 = { nounwind } +attributes #1 = { nounwind readnone noinline } + +; CHECK: define dso_local double @drelu(double %x) +; CHECK-NEXT: entry: +; CHECK-NEXT: %cmp.i = fcmp fast ogt double %x, 0.000000e+00 +; CHECK-NEXT: br i1 %cmp.i, label %cond.true.i, label %fwddifferelu.exit +; CHECK: cond.true.i: ; preds = %entry +; CHECK-NEXT: %0 = call fast double @fwddiffef(double %x, double 1.000000e+00) +; CHECK-NEXT: br label %fwddifferelu.exit +; CHECK: fwddifferelu.exit: ; preds = %entry, %cond.true.i +; CHECK-NEXT: %"cond'.i" = phi{{( fast)?}} double [ %0, %cond.true.i ], [ 0.000000e+00, %entry ] +; CHECK-NEXT: ret double %"cond'.i" +; CHECK-NEXT: } + + +; CHECK: define internal {{(dso_local )?}}double @fwddiffef(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret double %"x'" +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/rwloop.ll b/enzyme/test/Enzyme/ForwardModeSplit/rwloop.ll new file mode 100644 index 0000000000000..fae3aae06ea5c --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/rwloop.ll @@ -0,0 +1,166 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -early-cse -simplifycfg -instsimplify -correlated-propagation -adce -S | FileCheck %s + +; ModuleID = '../test/Integration/rwrloop.c' +source_filename = "../test/Integration/rwrloop.c" +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +%struct._IO_FILE = type { i32, i8*, i8*, i8*, i8*, i8*, i8*, i8*, i8*, i8*, i8*, i8*, %struct._IO_marker*, %struct._IO_FILE*, i32, i32, i64, i16, i8, [1 x i8], i8*, i64, i8*, i8*, i8*, i8*, i64, i32, [20 x i8] } +%struct._IO_marker = type { %struct._IO_marker*, %struct._IO_FILE*, i32 } + +@.str = private unnamed_addr constant [16 x i8] c"d_a[%d][%d]=%f\0A\00", align 1 +@stderr = external dso_local local_unnamed_addr global %struct._IO_FILE*, align 8 +@.str.1 = private unnamed_addr constant [68 x i8] c"Assertion Failed: fabs( [%s = %g] - [%s = %g] ) > %g 
at %s:%d (%s)\0A\00", align 1 +@.str.2 = private unnamed_addr constant [10 x i8] c"d_a[i][j]\00", align 1 +@.str.3 = private unnamed_addr constant [15 x i8] c"2. * (i*100+j)\00", align 1 +@.str.4 = private unnamed_addr constant [30 x i8] c"../test/Integration/rwrloop.c\00", align 1 +@__PRETTY_FUNCTION__.main = private unnamed_addr constant [23 x i8] c"int main(int, char **)\00", align 1 + +; Function Attrs: norecurse nounwind uwtable +define dso_local double @alldiv(double* noalias nocapture %a, i32* noalias nocapture %N) #0 { +entry: + %0 = load i32, i32* %N, align 4, !tbaa !2 + %cmp233 = icmp sgt i32 %0, 0 + br label %for.cond1.preheader + +for.cond1.preheader: ; preds = %for.cond.cleanup3, %entry + %indvar = phi i64 [ 0, %entry ], [ %indvar.next, %for.cond.cleanup3 ] + %sum.036 = phi double [ 0.000000e+00, %entry ], [ %sum.1.lcssa, %for.cond.cleanup3 ] + br i1 %cmp233, label %for.body4.lr.ph, label %for.cond.cleanup3 + +for.body4.lr.ph: ; preds = %for.cond1.preheader + %1 = mul nuw nsw i64 %indvar, 10 + %2 = load i32, i32* %N, align 4, !tbaa !2 + %3 = sext i32 %2 to i64 + br label %for.body4 + +for.body4: ; preds = %for.body4.lr.ph, %for.body4 + %indvars.iv = phi i64 [ 0, %for.body4.lr.ph ], [ %indvars.iv.next, %for.body4 ] + %sum.134 = phi double [ %sum.036, %for.body4.lr.ph ], [ %add10, %for.body4 ] + %4 = add nuw nsw i64 %indvars.iv, %1 + %arrayidx = getelementptr inbounds double, double* %a, i64 %4 + %5 = load double, double* %arrayidx, align 8, !tbaa !6 + %mul9 = fmul double %5, %5 + %add10 = fadd double %sum.134, %mul9 + store double 0.000000e+00, double* %arrayidx, align 8, !tbaa !6 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1 + %cmp2 = icmp slt i64 %indvars.iv.next, %3 + br i1 %cmp2, label %for.body4, label %for.cond.cleanup3 + +for.cond.cleanup3: ; preds = %for.body4, %for.cond1.preheader + %sum.1.lcssa = phi double [ %sum.036, %for.cond1.preheader ], [ %add10, %for.body4 ] + %indvar.next = add nuw nsw i64 %indvar, 1 + %exitcond = icmp eq i64 
%indvar.next, 10 + br i1 %exitcond, label %for.cond.cleanup, label %for.cond1.preheader + +for.cond.cleanup: ; preds = %for.cond.cleanup3 + store i32 7, i32* %N, align 4, !tbaa !2 + ret double %sum.1.lcssa +} + +define void @main(double* %a, double* %da, i32* %N) { +entry: + %call = call double (i8*, ...) @__enzyme_fwdsplit(i8* bitcast (double (double*, i32*)* @alldiv to i8*), metadata !"enzyme_nofree", double* nonnull %a, double* nonnull %da, i32* nonnull %N, i8* null) + ret void +} + +; Function Attrs: nounwind +declare i8* @llvm.stacksave() #3 + +; Function Attrs: argmemonly nounwind +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) #1 + +declare dso_local double @__enzyme_fwdsplit(i8*, ...) local_unnamed_addr #4 + +; Function Attrs: nounwind +declare dso_local i32 @printf(i8* nocapture readonly, ...) local_unnamed_addr #5 + +; Function Attrs: nounwind +declare dso_local i32 @fflush(%struct._IO_FILE* nocapture) local_unnamed_addr #5 + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.fabs.f64(double) #6 + +; Function Attrs: nounwind +declare dso_local i32 @fprintf(%struct._IO_FILE* nocapture, i8* nocapture readonly, ...) 
local_unnamed_addr #5 + +; Function Attrs: noreturn nounwind +declare dso_local void @abort() local_unnamed_addr #7 + +; Function Attrs: nounwind +declare void @llvm.stackrestore(i8*) #3 + +attributes #0 = { norecurse nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #1 = { argmemonly nounwind } +attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #3 = { nounwind } +attributes #4 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #5 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" 
"target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #6 = { nounwind readnone speculatable } +attributes #7 = { noreturn nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #8 = { cold } +attributes #9 = { noreturn nounwind } + +!llvm.module.flags = !{!0} +!llvm.ident = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{!"clang version 7.0.0 (trunk 336729)"} +!2 = !{!3, !3, i64 0} +!3 = !{!"int", !4, i64 0} +!4 = !{!"omnipotent char", !5, i64 0} +!5 = !{!"Simple C/C++ TBAA"} +!6 = !{!7, !7, i64 0} +!7 = !{!"double", !4, i64 0} +!8 = !{!9, !9, i64 0} +!9 = !{!"any pointer", !4, i64 0} + + +; CHECK: define internal double @fwddiffealldiv(double* noalias nocapture %a, double* nocapture %"a'", i32* noalias nocapture %N, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to { i1, i32*, double** }* +; CHECK-NEXT: %truetape = load { i1, i32*, double** }, { i1, i32*, double** }* %0 +; CHECK-DAG: %[[i1:.+]] = extractvalue { i1, i32*, double** } %truetape, 1 +; CHECK-DAG: %[[i2:.+]] = extractvalue { i1, i32*, double** } %truetape, 2 +; CHECK-DAG: %cmp233 = extractvalue { i1, i32*, double** } %truetape, 0 +; CHECK-NEXT: br label %for.cond1.preheader + +; CHECK: for.cond1.preheader: ; preds = %for.cond.cleanup3, %entry +; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.cond.cleanup3 ], [ 0, %entry ] +; CHECK-NEXT: %"sum.036'" = phi {{(fast )?}}double [ 0.000000e+00, %entry ], [ %"sum.1.lcssa'", %for.cond.cleanup3 ] +; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 +; CHECK-NEXT: br i1 %cmp233, label %for.body4.lr.ph, label 
%for.cond.cleanup3 + +; CHECK: for.body4.lr.ph: ; preds = %for.cond1.preheader +; CHECK-NEXT: %3 = mul nuw nsw i64 %iv, 10 +; CHECK-NEXT: %4 = getelementptr inbounds i32, i32* %[[i1]], i64 %iv +; CHECK-NEXT: %5 = load i32, i32* %4, align 4, !invariant.group !13 +; CHECK-NEXT: %6 = sext i32 %5 to i64 +; CHECK-NEXT: br label %for.body4 + +; CHECK: for.body4: ; preds = %for.body4, %for.body4.lr.ph +; CHECK-NEXT: %iv1 = phi i64 [ %iv.next2, %for.body4 ], [ 0, %for.body4.lr.ph ] +; CHECK-NEXT: %"sum.134'" = phi {{(fast )?}}double [ %"sum.036'", %for.body4.lr.ph ], [ %15, %for.body4 ] +; CHECK-NEXT: %iv.next2 = add nuw nsw i64 %iv1, 1 +; CHECK-NEXT: %7 = add nuw nsw i64 %iv1, %3 +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"a'", i64 %7 +; CHECK-NEXT: %8 = getelementptr inbounds double*, double** %[[i2]], i64 %iv +; CHECK-NEXT: %9 = load double*, double** %8, align 8, !dereferenceable !10, !invariant.group !14 +; CHECK-NEXT: %10 = getelementptr inbounds double, double* %9, i64 %iv1 +; CHECK-NEXT: %11 = load double, double* %10, align 8, !invariant.group !15 +; CHECK-NEXT: %12 = load double, double* %"arrayidx'ipg", align 8 +; CHECK-NEXT: %13 = fmul fast double %12, %11 +; CHECK-NEXT: %14 = fadd fast double %13, %13 +; CHECK-NEXT: %15 = fadd fast double %"sum.134'", %14 +; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg", align 8 +; CHECK-NEXT: %cmp2 = icmp slt i64 %iv.next2, %6 +; CHECK-NEXT: br i1 %cmp2, label %for.body4, label %for.cond.cleanup3 + +; CHECK: for.cond.cleanup3: ; preds = %for.body4, %for.cond1.preheader +; CHECK-NEXT: %"sum.1.lcssa'" = phi {{(fast )?}}double [ %"sum.036'", %for.cond1.preheader ], [ %15, %for.body4 ] +; CHECK-NEXT: %exitcond = icmp eq i64 %iv.next, 10 +; CHECK-NEXT: br i1 %exitcond, label %for.cond.cleanup, label %for.cond1.preheader + +; CHECK: for.cond.cleanup: ; preds = %for.cond.cleanup3 +; CHECK-NEXT: ret double %"sum.1.lcssa'" +; CHECK-NEXT: } diff --git 
a/enzyme/test/Enzyme/ForwardModeSplit/sin.ll b/enzyme/test/Enzyme/ForwardModeSplit/sin.ll new file mode 100644 index 0000000000000..f910bb404da73 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/sin.ll @@ -0,0 +1,29 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -O3 -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @llvm.sin.f64(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.cos.f64(double) + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.sin.f64(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) + +; CHECK: define double @test_derivative(double %x) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = tail call fast double @llvm.cos.f64(double %x) +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/sqrelu.ll b/enzyme/test/Enzyme/ForwardModeSplit/sqrelu.ll new file mode 100644 index 0000000000000..3cb4252867262 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/sqrelu.ll @@ -0,0 +1,72 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -inline -mem2reg -instcombine -early-cse -adce -S | FileCheck %s + +; #include <math.h> +; +; double sqrelu(double x) { +; return (x > 0) ? 
sqrt(x * sin(x)) : 0; +; } +; +; double dsqrelu(double x) { +; return __builtin_autodiff(sqrelu, x); +; } + +; Function Attrs: nounwind readnone uwtable +define dso_local double @sqrelu(double %x) #0 { +entry: + %cmp = fcmp fast ogt double %x, 0.000000e+00 + br i1 %cmp, label %cond.true, label %cond.end + +cond.true: ; preds = %entry + %0 = tail call fast double @llvm.sin.f64(double %x) + %mul = fmul fast double %0, %x + %1 = tail call fast double @llvm.sqrt.f64(double %mul) + br label %cond.end + +cond.end: ; preds = %entry, %cond.true + %cond = phi double [ %1, %cond.true ], [ 0.000000e+00, %entry ] + ret double %cond +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.sin.f64(double) #1 + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.sqrt.f64(double) #1 + +; Function Attrs: nounwind uwtable +define dso_local double @dsqrelu(double %x) local_unnamed_addr #2 { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @sqrelu, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) 
#3 + +attributes #0 = { nounwind readnone uwtable } +attributes #1 = { nounwind readnone speculatable } +attributes #2 = { nounwind uwtable } +attributes #3 = { nounwind } + +; CHECK: define dso_local double @dsqrelu(double %x) local_unnamed_addr +; CHECK-NEXT: entry: +; CHECK-NEXT: %cmp.i = fcmp fast ogt double %x, 0.000000e+00 +; CHECK-NEXT: br i1 %cmp.i, label %cond.true.i, label %fwddiffesqrelu.exit + +; CHECK: cond.true.i: +; CHECK-NEXT: %0 = call fast double @llvm.sin.f64(double %x) +; CHECK-NEXT: %1 = call fast double @llvm.cos.f64(double %x) +; CHECK-NEXT: %mul.i = fmul fast double %0, %x +; CHECK-NEXT: %2 = fmul fast double %1, %x +; CHECK-NEXT: %3 = fadd fast double %2, %0 +; CHECK-NEXT: %4 = call fast double @llvm.sqrt.f64(double %mul.i) +; CHECK-NEXT: %5 = fmul fast double %3, 5.000000e-01 +; CHECK-NEXT: %6 = fdiv fast double %5, %4 +; CHECK-NEXT: %7 = fcmp fast oeq double %mul.i, 0.000000e+00 +; CHECK-NEXT: %8 = select {{(fast )?}}i1 %7, double 0.000000e+00, double %6 +; CHECK-NEXT: br label %fwddiffesqrelu.exit + +; CHECK: fwddiffesqrelu.exit: +; CHECK-NEXT: %"cond'.i" = phi {{(fast )?}}double [ %8, %cond.true.i ], [ 0.000000e+00, %entry ] +; CHECK-NEXT: ret double %"cond'.i" +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/sqrt.ll b/enzyme/test/Enzyme/ForwardModeSplit/sqrt.ll new file mode 100644 index 0000000000000..5f7f8068a9748 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/sqrt.ll @@ -0,0 +1,29 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -O3 -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @llvm.sqrt.f64(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) 
@__enzyme_fwdsplit(double (double)* nonnull @tester, double %x, double 1.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @llvm.sqrt.f64(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double)*, ...) + +; CHECK: define double @test_derivative(double %x) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = tail call fast double @llvm.sqrt.f64(double %x) +; CHECK-NEXT: %1 = fdiv fast double 5.000000e-01, %0 +; CHECK-NEXT: %2 = fcmp fast oeq double %x, 0.000000e+00 +; CHECK-NEXT: %3 = select{{( fast)?}} i1 %2, double 0.000000e+00, double %1 +; CHECK-NEXT: ret double %3 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/square.ll b/enzyme/test/Enzyme/ForwardModeSplit/square.ll new file mode 100644 index 0000000000000..8460f9503a349 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/square.ll @@ -0,0 +1,32 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse -S | FileCheck %s + +; source code +; double square(double x) { +; return x * x; +; } +; +; double dsquare(double x) { +; return __builtin_autodiff(square, x); +; } + +define double @square(double %x) { +entry: + %mul = fmul fast double %x, %x + ret double %mul +} + +define double @dsquare(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwdsplit(double (double)* nonnull @square, double %x, double 1.0, i8* null) + ret double %0 +} + +declare double @__enzyme_fwdsplit(double (double)*, ...) 
+ +; CHECK: define internal {{(dso_local )?}}double @fwddiffesquare(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = fmul fast double %"x'", %x +; CHECK-NEXT: %1 = fadd fast double %0, %0 +; CHECK-NEXT: ret double %1 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/square2.ll b/enzyme/test/Enzyme/ForwardModeSplit/square2.ll new file mode 100644 index 0000000000000..8dcd3e2b0d59c --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/square2.ll @@ -0,0 +1,92 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +; #include + +; double __enzyme_fwdsplit(void*, ...); + +; __attribute__((noinline)) +; void square_(const double* src, double* dest) { +; *dest = *src * *src; +; } + +; double square(double x) { +; double y; +; square_(&x, &y); +; return y; +; } + +; double dsquare(double x) { +; return __enzyme_fwdsplit((void*)square, x, 1.0); +; } + + +define dso_local void @square_(double* nocapture readonly %src, double* nocapture %dest) local_unnamed_addr #0 { +entry: + %0 = load double, double* %src, align 8 + %mul = fmul double %0, %0 + store double %mul, double* %dest, align 8 + ret void +} + +define dso_local double @square(double %x) #1 { +entry: + %x.addr = alloca double, align 8 + %y = alloca double, align 8 + store double %x, double* %x.addr, align 8 + %0 = bitcast double* %y to i8* + call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %0) #4 + call void @square_(double* nonnull %x.addr, double* nonnull %y) + %1 = load double, double* %y, align 8 + call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %0) #4 + ret double %1 +} + +declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture) #2 + +declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture) #2 + +define dso_local double @dsquare(double %x) local_unnamed_addr #1 { +entry: + %call = tail call double (i8*, ...) 
@__enzyme_fwdsplit(i8* bitcast (double (double)* @square to i8*), metadata !"enzyme_nofree", double %x, double 1.000000e+00, i8* null) #4 + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(i8*, ...) local_unnamed_addr #3 + +attributes #0 = { norecurse nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #1 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #2 = { argmemonly nounwind } +attributes #3 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #4 = { nounwind } + + +; CHECK: define internal double @fwddiffesquare(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to { double, i8*, i8* }* +; CHECK-NEXT: %truetape = load { double, i8*, i8* }, { double, i8*, i8* }* %0 
+; CHECK-NEXT: %malloccall = extractvalue { double, i8*, i8* } %truetape, 2 +; CHECK-NEXT: %"malloccall'mi" = alloca i8, i64 8, align 8 +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull dereferenceable(8) dereferenceable_or_null(8) %"malloccall'mi", i8 0, i64 8, i1 false) +; CHECK-NEXT: %"x.addr'ipc" = bitcast i8* %"malloccall'mi" to double* +; CHECK-NEXT: %x.addr = bitcast i8* %malloccall to double* +; CHECK-NEXT: %malloccall1 = extractvalue { double, i8*, i8* } %truetape, 1 +; CHECK-NEXT: %"malloccall1'mi" = alloca i8, i64 8, align 8 +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull dereferenceable(8) dereferenceable_or_null(8) %"malloccall1'mi", i8 0, i64 8, i1 false) +; CHECK-NEXT: %"y'ipc" = bitcast i8* %"malloccall1'mi" to double* +; CHECK-NEXT: %y = bitcast i8* %malloccall1 to double* +; CHECK-NEXT: store double %"x'", double* %"x.addr'ipc", align 8 +; CHECK-NEXT: %tapeArg1 = extractvalue { double, i8*, i8* } %truetape, 0 +; CHECK-NEXT: call void @fwddiffesquare_(double* %x.addr, double* %"x.addr'ipc", double* %y, double* %"y'ipc", double %tapeArg1) +; CHECK-NEXT: %1 = load double, double* %"y'ipc", align 8 +; CHECK-NEXT: ret double %1 +; CHECK-NEXT: } + +; CHECK: define internal void @fwddiffesquare_(double* nocapture readonly %src, double* nocapture %"src'", double* nocapture %dest, double* nocapture %"dest'", double +; CHECK-NEXT: entry: +; CHECK-NEXT: %1 = load double, double* %"src'", align 8 +; CHECK-NEXT: %2 = fmul fast double %1, %0 +; CHECK-NEXT: %3 = fmul fast double %1, %0 +; CHECK-NEXT: %4 = fadd fast double %2, %3 +; CHECK-NEXT: store double %4, double* %"dest'", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/square_array.ll b/enzyme/test/Enzyme/ForwardModeSplit/square_array.ll new file mode 100644 index 0000000000000..370e427015397 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/square_array.ll @@ -0,0 +1,34 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false 
-mem2reg -simplifycfg -early-cse -S | FileCheck %s + +define { double, double } @squared(double %x) { +entry: + %mul = fmul double %x, %x + %mul2 = fmul double %mul, %x + %.fca.0.insert = insertvalue { double, double } undef, double %mul, 0 + %.fca.1.insert = insertvalue { double, double } %.fca.0.insert, double %mul2, 1 + ret { double, double } %.fca.1.insert +} + +define { double, double } @dsquared(double %x) { +entry: + %call = call { double, double } (i8*, ...) @__enzyme_fwdsplit(i8* bitcast ({ double, double } (double)* @squared to i8*), double %x, double 1.0, i8* null) + ret { double, double } %call +} + +declare { double, double } @__enzyme_fwdsplit(i8*, ...) + + + +; CHECK: define internal {{(dso_local )?}}{ double, double } @fwddiffesquared(double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %mul = fmul double %x, %x +; CHECK-NEXT: %0 = fmul fast double %"x'", %x +; CHECK-NEXT: %1 = fadd fast double %0, %0 +; CHECK-NEXT: %2 = fmul fast double %1, %x +; CHECK-NEXT: %3 = fmul fast double %"x'", %mul +; CHECK-NEXT: %4 = fadd fast double %2, %3 +; CHECK-NEXT: %5 = insertvalue { double, double } zeroinitializer, double %1, 0 +; CHECK-NEXT: %6 = insertvalue { double, double } %5, double %4, 1 +; CHECK-NEXT: ret { double, double } %6 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/sret.ll b/enzyme/test/Enzyme/ForwardModeSplit/sret.ll new file mode 100644 index 0000000000000..8dd7227bb9055 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/sret.ll @@ -0,0 +1,100 @@ +; RUN: if [ %llvmver -lt 12 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse -S | FileCheck %s ; fi + + +; #include +; #include + +; using namespace std; + +; extern array __enzyme_fwdsplit(void*, ...); + +; array square(double x) { +; return {x * x, x * x * x, x}; +; } +; array dsquare(double x) { +; // This returns the derivative of square or 2 * x +; 
return __enzyme_fwdsplit((void*)square, x, 1.0); +; } +; int main() { +; printf("%f \n", dsquare(3)[0]); +; } + + +%"struct.std::array" = type { [3 x double] } + +@.str = private unnamed_addr constant [5 x i8] c"%f \0A\00", align 1 + +define dso_local void @_Z6squared(%"struct.std::array"* noalias nocapture sret %agg.result, double %x) #0 { +entry: + %arrayinit.begin = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 0 + %mul = fmul double %x, %x + store double %mul, double* %arrayinit.begin, align 8 + %arrayinit.element = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 1 + %mul2 = fmul double %mul, %x + store double %mul2, double* %arrayinit.element, align 8 + %arrayinit.element3 = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 2 + store double %x, double* %arrayinit.element3, align 8 + ret void +} + +define dso_local void @_Z7dsquared(%"struct.std::array"* noalias sret %agg.result, double %x) local_unnamed_addr #1 { +entry: + tail call void (%"struct.std::array"*, i8*, ...) @_Z16__enzyme_fwdsplitPvz(%"struct.std::array"* sret %agg.result, i8* bitcast (void (%"struct.std::array"*, double)* @_Z6squared to i8*), double %x, double 1.000000e+00, i8* null) + ret void +} + +declare dso_local void @_Z16__enzyme_fwdsplitPvz(%"struct.std::array"* sret, i8*, ...) local_unnamed_addr #2 + +define dso_local i32 @main() local_unnamed_addr #3 { +entry: + %ref.tmp = alloca %"struct.std::array", align 8 + %0 = bitcast %"struct.std::array"* %ref.tmp to i8* + call void @llvm.lifetime.start.p0i8(i64 24, i8* nonnull %0) #6 + call void (%"struct.std::array"*, i8*, ...) 
@_Z16__enzyme_fwdsplitPvz(%"struct.std::array"* nonnull sret %ref.tmp, i8* bitcast (void (%"struct.std::array"*, double)* @_Z6squared to i8*), double 3.000000e+00, double 1.000000e+00, i8* null) + %arrayidx.i.i = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %ref.tmp, i64 0, i32 0, i64 0 + %1 = load double, double* %arrayidx.i.i, align 8 + %call1 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([5 x i8], [5 x i8]* @.str, i64 0, i64 0), double %1) + call void @llvm.lifetime.end.p0i8(i64 24, i8* nonnull %0) #6 + ret i32 0 +} + +declare dso_local i32 @printf(i8* nocapture readonly, ...) local_unnamed_addr #4 + +declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture) #5 + +declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture) #5 + +attributes #0 = { norecurse nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #1 = { uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #2 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" 
"stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #3 = { norecurse uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #4 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #5 = { argmemonly nounwind } +attributes #6 = { nounwind } + + +; CHECK: define dso_local void @_Z7dsquared(%"struct.std::array"* noalias sret %agg.result, double %x) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = alloca %"struct.std::array" +; CHECK-NEXT: call void @fwddiffe_Z6squared(%"struct.std::array"* %0, %"struct.std::array"* %agg.result, double %x, double 1.000000e+00, i8* null) +; CHECK-NEXT: ret void +; CHECK-NEXT: } + + +; CHECK: define internal void @fwddiffe_Z6squared(%"struct.std::array"* noalias nocapture %agg.result, %"struct.std::array"* nocapture %"agg.result'", double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %"arrayinit.begin'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 0 +; CHECK-NEXT: %mul = fmul double %x, %x +; CHECK-NEXT: %0 = fmul fast double %"x'", 
%x +; CHECK-NEXT: %1 = fadd fast double %0, %0 +; CHECK-NEXT: store double %1, double* %"arrayinit.begin'ipg", align 8 +; CHECK-NEXT: %"arrayinit.element'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 1 +; CHECK-NEXT: %2 = fmul fast double %1, %x +; CHECK-NEXT: %3 = fmul fast double %"x'", %mul +; CHECK-NEXT: %4 = fadd fast double %2, %3 +; CHECK-NEXT: store double %4, double* %"arrayinit.element'ipg", align 8 +; CHECK-NEXT: %"arrayinit.element3'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 2 +; CHECK-NEXT: store double %"x'", double* %"arrayinit.element3'ipg", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/sret12.ll b/enzyme/test/Enzyme/ForwardModeSplit/sret12.ll new file mode 100644 index 0000000000000..efbf7645984c4 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/sret12.ll @@ -0,0 +1,88 @@ +; RUN: if [ %llvmver -ge 12 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse -S | FileCheck %s ; fi + + +; #include +; #include + +; using namespace std; + +; extern array __enzyme_fwdsplit(void*, ...); + +; array square(double x) { +; return {x * x, x * x * x, x}; +; } +; array dsquare(double x) { +; // This returns the derivative of square or 2 * x +; return __enzyme_fwdsplit((void*)square, x, 1.0); +; } +; int main() { +; printf("%f \n", dsquare(3)[0]); +; } + + +%"struct.std::array" = type { [3 x double] } + +@.str = private unnamed_addr constant [5 x i8] c"%f \0A\00", align 1 + +define dso_local void @_Z6squared(%"struct.std::array"* noalias nocapture sret(%"struct.std::array") align 8 %agg.result, double %x) #0 { +entry: + %arrayinit.begin = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 0 + %mul = fmul double %x, %x + store double %mul, double* %arrayinit.begin, align 8 + %arrayinit.element = 
getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 1 + %mul2 = fmul double %mul, %x + store double %mul2, double* %arrayinit.element, align 8 + %arrayinit.element3 = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 2 + store double %x, double* %arrayinit.element3, align 8 + ret void +} + +define dso_local void @_Z7dsquared(%"struct.std::array"* noalias sret(%"struct.std::array") align 8 %agg.result, double %x) local_unnamed_addr #1 { +entry: + tail call void (%"struct.std::array"*, i8*, ...) @_Z16__enzyme_fwdsplitPvz(%"struct.std::array"* sret(%"struct.std::array") align 8 %agg.result, i8* bitcast (void (%"struct.std::array"*, double)* @_Z6squared to i8*), double %x, double 1.000000e+00, i8* null) + ret void +} + +declare dso_local void @_Z16__enzyme_fwdsplitPvz(%"struct.std::array"* sret(%"struct.std::array") align 8, i8*, ...) local_unnamed_addr #2 + + +declare dso_local noundef i32 @printf(i8* nocapture noundef readonly, ...) 
local_unnamed_addr #4 + +declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #5 + +declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #5 + +attributes #0 = { nofree norecurse nounwind uwtable willreturn writeonly mustprogress "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #1 = { uwtable mustprogress "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #2 = { "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #3 = { norecurse uwtable mustprogress "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" 
"use-soft-float"="false" } +attributes #4 = { nofree nounwind "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #5 = { argmemonly nofree nosync nounwind willreturn } +attributes #6 = { nounwind } + + +; CHECK: define dso_local void @_Z7dsquared(%"struct.std::array"* noalias sret(%"struct.std::array") align 8 %agg.result, double %x) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = alloca %"struct.std::array" +; CHECK-NEXT: call void @fwddiffe_Z6squared(%"struct.std::array"* %0, %"struct.std::array"* %agg.result, double %x, double 1.000000e+00, i8* null) +; CHECK-NEXT: ret void +; CHECK-NEXT: } + + +; CHECK: define internal void @fwddiffe_Z6squared(%"struct.std::array"* noalias nocapture align 8 %agg.result, %"struct.std::array"* nocapture %"agg.result'", double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %"arrayinit.begin'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 0 +; CHECK-NEXT: %mul = fmul double %x, %x +; CHECK-NEXT: %0 = fmul fast double %"x'", %x +; CHECK-NEXT: %1 = fadd fast double %0, %0 +; CHECK-NEXT: store double %1, double* %"arrayinit.begin'ipg", align 8 +; CHECK-NEXT: %"arrayinit.element'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 1 +; CHECK-NEXT: %2 = fmul fast double %1, %x +; CHECK-NEXT: %3 = fmul fast double %"x'", %mul +; CHECK-NEXT: %4 = fadd fast double %2, %3 +; CHECK-NEXT: store double %4, double* %"arrayinit.element'ipg", align 8 +; CHECK-NEXT: %"arrayinit.element3'ipg" = getelementptr inbounds 
%"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 2 +; CHECK-NEXT: store double %"x'", double* %"arrayinit.element3'ipg", align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/store3.ll b/enzyme/test/Enzyme/ForwardModeSplit/store3.ll new file mode 100644 index 0000000000000..4d6a714e1b204 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/store3.ll @@ -0,0 +1,29 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -correlated-propagation -simplifycfg -gvn -dse -S | FileCheck %s + +; Function Attrs: noinline norecurse nounwind uwtable +define dso_local double @f(double* noalias nocapture %out, double %x) #0 { +entry: + store double %x, double* %out, align 8 + store double 0.000000e+00, double* %out, align 8 + %res = load double, double* %out + ret double %res +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @dsumsquare(double* %x, double* %xp, double %inp, double %in2) local_unnamed_addr #1 { +entry: + %call = tail call fast double @__enzyme_fwdsplit(i8* bitcast (double (double*, double)* @f to i8*), double* %x, double* %xp, double %inp, double 1.0, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(i8*, double*, double*, double, double, i8*) local_unnamed_addr + +attributes #0 = { noinline norecurse nounwind uwtable } +attributes #1 = { noinline nounwind uwtable } + +; CHECK: define internal double @fwddiffef(double* noalias nocapture %out, double* nocapture %"out'", double %x, double %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: store double 0.000000e+00, double* %"out'", align 8 +; CHECK-NEXT: ret double 0.000000e+00 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/storeconstexpr.ll b/enzyme/test/Enzyme/ForwardModeSplit/storeconstexpr.ll new file mode 100644 index 0000000000000..0a7f1dd578e0a --- /dev/null +++ 
b/enzyme/test/Enzyme/ForwardModeSplit/storeconstexpr.ll @@ -0,0 +1,24 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -simplifycfg -instcombine -adce -S | FileCheck %s + +@.str = private unnamed_addr constant [18 x i8] c"W(o=%d, i=%d)=%f\0A\00", align 1 + +define void @derivative(i64* %from, i64* %fromp, i64* %to, i64* %top) { +entry: + %call = call double (i8*, ...) @__enzyme_fwdsplit(i8* bitcast (void (i64*, i64*)* @callee to i8*), metadata !"enzyme_dup", i64* %from, i64* %fromp, metadata !"enzyme_dup", i64* %to, i64* %top, i8* null) + ret void +} + +define void @callee(i64* %from, i64* %to) { +entry: + store i64 ptrtoint ([18 x i8]* @.str to i64), i64* %to + ret void +} + +; Function Attrs: alwaysinline +declare double @__enzyme_fwdsplit(i8*, ...) + +; CHECK: define internal void @fwddiffecallee(i64* %from, i64* %"from'", i64* %to, i64* %"to'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/sub.ll b/enzyme/test/Enzyme/ForwardModeSplit/sub.ll new file mode 100644 index 0000000000000..2e166d49b3708 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/sub.ll @@ -0,0 +1,24 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -instcombine -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %0 = fsub fast double %x, %y + ret double %0 +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fwdsplit(double (double, double)* nonnull @tester, double %x, double 1.0, double %y, double 0.0, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double, double)*, ...) 
+ +; CHECK: define internal {{(dso_local )?}}double @fwddiffetester(double %x, double %"x'", double %y, double %"y'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = fsub fast double %"x'", %"y'" +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/sumnlist.ll b/enzyme/test/Enzyme/ForwardModeSplit/sumnlist.ll new file mode 100644 index 0000000000000..8704cbfd045ef --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/sumnlist.ll @@ -0,0 +1,130 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -gvn -early-cse-memssa -instcombine -instsimplify -simplifycfg -adce -licm -correlated-propagation -instcombine -correlated-propagation -adce -instsimplify -correlated-propagation -jump-threading -instsimplify -early-cse -simplifycfg -S | FileCheck %s + +; #include +; #include +; +; struct n { +; double *values; +; struct n *next; +; }; +; +; __attribute__((noinline)) +; double sum_list(const struct n *__restrict node, unsigned long times) { +; double sum = 0; +; for(const struct n *val = node; val != 0; val = val->next) { +; for(int i=0; i<=times; i++) { +; sum += val->values[i]; +; } +; } +; return sum; +; } + +%struct.n = type { double*, %struct.n* } + +; Function Attrs: noinline norecurse nounwind readonly uwtable +define dso_local double @sum_list(%struct.n* noalias readonly %node, i64 %times) local_unnamed_addr #0 { +entry: + %cmp18 = icmp eq %struct.n* %node, null + br i1 %cmp18, label %for.cond.cleanup, label %for.cond1.preheader + +for.cond1.preheader: ; preds = %for.cond.cleanup4, %entry + %val.020 = phi %struct.n* [ %1, %for.cond.cleanup4 ], [ %node, %entry ] + %sum.019 = phi double [ %add, %for.cond.cleanup4 ], [ 0.000000e+00, %entry ] + %values = getelementptr inbounds %struct.n, %struct.n* %val.020, i64 0, i32 0 + %0 = load double*, double** %values, align 8, !tbaa !2 + br label %for.body5 + +for.cond.cleanup: ; preds = 
%for.cond.cleanup4, %entry + %sum.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %add, %for.cond.cleanup4 ] + ret double %sum.0.lcssa + +for.cond.cleanup4: ; preds = %for.body5 + %next = getelementptr inbounds %struct.n, %struct.n* %val.020, i64 0, i32 1 + %1 = load %struct.n*, %struct.n** %next, align 8, !tbaa !7 + %cmp = icmp eq %struct.n* %1, null + br i1 %cmp, label %for.cond.cleanup, label %for.cond1.preheader + +for.body5: ; preds = %for.body5, %for.cond1.preheader + %indvars.iv = phi i64 [ 0, %for.cond1.preheader ], [ %indvars.iv.next, %for.body5 ] + %sum.116 = phi double [ %sum.019, %for.cond1.preheader ], [ %add, %for.body5 ] + %arrayidx = getelementptr inbounds double, double* %0, i64 %indvars.iv + %2 = load double, double* %arrayidx, align 8, !tbaa !8 + %add = fadd fast double %2, %sum.116 + %indvars.iv.next = add nuw i64 %indvars.iv, 1 + %exitcond = icmp eq i64 %indvars.iv, %times + br i1 %exitcond, label %for.cond.cleanup4, label %for.body5 +} + +; Function Attrs: nounwind +declare dso_local noalias i8* @malloc(i64) local_unnamed_addr #2 + +; Function Attrs: noinline nounwind uwtable +define dso_local double @derivative(%struct.n* %x, %struct.n* %xp, i64 %n) { +entry: + %0 = tail call double (double (%struct.n*, i64)*, ...) @__enzyme_fwdsplit(double (%struct.n*, i64)* nonnull @sum_list, metadata !"enzyme_nofree", %struct.n* %x, %struct.n* %xp, i64 %n, i8* null) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (%struct.n*, i64)*, ...) 
#4 + + +attributes #0 = { noinline norecurse nounwind readonly uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-jump-tables"="false" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="true" "use-soft-float"="false" } +attributes #1 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-jump-tables"="false" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="true" "use-soft-float"="false" } +attributes #2 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="true" "use-soft-float"="false" } +attributes #3 = { noinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-jump-tables"="false" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="true" "use-soft-float"="false" } +attributes #4 = { nounwind } + +!llvm.module.flags = !{!0} +!llvm.ident = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{!"clang 
version 7.1.0 "} +!2 = !{!3, !4, i64 0} +!3 = !{!"n", !4, i64 0, !4, i64 8} +!4 = !{!"any pointer", !5, i64 0} +!5 = !{!"omnipotent char", !6, i64 0} +!6 = !{!"Simple C/C++ TBAA"} +!7 = !{!3, !4, i64 8} +!8 = !{!9, !9, i64 0} +!9 = !{!"double", !5, i64 0} +!10 = !{!4, !4, i64 0} + + +; CHECK: define internal double @fwddiffesum_list(%struct.n* noalias readonly %node, %struct.n* %"node'", i64 %times, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %truetape.elt = bitcast i8* %tapeArg to double*** +; CHECK-NEXT: %truetape.unpack = load double**, double*** %truetape.elt, align 8 +; CHECK-NEXT: %truetape.elt7 = getelementptr inbounds i8, i8* %tapeArg, i64 16 +; CHECK-NEXT: %0 = bitcast i8* %truetape.elt7 to %struct.n*** +; CHECK-NEXT: %truetape.unpack8 = load %struct.n**, %struct.n*** %0, align 8 +; CHECK-NEXT: %cmp18 = icmp eq %struct.n* %node, null +; CHECK-NEXT: br i1 %cmp18, label %for.cond.cleanup, label %for.cond1.preheader + +; CHECK: for.cond1.preheader: ; preds = %entry, %for.cond.cleanup4 +; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.cond.cleanup4 ], [ 0, %entry ] +; CHECK-NEXT: %"sum.019'" = phi {{(fast )?}}double [ %5, %for.cond.cleanup4 ], [ 0.000000e+00, %entry ] +; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 +; CHECK-NEXT: %1 = getelementptr inbounds double*, double** %truetape.unpack, i64 %iv +; CHECK-NEXT: %"'il_phi" = load double*, double** %1, align 8, !invariant.group !16 +; CHECK-NEXT: br label %for.body5 + +; CHECK: for.cond.cleanup: ; preds = %for.cond.cleanup4, %entry +; CHECK-NEXT: %"sum.0.lcssa'" = phi {{(fast )?}}double [ 0.000000e+00, %entry ], [ %5, %for.cond.cleanup4 ] +; CHECK-NEXT: ret double %"sum.0.lcssa'" + +; CHECK: for.cond.cleanup4: ; preds = %for.body5 +; CHECK-NEXT: %2 = getelementptr inbounds %struct.n*, %struct.n** %truetape.unpack8, i64 %iv +; CHECK-NEXT: %3 = load %struct.n*, %struct.n** %2, align 8, !invariant.group !17 +; CHECK-NEXT: %cmp = icmp eq %struct.n* %3, null +; CHECK-NEXT: br i1 %cmp, label %for.cond.cleanup, 
label %for.cond1.preheader + +; CHECK: for.body5: ; preds = %for.body5, %for.cond1.preheader +; CHECK-NEXT: %iv1 = phi i64 [ %iv.next2, %for.body5 ], [ 0, %for.cond1.preheader ] +; CHECK-NEXT: %"sum.116'" = phi {{(fast )?}}double [ %5, %for.body5 ], [ %"sum.019'", %for.cond1.preheader ] +; CHECK-NEXT: %iv.next2 = add nuw nsw i64 %iv1, 1 +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"'il_phi", i64 %iv1 +; CHECK-NEXT: %4 = load double, double* %"arrayidx'ipg", align 8 +; CHECK-NEXT: %5 = fadd fast double %4, %"sum.116'" +; CHECK-NEXT: %exitcond = icmp eq i64 %iv1, %times +; CHECK-NEXT: br i1 %exitcond, label %for.cond.cleanup4, label %for.body5 diff --git a/enzyme/test/Enzyme/ForwardModeSplit/sumsimple.ll b/enzyme/test/Enzyme/ForwardModeSplit/sumsimple.ll new file mode 100644 index 0000000000000..aa40ef7eb32cb --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/sumsimple.ll @@ -0,0 +1,64 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -inline -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S -early-cse -simplifycfg | FileCheck %s + +; Function Attrs: noinline nounwind uwtable +define dso_local void @f(double* %x, double** %y, i64 %n) #0 { +entry: + br label %for.cond + +for.cond: ; preds = %for.body, %entry + %i.0 = phi i64 [ 0, %entry ], [ %inc, %for.body ] + %cmp = icmp ule i64 %i.0, %n + br i1 %cmp, label %for.body, label %for.end + +for.body: ; preds = %for.cond + %arrayidx = getelementptr inbounds double, double* %x, i64 0 + %0 = load double, double* %arrayidx + %1 = load double*, double** %y + %2 = load double, double* %1 + %add = fadd fast double %2, %0 + store double %add, double* %1 + %inc = add i64 %i.0, 1 + br label %for.cond + +for.end: ; preds = %for.cond + ret void +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @dsumsquare(double* %x, double* %xp, double** %y, double** %yp, i64 %n) #0 { +entry: + %call = call double (...) 
@__enzyme_fwdsplit(i8* bitcast (void (double*, double**, i64)* @f to i8*), metadata !"enzyme_nofree", double* %x, double* %xp, double** %y, double** %yp, i64 %n, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(...) + + +attributes #0 = { noinline nounwind uwtable } + + +; CHECK: define internal void @fwddiffef(double* %x, double* %"x'", double** %y, double** %"y'", i64 %n, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to double*** +; CHECK-NEXT: %truetape = load double**, double*** %0 +; CHECK-NEXT: %1 = add nuw i64 %n, 1 +; CHECK-NEXT: br label %for.cond + +; CHECK: for.cond: ; preds = %for.body, %entry +; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.body ], [ 0, %entry ] +; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 +; CHECK-NEXT: %cmp = icmp ne i64 %iv, %1 +; CHECK-NEXT: br i1 %cmp, label %for.body, label %for.end + +; CHECK: for.body: ; preds = %for.cond +; CHECK-NEXT: %2 = load double, double* %"x'" +; CHECK-NEXT: %3 = getelementptr inbounds double*, double** %truetape, i64 %iv +; CHECK-NEXT: %"'il_phi" = load double*, double** %3, align 8, !invariant.group !1 +; CHECK-NEXT: %4 = load double, double* %"'il_phi" +; CHECK-NEXT: %5 = fadd fast double %4, %2 +; CHECK-NEXT: store double %5, double* %"'il_phi" +; CHECK-NEXT: br label %for.cond + +; CHECK: for.end: ; preds = %for.cond +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/sumsimpleoptnone.ll b/enzyme/test/Enzyme/ForwardModeSplit/sumsimpleoptnone.ll new file mode 100644 index 0000000000000..d7d39da6b630e --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/sumsimpleoptnone.ll @@ -0,0 +1,64 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -inline -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S -early-cse | FileCheck %s + +; Function Attrs: noinline nounwind uwtable +define dso_local void @f(double* %x, double** %y, i64 %n) #0 { +entry: + br label 
%for.cond + +for.cond: ; preds = %for.body, %entry + %i.0 = phi i64 [ 0, %entry ], [ %inc, %for.body ] + %cmp = icmp ule i64 %i.0, %n + br i1 %cmp, label %for.body, label %for.end + +for.body: ; preds = %for.cond + %arrayidx = getelementptr inbounds double, double* %x, i64 0 + %0 = load double, double* %arrayidx + %1 = load double*, double** %y + %2 = load double, double* %1 + %add = fadd fast double %2, %0 + store double %add, double* %1 + %inc = add i64 %i.0, 1 + br label %for.cond + +for.end: ; preds = %for.cond + ret void +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @dsumsquare(double* %x, double* %xp, double** %y, double** %yp, i64 %n) #0 { +entry: + %call = call double (...) @__enzyme_fwdsplit(i8* bitcast (void (double*, double**, i64)* @f to i8*), metadata !"enzyme_nofree", double* %x, double* %xp, double** %y, double** %yp, i64 %n, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(...) + + +attributes #0 = { noinline nounwind uwtable optnone } + + +; CHECK: define internal void @fwddiffef(double* %x, double* %"x'", double** %y, double** %"y'", i64 %n, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to double*** +; CHECK-NEXT: %truetape = load double**, double*** %0 +; CHECK-NEXT: %1 = add nuw i64 %n, 1 +; CHECK-NEXT: br label %for.cond + +; CHECK: for.cond: ; preds = %for.body, %entry +; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.body ], [ 0, %entry ] +; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 +; CHECK-NEXT: %cmp = icmp ne i64 %iv, %1 +; CHECK-NEXT: br i1 %cmp, label %for.body, label %for.end + +; CHECK: for.body: ; preds = %for.cond +; CHECK-NEXT: %2 = load double, double* %"x'" +; CHECK-NEXT: %3 = getelementptr inbounds double*, double** %truetape, i64 %iv +; CHECK-NEXT: %"'il_phi" = load double*, double** %3, align 8, !invariant.group !1 +; CHECK-NEXT: %4 = load double, double* %"'il_phi" +; CHECK-NEXT: %5 = fadd fast double %4, %2 +; CHECK-NEXT: store double 
%5, double* %"'il_phi" +; CHECK-NEXT: br label %for.cond + +; CHECK: for.end: ; preds = %for.cond +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/sumsquare.ll b/enzyme/test/Enzyme/ForwardModeSplit/sumsquare.ll new file mode 100644 index 0000000000000..0bd211311e391 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/sumsquare.ll @@ -0,0 +1,60 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -inline -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -early-cse -S | FileCheck %s + +; Function Attrs: norecurse nounwind readonly uwtable +define dso_local double @sumsquare(double* nocapture readonly %x, i64 %n) #0 { +entry: + br label %for.body + +for.cond.cleanup: ; preds = %for.body + ret double %add + +for.body: ; preds = %entry, %for.body + %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ] + %total.011 = phi double [ 0.000000e+00, %entry ], [ %add, %for.body ] + %arrayidx = getelementptr inbounds double, double* %x, i64 %indvars.iv + %0 = load double, double* %arrayidx, align 8 + %mul = fmul fast double %0, %0 + %add = fadd fast double %mul, %total.011 + %indvars.iv.next = add nuw i64 %indvars.iv, 1 + %exitcond = icmp eq i64 %indvars.iv, %n + br i1 %exitcond, label %for.cond.cleanup, label %for.body +} + +; Function Attrs: nounwind uwtable +define dso_local double @dsumsquare(double* %x, double* %xp, i64 %n, i8* %tapeArg) local_unnamed_addr #1 { +entry: + %0 = tail call double (double (double*, i64)*, ...) @__enzyme_fwdsplit(double (double*, i64)* nonnull @sumsquare, metadata !"enzyme_nofree", double* %x, double* %xp, i64 %n, i8* %tapeArg) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwdsplit(double (double*, i64)*, ...) 
#2 + +attributes #0 = { norecurse nounwind readonly uwtable } +attributes #1 = { nounwind uwtable } +attributes #2 = { nounwind } + + +; CHECK: define {{(dso_local )?}}double @dsumsquare(double* %x, double* %xp, i64 %n, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to double** +; CHECK-NEXT: %truetape.i = load double*, double** %0 +; CHECK-NEXT: br label %for.body.i + +; CHECK: for.body.i: ; preds = %for.body.i, %entry +; CHECK-NEXT: %iv.i = phi i64 [ %iv.next.i, %for.body.i ], [ 0, %entry ] +; CHECK-NEXT: %"total.011'.i" = phi {{(fast )?}}double [ 0.000000e+00, %entry ], [ %6, %for.body.i ] +; CHECK-NEXT: %iv.next.i = add nuw nsw i64 %iv.i, 1 +; CHECK-NEXT: %"arrayidx'ipg.i" = getelementptr inbounds double, double* %xp, i64 %iv.i +; CHECK-NEXT: %1 = getelementptr inbounds double, double* %truetape.i, i64 %iv.i +; CHECK-NEXT: %2 = load double, double* %1, align 8, !invariant.group !1 +; CHECK-NEXT: %3 = load double, double* %"arrayidx'ipg.i", align 8 +; CHECK-NEXT: %4 = fmul fast double %3, %2 +; CHECK-NEXT: %5 = fadd fast double %4, %4 +; CHECK-NEXT: %6 = fadd fast double %5, %"total.011'.i" +; CHECK-NEXT: %exitcond.i = icmp eq i64 %iv.i, %n +; CHECK-NEXT: br i1 %exitcond.i, label %fwddiffesumsquare.exit, label %for.body.i + +; CHECK: fwddiffesumsquare.exit: ; preds = %for.body.i +; CHECK-NEXT: ret double %6 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/sumwithbreak.ll b/enzyme/test/Enzyme/ForwardModeSplit/sumwithbreak.ll new file mode 100644 index 0000000000000..5150499d926e3 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/sumwithbreak.ll @@ -0,0 +1,76 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instcombine -correlated-propagation -adce -instcombine -simplifycfg -early-cse -simplifycfg -loop-unroll -instcombine -simplifycfg -gvn -jump-threading -instcombine -simplifycfg -S | FileCheck %s + +; Function Attrs: noinline nounwind uwtable +define dso_local double @f(double* 
nocapture readonly %x, i64 %n) #0 { +entry: + br label %for.body + +for.body: ; preds = %if.end, %entry + %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %if.end ] + %data.016 = phi double [ 0.000000e+00, %entry ], [ %add5, %if.end ] + %cmp2 = fcmp fast ogt double %data.016, 1.000000e+01 + br i1 %cmp2, label %if.then, label %if.end + +if.then: ; preds = %for.body + %arrayidx = getelementptr inbounds double, double* %x, i64 %n + %0 = load double, double* %arrayidx, align 8 + %add = fadd fast double %0, %data.016 + br label %cleanup + +if.end: ; preds = %for.body + %arrayidx4 = getelementptr inbounds double, double* %x, i64 %indvars.iv + %1 = load double, double* %arrayidx4, align 8 + %add5 = fadd fast double %1, %data.016 + %indvars.iv.next = add nuw i64 %indvars.iv, 1 + %cmp = icmp ult i64 %indvars.iv, %n + br i1 %cmp, label %for.body, label %cleanup + +cleanup: ; preds = %if.end, %if.then + %data.1 = phi double [ %add, %if.then ], [ %add5, %if.end ] + ret double %data.1 +} + +; Function Attrs: noinline nounwind uwtable +define dso_local double @dsumsquare(double* %x, double* %xp, i64 %n) #0 { +entry: + %call = call double (...) @__enzyme_fwdsplit(i8* bitcast (double (double*, i64)* @f to i8*), metadata !"enzyme_nofree", double* %x, double* %xp, i64 %n, i8* null) + ret double %call +} + +declare dso_local double @__enzyme_fwdsplit(...) 
+ + +attributes #0 = { noinline nounwind uwtable } + + +; CHECK: define internal double @fwddiffef(double* nocapture readonly %x, double* nocapture %"x'", i64 %n, i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast i8* %tapeArg to i1** +; CHECK-NEXT: %truetape = load i1*, i1** %0, align 8, !enzyme_mustcache !4 +; CHECK-NEXT: br label %for.body + +; CHECK: for.body: ; preds = %if.end, %entry +; CHECK-NEXT: %iv = phi i64 [ %iv.next, %if.end ], [ 0, %entry ] +; CHECK-NEXT: %"data.016'" = phi {{(fast )?}}double [ %5, %if.end ], [ 0.000000e+00, %entry ] +; CHECK-NEXT: %1 = getelementptr inbounds i1, i1* %truetape, i64 %iv +; CHECK-NEXT: %cmp2 = load i1, i1* %1, align 1, !invariant.group !5 +; CHECK-NEXT: br i1 %cmp2, label %if.then, label %if.end + +; CHECK: if.then: ; preds = %for.body +; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 %n +; CHECK-NEXT: %2 = load double, double* %"arrayidx'ipg", align 8 +; CHECK-NEXT: %3 = fadd fast double %2, %"data.016'" +; CHECK-NEXT: br label %cleanup + +; CHECK: if.end: ; preds = %for.body +; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 +; CHECK-NEXT: %"arrayidx4'ipg" = getelementptr inbounds double, double* %"x'", i64 %iv +; CHECK-NEXT: %4 = load double, double* %"arrayidx4'ipg", align 8 +; CHECK-NEXT: %5 = fadd fast double %4, %"data.016'" +; CHECK-NEXT: %cmp = icmp ult i64 %iv, %n +; CHECK-NEXT: br i1 %cmp, label %for.body, label %cleanup + +; CHECK: cleanup: ; preds = %if.end, %if.then +; CHECK-NEXT: %"data.1'" = phi {{(fast )?}}double [ %3, %if.then ], [ %5, %if.end ] +; CHECK-NEXT: ret double %"data.1'" +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardModeSplit/vecsquare.ll b/enzyme/test/Enzyme/ForwardModeSplit/vecsquare.ll new file mode 100644 index 0000000000000..36bf3b53a4af7 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/vecsquare.ll @@ -0,0 +1,43 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse -instcombine -S | 
FileCheck %s + +declare {float, float, float} @__enzyme_fwdsplit({float, float, float} (<4 x float>)*, <4 x float>, <4 x float>, i8*) + +define {float, float, float} @square(<4 x float> %x) { +entry: + %vec = insertelement <4 x float> %x, float 1.0, i32 3 + %sq = fmul <4 x float> %x, %x + %cb = fmul <4 x float> %sq, %x + %id = shufflevector <4 x float> %sq, <4 x float> %cb, <4 x i32> + %res1 = extractelement <4 x float> %id, i32 1 + %res2 = extractelement <4 x float> %id, i32 2 + %res3 = extractelement <4 x float> %id, i32 3 + %agg1 = insertvalue {float, float, float} undef, float %res1, 0 + %agg2 = insertvalue {float, float, float} %agg1, float %res2, 1 + %agg3 = insertvalue {float, float, float} %agg2, float %res3, 2 + ret {float, float, float} %agg3 +} + +define {float, float, float} @dsquare(<4 x float> %x) { +entry: + %call = tail call {float, float, float} @__enzyme_fwdsplit({float, float, float} (<4 x float>)* @square, <4 x float> %x, <4 x float> , i8* null) + ret {float, float, float} %call +} + + +; CHECK: define internal { float, float, float } @fwddiffesquare(<4 x float> %x, <4 x float> %"x'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %sq = fmul <4 x float> %x, %x +; CHECK-NEXT: %0 = fmul fast <4 x float> %"x'", %x +; CHECK-NEXT: %1 = fadd fast <4 x float> %0, %0 +; CHECK-NEXT: %2 = fmul fast <4 x float> %1, %x +; CHECK-NEXT: %3 = fmul fast <4 x float> %sq, %"x'" +; CHECK-NEXT: %4 = fadd fast <4 x float> %2, %3 +; CHECK-NEXT: %5 = extractelement <4 x float> %1, i32 1 +; CHECK-NEXT: %6 = extractelement <4 x float> %4, i32 0 +; CHECK-NEXT: %7 = extractelement <4 x float> %4, i32 1 +; CHECK-NEXT: %8 = insertvalue { float, float, float } zeroinitializer, float %5, 0 +; CHECK-NEXT: %9 = insertvalue { float, float, float } %8, float %6, 1 +; CHECK-NEXT: %10 = insertvalue { float, float, float } %9, float %7, 2 +; CHECK-NEXT: ret { float, float, float } %10 +; CHECK-NEXT: } diff --git 
a/enzyme/test/Enzyme/ForwardModeSplit/vector_reduce_fadd.ll b/enzyme/test/Enzyme/ForwardModeSplit/vector_reduce_fadd.ll new file mode 100644 index 0000000000000..6b8915d907267 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeSplit/vector_reduce_fadd.ll @@ -0,0 +1,26 @@ +; RUN: if [ %llvmver -ge 12 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi + +define float @tester(float %start_value, <4 x float> %input) { +entry: + %ord = call float @llvm.vector.reduce.fadd.v4f32(float %start_value, <4 x float> %input) + ret float %ord +} + +define float @test_derivative(float %start_value, <4 x float> %input) { +entry: + %0 = tail call float (float (float, <4 x float>)*, ...) @__enzyme_fwdsplit(float (float, <4 x float>)* nonnull @tester, float %start_value, float 1.0, <4 x float> %input, <4 x float> , i8* null) + ret float %0 +} + +declare float @llvm.vector.reduce.fadd.v4f32(float, <4 x float>) + +; Function Attrs: nounwind +declare float @__enzyme_fwdsplit(float (float, <4 x float>)*, ...) 
+ + +; CHECK: define internal {{(dso_local )?}}float @fwddiffetester(float %start_value, float %"start_value'", <4 x float> %input, <4 x float> %"input'", i8* %tapeArg) +; CHECK-NEXT: entry: +; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg) +; CHECK-NEXT: %0 = call fast float @llvm.vector.reduce.fadd.v4f32(float %"start_value'", <4 x float> %"input'") +; CHECK-NEXT: ret float %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/atomicadd.ll b/enzyme/test/Enzyme/ReverseMode/atomicadd.ll new file mode 100644 index 0000000000000..d99da26a6faa0 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/atomicadd.ll @@ -0,0 +1,33 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S | FileCheck %s + +; Function Attrs: norecurse nounwind readonly uwtable +define dso_local double @sum(i64* nocapture %n, double %x) #0 { +entry: + %res = atomicrmw add i64* %n, i64 1 monotonic + %fp = uitofp i64 %res to double + %mul = fmul double %fp, %x + ret double %mul +} + +; Function Attrs: nounwind uwtable +define dso_local void @dsum(i64* %x, i64* %xp, double %n) local_unnamed_addr #1 { +entry: + %0 = tail call double (double (i64*, double)*, ...) @__enzyme_autodiff(double (i64*, double)* nonnull @sum, i64* %x, double %n) + ret void +} + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (i64*, double)*, ...) 
#2 + +attributes #0 = { norecurse nounwind readonly uwtable } +attributes #1 = { nounwind uwtable } +attributes #2 = { nounwind } + +; CHECK: define internal { double } @diffesum(i64* nocapture %n, double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %res = atomicrmw add i64* %n, i64 1 monotonic +; CHECK-NEXT: %fp = uitofp i64 %res to double +; CHECK-NEXT: %m1diffex = fmul fast double %differeturn, %fp +; CHECK-NEXT: %0 = insertvalue { double } undef, double %m1diffex, 0 +; CHECK-NEXT: ret { double } %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/bitcastfn.ll b/enzyme/test/Enzyme/ReverseMode/bitcastfn.ll index 8b421981479f9..0f47b4c9890a6 100644 --- a/enzyme/test/Enzyme/ReverseMode/bitcastfn.ll +++ b/enzyme/test/Enzyme/ReverseMode/bitcastfn.ll @@ -324,8 +324,7 @@ attributes #8 = { noreturn nounwind "correctly-rounded-divide-sqrt-fp-math"="fal ; TODO no need for malloc/free ; CHECK: define internal i8* @augmented_indir(%"class.boost::array.1"* dereferenceable(8) %x, %"class.boost::array.1"* %"x'") ; CHECK-NEXT: entry: -; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 0) -; CHECK-NEXT: ret i8* %malloccall +; CHECK-NEXT: ret i8* null ; CHECK-NEXT: } ; CHECK: define internal void @diffeindir(%"class.boost::array.1"* dereferenceable(8) %x, %"class.boost::array.1"* %"x'", i8* %tapeArg) diff --git a/enzyme/test/Enzyme/ReverseMode/callvalue.ll b/enzyme/test/Enzyme/ReverseMode/callvalue.ll index 0523ee5101797..736cba898122f 100644 --- a/enzyme/test/Enzyme/ReverseMode/callvalue.ll +++ b/enzyme/test/Enzyme/ReverseMode/callvalue.ll @@ -51,10 +51,8 @@ attributes #2 = { nounwind } ; CHECK: define internal { i8*, double } @augmented_square(double %x) ; CHECK-NEXT: entry: -; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 0) ; CHECK-NEXT: %mul = fmul fast double %x, %x -; CHECK-NEXT: %[[iv1:.+]] = insertvalue { i8*, double } undef, i8* %malloccall, 0 -; CHECK-NEXT: %[[iv2:.+]] = insertvalue { i8*, double } 
%[[iv1]], double %mul, 1 +; CHECK-NEXT: %[[iv2:.+]] = insertvalue { i8*, double } { i8* null, double undef }, double %mul, 1 ; CHECK-NEXT: ret { i8*, double } %[[iv2]] ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/phiswitch.ll b/enzyme/test/Enzyme/ReverseMode/phiswitch.ll index b48110cb05fec..f008e89b83ef3 100644 --- a/enzyme/test/Enzyme/ReverseMode/phiswitch.ll +++ b/enzyme/test/Enzyme/ReverseMode/phiswitch.ll @@ -46,12 +46,12 @@ bb13: ; preds = %bb12, %bb9, %bb8, % ; CHECK: define internal { double } @diffejulia_euroad_1769(double %arg, i64 %i5, double %differeturn) ; CHECK-NEXT: bb: -; CHECK-NEXT: %0 = icmp eq i64 7, %i5 -; CHECK-NEXT: %1 = icmp eq i64 12, %i5 -; CHECK-NEXT: %2 = or i1 %1, %0 -; CHECK-NEXT: %3 = select {{(fast )?}}i1 %1, double %differeturn, double 0.000000e+00 +; CHECK-DAG: %[[i0:.+]] = icmp eq i64 7, %i5 +; CHECK-DAG: %[[i1:.+]] = icmp eq i64 12, %i5 +; CHECK-NEXT: %2 = or i1 %[[i1]], %[[i0]] +; CHECK-NEXT: %3 = select {{(fast )?}}i1 %[[i1]], double %differeturn, double 0.000000e+00 ; CHECK-NEXT: %4 = select {{(fast )?}}i1 %2, double 0.000000e+00, double %differeturn -; CHECK-NEXT: %5 = select {{(fast )?}}i1 %0, double %differeturn, double 0.000000e+00 +; CHECK-NEXT: %5 = select {{(fast )?}}i1 %[[i0]], double %differeturn, double 0.000000e+00 ; CHECK-NEXT: switch i64 %i5, label %invertbb9 [ ; CHECK-NEXT: i64 12, label %invertbb ; CHECK-NEXT: i64 7, label %invertbb7