
(WIP) Batched autodiff #2181

Draft · wants to merge 12 commits into base: main
16 changes: 16 additions & 0 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
@@ -192,4 +192,20 @@ def GenericAdjointOp : Enzyme_Op<"genericAdjoint", [AttrSizedOperandSegments]> {

}

def BroadcastOp : Enzyme_Op<"broadcast"> {
  let description = [{
    Broadcast the operand by adding an extra dimension at the front, with a size equal to the width attribute.
    For scalar operands, a one-dimensional ranked tensor is created.

    NOTE: Only works for scalars and *ranked* tensors for now.
  }];

  let arguments = (ins AnyType:$input, I64Attr:$width);
  let results = (outs AnyRankedTensor:$output);

  let builders = [
    OpBuilder<(ins "Value":$input, "int64_t":$width)>
  ];
}
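As an illustration (not taken from the patch), a width-2 broadcast of a scalar would look like this in generic IR form, matching the tests further down:

    // Hypothetical example: %x is any f64 value in scope.
    %b = "enzyme.broadcast"(%x) <{width = 2 : i64}> : (f64) -> tensor<2xf64>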

#endif // ENZYME_OPS
20 changes: 20 additions & 0 deletions enzyme/Enzyme/MLIR/Dialect/Ops.cpp
@@ -191,3 +191,23 @@ LogicalResult BatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input,
                        int64_t width) {
  auto widthAttr = builder.getI64IntegerAttr(width);
  RankedTensorType output;
  // TODO: support things other than scalars and ranked tensors, maybe reuse
  // getShadowType here?
  if (auto tensorType = input.getType().dyn_cast<TensorType>()) {
    // Ranked tensor: prepend the batch dimension to the existing shape.
    auto shape = tensorType.getShape();
    SmallVector<int64_t, 4> newShape;
    newShape.push_back(width);
    newShape.append(shape.begin(), shape.end());
    output = RankedTensorType::get(newShape, tensorType.getElementType());
  } else {
    // Scalar: the result is a one-dimensional tensor of length `width`.
    output = RankedTensorType::get({width}, input.getType());
  }
  build(builder, result, output, input, widthAttr);
}
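For ranked-tensor inputs the builder prepends the batch dimension rather than wrapping the whole type; a hypothetical sketch of the inferred result type:

    // Hypothetical example: %t has type tensor<3x5xf32>; width = 4 prepends a dimension.
    %b = "enzyme.broadcast"(%t) <{width = 4 : i64}> : (tensor<3x5xf32>) -> tensor<4x3x5xf32>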
@@ -17,6 +17,7 @@
#include "Interfaces/GradientUtilsReverse.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

@@ -69,3 +70,10 @@ void mlir::enzyme::registerArithDialectAutoDiffInterface(
    arith::ConstantOp::attachInterface<ArithConstantOpBatchInterface>(*context);
  });
}

void mlir::enzyme::registerTensorDialectAutoDiffInterface(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *context, tensor::TensorDialect *) {
    registerInterfaces(context);
  });
}
@@ -45,8 +45,11 @@ class FloatTypeInterface
}

  Type getShadowType(Type self, unsigned width) const {
-   assert(width == 1 && "unsupported width != 1");
-   return self;
+   if (width > 1) {
+     return RankedTensorType::get({width}, self);
+   } else {
+     return self;
+   }
  }

  bool isMutable(Type self) const { return false; }
@@ -106,7 +109,14 @@ class TensorTypeInterface
}

  Type getShadowType(Type self, unsigned width) const {
-   assert(width == 1 && "unsupported width != 1");
+   if (width != 1) {
+     auto tenType = self.cast<TensorType>();
+     auto shape = tenType.getShape();
+     SmallVector<int64_t, 4> newShape;
+     newShape.push_back(width);
+     newShape.append(shape.begin(), shape.end());
+     return RankedTensorType::get(newShape, tenType.getElementType());
+   }
    return self;
  }
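In other words, getShadowType at width N prepends a batch dimension of size N, with scalars becoming 1-D tensors. A hypothetical external declaration showing the widened shadow types at width = 2:

    // Hypothetical signature: primal f64 pairs with shadow tensor<2xf64>,
    // primal tensor<10xf64> pairs with shadow tensor<2x10xf64>.
    func.func private @fwddiffe2foo(f64, tensor<2xf64>, tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64>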

@@ -432,4 +432,5 @@ void mlir::enzyme::registerCoreDialectAutodiffInterfaces(
  enzyme::registerCFDialectAutoDiffInterface(registry);
  enzyme::registerLinalgDialectAutoDiffInterface(registry);
  enzyme::registerFuncDialectAutoDiffInterface(registry);
  enzyme::registerTensorDialectAutoDiffInterface(registry);
}
@@ -260,6 +260,7 @@ void registerCFDialectAutoDiffInterface(DialectRegistry &registry);
void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry);
void registerMathDialectAutoDiffInterface(DialectRegistry &registry);
void registerFuncDialectAutoDiffInterface(DialectRegistry &registry);
void registerTensorDialectAutoDiffInterface(DialectRegistry &registry);

void registerCoreDialectAutodiffInterfaces(DialectRegistry &registry);

4 changes: 2 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp
@@ -245,9 +245,9 @@ FunctionOpInterface CloneFunctionWithReturns(
      mlir::Value val = blk.getArgument(i);
      mlir::Value dval;
      if (i == ArgActivity.size() - 1)
-       dval = blk.addArgument(val.getType(), val.getLoc());
+       dval = blk.addArgument(getShadowType(val.getType(), width), val.getLoc());
      else
-       dval = blk.insertArgument(blk.args_begin() + i + 1, val.getType(),
+       dval = blk.insertArgument(blk.args_begin() + i + 1, getShadowType(val.getType(), width),
                                  val.getLoc());
      ptrInputs.map(oval, dval);
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Passes/CMakeLists.txt
@@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms
  MLIRFuncDialect
  MLIRFuncTransforms
  MLIRGPUDialect
  MLIRTensorDialect
  MLIRIR
  MLIRLLVMDialect
  MLIRMathDialect
5 changes: 5 additions & 0 deletions enzyme/Enzyme/MLIR/Passes/Passes.h
@@ -15,6 +15,7 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include "Dialect/Dialect.h"

@@ -80,6 +81,10 @@ namespace affine {
class AffineDialect;
} // end namespace affine

namespace tensor {
class TensorDialect;
} // end namespace tensor

namespace LLVM {
class LLVMDialect;
} // end namespace LLVM
3 changes: 2 additions & 1 deletion enzyme/Enzyme/MLIR/Passes/Passes.td
@@ -16,7 +16,8 @@ def DifferentiatePass : Pass<"enzyme"> {
  let dependentDialects = [
    "arith::ArithDialect",
    "complex::ComplexDialect",
-   "cf::ControlFlowDialect"
+   "cf::ControlFlowDialect",
+   "tensor::TensorDialect",
  ];
  let constructor = "mlir::enzyme::createDifferentiatePass()";
}
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/enzymemlir-opt.cpp
@@ -67,6 +67,7 @@ int main(int argc, char **argv) {
  registry.insert<mlir::omp::OpenMPDialect>();
  registry.insert<mlir::math::MathDialect>();
  registry.insert<mlir::linalg::LinalgDialect>();
  registry.insert<mlir::tensor::TensorDialect>();
  registry.insert<DLTIDialect>();

  registry.insert<mlir::enzyme::EnzymeDialect>();
26 changes: 26 additions & 0 deletions enzyme/test/MLIR/ForwardMode/batched_scalar.mlir
@@ -0,0 +1,26 @@
// RUN: %eopt --enzyme %s | FileCheck %s

module {
  func.func @square(%x : f64) -> f64 {
    %y = arith.mulf %x, %x : f64
    return %y : f64
  }
  func.func @dsq(%x : f64, %dx : tensor<2xf64>) -> tensor<2xf64> {
    %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>) -> (tensor<2xf64>)
    return %r : tensor<2xf64>
  }
}

// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (f64, tensor<2xf64>) -> tensor<2xf64>
// CHECK-NEXT: return %[[i0]] : tensor<2xf64>
// CHECK-NEXT: }
// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{width = 2 : i64}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2xf64>
// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{width = 2 : i64}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64>
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64>
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64
// CHECK-NEXT: return %[[i2]] : tensor<2xf64>
// CHECK-NEXT: }
26 changes: 26 additions & 0 deletions enzyme/test/MLIR/ForwardMode/batched_tensor.mlir
@@ -0,0 +1,26 @@
// RUN: %eopt --enzyme %s | FileCheck %s

module {
  func.func @square(%x : tensor<10xf64>) -> tensor<10xf64> {
    %y = arith.mulf %x, %x : tensor<10xf64>
    return %y : tensor<10xf64>
  }
  func.func @dsq(%x : tensor<10xf64>, %dx : tensor<2x10xf64>) -> tensor<2x10xf64> {
    %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<2x10xf64>)
    return %r : tensor<2x10xf64>
  }
}

// CHECK: func.func @dsq(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> {
// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64>
// CHECK-NEXT: return %[[i0]] : tensor<2x10xf64>
// CHECK-NEXT: }
// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> {
// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{width = 2 : i64}> : (tensor<10xf64>) -> tensor<2x10xf64>
// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2x10xf64>
// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{width = 2 : i64}> : (tensor<10xf64>) -> tensor<2x10xf64>
// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2x10xf64>
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2x10xf64>
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<10xf64>
// CHECK-NEXT: return %[[i2]] : tensor<2x10xf64>
// CHECK-NEXT: }
12 changes: 11 additions & 1 deletion enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
@@ -275,8 +275,18 @@ SmallVector<bool, 1> prepareArgs(const Twine &curIndent, raw_ostream &os,
        os << ord;
      }
      if (!vecValue && !startsWith(ord, "local")) {
-       if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives))
+       if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) {
          os << ")";
+         if (intrinsic == MLIRDerivatives) {
+           os << ";\n";
+           os << "if (gutils->width != 1) {\n"
+              << " " << argName << "_" << (idx - 1) << " = builder.create<enzyme::BroadcastOp>(\n"
+              << " op.getLoc(),\n"
+              << " " << argName << "_" << (idx - 1) << ",\n"
+              << " gutils->width);\n"
+              << "}";
+         }
+       }

        if (lookup && intrinsic != MLIRDerivatives)
          os << ", " << builder << ")";
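The emitted guard means that, whenever gutils->width != 1, each primal operand pulled from the original op is routed through enzyme.broadcast before it meets the batched shadow, which is the pattern the batched_scalar.mlir test above checks for. A hypothetical width-2 excerpt of the resulting derivative IR:

    // Hypothetical excerpt: %x is the scalar primal, %dx the tensor<2xf64> shadow.
    %s = "enzyme.broadcast"(%x) <{width = 2 : i64}> : (f64) -> tensor<2xf64>
    %d = arith.mulf %dx, %s : tensor<2xf64>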