Skip to content

Commit

Permalink
Add memref/llvm.ptr handling for fwd mode (rust-lang#910)
Browse files Browse the repository at this point in the history
* Add memref handling for fwd mode

* Add simple llvm dialect

* Update enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp

Co-authored-by: Tim Gymnich <timgymnich@me.com>

* fixup

Co-authored-by: Tim Gymnich <timgymnich@me.com>
  • Loading branch information
wsmoses and tgymnich authored Oct 24, 2022
1 parent 7da8126 commit d52cb9c
Show file tree
Hide file tree
Showing 9 changed files with 278 additions and 0 deletions.
4 changes: 4 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
add_mlir_library(MLIREnzymeImplementations
ArithAutoDiffOpInterfaceImpl.cpp
LLVMAutoDiffOpInterfaceImpl.cpp
MemRefAutoDiffOpInterfaceImpl.cpp
BuiltinAutoDiffTypeInterfaceImpl.cpp
SCFAutoDiffOpInterfaceImpl.cpp

Expand All @@ -8,6 +10,8 @@ add_mlir_library(MLIREnzymeImplementations

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRLLVMDialect
MLIRMemRefDialect
MLIREnzymeAutoDiffInterface
MLIRIR
MLIRSCFDialect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class DialectRegistry;
namespace enzyme {
void registerArithDialectAutoDiffInterface(DialectRegistry &registry);
void registerBuiltinDialectAutoDiffInterface(DialectRegistry &registry);
void registerLLVMDialectAutoDiffInterface(DialectRegistry &registry);
void registerMemRefDialectAutoDiffInterface(DialectRegistry &registry);
void registerSCFDialectAutoDiffInterface(DialectRegistry &registry);
} // namespace enzyme
} // namespace mlir
95 changes: 95 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
//===- LLVMAutoDiffOpInterfaceImpl.cpp - Interface external model --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the external model implementation of the automatic
// differentiation op interfaces for the upstream LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Interfaces/AutoDiffOpInterface.h"
#include "Interfaces/AutoDiffTypeInterface.h"
#include "Interfaces/GradientUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;
using namespace mlir::enzyme;

namespace {
struct LoadOpInterface
: public AutoDiffOpInterface::ExternalModel<LoadOpInterface, LLVM::LoadOp> {
LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtils *gutils) const {
auto loadOp = cast<LLVM::LoadOp>(op);
if (!gutils->isConstantValue(loadOp)) {
mlir::Value res = builder.create<LLVM::LoadOp>(
loadOp.getLoc(), gutils->invertPointerM(loadOp.getAddr(), builder));
gutils->setDiffe(loadOp, res, builder);
}
gutils->eraseIfUnused(op);
return success();
}
};

struct StoreOpInterface
: public AutoDiffOpInterface::ExternalModel<StoreOpInterface,
LLVM::StoreOp> {
LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtils *gutils) const {
auto storeOp = cast<LLVM::StoreOp>(op);
if (!gutils->isConstantValue(storeOp.getAddr())) {
builder.create<LLVM::StoreOp>(
storeOp.getLoc(), gutils->invertPointerM(storeOp.getValue(), builder),
gutils->invertPointerM(storeOp.getAddr(), builder));
}
gutils->eraseIfUnused(op);
return success();
}
};

struct AllocaOpInterface
: public AutoDiffOpInterface::ExternalModel<AllocaOpInterface,
LLVM::AllocaOp> {
LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtils *gutils) const {
auto allocOp = cast<LLVM::AllocaOp>(op);
if (!gutils->isConstantValue(allocOp)) {
Operation *nop = gutils->cloneWithNewOperands(builder, op);
gutils->setDiffe(allocOp, nop->getResult(0), builder);
}
gutils->eraseIfUnused(op);
return success();
}
};

class PointerTypeInterface
: public AutoDiffTypeInterface::ExternalModel<PointerTypeInterface,
LLVM::LLVMPointerType> {
public:
Value createNullValue(Type self, OpBuilder &builder, Location loc) const {
return builder.create<LLVM::NullOp>(loc, self);
}

Type getShadowType(Type self, unsigned width) const {
assert(width == 1 && "unsupported width != 1");
return self;
}
};
} // namespace

void mlir::enzyme::registerLLVMDialectAutoDiffInterface(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *context, LLVM::LLVMDialect *) {
LLVM::LoadOp::attachInterface<LoadOpInterface>(*context);
LLVM::StoreOp::attachInterface<StoreOpInterface>(*context);
LLVM::AllocaOp::attachInterface<AllocaOpInterface>(*context);
LLVM::LLVMPointerType::attachInterface<PointerTypeInterface>(*context);
});
}
103 changes: 103 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
//===- MemRefAutoDiffOpInterfaceImpl.cpp - Interface external model -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the external model implementation of the automatic
// differentiation op interfaces for the upstream MLIR memref dialect.
//
//===----------------------------------------------------------------------===//

#include "Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Interfaces/AutoDiffOpInterface.h"
#include "Interfaces/AutoDiffTypeInterface.h"
#include "Interfaces/GradientUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;
using namespace mlir::enzyme;

namespace {
struct LoadOpInterface
: public AutoDiffOpInterface::ExternalModel<LoadOpInterface,
memref::LoadOp> {
LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtils *gutils) const {
auto loadOp = cast<memref::LoadOp>(op);
if (!gutils->isConstantValue(loadOp)) {
SmallVector<Value> inds;
for (auto ind : loadOp.getIndices())
inds.push_back(gutils->getNewFromOriginal(ind));
mlir::Value res = builder.create<memref::LoadOp>(
loadOp.getLoc(), gutils->invertPointerM(loadOp.getMemref(), builder),
inds);
gutils->setDiffe(loadOp, res, builder);
}
gutils->eraseIfUnused(op);
return success();
}
};

struct StoreOpInterface
: public AutoDiffOpInterface::ExternalModel<StoreOpInterface,
memref::StoreOp> {
LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtils *gutils) const {
auto storeOp = cast<memref::StoreOp>(op);
if (!gutils->isConstantValue(storeOp.getMemref())) {
SmallVector<Value> inds;
for (auto ind : storeOp.getIndices())
inds.push_back(gutils->getNewFromOriginal(ind));
builder.create<memref::StoreOp>(
storeOp.getLoc(), gutils->invertPointerM(storeOp.getValue(), builder),
gutils->invertPointerM(storeOp.getMemref(), builder), inds);
}
gutils->eraseIfUnused(op);
return success();
}
};

struct AllocOpInterface
: public AutoDiffOpInterface::ExternalModel<AllocOpInterface,
memref::AllocOp> {
LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtils *gutils) const {
auto allocOp = cast<memref::AllocOp>(op);
if (!gutils->isConstantValue(allocOp)) {
Operation *nop = gutils->cloneWithNewOperands(builder, op);
gutils->setDiffe(allocOp, nop->getResult(0), builder);
}
gutils->eraseIfUnused(op);
return success();
}
};

class MemRefTypeInterface
: public AutoDiffTypeInterface::ExternalModel<MemRefTypeInterface,
MemRefType> {
public:
Value createNullValue(Type self, OpBuilder &builder, Location loc) const {
llvm_unreachable("Cannot create null of memref (todo polygeist null)");
}

Type getShadowType(Type self, unsigned width) const {
assert(width == 1 && "unsupported width != 1");
return self;
}
};
} // namespace

void mlir::enzyme::registerMemRefDialectAutoDiffInterface(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *context, memref::MemRefDialect *) {
memref::LoadOp::attachInterface<LoadOpInterface>(*context);
memref::StoreOp::attachInterface<StoreOpInterface>(*context);
memref::AllocOp::attachInterface<AllocOpInterface>(*context);
MemRefType::attachInterface<MemRefTypeInterface>(*context);
});
}
8 changes: 8 additions & 0 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ mlir::enzyme::MGradientUtils::getNewFromOriginal(Operation *originst) const {
return found->second;
}

Operation *mlir::enzyme::MGradientUtils::cloneWithNewOperands(OpBuilder &B,
Operation *op) {
BlockAndValueMapping map;
for (auto operand : op->getOperands())
map.map(operand, getNewFromOriginal(operand));
return B.clone(*op, map);
}

bool mlir::enzyme::MGradientUtils::isConstantValue(Value v) const {
if (isa<mlir::IntegerType>(v.getType()))
return true;
Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class MGradientUtils {
void setDiffe(mlir::Value val, mlir::Value toset, OpBuilder &BuilderM);
void forceAugmentedReturns();

Operation *cloneWithNewOperands(OpBuilder &B, Operation *op);

LogicalResult visitChild(Operation *op);
};

Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/MLIR/enzymemlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ int main(int argc, char **argv) {
// Register the autodiff interface implementations for upstream dialects.
enzyme::registerArithDialectAutoDiffInterface(registry);
enzyme::registerBuiltinDialectAutoDiffInterface(registry);
enzyme::registerLLVMDialectAutoDiffInterface(registry);
enzyme::registerMemRefDialectAutoDiffInterface(registry);
enzyme::registerSCFDialectAutoDiffInterface(registry);

return mlir::failed(
Expand Down
31 changes: 31 additions & 0 deletions enzyme/test/MLIR/llvm.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: %eopt --enzyme %s | FileCheck %s

module {
func.func @square(%x : f64) -> f64 {
%c1_i64 = arith.constant 1 : i64
%tmp = llvm.alloca %c1_i64 x f64 : (i64) -> !llvm.ptr<f64>
%y = arith.mulf %x, %x : f64
llvm.store %y, %tmp : !llvm.ptr<f64>
%r = llvm.load %tmp : !llvm.ptr<f64>
return %r : f64
}
func.func @dsq(%x : f64, %dx : f64) -> f64 {
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>] } : (f64, f64) -> (f64)
return %r : f64
}
}

// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 {
// CHECK-NEXT: %[[c1_i64:.+]] = arith.constant 1 : i64
// CHECK-NEXT: %[[i0:.+]] = llvm.alloca %[[c1_i64]] x f64 : (i64) -> !llvm.ptr<f64>
// CHECK-NEXT: %[[i1:.+]] = llvm.alloca %[[c1_i64]] x f64 : (i64) -> !llvm.ptr<f64>
// CHECK-NEXT: %[[i2:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64
// CHECK-NEXT: %[[i4:.+]] = arith.addf %[[i2]], %[[i3]] : f64
// CHECK-NEXT: %[[i5:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64
// CHECK-NEXT: llvm.store %[[i4]], %[[i0]] : !llvm.ptr<f64>
// CHECK-NEXT: llvm.store %[[i5]], %[[i1]] : !llvm.ptr<f64>
// CHECK-NEXT: %[[i6:.+]] = llvm.load %[[i0]] : !llvm.ptr<f64>
// CHECK-NEXT: %[[i7:.+]] = llvm.load %[[i1]] : !llvm.ptr<f64>
// CHECK-NEXT: return %[[i6]] : f64
// CHECK-NEXT: }
31 changes: 31 additions & 0 deletions enzyme/test/MLIR/memref.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: %eopt --enzyme %s | FileCheck %s

module {
func.func @square(%x : f64) -> f64 {
%c0 = arith.constant 0 : index
%tmp = memref.alloc() : memref<1xf64>
%y = arith.mulf %x, %x : f64
memref.store %y, %tmp[%c0] : memref<1xf64>
%r = memref.load %tmp[%c0] : memref<1xf64>
return %r : f64
}
func.func @dsq(%x : f64, %dx : f64) -> f64 {
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>] } : (f64, f64) -> (f64)
return %r : f64
}
}

// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 {
// CHECK-NEXT: %[[c0:.+]] = arith.constant 0 : index
// CHECK-NEXT: %[[i0:.+]] = memref.alloc() : memref<1xf64>
// CHECK-NEXT: %[[i1:.+]] = memref.alloc() : memref<1xf64>
// CHECK-NEXT: %[[i2:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64
// CHECK-NEXT: %[[i4:.+]] = arith.addf %[[i2]], %[[i3]] : f64
// CHECK-NEXT: %[[i5:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64
// CHECK-NEXT: memref.store %[[i4]], %[[i0]][%[[c0]]] : memref<1xf64>
// CHECK-NEXT: memref.store %[[i5]], %[[i1]][%[[c0]]] : memref<1xf64>
// CHECK-NEXT: %[[i6:.+]] = memref.load %[[i0]][%[[c0]]] : memref<1xf64>
// CHECK-NEXT: %[[i7:.+]] = memref.load %[[i1]][%[[c0]]] : memref<1xf64>
// CHECK-NEXT: return %[[i6]] : f64
// CHECK-NEXT: }

0 comments on commit d52cb9c

Please sign in to comment.