Skip to content

Commit

Permalink
Minor changes to enable AD of Fortran. (rust-lang#295)
Browse files Browse the repository at this point in the history
  • Loading branch information
ludgerpaehler authored Aug 12, 2021
1 parent dce5f8b commit d40597a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
22 changes: 18 additions & 4 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -1014,10 +1014,20 @@ class AdjointGenerator
diffe(&SI, Builder2), "diffe" + op2->getName());

setDiffe(&SI, Constant::getNullValue(SI.getType()), Builder2);
if (dif1)
addToDiffe(orig_op1, dif1, Builder2, TR.addingType(size, orig_op1));
if (dif2)
addToDiffe(orig_op2, dif2, Builder2, TR.addingType(size, orig_op2));
if (dif1) {
Type *addingType = TR.addingType(size, orig_op1);
if (addingType || !looseTypeAnalysis)
addToDiffe(orig_op1, dif1, Builder2, addingType);
else
llvm::errs() << " warning: assuming integral for " << SI << "\n";
}
if (dif2) {
Type *addingType = TR.addingType(size, orig_op2);
if (addingType || !looseTypeAnalysis)
addToDiffe(orig_op2, dif2, Builder2, addingType);
else
llvm::errs() << " warning: assuming integral for " << SI << "\n";
}
}

void createSelectInstDual(llvm::SelectInst &SI) {
Expand Down Expand Up @@ -1710,8 +1720,12 @@ class AdjointGenerator
}
goto def;
}
case Instruction::Mul:
case Instruction::Sub:
case Instruction::Add: {
if (looseTypeAnalysis) {
llvm::errs() << "warning: binary operator is integer and constant: "
<< BO << "\n";
// if loose type analysis, assume this integer add is constant
return;
}
Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2509,12 +2509,12 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
Vals.push_back(cast<Constant>(
invertPointerM(CD->getElementAsConstant(i), BuilderM)));
}
return ConstantDataArray::get(CD->getContext(), Vals);
return ConstantArray::get(CD->getType(), Vals);
} else if (auto CD = dyn_cast<ConstantArray>(oval)) {
SmallVector<Constant *, 1> Vals;
for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) {
Vals.push_back(
cast<Constant>(invertPointerM(CD->getOperand(i), BuilderM)));
Value *val = invertPointerM(CD->getOperand(i), BuilderM);
Vals.push_back(cast<Constant>(val));
}
return ConstantArray::get(CD->getType(), Vals);
} else if (auto CD = dyn_cast<ConstantStruct>(oval)) {
Expand Down

0 comments on commit d40597a

Please sign in to comment.