-
Notifications
You must be signed in to change notification settings - Fork 111
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
(WIP) Batched autodiff #2181
base: main
Are you sure you want to change the base?
(WIP) Batched autodiff #2181
Conversation
@@ -27,7 +27,11 @@ getFunctionTypeForClone(mlir::FunctionType FTy, DerivativeMode mode, | |||
for (auto &&[Ty, returnPrimal, returnShadow, activity] : llvm::zip( | |||
FTy.getResults(), returnPrimals, returnShadows, ReturnActivity)) { | |||
if (returnPrimal) { | |||
RetTypes.push_back(Ty); | |||
if (width != 1) { | |||
RetTypes.push_back(mlir::RankedTensorType::get({width}, Ty)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This shouldn’t need changing since the primal is always unmodified, only Derivatives are changed (and we should be pushing the getshadow types for those below)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, then I'm confused of what batched autodiff is.
How should my testcase change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nvm, it clicked. It's just the shadow that's batched 😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so here's an example from llvm vector mode for example: https://github.com/EnzymeAD/Enzyme/blob/main/enzyme/test/Enzyme/ForwardModeVector/add.ll
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tho perhaps mul will be more illustrative, https://github.com/EnzymeAD/Enzyme/blob/main/enzyme/test/Enzyme/ForwardModeVector/mul.ll (and obviously feel free to look at any/all of the other examples
I haven't yet fully made the changes in enzyme-tblgen.cpp, and either way this just works for the simple test case. mlir::Value itmp = ({
// Computing MulFOp
auto fwdarg_0 = dif;
auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1));
if (gutils->width != 1)
{
fwdarg_1 = builder.create<tensor::SplatOp>(
op.getLoc(),
mlir::RankedTensorType::get({gutils->width},
fwdarg_1.getType()),
fwdarg_1);
}
builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
}); But this is the MLIR code that is generated for this simple test: func.func private @fwddiffe2square(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> {
%splat = tensor.splat %arg0 : tensor<2xf64>
%0 = arith.mulf %arg1, %splat : tensor<2xf64>
%splat_0 = tensor.splat %arg0 : tensor<2xf64>
%1 = arith.mulf %arg1, %splat_0 : tensor<2xf64>
%2 = arith.addf %0, %1 : tensor<2xf64>
%3 = arith.mulf %arg0, %arg0 : f64
return %2 : tensor<2xf64>
} |
This still requires changes in the tblgenerated derivative files. For example, createForwardModeTangent in MulFOpFwdDerivative could be altered like this: ``` LogicalResult createForwardModeTangent(Operation *op0, OpBuilder &builder, MGradientUtils *gutils) const { auto op = cast<arith::MulFOp>(op0); if (gutils->width != 1) { auto newop = gutils->getNewFromOriginal(op0); for (auto res : newop->getResults()) { res.setType(mlir::RankedTensorType::get({gutils->width}, res.getType())); } } gutils->eraseIfUnused(op); if (gutils->isConstantInstruction(op)) return success(); mlir::Value res = nullptr; if (!gutils->isConstantValue(op->getOperand(0))) { auto dif = gutils->invertPointerM(op->getOperand(0), builder); { mlir::Value itmp = ({ // Computing MulFOp auto fwdarg_0 = dif; dif.dump(); // TODO: gutils->makeBatched(...) auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1)); builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1); }); itmp.dump(); if (!res) res = itmp; else { auto operandType = cast<AutoDiffTypeInterface>(res.getType()); res = operandType.createAddOp(builder, op.getLoc(), res, itmp); } } } if (!gutils->isConstantValue(op->getOperand(1))) { auto dif = gutils->invertPointerM(op->getOperand(1), builder); { mlir::Value itmp = ({ // Computing MulFOp auto fwdarg_0 = dif; dif.dump(); auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(0)); builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1); }); if (!res) res = itmp; else { auto operandType = cast<AutoDiffTypeInterface>(res.getType()); res = operandType.createAddOp(builder, op.getLoc(), res, itmp); } } } assert(res); gutils->setDiffe(op->getResult(0), res, builder); return success(); } ```
This reverts commit c06ed01.
Added some type conversions to tensor types if
width != 1
. The simple test case seems correct now.Corresponding Enzyme-JAX PR: EnzymeAD/Enzyme-JAX#197