diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h index 5b6e03a2e6e75e..c950ef220f692f 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h +++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h @@ -51,7 +51,7 @@ class TargetOptions { /// obtaining the parent symbol table. The default compilation target is /// `Fatbin`. TargetOptions( - StringRef toolkitPath = {}, ArrayRef linkFiles = {}, + StringRef toolkitPath = {}, ArrayRef librariesToLink = {}, StringRef cmdOptions = {}, StringRef elfSection = {}, CompilationTarget compilationTarget = getDefaultCompilationTarget(), function_ref getSymbolTableCallback = {}, @@ -66,8 +66,8 @@ class TargetOptions { /// Returns the toolkit path. StringRef getToolkitPath() const; - /// Returns the files to link to. - ArrayRef getLinkFiles() const; + /// Returns the LLVM libraries to link to. + ArrayRef getLibrariesToLink() const; /// Returns the command line options. StringRef getCmdOptions() const; @@ -113,7 +113,7 @@ class TargetOptions { /// appropiate value: ie. `TargetOptions(TypeID::get())`. TargetOptions( TypeID typeID, StringRef toolkitPath = {}, - ArrayRef linkFiles = {}, StringRef cmdOptions = {}, + ArrayRef librariesToLink = {}, StringRef cmdOptions = {}, StringRef elfSection = {}, CompilationTarget compilationTarget = getDefaultCompilationTarget(), function_ref getSymbolTableCallback = {}, @@ -126,7 +126,7 @@ class TargetOptions { std::string toolkitPath; /// List of files to link with the LLVM module. - SmallVector linkFiles; + SmallVector librariesToLink; /// An optional set of command line options to be used by the compilation /// process. diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td index 954429c7d8eaeb..017559cc353e6b 100644 --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td @@ -34,6 +34,24 @@ def TypedAttrInterface : AttrInterface<"TypedAttr"> { >]; } +//===----------------------------------------------------------------------===// +// BlobAttrInterface +//===----------------------------------------------------------------------===// + +def BlobAttrInterface : AttrInterface<"BlobAttr"> { + let cppNamespace = "::mlir"; + let description = [{ + This interface allows an attribute to expose a blob of data without more + information. The data must be stored so that it can be accessed as a + contiguous ArrayRef. + }]; + + let methods = [InterfaceMethod< + "Get the attribute's data", + "::llvm::ArrayRef", "getData" + >]; +} + //===----------------------------------------------------------------------===// // ElementsAttrInterface //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td index 492b8309a5ea33..06f5e172a9909d 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -153,7 +153,8 @@ def Builtin_DenseArrayRawDataParameter : ArrayRefParameter< }]; } -def Builtin_DenseArray : Builtin_Attr<"DenseArray", "dense_array"> { +def Builtin_DenseArray : Builtin_Attr<"DenseArray", "dense_array", + [BlobAttrInterface]> { let summary = "A dense array of integer or floating point elements."; let description = [{ A dense array attribute is an attribute that represents a dense array of @@ -211,6 +212,10 @@ def Builtin_DenseArray : Builtin_Attr<"DenseArray", "dense_array"> { int64_t size() const { return getSize(); } /// Return true if there are no elements in the dense array. bool empty() const { return !size(); } + /// BlobAttrInterface method. + ArrayRef getData() { + return getRawData(); + } }]; } @@ -431,7 +436,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr< //===----------------------------------------------------------------------===// def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", - "dense_resource_elements", [ElementsAttrInterface]> { + "dense_resource_elements", [ElementsAttrInterface, BlobAttrInterface]> { let summary = "An Attribute containing a dense multi-dimensional array " "backed by a resource"; let description = [{ @@ -485,6 +490,10 @@ def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", "ShapedType":$type, "StringRef":$blobName, "AsmResourceBlob":$blob )> ]; + let extraClassDeclaration = [{ + /// BlobAttrInterface method. + ArrayRef getData(); + }]; let skipDefaultBuilders = 1; } diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h index 07fc55b41ae9c5..11fea6f0a44432 100644 --- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h +++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h @@ -83,7 +83,7 @@ class ModuleToObject { /// Loads multiple bitcode files. LogicalResult loadBitcodeFilesFromList( - llvm::LLVMContext &context, ArrayRef fileList, + llvm::LLVMContext &context, ArrayRef librariesToLink, SmallVector> &llvmModules, bool failureOnError = true); diff --git a/mlir/include/mlir/Target/LLVM/NVVM/Utils.h b/mlir/include/mlir/Target/LLVM/NVVM/Utils.h index 65ae8a6bdb4ada..2d6157b1e5a60d 100644 --- a/mlir/include/mlir/Target/LLVM/NVVM/Utils.h +++ b/mlir/include/mlir/Target/LLVM/NVVM/Utils.h @@ -46,14 +46,15 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject { /// Returns the CUDA toolkit path. StringRef getToolkitPath() const; - /// Returns the bitcode files to be loaded. - ArrayRef getFileList() const; + /// Returns the bitcode libraries to be linked into the gpu module after + /// translation to LLVM IR. + ArrayRef getLibrariesToLink() const; - /// Appends `nvvm/libdevice.bc` into `fileList`. Returns failure if the + /// Appends `nvvm/libdevice.bc` into `librariesToLink`. Returns failure if the /// library couldn't be found. LogicalResult appendStandardLibs(); - /// Loads the bitcode files in `fileList`. + /// Loads the bitcode files in `librariesToLink`. virtual std::optional>> loadBitcodeFiles(llvm::Module &module) override; @@ -64,8 +65,10 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject { /// CUDA toolkit path. std::string toolkitPath; - /// List of LLVM bitcode files to link to. - SmallVector fileList; + /// List of LLVM bitcode to link into after translation to LLVM IR. + /// The attributes can be StringAttr pointing to a file path, or + /// a Resource blob pointing to the LLVM bitcode in-memory. + SmallVector librariesToLink; }; } // namespace NVVM } // namespace mlir diff --git a/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h b/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h index 2d8204b55d360d..8f5d4162984fac 100644 --- a/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h +++ b/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h @@ -16,6 +16,7 @@ #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/Attributes.h" #include "mlir/Support/LLVM.h" #include "mlir/Target/LLVM/ModuleToObject.h" @@ -61,8 +62,8 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject { /// Returns the ROCM toolkit path. StringRef getToolkitPath() const; - /// Returns the bitcode files to be loaded. - ArrayRef getFileList() const; + /// Returns the LLVM bitcode libraries to be linked. + ArrayRef getLibrariesToLink() const; /// Appends standard ROCm device libraries to `fileList`. LogicalResult appendStandardLibs(AMDGCNLibraries libs); @@ -107,7 +108,7 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject { std::string toolkitPath; /// List of LLVM bitcode files to link to. - SmallVector fileList; + SmallVector librariesToLink; /// AMD GCN libraries to use when linking, the default is using none. AMDGCNLibraries deviceLibs = AMDGCNLibraries::None; diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 1fad251b2f79e0..ed2d81ee65eb4a 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2483,7 +2483,7 @@ KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const { //===----------------------------------------------------------------------===// TargetOptions::TargetOptions( - StringRef toolkitPath, ArrayRef linkFiles, + StringRef toolkitPath, ArrayRef librariesToLink, StringRef cmdOptions, StringRef elfSection, CompilationTarget compilationTarget, function_ref getSymbolTableCallback, @@ -2491,14 +2491,14 @@ TargetOptions::TargetOptions( function_ref linkedLlvmIRCallback, function_ref optimizedLlvmIRCallback, function_ref isaCallback) - : TargetOptions(TypeID::get(), toolkitPath, linkFiles, + : TargetOptions(TypeID::get(), toolkitPath, librariesToLink, cmdOptions, elfSection, compilationTarget, getSymbolTableCallback, initialLlvmIRCallback, linkedLlvmIRCallback, optimizedLlvmIRCallback, isaCallback) {} TargetOptions::TargetOptions( - TypeID typeID, StringRef toolkitPath, ArrayRef linkFiles, + TypeID typeID, StringRef toolkitPath, ArrayRef librariesToLink, StringRef cmdOptions, StringRef elfSection, CompilationTarget compilationTarget, function_ref getSymbolTableCallback, @@ -2506,7 +2506,7 @@ TargetOptions::TargetOptions( function_ref linkedLlvmIRCallback, function_ref optimizedLlvmIRCallback, function_ref isaCallback) - : toolkitPath(toolkitPath.str()), linkFiles(linkFiles), + : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink), cmdOptions(cmdOptions.str()), elfSection(elfSection.str()), compilationTarget(compilationTarget), getSymbolTableCallback(getSymbolTableCallback), @@ -2519,7 +2519,9 @@ TypeID TargetOptions::getTypeID() const { return typeID; } StringRef TargetOptions::getToolkitPath() const { return toolkitPath; } -ArrayRef TargetOptions::getLinkFiles() const { return linkFiles; } +ArrayRef TargetOptions::getLibrariesToLink() const { + return librariesToLink; +} StringRef TargetOptions::getCmdOptions() const { return cmdOptions; } diff --git a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp index 295ece4782fdbf..f8a548af6b3e8f 100644 --- a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp @@ -68,9 +68,11 @@ void GpuModuleToBinaryPass::runOnOperation() { } return &parentTable.value(); }; - - TargetOptions targetOptions(toolkitPath, linkFiles, cmdOptions, elfSection, - *targetFormat, lazyTableBuilder); + SmallVector librariesToLink; + for (const std::string &path : linkFiles) + librariesToLink.push_back(StringAttr::get(&getContext(), path)); + TargetOptions targetOptions(toolkitPath, librariesToLink, cmdOptions, + elfSection, *targetFormat, lazyTableBuilder); if (failed(transformGpuModulesToBinaries( getOperation(), OffloadingLLVMTranslationAttrInterface(nullptr), targetOptions))) diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index f288dd42baaa16..112e3f376bd418 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -1544,6 +1544,12 @@ DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type, return get(type, manager.insert(blobName, std::move(blob))); } +ArrayRef DenseResourceElementsAttr::getData() { + if (AsmResourceBlob *blob = this->getRawHandle().getBlob()) + return blob->getDataAs(); + return {}; +} + //===----------------------------------------------------------------------===// // DenseResourceElementsAttrBase diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp index 3f5b3d5e31864b..102d149a7bb584 100644 --- a/mlir/lib/Target/LLVM/ModuleToObject.cpp +++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp @@ -14,6 +14,8 @@ #include "mlir/Target/LLVM/ModuleToObject.h" #include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" @@ -25,6 +27,7 @@ #include "llvm/Linker/Linker.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/Path.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" @@ -89,22 +92,53 @@ ModuleToObject::loadBitcodeFile(llvm::LLVMContext &context, StringRef path) { } LogicalResult ModuleToObject::loadBitcodeFilesFromList( - llvm::LLVMContext &context, ArrayRef fileList, + llvm::LLVMContext &context, ArrayRef librariesToLink, SmallVector> &llvmModules, bool failureOnError) { - for (const std::string &str : fileList) { - // Test if the path exists, if it doesn't abort. - StringRef pathRef = StringRef(str.data(), str.size()); - if (!llvm::sys::fs::is_regular_file(pathRef)) { + for (Attribute linkLib : librariesToLink) { + // Attributes in this list can be either list of file paths using + // StringAttr, or a resource attribute pointing to the LLVM bitcode in + // memory. + if (auto filePath = dyn_cast(linkLib)) { + // Test if the path exists, if it doesn't abort. + if (!llvm::sys::fs::is_regular_file(filePath.strref())) { + getOperation().emitError() + << "File path: " << filePath << " does not exist or is not a file."; + return failure(); + } + // Load the file or abort on error. + if (auto bcFile = loadBitcodeFile(context, filePath)) + llvmModules.push_back(std::move(bcFile)); + else if (failureOnError) + return failure(); + continue; + } + if (auto blobAttr = dyn_cast(linkLib)) { + // Load the file or abort on error. + llvm::SMDiagnostic error; + ArrayRef data = blobAttr.getData(); + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(StringRef(data.data(), data.size()), + "blobLinkedLib", + /*RequiresNullTerminator=*/false); + std::unique_ptr mod = + getLazyIRModule(std::move(buffer), error, context); + if (mod) { + if (failed(handleBitcodeFile(*mod))) + return failure(); + llvmModules.push_back(std::move(mod)); + } else if (failureOnError) { + getOperation().emitError() + << "Couldn't load LLVM library for linking: " << error.getMessage(); + return failure(); + } + continue; + } + if (failureOnError) { getOperation().emitError() - << "File path: " << pathRef << " does not exist or is not a file.\n"; + << "Unknown attribute describing LLVM library to load: " << linkLib; return failure(); } - // Load the file or abort on error. - if (auto bcFile = loadBitcodeFile(context, pathRef)) - llvmModules.push_back(std::move(bcFile)); - else if (failureOnError) - return failure(); } return success(); } diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp index 745c1a5a6ee601..a9f7806b10f404 100644 --- a/mlir/lib/Target/LLVM/NVVM/Target.cpp +++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp @@ -93,17 +93,15 @@ SerializeGPUModuleBase::SerializeGPUModuleBase( targetOptions.getOptimizedLlvmIRCallback(), targetOptions.getISACallback()), target(target), toolkitPath(targetOptions.getToolkitPath()), - fileList(targetOptions.getLinkFiles()) { + librariesToLink(targetOptions.getLibrariesToLink()) { // If `targetOptions` have an empty toolkitPath use `getCUDAToolkitPath` if (toolkitPath.empty()) toolkitPath = getCUDAToolkitPath(); // Append the files in the target attribute. - if (ArrayAttr files = target.getLink()) - for (Attribute attr : files.getValue()) - if (auto file = dyn_cast(attr)) - fileList.push_back(file.str()); + if (target.getLink()) + librariesToLink.append(target.getLink().begin(), target.getLink().end()); // Append libdevice to the files to be loaded. (void)appendStandardLibs(); @@ -126,8 +124,8 @@ NVVMTargetAttr SerializeGPUModuleBase::getTarget() const { return target; } StringRef SerializeGPUModuleBase::getToolkitPath() const { return toolkitPath; } -ArrayRef SerializeGPUModuleBase::getFileList() const { - return fileList; +ArrayRef SerializeGPUModuleBase::getLibrariesToLink() const { + return librariesToLink; } // Try to append `libdevice` from a CUDA toolkit installation. @@ -149,7 +147,7 @@ LogicalResult SerializeGPUModuleBase::appendStandardLibs() { << " does not exist or is not a file.\n"; return failure(); } - fileList.push_back(pathRef.str()); + librariesToLink.push_back(StringAttr::get(target.getContext(), pathRef)); } return success(); } @@ -157,8 +155,8 @@ LogicalResult SerializeGPUModuleBase::appendStandardLibs() { std::optional>> SerializeGPUModuleBase::loadBitcodeFiles(llvm::Module &module) { SmallVector> bcFiles; - if (failed(loadBitcodeFilesFromList(module.getContext(), fileList, bcFiles, - true))) + if (failed(loadBitcodeFilesFromList(module.getContext(), librariesToLink, + bcFiles, true))) return std::nullopt; return std::move(bcFiles); } diff --git a/mlir/lib/Target/LLVM/ROCDL/Target.cpp b/mlir/lib/Target/LLVM/ROCDL/Target.cpp index 227b45133b57e3..cd7a67e58d612b 100644 --- a/mlir/lib/Target/LLVM/ROCDL/Target.cpp +++ b/mlir/lib/Target/LLVM/ROCDL/Target.cpp @@ -97,17 +97,15 @@ SerializeGPUModuleBase::SerializeGPUModuleBase( : ModuleToObject(module, target.getTriple(), target.getChip(), target.getFeatures(), target.getO()), target(target), toolkitPath(targetOptions.getToolkitPath()), - fileList(targetOptions.getLinkFiles()) { + librariesToLink(targetOptions.getLibrariesToLink()) { // If `targetOptions` has an empty toolkitPath use `getROCMPath` if (toolkitPath.empty()) toolkitPath = getROCMPath(); // Append the files in the target attribute. - if (ArrayAttr files = target.getLink()) - for (Attribute attr : files.getValue()) - if (auto file = dyn_cast(attr)) - fileList.push_back(file.str()); + if (target.getLink()) + librariesToLink.append(target.getLink().begin(), target.getLink().end()); } void SerializeGPUModuleBase::init() { @@ -128,8 +126,8 @@ ROCDLTargetAttr SerializeGPUModuleBase::getTarget() const { return target; } StringRef SerializeGPUModuleBase::getToolkitPath() const { return toolkitPath; } -ArrayRef SerializeGPUModuleBase::getFileList() const { - return fileList; +ArrayRef SerializeGPUModuleBase::getLibrariesToLink() const { + return librariesToLink; } LogicalResult SerializeGPUModuleBase::appendStandardLibs(AMDGCNLibraries libs) { @@ -160,7 +158,7 @@ LogicalResult SerializeGPUModuleBase::appendStandardLibs(AMDGCNLibraries libs) { << " does not exist or is not a file"; return true; } - fileList.push_back(pathRef.str()); + librariesToLink.push_back(StringAttr::get(target.getContext(), pathRef)); path.truncate(baseSize); return false; }; @@ -178,13 +176,13 @@ LogicalResult SerializeGPUModuleBase::appendStandardLibs(AMDGCNLibraries libs) { std::optional>> SerializeGPUModuleBase::loadBitcodeFiles(llvm::Module &module) { // Return if there are no libs to load. - if (deviceLibs == AMDGCNLibraries::None && fileList.empty()) + if (deviceLibs == AMDGCNLibraries::None && librariesToLink.empty()) return SmallVector>(); if (failed(appendStandardLibs(deviceLibs))) return std::nullopt; SmallVector> bcFiles; - if (failed(loadBitcodeFilesFromList(module.getContext(), fileList, bcFiles, - true))) + if (failed(loadBitcodeFilesFromList(module.getContext(), librariesToLink, + bcFiles, true))) return std::nullopt; return std::move(bcFiles); } diff --git a/mlir/unittests/Target/LLVM/CMakeLists.txt b/mlir/unittests/Target/LLVM/CMakeLists.txt index 4dcbc9653fa059..15835b904c4649 100644 --- a/mlir/unittests/Target/LLVM/CMakeLists.txt +++ b/mlir/unittests/Target/LLVM/CMakeLists.txt @@ -1,9 +1,13 @@ set(LLVM_LINK_COMPONENTS nativecodegen) +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) + add_mlir_unittest(MLIRTargetLLVMTests SerializeNVVMTarget.cpp SerializeROCDLTarget.cpp SerializeToLLVMBitcode.cpp +DEPENDS + ${dialect_libs} ) mlir_target_link_libraries(MLIRTargetLLVMTests diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp index a92ad18c956821..eabfd1c4d32eb0 100644 --- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp +++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp @@ -18,15 +18,18 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +#include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/Config/llvm-config.h" // for LLVM_HAS_NVPTX_TARGET #include "llvm/IRReader/IRReader.h" #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/Process.h" +#include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TargetParser/Host.h" #include "gmock/gmock.h" +#include using namespace mlir; @@ -215,3 +218,81 @@ TEST_F(MLIRTargetLLVMNVVM, isaResult.clear(); } } + +// Test linking LLVM IR from a resource attribute. +TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) { + MLIRContext context(registry); + std::string moduleStr = R"mlir( + gpu.module @nvvm_test { + llvm.func @bar() + llvm.func @nvvm_kernel(%arg0: f32) attributes {gpu.kernel, nvvm.kernel} { + llvm.call @bar() : () -> () + llvm.return + } + } + )mlir"; + // Provide the library to link as a serialized bitcode blob. + SmallVector bitcodeToLink; + { + std::string linkedLib = R"llvm( + define void @bar() { + ret void + } + )llvm"; + llvm::SMDiagnostic err; + llvm::MemoryBufferRef buffer(linkedLib, "linkedLib"); + llvm::LLVMContext llvmCtx; + std::unique_ptr module = llvm::parseIR(buffer, err, llvmCtx); + ASSERT_TRUE(module) << " Can't parse IR: " << err.getMessage(); + { + llvm::raw_svector_ostream os(bitcodeToLink); + WriteBitcodeToFile(*module, os); + } + } + + OwningOpRef module = + parseSourceString(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + + NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context); + auto serializer = dyn_cast(target); + + // Hook to intercept the LLVM IR after linking external libs. + std::string linkedLLVMIR; + auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + llvm::raw_string_ostream ros(linkedLLVMIR); + module.print(ros, nullptr); + }; + + // Store the bitcode as a DenseI8ArrayAttr. + SmallVector librariesToLink; + librariesToLink.push_back(DenseI8ArrayAttr::get( + &context, + ArrayRef((int8_t *)bitcodeToLink.data(), bitcodeToLink.size()))); + gpu::TargetOptions options({}, librariesToLink, {}, {}, + gpu::CompilationTarget::Assembly, {}, {}, + linkedCallback); + for (auto gpuModule : (*module).getBody()->getOps()) { + std::optional> object = + serializer.serializeToObject(gpuModule, options); + + // Verify that we correctly linked in the library: the external call is + // replaced by the definition. + ASSERT_TRUE(!linkedLLVMIR.empty()); + { + llvm::SMDiagnostic err; + llvm::MemoryBufferRef buffer(linkedLLVMIR, "linkedLLVMIR"); + llvm::LLVMContext llvmCtx; + std::unique_ptr module = + llvm::parseIR(buffer, err, llvmCtx); + ASSERT_TRUE(module) << " Can't parse linkedLLVMIR: " << err.getMessage() + << " IR: \n\b" << linkedLLVMIR; + llvm::Function *bar = module->getFunction("bar"); + ASSERT_TRUE(bar); + ASSERT_FALSE(bar->empty()); + } + ASSERT_TRUE(object != std::nullopt); + ASSERT_TRUE(!object->empty()); + } +}