From d52cb9cf9ea299dead420711d989b1697216f709 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 24 Oct 2022 15:32:50 -0400 Subject: [PATCH] Add memref/llvm.ptr handling for fwd mode (#910) * Add memref handling for fwd mode * Add simple llvm dialect * Update enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp Co-authored-by: Tim Gymnich * fixup Co-authored-by: Tim Gymnich --- .../MLIR/Implementations/CMakeLists.txt | 4 + .../CoreDialectsAutoDiffImplementations.h | 2 + .../LLVMAutoDiffOpInterfaceImpl.cpp | 95 ++++++++++++++++ .../MemRefAutoDiffOpInterfaceImpl.cpp | 103 ++++++++++++++++++ .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 8 ++ enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h | 2 + enzyme/Enzyme/MLIR/enzymemlir-opt.cpp | 2 + enzyme/test/MLIR/llvm.mlir | 31 ++++++ enzyme/test/MLIR/memref.mlir | 31 ++++++ 9 files changed, 278 insertions(+) create mode 100644 enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp create mode 100644 enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp create mode 100644 enzyme/test/MLIR/llvm.mlir create mode 100644 enzyme/test/MLIR/memref.mlir diff --git a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt index 167b6d975c238..0472d1bf0caac 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt @@ -1,5 +1,7 @@ add_mlir_library(MLIREnzymeImplementations ArithAutoDiffOpInterfaceImpl.cpp + LLVMAutoDiffOpInterfaceImpl.cpp + MemRefAutoDiffOpInterfaceImpl.cpp BuiltinAutoDiffTypeInterfaceImpl.cpp SCFAutoDiffOpInterfaceImpl.cpp @@ -8,6 +10,8 @@ add_mlir_library(MLIREnzymeImplementations LINK_LIBS PUBLIC MLIRArithDialect + MLIRLLVMDialect + MLIRMemRefDialect MLIREnzymeAutoDiffInterface MLIRIR MLIRSCFDialect diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index 5a3abde39b671..b994579eef65c 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -18,6 +18,8 @@ class DialectRegistry; namespace enzyme { void registerArithDialectAutoDiffInterface(DialectRegistry ®istry); void registerBuiltinDialectAutoDiffInterface(DialectRegistry ®istry); +void registerLLVMDialectAutoDiffInterface(DialectRegistry ®istry); +void registerMemRefDialectAutoDiffInterface(DialectRegistry ®istry); void registerSCFDialectAutoDiffInterface(DialectRegistry ®istry); } // namespace enzyme } // namespace mlir diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 0000000000000..ff9f1494afa0f --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,95 @@ +//===- LLVMAutoDiffOpInterfaceImpl.cpp - Interface external model --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream LLVM dialect. +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/AutoDiffTypeInterface.h" +#include "Interfaces/GradientUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Support/LogicalResult.h" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +struct LoadOpInterface + : public AutoDiffOpInterface::ExternalModel { + LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + auto loadOp = cast(op); + if (!gutils->isConstantValue(loadOp)) { + mlir::Value res = builder.create( + loadOp.getLoc(), gutils->invertPointerM(loadOp.getAddr(), builder)); + gutils->setDiffe(loadOp, res, builder); + } + gutils->eraseIfUnused(op); + return success(); + } +}; + +struct StoreOpInterface + : public AutoDiffOpInterface::ExternalModel { + LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + auto storeOp = cast(op); + if (!gutils->isConstantValue(storeOp.getAddr())) { + builder.create( + storeOp.getLoc(), gutils->invertPointerM(storeOp.getValue(), builder), + gutils->invertPointerM(storeOp.getAddr(), builder)); + } + gutils->eraseIfUnused(op); + return success(); + } +}; + +struct AllocaOpInterface + : public AutoDiffOpInterface::ExternalModel { + LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + auto allocOp = cast(op); + if (!gutils->isConstantValue(allocOp)) { + Operation *nop = gutils->cloneWithNewOperands(builder, op); + gutils->setDiffe(allocOp, nop->getResult(0), builder); + } + gutils->eraseIfUnused(op); + return success(); + } +}; + +class PointerTypeInterface + : public AutoDiffTypeInterface::ExternalModel { +public: + Value createNullValue(Type self, OpBuilder &builder, Location loc) const { + return builder.create(loc, self); + } + + Type getShadowType(Type self, unsigned width) const { + assert(width == 1 && "unsupported width != 1"); + return self; + } +}; +} // namespace + +void mlir::enzyme::registerLLVMDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, LLVM::LLVMDialect *) { + LLVM::LoadOp::attachInterface(*context); + LLVM::StoreOp::attachInterface(*context); + LLVM::AllocaOp::attachInterface(*context); + LLVM::LLVMPointerType::attachInterface(*context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 0000000000000..2f338e0092ac3 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,103 @@ +//===- MemRefAutoDiffOpInterfaceImpl.cpp - Interface external model -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream MLIR memref dialect. +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/AutoDiffTypeInterface.h" +#include "Interfaces/GradientUtils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Support/LogicalResult.h" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +struct LoadOpInterface + : public AutoDiffOpInterface::ExternalModel { + LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + auto loadOp = cast(op); + if (!gutils->isConstantValue(loadOp)) { + SmallVector inds; + for (auto ind : loadOp.getIndices()) + inds.push_back(gutils->getNewFromOriginal(ind)); + mlir::Value res = builder.create( + loadOp.getLoc(), gutils->invertPointerM(loadOp.getMemref(), builder), + inds); + gutils->setDiffe(loadOp, res, builder); + } + gutils->eraseIfUnused(op); + return success(); + } +}; + +struct StoreOpInterface + : public AutoDiffOpInterface::ExternalModel { + LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + auto storeOp = cast(op); + if (!gutils->isConstantValue(storeOp.getMemref())) { + SmallVector inds; + for (auto ind : storeOp.getIndices()) + inds.push_back(gutils->getNewFromOriginal(ind)); + builder.create( + storeOp.getLoc(), gutils->invertPointerM(storeOp.getValue(), builder), + gutils->invertPointerM(storeOp.getMemref(), builder), inds); + } + gutils->eraseIfUnused(op); + return success(); + } +}; + +struct AllocOpInterface + : public AutoDiffOpInterface::ExternalModel { + LogicalResult createForwardModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + auto allocOp = cast(op); + if (!gutils->isConstantValue(allocOp)) { + Operation *nop = gutils->cloneWithNewOperands(builder, op); + gutils->setDiffe(allocOp, nop->getResult(0), builder); + } + gutils->eraseIfUnused(op); + return success(); + } +}; + +class MemRefTypeInterface + : public AutoDiffTypeInterface::ExternalModel { +public: + Value createNullValue(Type self, OpBuilder &builder, Location loc) const { + llvm_unreachable("Cannot create null of memref (todo polygeist null)"); + } + + Type getShadowType(Type self, unsigned width) const { + assert(width == 1 && "unsupported width != 1"); + return self; + } +}; +} // namespace + +void mlir::enzyme::registerMemRefDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, memref::MemRefDialect *) { + memref::LoadOp::attachInterface(*context); + memref::StoreOp::attachInterface(*context); + memref::AllocOp::attachInterface(*context); + MemRefType::attachInterface(*context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index d03c0eb945188..fab77a7903952 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -126,6 +126,14 @@ mlir::enzyme::MGradientUtils::getNewFromOriginal(Operation *originst) const { return found->second; } +Operation *mlir::enzyme::MGradientUtils::cloneWithNewOperands(OpBuilder &B, + Operation *op) { + BlockAndValueMapping map; + for (auto operand : op->getOperands()) + map.map(operand, getNewFromOriginal(operand)); + return B.clone(*op, map); +} + bool mlir::enzyme::MGradientUtils::isConstantValue(Value v) const { if (isa(v.getType())) return true; diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h index ed2e4b20d7e2d..60c92b4c78535 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h @@ -77,6 +77,8 @@ class MGradientUtils { void setDiffe(mlir::Value val, mlir::Value toset, OpBuilder &BuilderM); void forceAugmentedReturns(); + Operation *cloneWithNewOperands(OpBuilder &B, Operation *op); + LogicalResult visitChild(Operation *op); }; diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index 1642d2add85f9..750ce5069661f 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -92,6 +92,8 @@ int main(int argc, char **argv) { // Register the autodiff interface implementations for upstream dialects. enzyme::registerArithDialectAutoDiffInterface(registry); enzyme::registerBuiltinDialectAutoDiffInterface(registry); + enzyme::registerLLVMDialectAutoDiffInterface(registry); + enzyme::registerMemRefDialectAutoDiffInterface(registry); enzyme::registerSCFDialectAutoDiffInterface(registry); return mlir::failed( diff --git a/enzyme/test/MLIR/llvm.mlir b/enzyme/test/MLIR/llvm.mlir new file mode 100644 index 0000000000000..544baf63e9663 --- /dev/null +++ b/enzyme/test/MLIR/llvm.mlir @@ -0,0 +1,31 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64) -> f64 { + %c1_i64 = arith.constant 1 : i64 + %tmp = llvm.alloca %c1_i64 x f64 : (i64) -> !llvm.ptr + %y = arith.mulf %x, %x : f64 + llvm.store %y, %tmp : !llvm.ptr + %r = llvm.load %tmp : !llvm.ptr + return %r : f64 + } + func.func @dsq(%x : f64, %dx : f64) -> f64 { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme] } : (f64, f64) -> (f64) + return %r : f64 + } +} + +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { +// CHECK-NEXT: %[[c1_i64:.+]] = arith.constant 1 : i64 +// CHECK-NEXT: %[[i0:.+]] = llvm.alloca %[[c1_i64]] x f64 : (i64) -> !llvm.ptr +// CHECK-NEXT: %[[i1:.+]] = llvm.alloca %[[c1_i64]] x f64 : (i64) -> !llvm.ptr +// CHECK-NEXT: %[[i2:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[i4:.+]] = arith.addf %[[i2]], %[[i3]] : f64 +// CHECK-NEXT: %[[i5:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: llvm.store %[[i4]], %[[i0]] : !llvm.ptr +// CHECK-NEXT: llvm.store %[[i5]], %[[i1]] : !llvm.ptr +// CHECK-NEXT: %[[i6:.+]] = llvm.load %[[i0]] : !llvm.ptr +// CHECK-NEXT: %[[i7:.+]] = llvm.load %[[i1]] : !llvm.ptr +// CHECK-NEXT: return %[[i6]] : f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/memref.mlir b/enzyme/test/MLIR/memref.mlir new file mode 100644 index 0000000000000..cfd076844bda0 --- /dev/null +++ b/enzyme/test/MLIR/memref.mlir @@ -0,0 +1,31 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64) -> f64 { + %c0 = arith.constant 0 : index + %tmp = memref.alloc() : memref<1xf64> + %y = arith.mulf %x, %x : f64 + memref.store %y, %tmp[%c0] : memref<1xf64> + %r = memref.load %tmp[%c0] : memref<1xf64> + return %r : f64 + } + func.func @dsq(%x : f64, %dx : f64) -> f64 { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme] } : (f64, f64) -> (f64) + return %r : f64 + } +} + +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { +// CHECK-NEXT: %[[c0:.+]] = arith.constant 0 : index +// CHECK-NEXT: %[[i0:.+]] = memref.alloc() : memref<1xf64> +// CHECK-NEXT: %[[i1:.+]] = memref.alloc() : memref<1xf64> +// CHECK-NEXT: %[[i2:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[i4:.+]] = arith.addf %[[i2]], %[[i3]] : f64 +// CHECK-NEXT: %[[i5:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: memref.store %[[i4]], %[[i0]][%[[c0]]] : memref<1xf64> +// CHECK-NEXT: memref.store %[[i5]], %[[i1]][%[[c0]]] : memref<1xf64> +// CHECK-NEXT: %[[i6:.+]] = memref.load %[[i0]][%[[c0]]] : memref<1xf64> +// CHECK-NEXT: %[[i7:.+]] = memref.load %[[i1]][%[[c0]]] : memref<1xf64> +// CHECK-NEXT: return %[[i6]] : f64 +// CHECK-NEXT: }