Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move internal changes #455

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/mlir-tensorrt-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ jobs:
cat > run_format_check.sh <<EOF
#!/bin/bash
set -e
python3 -m black --check --exclude='.*\.pyi' mlir-tensorrt/test/
python3 -m black --check --exclude='.*\.pyi' mlir-tensorrt/python/
python3 -m black --check --extend-exclude='.*\.pyi' mlir-tensorrt/compiler/
python3 -m black --check --extend-exclude='.*\.pyi' mlir-tensorrt/python/
git clang-format HEAD~1 --diff
EOF

Expand Down
6 changes: 0 additions & 6 deletions mlir-tensorrt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,3 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/tensorrt/include)

add_subdirectory(compiler)
add_subdirectory(python)

if(MLIR_TRT_ENABLE_TESTING)
add_subdirectory(test)
endif()

add_subdirectory(tools)
15 changes: 15 additions & 0 deletions mlir-tensorrt/build_tools/cmake/ManagedLLVM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,19 @@ macro(mtrt_llvm_project)

set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin)
set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib)

# The upstream 'MLIRPythonExtension.Core' target is missing 'MLIRCAPITransforms'
# in its EMBED_CAPI_LINK_LIBS argument. Instead, it is declared on the
# '_mlirRegisterEverything' extension, which appears to be wrong.
# TODO: fix this upstream.
if(MLIR_TRT_ENABLE_PYTHON)
get_property(mlir_core_pybind_capi_embed
TARGET MLIRPythonExtension.Core
PROPERTY mlir_python_EMBED_CAPI_LINK_LIBS)
list(FIND mlir_core_pybind_capi_embed MLIRCAPITransforms item_index)
if(item_index EQUAL -1)
set_property(TARGET MLIRPythonExtension.Core
APPEND PROPERTY mlir_python_EMBED_CAPI_LINK_LIBS MLIRCAPITransforms)
endif()
endif()
endmacro()
4 changes: 4 additions & 0 deletions mlir-tensorrt/compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
set(MLIR_TENSORRT_COMPILER_DIR "${CMAKE_CURRENT_SOURCE_DIR}")

include_directories(${CMAKE_CURRENT_LIST_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
include_directories(${MLIR_TENSORRT_ROOT_DIR}/executor/include)
include_directories(${MLIR_TENSORRT_ROOT_BINARY_DIR}/executor/include)

add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(test)
add_subdirectory(tools)

Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,6 @@ typedef struct MTRT_StableHLOToExecutableOptions {
void *ptr;
} MTRT_StableHLOToExecutableOptions;

/// A callback that allows the user to customize the metadata set for layers
/// corresponding to each MLIR operation. The callback should invoke the
/// provided append function in order to manipulate the result string.
typedef void (*MTRT_MetadataCallback)(MlirOperation op,
MlirStringCallback append,
void *appendCtx, void *userData);

MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsCreate(
MTRT_CompilerClient client, MTRT_StableHLOToExecutableOptions *options,
int32_t tensorRTBuilderOptLevel, bool tensorRTStronglyTyped);
Expand All @@ -108,13 +101,6 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
const char **debugTypes, size_t debugTypeSizes,
const char *dumpIrTreeDir = nullptr, const char *dumpTensorRTDir = nullptr);

/// Sets the layer metadata callback. The `userData` argument is passed along
/// to the callback when it is invoked.
MLIR_CAPI_EXPORTED MTRT_Status
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback,
void *userData);

MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
MTRT_StableHLOToExecutableOptions options);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ extern "C" {

/// Add all the dialects used by MLIR-TensorRT to the registry.
MLIR_CAPI_EXPORTED void
mlirTensorRTRegisterAllDialects(MlirDialectRegistry registry);
mtrtCompilerRegisterDialects(MlirDialectRegistry registry);

/// Register all the compiler passes used by MLIR-TensorRT.
MLIR_CAPI_EXPORTED void mlirTensorRTRegisterAllPasses();
MLIR_CAPI_EXPORTED void mtrtCompilerRegisterPasses();

/// Register all the compiler task types (pass manager types).
MLIR_CAPI_EXPORTED void mtrtCompilerRegisterTasks();

#ifdef __cplusplus
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,13 @@ namespace mlirtrt::compiler {
// StableHLOToExecutableOptions
//===----------------------------------------------------------------------===//

class StableHloToExecutableTask;
class StablehloToExecutableTask;

struct StableHLOToExecutableOptions
struct StablehloToExecutableOptions
: public mlir::OptionsBundle<DebugOptions, ExecutorOptions, DeviceOptions> {
/// Initializes the options. The extensions in the provided registry
/// must be extensions for the StableHloToExecutable task.
StableHLOToExecutableOptions(TaskExtensionRegistry extensions);

/// Return the hash of the options. Returns `nullopt` when the TensorRT
/// layer metadata callback is set since that can't be reliably hashed.
std::optional<llvm::hash_code> getHash() const override;
StablehloToExecutableOptions(TaskExtensionRegistry extensions);

/// Whether to disallow host tensors in TensorRT clusters.
bool disallowHostTensorsInTensorRTClusters = false;
Expand All @@ -71,18 +67,16 @@ struct StableHLOToExecutableOptions
/// Entrypoint function name.
std::string entrypoint = "main";

std::function<std::string(mlir::Operation *)> layerMetadataCallback{nullptr};

/// Base class for extensions associated with StablehloToExecutableTask.
class ExtensionBase : public TaskExtensionBase {
public:
ExtensionBase(mlir::TypeID typeID)
: TaskExtensionBase(typeID,
mlir::TypeID::get<StableHloToExecutableTask>()) {}
mlir::TypeID::get<StablehloToExecutableTask>()) {}

static bool classof(const TaskExtensionBase *extension) {
return extension->getTaskID() ==
mlir::TypeID::get<StableHloToExecutableTask>();
mlir::TypeID::get<StablehloToExecutableTask>();
}

enum class Phase {
Expand All @@ -98,7 +92,7 @@ struct StableHLOToExecutableOptions
/// relative to each other (yet).
virtual void
populatePasses(mlir::OpPassManager &pm, Phase phase,
const StableHLOToExecutableOptions &options) const = 0;
const StablehloToExecutableOptions &options) const = 0;
};

/// A StablehloToExecutableOptions::Extension is an extension that must
Expand All @@ -120,39 +114,39 @@ struct StableHLOToExecutableOptions
/// A StablehloToExecutableTask is a concrete CompilationTask (PassManager) that
/// accepts StableHLO input IR and lowers it down to Executor IR which can be
/// translated into a MLIR-TensorRT executable.
class StableHloToExecutableTask
: public CompilationTask<StableHloToExecutableTask,
StableHLOToExecutableOptions> {
class StablehloToExecutableTask
: public CompilationTask<StablehloToExecutableTask,
StablehloToExecutableOptions> {
public:
using Base::Base;

/// Build the clustering pipeline that occurs on Stablehlo Ops.
static void
buildStablehloClusteringPipeline(mlir::OpPassManager &pm,
const StableHLOToExecutableOptions &options);
const StablehloToExecutableOptions &options);

/// Build the pipeline (bufferization and lowering) that runs after
/// clustering.
static void
buildPostClusteringPipeline(mlir::OpPassManager &pm,
const StableHLOToExecutableOptions &options);
const StablehloToExecutableOptions &options);

static void populatePassManager(mlir::PassManager &pm,
const StableHLOToExecutableOptions &options);
const StablehloToExecutableOptions &options);

/// Compile a StableHLO module into a MLIR-TensorRT Runtime executable.
/// This is the "functional" entrypoint that will allocate a new PassManager
/// for a single run.
static mlirtrt::StatusOr<std::unique_ptr<runtime::Executable>>
compileStableHLOToExecutable(mlir::ModuleOp module,
const StableHLOToExecutableOptions &options);
const StablehloToExecutableOptions &options);

/// Compile a StableHLO module into a MLIR-TensorRT Runtime executable.
/// This is the "functional" entrypoint that will allocate a new PassManager
/// for a single run.
static mlirtrt::StatusOr<std::unique_ptr<runtime::Executable>>
compileStableHLOToExecutable(CompilerClient &client, mlir::ModuleOp module,
const StableHLOToExecutableOptions &options);
const StablehloToExecutableOptions &options);
};

/// Register the task/options with the client's registry.
Expand All @@ -175,7 +169,7 @@ void registerStablehloClusteringPipelines();

} // namespace mlirtrt::compiler

MLIR_DECLARE_EXPLICIT_TYPE_ID(mlirtrt::compiler::StableHloToExecutableTask)
MLIR_DECLARE_EXPLICIT_TYPE_ID(mlirtrt::compiler::StablehloToExecutableTask)

#endif // MLIR_TRT_ENABLE_HLO
#endif // MLIR_TENSORRT_COMPILER_STABLEHLOTOEXECUTABLE
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace mlirtrt::compiler {
//===----------------------------------------------------------------------===//

class StableHLOToExecutableTensorRTExtension
: public StableHLOToExecutableOptions::Extension<
: public StablehloToExecutableOptions::Extension<
StableHLOToExecutableTensorRTExtension> {
public:
StableHLOToExecutableTensorRTExtension();
Expand All @@ -45,7 +45,7 @@ class StableHLOToExecutableTensorRTExtension
/// The order in which different extensions are run relative to each other is
/// not guaranteed (yet).
void populatePasses(mlir::OpPassManager &pm, Phase phase,
const StableHLOToExecutableOptions &options) const final;
const StablehloToExecutableOptions &options) const final;

/// Allows the extension to hook into the option parsing infrastructure.
void addToOptions(mlir::OptionsContext &context) final {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ inline void registerAllMlirTensorRtPasses() {
mlir::registerConvertPDLToPDLInterp();

#ifdef MLIR_TRT_ENABLE_HLO
mlirtrt::compiler::registerStableHloToExecutableTask();
mlirtrt::compiler::registerStablehloClusteringPipelines();
registerStableHloInputPipelines();
stablehlo_ext::registerStableHloExtPasses();
Expand Down
46 changes: 10 additions & 36 deletions mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ using namespace mlir;
#endif
DEFINE_C_API_PTR_METHODS(MTRT_CompilerClient, CompilerClient)
DEFINE_C_API_PTR_METHODS(MTRT_StableHLOToExecutableOptions,
StableHLOToExecutableOptions)
StablehloToExecutableOptions)
DEFINE_C_API_PTR_METHODS(MTRT_OptionsContext, OptionsContext)
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
Expand Down Expand Up @@ -84,7 +84,7 @@ MTRT_Status mtrtCompilerClientCreate(MlirContext context,
ctx->getOrLoadDialect<mlir::plan::PlanDialect>();
assert(planDialect && "expected loaded PlanDialect");
if (failed(planDialect->extensionConstructors.addCheckedExtensionConstructor<
compiler::StableHloToExecutableTask,
compiler::StablehloToExecutableTask,
compiler::StableHLOToExecutableTensorRTExtension>()))
emitWarning(mlir::UnknownLoc::get(ctx))
<< "ignoring duplicate extension load request; TensorRTExtension is "
Expand Down Expand Up @@ -156,7 +156,7 @@ MTRT_Status mtrtStableHloToExecutableOptionsCreate(
context->getLoadedDialect<mlir::plan::PlanDialect>();
compiler::TaskExtensionRegistry extensions =
planDialect->extensionConstructors
.getExtensionRegistryForTask<compiler::StableHloToExecutableTask>();
.getExtensionRegistryForTask<compiler::StablehloToExecutableTask>();

// Check that default extension set is loaded and set options on the TRT
// extension.
Expand All @@ -168,7 +168,7 @@ MTRT_Status mtrtStableHloToExecutableOptionsCreate(
trtExtension->setOptions(translationOpts);

auto result =
std::make_unique<StableHLOToExecutableOptions>(std::move(extensions));
std::make_unique<StablehloToExecutableOptions>(std::move(extensions));

llvm::Error finalizeStatus = result->finalize();

Expand All @@ -194,7 +194,7 @@ MTRT_Status mtrtStableHloToExecutableOptionsCreateFromArgs(
context->getLoadedDialect<mlir::plan::PlanDialect>();
compiler::TaskExtensionRegistry extensions =
planDialect->extensionConstructors
.getExtensionRegistryForTask<compiler::StableHloToExecutableTask>();
.getExtensionRegistryForTask<compiler::StablehloToExecutableTask>();

// Check that default extension set is loaded.
assert(
Expand All @@ -203,7 +203,7 @@ MTRT_Status mtrtStableHloToExecutableOptionsCreateFromArgs(
"expected valid StableHLOToExecutableTensorRTExtension");

auto result =
std::make_unique<StableHLOToExecutableOptions>(std::move(extensions));
std::make_unique<StablehloToExecutableOptions>(std::move(extensions));
std::vector<llvm::StringRef> argvStrRef(argc);
for (unsigned i = 0; i < argc; i++)
argvStrRef[i] = llvm::StringRef(argv[i].data, argv[i].length);
Expand Down Expand Up @@ -234,7 +234,7 @@ MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
const char **debugTypes, size_t debugTypeSizes, const char *dumpIrTreeDir,
const char *dumpTensorRTDir) {

StableHLOToExecutableOptions *cppOpts = unwrap(options);
StablehloToExecutableOptions *cppOpts = unwrap(options);
cppOpts->get<DebugOptions>().enableLLVMDebugFlag = enableDebugging;
for (unsigned i = 0; i < debugTypeSizes; i++)
cppOpts->get<DebugOptions>().llvmDebugTypes.emplace_back(debugTypes[i]);
Expand All @@ -245,35 +245,9 @@ MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
return mtrtStatusGetOk();
}

MTRT_Status
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback,
void *userData) {
StableHLOToExecutableOptions *cppOpts = unwrap(options);

// Construct the append callback which we will pass to the callback provided
// by the user. We do it this way to avoid needing a string construct in the C
// API.
auto appendFunc = [](MlirStringRef str, void *appendCtx) {
std::string &accum = *reinterpret_cast<std::string *>(appendCtx);
accum += std::string(str.data, str.length);
};

// Capturing by reference here will cause `callback` to point to the wrong
// place at the time this callback is invoked.
cppOpts->layerMetadataCallback = [=](Operation *op) {
std::string accum;
void *appendCtx = reinterpret_cast<void *>(&accum);
callback(wrap(op), appendFunc, appendCtx, userData);
return accum;
};

return mtrtStatusGetOk();
}

MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
MTRT_StableHLOToExecutableOptions options) {
delete reinterpret_cast<StableHLOToExecutableOptions *>(options.ptr);
delete reinterpret_cast<StablehloToExecutableOptions *>(options.ptr);
return mtrtStatusGetOk();
}

Expand All @@ -288,7 +262,7 @@ mtrtStableHloPipelineGetCached(MTRT_CompilerClient client,

mlir::PassManager *runner{};
if (unwrap(options)->getHash()) {
runner = &unwrap(client)->getOrCreatePassManager<StableHloToExecutableTask>(
runner = &unwrap(client)->getOrCreatePassManager<StablehloToExecutableTask>(
*unwrap(options));
result->ptr = runner;
return mtrtStatusGetOk();
Expand Down Expand Up @@ -340,7 +314,7 @@ MTRT_Status mtrtCompilerStableHLOToExecutable(
"StableHLO-to-Executable compilation expects a ModuleOp");

StatusOr<std::unique_ptr<mlirtrt::runtime::Executable>> exe =
compiler::StableHloToExecutableTask::compileStableHLOToExecutable(
compiler::StablehloToExecutableTask::compileStableHLOToExecutable(
*unwrap(client), moduleOp, *unwrap(stableHloToExecutableOptions));
if (!exe.isOk())
return mtrtStatusCreate(MTRT_StatusCode::MTRT_StatusCode_InternalError,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,5 @@ add_mlir_tensorrt_public_c_api_library(MLIRTensorRTCAPIRegisterAllDialects
LINK_LIBS PUBLIC
MLIRTensorRTRegistration
MLIRCAPIIR
MLIRCAPITransforms
MLIRTensorRTCompilerStableHloToExecutable
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,19 @@
//===----------------------------------------------------------------------===//

#include "mlir-tensorrt-c/Compiler/Registration/RegisterAllDialects.h"
#include "mlir-tensorrt/Compiler/StableHloToExecutable.h"
#include "mlir-tensorrt/Registration/RegisterMlirTensorRtDialects.h"
#include "mlir-tensorrt/Registration/RegisterMlirTensorRtPasses.h"
#include "mlir/CAPI/IR.h"

void mlirTensorRTRegisterAllDialects(MlirDialectRegistry registry) {
void mtrtCompilerRegisterDialects(MlirDialectRegistry registry) {
mlir::registerAllMlirTensorRtDialects(*unwrap(registry));
}

void mlirTensorRTRegisterAllPasses() {
void mtrtCompilerRegisterPasses() {
mlir::tensorrt::registerAllMlirTensorRtPasses();
}

/// Register all the compiler task types (pass manager types).
///
/// Currently this forwards to
/// `mlirtrt::compiler::registerStableHloToExecutableTask()`.
///
/// NOTE(review): the call is wrapped in `MLIR_TRT_ENABLE_HLO` because the
/// StableHLO-to-executable header appears to declare its contents only under
/// that macro, and the previous call site of this registration was guarded the
/// same way. When the macro is not defined this function registers nothing.
/// Confirm that the C API is expected to build with HLO support disabled.
void mtrtCompilerRegisterTasks() {
#ifdef MLIR_TRT_ENABLE_HLO
  mlirtrt::compiler::registerStableHloToExecutableTask();
#endif // MLIR_TRT_ENABLE_HLO
}
Loading
Loading