Skip to content

Commit

Permalink
Move internal changes
Browse files Browse the repository at this point in the history
## NFC: Simplify some aspects of options management (OptionsContext)

- Adds a convenience 'OptionsContext::Option' class that simplifies how
  options are declared.

- Closes a loophole where tuples of structs containing options can cause
  crashes if they populate options in their constructor. Due to how the
  external storage mechanism works, we can no longer use direct `std::tuple`
  of aggregate objects which invoke `addOption`. Instead, one must use
  `unique_ptr` to wrap those types when used as elements of a `std::tuple`.

- To help enforce this, we explicitly delete the move constructor of
  `OptionsProvider`.

## [compiler|python] Update how cached pipelines/"Compiler Tasks" are registered

This change updates how registration functions for  "compilation tasks"
invoked. We now expose a C API method that can be invoked within the
Pybind11 module initializer. This decouples compiler task registration from
pass or dialect registration.

This change also cleans up the C API function naming for pass/dialect
registration functions.

## [python] Add more robust CMake logic for fixing missing CAPI dependency in core MLIR PyBind module

Adds CMake logic to ensure that the Core '_mlir' pybind extension has the
correct CAPI dependencies declared until the upstream CMake declarations can be
fixed.

## NFC: Remove unnecessary PyCapsule <-> CAPI casters in compiler and runtime bindings

Removes unnecessary custom PyBind11 capsule -> C API object casters.
These cast functions are only required when it is desired to allow
PyBind11 to extract the C API object from the C++ python wrapper type
automatically.

## [tensorrt|compiler] Drop "layer metadata callback" utility from TensorRT translation

This change removes the "layer metadata callback" feature from the
MLIR-to-TensorRT translation. It also removes the relevant APIs from the
MLIR-TensorRT compiler's C++ and Python APIs.

This capability was originally offered as a bridge for populating the
generated TensorRT ILayers with custom metadata. However, the mechanism
prevents caching of pass pipelines and therefore is too expensive to use.

In the future, any metadata passed to TensorRT should be derived from
the MLIR operations' location information.

## NFC: update various uses of "Stablehlo" in class and function names to have consistent capitalization

## NFC: Reorganize some directories

This change:

- Moves the top-level 'tools' to 'compiler/tools'
- Moves the top-level 'test' to 'compiler/test'
- Moves the 'mlir-tensorrt-tblgen' tool under 'tensorrt/tools'
  since the 'tensorrt' project is supposed to be independent.
- Similarly move TensorRT-specific python definitions under `tensorrt/python`.

## [executor]: Add a missing guard for builds without CUDA enabled.

Wrapping the makeCudaStringError function with MLIR_EXECUTOR_ENABLE_CUDA fixes builds without CUDA enabled.

## [executor] Use Lua locals for block arguments

Previously, the Executor MLIR-to-Lua translator used Lua globals for
block arguments outside of the entry block since the values that
represent block arguments need to be passed between blocks. On the
other hand, the scope of Lua local variables are restricted to their
block. It is almost never a good idea to use Lua global variables in
our translation strategy, however -- for coroutine functions, a
translation that uses globals is obviously incorrect since all Lua
coroutines in a single Lua environment share the same set of globals.

This change declares all block arguments up front as locals in the
"entry block" and just sets them to `nil` initially. Since we don't
declare a block scope for the entry block, all the following Lua block
scopes will have these locals in scope. This allows us to retain the
use of locals for all block arguments.

GitOrigin-RevId: e9dd03c47eab6145e889ea8ff56fd1c71181f72a
  • Loading branch information
Copybara Bot authored and christopherbate committed Dec 17, 2024
1 parent cbf0e09 commit 33dcba8
Show file tree
Hide file tree
Showing 177 changed files with 419 additions and 498 deletions.
4 changes: 4 additions & 0 deletions mlir-tensorrt/compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
set(MLIR_TENSORRT_COMPILER_DIR "${CMAKE_CURRENT_SOURCE_DIR}")

include_directories(${CMAKE_CURRENT_LIST_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
include_directories(${MLIR_TENSORRT_ROOT_DIR}/executor/include)
include_directories(${MLIR_TENSORRT_ROOT_BINARY_DIR}/executor/include)

add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(test)
add_subdirectory(tools)

14 changes: 0 additions & 14 deletions mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,6 @@ typedef struct MTRT_StableHLOToExecutableOptions {
void *ptr;
} MTRT_StableHLOToExecutableOptions;

/// A callback that allows the user to customize the metadata set for layers
/// corresponding to each MLIR operation. The callback should invoke the
/// provided append function in order to manipulate the result string.
typedef void (*MTRT_MetadataCallback)(MlirOperation op,
MlirStringCallback append,
void *appendCtx, void *userData);

MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsCreate(
MTRT_CompilerClient client, MTRT_StableHLOToExecutableOptions *options,
int32_t tensorRTBuilderOptLevel, bool tensorRTStronglyTyped);
Expand All @@ -108,13 +101,6 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
const char **debugTypes, size_t debugTypeSizes,
const char *dumpIrTreeDir = nullptr, const char *dumpTensorRTDir = nullptr);

/// Sets the layer metadata callback. The `userData` argument is passed along
/// to the callback when it is invoked.
MLIR_CAPI_EXPORTED MTRT_Status
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback,
void *userData);

MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
MTRT_StableHLOToExecutableOptions options);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ extern "C" {

/// Add all the dialects used by MLIR-TensorRT to the registry.
MLIR_CAPI_EXPORTED void
mlirTensorRTRegisterAllDialects(MlirDialectRegistry registry);
mtrtCompilerRegisterDialects(MlirDialectRegistry registry);

/// Register all the compiler passes used by MLIR-TensorRT.
MLIR_CAPI_EXPORTED void mlirTensorRTRegisterAllPasses();
MLIR_CAPI_EXPORTED void mtrtCompilerRegisterPasses();

/// Register all the compiler task types (pass manager types).
MLIR_CAPI_EXPORTED void mtrtCompilerRegisterTasks();

#ifdef __cplusplus
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,13 @@ namespace mlirtrt::compiler {
// StableHLOToExecutableOptions
//===----------------------------------------------------------------------===//

class StableHloToExecutableTask;
class StablehloToExecutableTask;

struct StableHLOToExecutableOptions
struct StablehloToExecutableOptions
: public mlir::OptionsBundle<DebugOptions, ExecutorOptions, DeviceOptions> {
/// Initializes the options. The extensions in the provided registry
/// must be extensions for the StableHloToExecutable task.
StableHLOToExecutableOptions(TaskExtensionRegistry extensions);

/// Return the hash of the options. Returns `nullopt` when the TensorRT
/// layer metadata callback is set since that can't be reliably hashed.
std::optional<llvm::hash_code> getHash() const override;
StablehloToExecutableOptions(TaskExtensionRegistry extensions);

/// Whether to disallow host tensors in TensorRT clusters.
bool disallowHostTensorsInTensorRTClusters = false;
Expand All @@ -71,18 +67,16 @@ struct StableHLOToExecutableOptions
/// Entrypoint function name.
std::string entrypoint = "main";

std::function<std::string(mlir::Operation *)> layerMetadataCallback{nullptr};

/// Base class for extensions associated with StableHloToExecutableTask.
class ExtensionBase : public TaskExtensionBase {
public:
ExtensionBase(mlir::TypeID typeID)
: TaskExtensionBase(typeID,
mlir::TypeID::get<StableHloToExecutableTask>()) {}
mlir::TypeID::get<StablehloToExecutableTask>()) {}

static bool classof(const TaskExtensionBase *extension) {
return extension->getTaskID() ==
mlir::TypeID::get<StableHloToExecutableTask>();
mlir::TypeID::get<StablehloToExecutableTask>();
}

enum class Phase {
Expand All @@ -98,7 +92,7 @@ struct StableHLOToExecutableOptions
/// relative to each other (yet).
virtual void
populatePasses(mlir::OpPassManager &pm, Phase phase,
const StableHLOToExecutableOptions &options) const = 0;
const StablehloToExecutableOptions &options) const = 0;
};

/// A StableHLOToExecutableOptions::Extension is an extension that must
Expand All @@ -120,39 +114,39 @@ struct StableHLOToExecutableOptions
/// A StableHloToExecutableTask is a concrete CompilationTask (PassManager) that
/// accepts StableHLO input IR and lowers it down to Executor IR which can be
/// translated into a MLIR-TensorRT executable.
class StableHloToExecutableTask
: public CompilationTask<StableHloToExecutableTask,
StableHLOToExecutableOptions> {
class StablehloToExecutableTask
: public CompilationTask<StablehloToExecutableTask,
StablehloToExecutableOptions> {
public:
using Base::Base;

/// Build the clustering pipeline that occurs on Stablehlo Ops.
static void
buildStablehloClusteringPipeline(mlir::OpPassManager &pm,
const StableHLOToExecutableOptions &options);
const StablehloToExecutableOptions &options);

/// Build the pipeline (bufferization and lowering) that runs after
/// clustering.
static void
buildPostClusteringPipeline(mlir::OpPassManager &pm,
const StableHLOToExecutableOptions &options);
const StablehloToExecutableOptions &options);

static void populatePassManager(mlir::PassManager &pm,
const StableHLOToExecutableOptions &options);
const StablehloToExecutableOptions &options);

/// Compile a StableHLO module into a MLIR-TensorRT Runtime executable.
/// This is the "functional" entrypoint that will allocate a new PassManager
/// for a single run.
static mlirtrt::StatusOr<std::unique_ptr<runtime::Executable>>
compileStableHLOToExecutable(mlir::ModuleOp module,
const StableHLOToExecutableOptions &options);
const StablehloToExecutableOptions &options);

/// Compile a StableHLO module into a MLIR-TensorRT Runtime executable.
/// This is the "functional" entrypoint that will allocate a new PassManager
/// for a single run.
static mlirtrt::StatusOr<std::unique_ptr<runtime::Executable>>
compileStableHLOToExecutable(CompilerClient &client, mlir::ModuleOp module,
const StableHLOToExecutableOptions &options);
const StablehloToExecutableOptions &options);
};

/// Register the task/options with the client's registry.
Expand All @@ -175,7 +169,7 @@ void registerStablehloClusteringPipelines();

} // namespace mlirtrt::compiler

MLIR_DECLARE_EXPLICIT_TYPE_ID(mlirtrt::compiler::StableHloToExecutableTask)
MLIR_DECLARE_EXPLICIT_TYPE_ID(mlirtrt::compiler::StablehloToExecutableTask)

#endif // MLIR_TRT_ENABLE_HLO
#endif // MLIR_TENSORRT_COMPILER_STABLEHLOTOEXECUTABLE
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace mlirtrt::compiler {
//===----------------------------------------------------------------------===//

class StableHLOToExecutableTensorRTExtension
: public StableHLOToExecutableOptions::Extension<
: public StablehloToExecutableOptions::Extension<
StableHLOToExecutableTensorRTExtension> {
public:
StableHLOToExecutableTensorRTExtension();
Expand All @@ -45,7 +45,7 @@ class StableHLOToExecutableTensorRTExtension
/// It is not guarunteed the order in which different extensions are run
/// relative to each other (yet).
void populatePasses(mlir::OpPassManager &pm, Phase phase,
const StableHLOToExecutableOptions &options) const final;
const StablehloToExecutableOptions &options) const final;

/// Allows the extension to hook into the option parsing infrastructure.
void addToOptions(mlir::OptionsContext &context) final {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ inline void registerAllMlirTensorRtPasses() {
mlir::registerConvertPDLToPDLInterp();

#ifdef MLIR_TRT_ENABLE_HLO
mlirtrt::compiler::registerStableHloToExecutableTask();
mlirtrt::compiler::registerStablehloClusteringPipelines();
registerStableHloInputPipelines();
stablehlo_ext::registerStableHloExtPasses();
Expand Down
46 changes: 10 additions & 36 deletions mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ using namespace mlir;
#endif
DEFINE_C_API_PTR_METHODS(MTRT_CompilerClient, CompilerClient)
DEFINE_C_API_PTR_METHODS(MTRT_StableHLOToExecutableOptions,
StableHLOToExecutableOptions)
StablehloToExecutableOptions)
DEFINE_C_API_PTR_METHODS(MTRT_OptionsContext, OptionsContext)
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
Expand Down Expand Up @@ -84,7 +84,7 @@ MTRT_Status mtrtCompilerClientCreate(MlirContext context,
ctx->getOrLoadDialect<mlir::plan::PlanDialect>();
assert(planDialect && "expected loaded PlanDialect");
if (failed(planDialect->extensionConstructors.addCheckedExtensionConstructor<
compiler::StableHloToExecutableTask,
compiler::StablehloToExecutableTask,
compiler::StableHLOToExecutableTensorRTExtension>()))
emitWarning(mlir::UnknownLoc::get(ctx))
<< "ignoring duplicate extension load request; TensorRTExtension is "
Expand Down Expand Up @@ -156,7 +156,7 @@ MTRT_Status mtrtStableHloToExecutableOptionsCreate(
context->getLoadedDialect<mlir::plan::PlanDialect>();
compiler::TaskExtensionRegistry extensions =
planDialect->extensionConstructors
.getExtensionRegistryForTask<compiler::StableHloToExecutableTask>();
.getExtensionRegistryForTask<compiler::StablehloToExecutableTask>();

// Check that default extension set is loaded and set options on the TRT
// extension.
Expand All @@ -168,7 +168,7 @@ MTRT_Status mtrtStableHloToExecutableOptionsCreate(
trtExtension->setOptions(translationOpts);

auto result =
std::make_unique<StableHLOToExecutableOptions>(std::move(extensions));
std::make_unique<StablehloToExecutableOptions>(std::move(extensions));

llvm::Error finalizeStatus = result->finalize();

Expand All @@ -194,7 +194,7 @@ MTRT_Status mtrtStableHloToExecutableOptionsCreateFromArgs(
context->getLoadedDialect<mlir::plan::PlanDialect>();
compiler::TaskExtensionRegistry extensions =
planDialect->extensionConstructors
.getExtensionRegistryForTask<compiler::StableHloToExecutableTask>();
.getExtensionRegistryForTask<compiler::StablehloToExecutableTask>();

// Check that default extension set is loaded.
assert(
Expand All @@ -203,7 +203,7 @@ MTRT_Status mtrtStableHloToExecutableOptionsCreateFromArgs(
"expected valid StableHLOToExecutableTensorRTExtension");

auto result =
std::make_unique<StableHLOToExecutableOptions>(std::move(extensions));
std::make_unique<StablehloToExecutableOptions>(std::move(extensions));
std::vector<llvm::StringRef> argvStrRef(argc);
for (unsigned i = 0; i < argc; i++)
argvStrRef[i] = llvm::StringRef(argv[i].data, argv[i].length);
Expand Down Expand Up @@ -234,7 +234,7 @@ MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
const char **debugTypes, size_t debugTypeSizes, const char *dumpIrTreeDir,
const char *dumpTensorRTDir) {

StableHLOToExecutableOptions *cppOpts = unwrap(options);
StablehloToExecutableOptions *cppOpts = unwrap(options);
cppOpts->get<DebugOptions>().enableLLVMDebugFlag = enableDebugging;
for (unsigned i = 0; i < debugTypeSizes; i++)
cppOpts->get<DebugOptions>().llvmDebugTypes.emplace_back(debugTypes[i]);
Expand All @@ -245,35 +245,9 @@ MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
return mtrtStatusGetOk();
}

MTRT_Status
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback,
void *userData) {
StableHLOToExecutableOptions *cppOpts = unwrap(options);

// Construct the append callback which we will pass to the callback provided
// by the user. We do it this way to avoid needing a string construct in the C
// API.
auto appendFunc = [](MlirStringRef str, void *appendCtx) {
std::string &accum = *reinterpret_cast<std::string *>(appendCtx);
accum += std::string(str.data, str.length);
};

// Capturing by reference here will cause `callback` to point to the wrong
// place at the time this callback is invoked.
cppOpts->layerMetadataCallback = [=](Operation *op) {
std::string accum;
void *appendCtx = reinterpret_cast<void *>(&accum);
callback(wrap(op), appendFunc, appendCtx, userData);
return accum;
};

return mtrtStatusGetOk();
}

MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
MTRT_StableHLOToExecutableOptions options) {
delete reinterpret_cast<StableHLOToExecutableOptions *>(options.ptr);
delete reinterpret_cast<StablehloToExecutableOptions *>(options.ptr);
return mtrtStatusGetOk();
}

Expand All @@ -288,7 +262,7 @@ mtrtStableHloPipelineGetCached(MTRT_CompilerClient client,

mlir::PassManager *runner{};
if (unwrap(options)->getHash()) {
runner = &unwrap(client)->getOrCreatePassManager<StableHloToExecutableTask>(
runner = &unwrap(client)->getOrCreatePassManager<StablehloToExecutableTask>(
*unwrap(options));
result->ptr = runner;
return mtrtStatusGetOk();
Expand Down Expand Up @@ -340,7 +314,7 @@ MTRT_Status mtrtCompilerStableHLOToExecutable(
"StableHLO-to-Executable compilation expects a ModuleOp");

StatusOr<std::unique_ptr<mlirtrt::runtime::Executable>> exe =
compiler::StableHloToExecutableTask::compileStableHLOToExecutable(
compiler::StablehloToExecutableTask::compileStableHLOToExecutable(
*unwrap(client), moduleOp, *unwrap(stableHloToExecutableOptions));
if (!exe.isOk())
return mtrtStatusCreate(MTRT_StatusCode::MTRT_StatusCode_InternalError,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,5 @@ add_mlir_tensorrt_public_c_api_library(MLIRTensorRTCAPIRegisterAllDialects
LINK_LIBS PUBLIC
MLIRTensorRTRegistration
MLIRCAPIIR
MLIRCAPITransforms
MLIRTensorRTCompilerStableHloToExecutable
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,19 @@
//===----------------------------------------------------------------------===//

#include "mlir-tensorrt-c/Compiler/Registration/RegisterAllDialects.h"
#include "mlir-tensorrt/Compiler/StableHloToExecutable.h"
#include "mlir-tensorrt/Registration/RegisterMlirTensorRtDialects.h"
#include "mlir-tensorrt/Registration/RegisterMlirTensorRtPasses.h"
#include "mlir/CAPI/IR.h"

void mlirTensorRTRegisterAllDialects(MlirDialectRegistry registry) {
void mtrtCompilerRegisterDialects(MlirDialectRegistry registry) {
mlir::registerAllMlirTensorRtDialects(*unwrap(registry));
}

void mlirTensorRTRegisterAllPasses() {
void mtrtCompilerRegisterPasses() {
mlir::tensorrt::registerAllMlirTensorRtPasses();
}

void mtrtCompilerRegisterTasks() {
mlirtrt::compiler::registerStableHloToExecutableTask();
}
Loading

0 comments on commit 33dcba8

Please sign in to comment.