Skip to content

Commit

Permalink
[MLIR] Add a BlobAttr interface for attribute to wrap arbitrary conte…
Browse files Browse the repository at this point in the history
…nt and use it as linkLibs for ModuleToObject (#120116)

This change allows to expose through an interface attributes wrapping
content as external resources, and the usage inside the ModuleToObject
show how we will be able to provide runtime libraries without relying on
the filesystem.
  • Loading branch information
joker-eph authored Dec 17, 2024
1 parent 41a6e9c commit 72e8b9a
Show file tree
Hide file tree
Showing 14 changed files with 213 additions and 57 deletions.
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class TargetOptions {
/// obtaining the parent symbol table. The default compilation target is
/// `Fatbin`.
TargetOptions(
StringRef toolkitPath = {}, ArrayRef<std::string> linkFiles = {},
StringRef toolkitPath = {}, ArrayRef<Attribute> librariesToLink = {},
StringRef cmdOptions = {}, StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
Expand All @@ -66,8 +66,8 @@ class TargetOptions {
/// Returns the toolkit path.
StringRef getToolkitPath() const;

/// Returns the files to link to.
ArrayRef<std::string> getLinkFiles() const;
/// Returns the LLVM libraries to link to.
ArrayRef<Attribute> getLibrariesToLink() const;

/// Returns the command line options.
StringRef getCmdOptions() const;
Expand Down Expand Up @@ -113,7 +113,7 @@ class TargetOptions {
/// appropiate value: ie. `TargetOptions(TypeID::get<DerivedClass>())`.
TargetOptions(
TypeID typeID, StringRef toolkitPath = {},
ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
ArrayRef<Attribute> librariesToLink = {}, StringRef cmdOptions = {},
StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
Expand All @@ -126,7 +126,7 @@ class TargetOptions {
std::string toolkitPath;

/// List of files to link with the LLVM module.
SmallVector<std::string> linkFiles;
SmallVector<Attribute> librariesToLink;

/// An optional set of command line options to be used by the compilation
/// process.
Expand Down
18 changes: 18 additions & 0 deletions mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,24 @@ def TypedAttrInterface : AttrInterface<"TypedAttr"> {
>];
}

//===----------------------------------------------------------------------===//
// BlobAttrInterface
//===----------------------------------------------------------------------===//

def BlobAttrInterface : AttrInterface<"BlobAttr"> {
let cppNamespace = "::mlir";
let description = [{
This interface allows an attribute to expose a blob of data without more
information. The data must be stored so that it can be accessed as a
contiguous ArrayRef.
}];

let methods = [InterfaceMethod<
"Get the attribute's data",
"::llvm::ArrayRef<char>", "getData"
>];
}

//===----------------------------------------------------------------------===//
// ElementsAttrInterface
//===----------------------------------------------------------------------===//
Expand Down
13 changes: 11 additions & 2 deletions mlir/include/mlir/IR/BuiltinAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def Builtin_DenseArrayRawDataParameter : ArrayRefParameter<
}];
}

def Builtin_DenseArray : Builtin_Attr<"DenseArray", "dense_array"> {
def Builtin_DenseArray : Builtin_Attr<"DenseArray", "dense_array",
[BlobAttrInterface]> {
let summary = "A dense array of integer or floating point elements.";
let description = [{
A dense array attribute is an attribute that represents a dense array of
Expand Down Expand Up @@ -211,6 +212,10 @@ def Builtin_DenseArray : Builtin_Attr<"DenseArray", "dense_array"> {
int64_t size() const { return getSize(); }
/// Return true if there are no elements in the dense array.
bool empty() const { return !size(); }
/// BlobAttrInterface method.
ArrayRef<char> getData() {
return getRawData();
}
}];
}

Expand Down Expand Up @@ -431,7 +436,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
//===----------------------------------------------------------------------===//

def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements",
"dense_resource_elements", [ElementsAttrInterface]> {
"dense_resource_elements", [ElementsAttrInterface, BlobAttrInterface]> {
let summary = "An Attribute containing a dense multi-dimensional array "
"backed by a resource";
let description = [{
Expand Down Expand Up @@ -485,6 +490,10 @@ def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements",
"ShapedType":$type, "StringRef":$blobName, "AsmResourceBlob":$blob
)>
];
let extraClassDeclaration = [{
/// BlobAttrInterface method.
ArrayRef<char> getData();
}];

let skipDefaultBuilders = 1;
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Target/LLVM/ModuleToObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class ModuleToObject {

/// Loads multiple bitcode files.
LogicalResult loadBitcodeFilesFromList(
llvm::LLVMContext &context, ArrayRef<std::string> fileList,
llvm::LLVMContext &context, ArrayRef<Attribute> librariesToLink,
SmallVector<std::unique_ptr<llvm::Module>> &llvmModules,
bool failureOnError = true);

Expand Down
15 changes: 9 additions & 6 deletions mlir/include/mlir/Target/LLVM/NVVM/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,15 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject {
/// Returns the CUDA toolkit path.
StringRef getToolkitPath() const;

/// Returns the bitcode files to be loaded.
ArrayRef<std::string> getFileList() const;
/// Returns the bitcode libraries to be linked into the gpu module after
/// translation to LLVM IR.
ArrayRef<Attribute> getLibrariesToLink() const;

/// Appends `nvvm/libdevice.bc` into `fileList`. Returns failure if the
/// Appends `nvvm/libdevice.bc` into `librariesToLink`. Returns failure if the
/// library couldn't be found.
LogicalResult appendStandardLibs();

/// Loads the bitcode files in `fileList`.
/// Loads the bitcode files in `librariesToLink`.
virtual std::optional<SmallVector<std::unique_ptr<llvm::Module>>>
loadBitcodeFiles(llvm::Module &module) override;

Expand All @@ -64,8 +65,10 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject {
/// CUDA toolkit path.
std::string toolkitPath;

/// List of LLVM bitcode files to link to.
SmallVector<std::string> fileList;
/// List of LLVM bitcode to link into after translation to LLVM IR.
/// The attributes can be StringAttr pointing to a file path, or
/// a Resource blob pointing to the LLVM bitcode in-memory.
SmallVector<Attribute> librariesToLink;
};
} // namespace NVVM
} // namespace mlir
Expand Down
7 changes: 4 additions & 3 deletions mlir/include/mlir/Target/LLVM/ROCDL/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVM/ModuleToObject.h"

Expand Down Expand Up @@ -61,8 +62,8 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject {
/// Returns the ROCM toolkit path.
StringRef getToolkitPath() const;

/// Returns the bitcode files to be loaded.
ArrayRef<std::string> getFileList() const;
/// Returns the LLVM bitcode libraries to be linked.
ArrayRef<Attribute> getLibrariesToLink() const;

/// Appends standard ROCm device libraries to `fileList`.
LogicalResult appendStandardLibs(AMDGCNLibraries libs);
Expand Down Expand Up @@ -107,7 +108,7 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject {
std::string toolkitPath;

/// List of LLVM bitcode files to link to.
SmallVector<std::string> fileList;
SmallVector<Attribute> librariesToLink;

/// AMD GCN libraries to use when linking, the default is using none.
AMDGCNLibraries deviceLibs = AMDGCNLibraries::None;
Expand Down
12 changes: 7 additions & 5 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2483,30 +2483,30 @@ KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
//===----------------------------------------------------------------------===//

TargetOptions::TargetOptions(
StringRef toolkitPath, ArrayRef<std::string> linkFiles,
StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<void(StringRef)> isaCallback)
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
cmdOptions, elfSection, compilationTarget,
getSymbolTableCallback, initialLlvmIRCallback,
linkedLlvmIRCallback, optimizedLlvmIRCallback,
isaCallback) {}

TargetOptions::TargetOptions(
TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<void(StringRef)> isaCallback)
: toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
: toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
compilationTarget(compilationTarget),
getSymbolTableCallback(getSymbolTableCallback),
Expand All @@ -2519,7 +2519,9 @@ TypeID TargetOptions::getTypeID() const { return typeID; }

StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }

ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const {
return librariesToLink;
}

StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }

Expand Down
8 changes: 5 additions & 3 deletions mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ void GpuModuleToBinaryPass::runOnOperation() {
}
return &parentTable.value();
};

TargetOptions targetOptions(toolkitPath, linkFiles, cmdOptions, elfSection,
*targetFormat, lazyTableBuilder);
SmallVector<Attribute> librariesToLink;
for (const std::string &path : linkFiles)
librariesToLink.push_back(StringAttr::get(&getContext(), path));
TargetOptions targetOptions(toolkitPath, librariesToLink, cmdOptions,
elfSection, *targetFormat, lazyTableBuilder);
if (failed(transformGpuModulesToBinaries(
getOperation(), OffloadingLLVMTranslationAttrInterface(nullptr),
targetOptions)))
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/IR/BuiltinAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,12 @@ DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type,
return get(type, manager.insert(blobName, std::move(blob)));
}

ArrayRef<char> DenseResourceElementsAttr::getData() {
if (AsmResourceBlob *blob = this->getRawHandle().getBlob())
return blob->getDataAs<char>();
return {};
}

//===----------------------------------------------------------------------===//
// DenseResourceElementsAttrBase

Expand Down
56 changes: 45 additions & 11 deletions mlir/lib/Target/LLVM/ModuleToObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include "mlir/Target/LLVM/ModuleToObject.h"

#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
Expand All @@ -25,6 +27,7 @@
#include "llvm/Linker/Linker.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -89,22 +92,53 @@ ModuleToObject::loadBitcodeFile(llvm::LLVMContext &context, StringRef path) {
}

LogicalResult ModuleToObject::loadBitcodeFilesFromList(
llvm::LLVMContext &context, ArrayRef<std::string> fileList,
llvm::LLVMContext &context, ArrayRef<Attribute> librariesToLink,
SmallVector<std::unique_ptr<llvm::Module>> &llvmModules,
bool failureOnError) {
for (const std::string &str : fileList) {
// Test if the path exists, if it doesn't abort.
StringRef pathRef = StringRef(str.data(), str.size());
if (!llvm::sys::fs::is_regular_file(pathRef)) {
for (Attribute linkLib : librariesToLink) {
// Attributes in this list can be either list of file paths using
// StringAttr, or a resource attribute pointing to the LLVM bitcode in
// memory.
if (auto filePath = dyn_cast<StringAttr>(linkLib)) {
// Test if the path exists, if it doesn't abort.
if (!llvm::sys::fs::is_regular_file(filePath.strref())) {
getOperation().emitError()
<< "File path: " << filePath << " does not exist or is not a file.";
return failure();
}
// Load the file or abort on error.
if (auto bcFile = loadBitcodeFile(context, filePath))
llvmModules.push_back(std::move(bcFile));
else if (failureOnError)
return failure();
continue;
}
if (auto blobAttr = dyn_cast<BlobAttr>(linkLib)) {
// Load the file or abort on error.
llvm::SMDiagnostic error;
ArrayRef<char> data = blobAttr.getData();
std::unique_ptr<llvm::MemoryBuffer> buffer =
llvm::MemoryBuffer::getMemBuffer(StringRef(data.data(), data.size()),
"blobLinkedLib",
/*RequiresNullTerminator=*/false);
std::unique_ptr<llvm::Module> mod =
getLazyIRModule(std::move(buffer), error, context);
if (mod) {
if (failed(handleBitcodeFile(*mod)))
return failure();
llvmModules.push_back(std::move(mod));
} else if (failureOnError) {
getOperation().emitError()
<< "Couldn't load LLVM library for linking: " << error.getMessage();
return failure();
}
continue;
}
if (failureOnError) {
getOperation().emitError()
<< "File path: " << pathRef << " does not exist or is not a file.\n";
<< "Unknown attribute describing LLVM library to load: " << linkLib;
return failure();
}
// Load the file or abort on error.
if (auto bcFile = loadBitcodeFile(context, pathRef))
llvmModules.push_back(std::move(bcFile));
else if (failureOnError)
return failure();
}
return success();
}
Expand Down
18 changes: 8 additions & 10 deletions mlir/lib/Target/LLVM/NVVM/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,15 @@ SerializeGPUModuleBase::SerializeGPUModuleBase(
targetOptions.getOptimizedLlvmIRCallback(),
targetOptions.getISACallback()),
target(target), toolkitPath(targetOptions.getToolkitPath()),
fileList(targetOptions.getLinkFiles()) {
librariesToLink(targetOptions.getLibrariesToLink()) {

// If `targetOptions` have an empty toolkitPath use `getCUDAToolkitPath`
if (toolkitPath.empty())
toolkitPath = getCUDAToolkitPath();

// Append the files in the target attribute.
if (ArrayAttr files = target.getLink())
for (Attribute attr : files.getValue())
if (auto file = dyn_cast<StringAttr>(attr))
fileList.push_back(file.str());
if (target.getLink())
librariesToLink.append(target.getLink().begin(), target.getLink().end());

// Append libdevice to the files to be loaded.
(void)appendStandardLibs();
Expand All @@ -126,8 +124,8 @@ NVVMTargetAttr SerializeGPUModuleBase::getTarget() const { return target; }

StringRef SerializeGPUModuleBase::getToolkitPath() const { return toolkitPath; }

ArrayRef<std::string> SerializeGPUModuleBase::getFileList() const {
return fileList;
ArrayRef<Attribute> SerializeGPUModuleBase::getLibrariesToLink() const {
return librariesToLink;
}

// Try to append `libdevice` from a CUDA toolkit installation.
Expand All @@ -149,16 +147,16 @@ LogicalResult SerializeGPUModuleBase::appendStandardLibs() {
<< " does not exist or is not a file.\n";
return failure();
}
fileList.push_back(pathRef.str());
librariesToLink.push_back(StringAttr::get(target.getContext(), pathRef));
}
return success();
}

std::optional<SmallVector<std::unique_ptr<llvm::Module>>>
SerializeGPUModuleBase::loadBitcodeFiles(llvm::Module &module) {
SmallVector<std::unique_ptr<llvm::Module>> bcFiles;
if (failed(loadBitcodeFilesFromList(module.getContext(), fileList, bcFiles,
true)))
if (failed(loadBitcodeFilesFromList(module.getContext(), librariesToLink,
bcFiles, true)))
return std::nullopt;
return std::move(bcFiles);
}
Expand Down
Loading

0 comments on commit 72e8b9a

Please sign in to comment.