Skip to content

Commit

Permalink
[Flang][MLIR] Addition of declare target to and other related map fix…
Browse files Browse the repository at this point in the history
…es (#219)

This PR implements some initial support for declare target to and provides some map modifications that support it in relation to descriptor types, it borrows heavily from Clang's implementation of mapping when you map a full structure explicitly and then individiual components with the same or different map types.
  • Loading branch information
agozillon authored Dec 4, 2024
1 parent 16d4aea commit 13b4a9d
Show file tree
Hide file tree
Showing 20 changed files with 731 additions and 282 deletions.
4 changes: 4 additions & 0 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,10 @@ elideLengthsAlreadyInType(mlir::Type type, mlir::ValueRange lenParams);
/// Get the address space which should be used for allocas
uint64_t getAllocaAddressSpace(mlir::DataLayout *dataLayout);

uint64_t getGlobalAddressSpace(mlir::DataLayout *dataLayout);

uint64_t getProgramAddressSpace(mlir::DataLayout *dataLayout);

} // namespace fir::factory

#endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
3 changes: 3 additions & 0 deletions flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {
unsigned
getProgramAddressSpace(mlir::ConversionPatternRewriter &rewriter) const;

unsigned
getGlobalAddressSpace(mlir::ConversionPatternRewriter &rewriter) const;

const fir::FIRToLLVMPassOptions &options;

using ConvertToLLVMPattern::match;
Expand Down
14 changes: 14 additions & 0 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1721,3 +1721,17 @@ uint64_t fir::factory::getAllocaAddressSpace(mlir::DataLayout *dataLayout) {
return mlir::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return 0;
}

uint64_t fir::factory::getGlobalAddressSpace(mlir::DataLayout *dataLayout) {
if (dataLayout)
if (mlir::Attribute addrSpace = dataLayout->getGlobalMemorySpace())
return mlir::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return 0;
}

uint64_t fir::factory::getProgramAddressSpace(mlir::DataLayout *dataLayout) {
if (dataLayout)
if (mlir::Attribute addrSpace = dataLayout->getProgramMemorySpace())
return mlir::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return 0;
}
73 changes: 60 additions & 13 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,54 @@ addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
}

namespace {

mlir::Value replaceWithAddrOfOrASCast(mlir::ConversionPatternRewriter &rewriter,
mlir::Location loc,
std::uint64_t globalAS,
std::uint64_t programAS,
llvm::StringRef symName, mlir::Type type,
mlir::Operation *replaceOp = nullptr) {
if (mlir::isa<mlir::LLVM::LLVMPointerType>(type)) {
if (globalAS != programAS) {
auto llvmAddrOp = rewriter.create<mlir::LLVM::AddressOfOp>(
loc, getLlvmPtrType(rewriter.getContext(), globalAS), symName);
if (replaceOp)
return rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(
replaceOp, ::getLlvmPtrType(rewriter.getContext(), programAS),
llvmAddrOp);
return rewriter.create<mlir::LLVM::AddrSpaceCastOp>(
loc, getLlvmPtrType(rewriter.getContext(), programAS), llvmAddrOp);
}

if (replaceOp)
return rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
replaceOp, getLlvmPtrType(rewriter.getContext(), globalAS), symName);
return rewriter.create<mlir::LLVM::AddressOfOp>(
loc, getLlvmPtrType(rewriter.getContext(), globalAS), symName);
}

if (replaceOp)
return rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(replaceOp, type,
symName);
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, type, symName);
}

/// Lower `fir.address_of` operation to `llvm.address_of` operation.
struct AddrOfOpConversion : public fir::FIROpConversion<fir::AddrOfOp> {
using FIROpConversion::FIROpConversion;

llvm::LogicalResult
matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto ty = convertType(addr.getType());
rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
addr, ty, addr.getSymbol().getRootReference().getValue());
auto global = addr->getParentOfType<mlir::ModuleOp>()
.lookupSymbol<mlir::LLVM::GlobalOp>(addr.getSymbol());
replaceWithAddrOfOrASCast(
rewriter, addr->getLoc(),
global ? global.getAddrSpace() : getGlobalAddressSpace(rewriter),
getProgramAddressSpace(rewriter),
global ? global.getSymName()
: addr.getSymbol().getRootReference().getValue(),
convertType(addr.getType()), addr);
return mlir::success();
}
};
Expand Down Expand Up @@ -1255,14 +1293,19 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
? fir::NameUniquer::getTypeDescriptorAssemblyName(recType.getName())
: fir::NameUniquer::getTypeDescriptorName(recType.getName());
mlir::Type llvmPtrTy = ::getLlvmPtrType(mod.getContext());
mlir::DataLayout dataLayout(mod);
if (auto global = mod.template lookupSymbol<fir::GlobalOp>(name)) {
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
global.getSymName());
return replaceWithAddrOfOrASCast(
rewriter, loc, fir::factory::getGlobalAddressSpace(&dataLayout),
fir::factory::getProgramAddressSpace(&dataLayout),
global.getSymName(), llvmPtrTy);
}
if (auto global = mod.template lookupSymbol<mlir::LLVM::GlobalOp>(name)) {
// The global may have already been translated to LLVM.
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
global.getSymName());
return replaceWithAddrOfOrASCast(
rewriter, loc, global.getAddrSpace(),
fir::factory::getProgramAddressSpace(&dataLayout),
global.getSymName(), llvmPtrTy);
}
// Type info derived types do not have type descriptors since they are the
// types defining type descriptors.
Expand Down Expand Up @@ -2763,12 +2806,16 @@ struct TypeDescOpConversion : public fir::FIROpConversion<fir::TypeDescOp> {
: fir::NameUniquer::getTypeDescriptorName(recordType.getName());
auto llvmPtrTy = ::getLlvmPtrType(typeDescOp.getContext());
if (auto global = module.lookupSymbol<mlir::LLVM::GlobalOp>(typeDescName)) {
rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
typeDescOp, llvmPtrTy, global.getSymName());
replaceWithAddrOfOrASCast(rewriter, typeDescOp->getLoc(),
global.getAddrSpace(),
getProgramAddressSpace(rewriter),
global.getSymName(), llvmPtrTy, typeDescOp);
return mlir::success();
} else if (auto global = module.lookupSymbol<fir::GlobalOp>(typeDescName)) {
rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
typeDescOp, llvmPtrTy, global.getSymName());
replaceWithAddrOfOrASCast(rewriter, typeDescOp->getLoc(),
getGlobalAddressSpace(rewriter),
getProgramAddressSpace(rewriter),
global.getSymName(), llvmPtrTy, typeDescOp);
return mlir::success();
}
return mlir::failure();
Expand Down Expand Up @@ -2859,8 +2906,8 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
mlir::SymbolRefAttr comdat;
llvm::ArrayRef<mlir::NamedAttribute> attrs;
auto g = rewriter.create<mlir::LLVM::GlobalOp>(
loc, tyAttr, isConst, linkage, global.getSymName(), initAttr, 0, 0,
false, false, comdat, attrs, dbgExprs);
loc, tyAttr, isConst, linkage, global.getSymName(), initAttr, 0,
getGlobalAddressSpace(rewriter), false, false, comdat, attrs, dbgExprs);

if (global.getAlignment() && *global.getAlignment() > 0)
g.setAlignment(*global.getAlignment());
Expand Down
25 changes: 23 additions & 2 deletions flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,10 @@ unsigned ConvertFIRToLLVMPattern::getAllocaAddressSpace(
mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
assert(parentOp != nullptr &&
"expected insertion block to have parent operation");
if (auto module = parentOp->getParentOfType<mlir::ModuleOp>())
auto module = mlir::isa<mlir::ModuleOp>(parentOp)
? mlir::cast<mlir::ModuleOp>(parentOp)
: parentOp->getParentOfType<mlir::ModuleOp>();
if (module)
if (mlir::Attribute addrSpace =
mlir::DataLayout(module).getAllocaMemorySpace())
return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
Expand All @@ -358,11 +361,29 @@ unsigned ConvertFIRToLLVMPattern::getProgramAddressSpace(
mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
assert(parentOp != nullptr &&
"expected insertion block to have parent operation");
if (auto module = parentOp->getParentOfType<mlir::ModuleOp>())
auto module = mlir::isa<mlir::ModuleOp>(parentOp)
? mlir::cast<mlir::ModuleOp>(parentOp)
: parentOp->getParentOfType<mlir::ModuleOp>();
if (module)
if (mlir::Attribute addrSpace =
mlir::DataLayout(module).getProgramMemorySpace())
return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return defaultAddressSpace;
}

unsigned ConvertFIRToLLVMPattern::getGlobalAddressSpace(
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
assert(parentOp != nullptr &&
"expected insertion block to have parent operation");
auto module = mlir::isa<mlir::ModuleOp>(parentOp)
? mlir::cast<mlir::ModuleOp>(parentOp)
: parentOp->getParentOfType<mlir::ModuleOp>();
if (module)
if (mlir::Attribute addrSpace =
mlir::DataLayout(module).getGlobalMemorySpace())
return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return defaultAddressSpace;
}

} // namespace fir
6 changes: 4 additions & 2 deletions flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,10 @@ class MapInfoFinalizationPass
return llvm::to_underlying(
hasImplicitMap
? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
}

mlir::omp::MapInfoOp genDescriptorMemberMaps(mlir::omp::MapInfoOp op,
Expand Down
Loading

0 comments on commit 13b4a9d

Please sign in to comment.