Commit ff6e8e7: Move internal changes

GitOrigin-RevId: 169b47d9d5749a3ae2fd003369d1e6f0a625c97f

Copybara Bot committed Dec 11, 2024
1 parent 85d1d7d commit ff6e8e7
Showing 11 changed files with 181 additions and 27 deletions.
17 changes: 17 additions & 0 deletions mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h
@@ -25,6 +25,7 @@
#define MLIR_TENSORRT_C_COMPILER_COMPILER

#include "mlir-c/IR.h"
#include "mlir-c/Pass.h"
#include "mlir-c/Support.h"
#include "mlir-executor-c/Common/Common.h"
#include "mlir-executor-c/Support/Status.h"
@@ -122,10 +123,26 @@ static inline bool mtrtStableHloToExecutableOptionsIsNull(
return !options.ptr;
}

//===----------------------------------------------------------------------===//
// StableHloPipeline APIs
//===----------------------------------------------------------------------===//

static inline bool mtrtStableHloPipelineIsNull(MlirPassManager pm) {
return !pm.ptr;
}

MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloPipelineGetCached(
MTRT_CompilerClient client, MTRT_StableHLOToExecutableOptions options,
MlirPassManager *result);

//===----------------------------------------------------------------------===//
// Main StableHLO Compiler API Functions
//===----------------------------------------------------------------------===//

/// Get an Executable by running the given pass manager on the module.
MLIR_CAPI_EXPORTED MTRT_Status mtrtCompilerGetExecutable(
MlirPassManager pm, MlirOperation module, MTRT_Executable *result);

/// Compile StableHLO to an Executable.
MLIR_CAPI_EXPORTED MTRT_Status mtrtCompilerStableHLOToExecutable(
MTRT_CompilerClient client, MlirOperation module,
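Taken together, the new declarations split compilation into two steps: mtrtStableHloPipelineGetCached fetches (or lazily creates) a pass manager keyed by the given options, and mtrtCompilerGetExecutable runs that pass manager over a module and translates the result into an executable. A hedged Python sketch of the equivalent bindings follows the .pyi stub changes below.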
@@ -108,6 +108,7 @@ struct DeviceOptions : public OptionsProvider<DeviceOptions> {
/// Whether to ignore `deviceX` options and instead infer them from the GPUs
/// on the host system running the compilation.
bool shouldInferFromHost = false;
Status inferFromHost();

public:
void addToOptions(mlir::OptionsContext &context) {
49 changes: 49 additions & 0 deletions mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp
@@ -25,6 +25,7 @@
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir-executor-c/Support/Status.h"
#include "mlir-executor/Target/Lua/TranslateToRuntimeExecutable.h"
#include "mlir-tensorrt-dialect/Target/TranslateToTensorRT.h"
#include "mlir-tensorrt-dialect/Utils/Options.h"
#include "mlir-tensorrt/Compiler/Extension.h"
@@ -276,10 +277,58 @@ MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
return mtrtStatusGetOk();
}

//===----------------------------------------------------------------------===//
// StableHloPipeline APIs
//===----------------------------------------------------------------------===//

MTRT_Status
mtrtStableHloPipelineGetCached(MTRT_CompilerClient client,
MTRT_StableHLOToExecutableOptions options,
MlirPassManager *result) {

mlir::PassManager *runner{};
if (unwrap(options)->getHash()) {
runner = &unwrap(client)->getOrCreatePassManager<StableHloToExecutableTask>(
*unwrap(options));
result->ptr = runner;
return mtrtStatusGetOk();
}
return mtrtStatusCreate(MTRT_StatusCode::MTRT_StatusCode_InternalError,
"options cannot be hashed");
}

//===----------------------------------------------------------------------===//
// Main StableHLO Compiler API Functions
//===----------------------------------------------------------------------===//

MTRT_Status mtrtCompilerGetExecutable(MlirPassManager pm, MlirOperation module,
MTRT_Executable *result) {

ModuleOp moduleOp = llvm::dyn_cast<ModuleOp>(unwrap(module));
if (!moduleOp)
return mtrtStatusCreate(
MTRT_StatusCode::MTRT_StatusCode_InvalidArgument,
"StableHLO-to-Executable compilation expects a ModuleOp");

// Run the pass manager over the module.
mlir::PassManager *runner = static_cast<mlir::PassManager *>(pm.ptr);
if (failed(runner->run(moduleOp)))
return mtrtStatusCreate(MTRT_StatusCode::MTRT_StatusCode_InternalError,
"failed to run MLIR compilation pipeline");

// Translate to Runtime Executable
FailureOr<std::unique_ptr<runtime::ExecutableStorage>> exeStorage =
mlir::translateToRuntimeExecutable(unwrap(module));
if (failed(exeStorage))
return mtrtStatusCreate(
MTRT_StatusCode::MTRT_StatusCode_InternalError,
"failed to perform MLIR-to-RuntimeExecutable translation");

result->ptr =
std::make_unique<runtime::Executable>(std::move(*exeStorage)).release();
return mtrtStatusGetOk();
}

MTRT_Status mtrtCompilerStableHLOToExecutable(
MTRT_CompilerClient client, MlirOperation module,
MTRT_StableHLOToExecutableOptions stableHloToExecutableOptions,
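Two implementation details are worth noting. mtrtStableHloPipelineGetCached rejects options that cannot be hashed, since the hash is the cache key used by getOrCreatePassManager. And mtrtCompilerGetExecutable heap-allocates the resulting runtime::Executable and releases ownership to the caller through result->ptr, so the caller is responsible for freeing it, presumably through the matching destroy entry point of the C API (not shown in this diff).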
27 changes: 27 additions & 0 deletions mlir-tensorrt/compiler/lib/Compiler/OptionsProviders.cpp
@@ -22,9 +22,36 @@
///
//===----------------------------------------------------------------------===//
#include "mlir-tensorrt/Compiler/OptionsProviders.h"
#include "cuda_runtime_api.h"
#include "mlir-executor/Support/DeviceInfo.h"
#include "llvm/Support/Error.h"

mlirtrt::Status mlirtrt::compiler::DeviceOptions::inferFromHost() {
cudaDeviceProp properties;
cudaError_t err = cudaGetDeviceProperties(&properties, 0);
if (err != cudaSuccess)
return getStatusWithMsg(StatusCode::InternalError,
"failed to get cuda device properties");
int ccMajor = 0;
int ccMinor = 0;
err = cudaDeviceGetAttribute(
&ccMajor, cudaDeviceAttr::cudaDevAttrComputeCapabilityMajor, 0);
if (err != cudaSuccess)
return getStatusWithMsg(StatusCode::InternalError,
"failed to get cuda device compute capability");
err = cudaDeviceGetAttribute(
&ccMinor, cudaDeviceAttr::cudaDevAttrComputeCapabilityMinor, 0);
if (err != cudaSuccess)
return getStatusWithMsg(StatusCode::InternalError,
"failed to get cuda device compute capability");
// We want SM version as a single number.
int64_t smVersion = ccMajor * 10 + ccMinor;
info.computeCapability = smVersion;
info.maxSharedMemoryPerBlockKb = properties.sharedMemPerBlock / 1024;
info.maxRegistersPerBlock = properties.regsPerBlock;
return Status::getOk();
}

llvm::Error mlirtrt::compiler::DeviceOptions::finalizeImpl() {
if (shouldInferFromHost) {
StatusOr<DeviceInfo> deviceInfo = getDeviceInformationFromHost();
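As a worked example of the encoding: a device with compute capability 8.6 reports ccMajor = 8 and ccMinor = 6, so smVersion = 8 * 10 + 6 = 86, the single-number form corresponding to the familiar sm_86 architecture name.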
32 changes: 18 additions & 14 deletions mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp
@@ -461,20 +461,24 @@ static Status validateArgsTypesAgainstFuncArgs(const RuntimeValue *runArg,
}

if (view.getStrides() != value->getStrides()) {
for (unsigned i = 0; i < view.getStrides().size(); ++i) {
if (value->getStrides()[i] < 0)
return getInvalidArgStatus(
"all strides must be non-negative but received shape [{0:$[, ]}]",
value->getStrides());
if (view.getStrides()[i] >= 0 &&
view.getStrides()[i] != value->getStrides()[i])
// Allow the special case of non-canonical stride for unit dimensions
// See https://github.com/pytorch/pytorch/issues/99803 for more detail
if (value->getShape()[i] != 1 || value->getStrides()[i] != 1)
return getInvalidArgStatus(
"Runtime stride mismatch. Expected [{0:$[, ]}] "
"but received [{1:$[, ]}]",
view.getStrides(), value->getStrides());
bool isEmpty = llvm::is_contained(view.getShape(), 0);
if (!isEmpty) { // Allow any non-canonical strides for an empty tensor
for (unsigned i = 0; i < view.getStrides().size(); ++i) {
if (value->getStrides()[i] < 0)
return getInvalidArgStatus("all strides must be non-negative but "
"received shape [{0:$[, ]}]",
value->getStrides());
if (view.getStrides()[i] >= 0 &&
view.getStrides()[i] != value->getStrides()[i])
          // Allow the special case of a non-canonical stride for unit
          // dimensions. See https://github.com/pytorch/pytorch/issues/99803
          // for more detail.
if (value->getShape()[i] != 1 || value->getStrides()[i] != 1)
return getInvalidArgStatus(
"Runtime stride mismatch. Expected [{0:$[, ]}] "
"but received [{1:$[, ]}]",
view.getStrides(), value->getStrides());
}
}
}

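The relaxed check above is easier to follow outside the nested C++. Below is a minimal Python sketch of the same rule; the function name and arguments are illustrative only, not part of the runtime API:

def strides_compatible(shape, expected, actual):
    # Mirrors the relaxed stride validation in validateArgsTypesAgainstFuncArgs.
    if list(expected) == list(actual):
        return True
    # Empty tensors (any dimension of size 0) may use arbitrary strides.
    if 0 in shape:
        return True
    for dim, exp, act in zip(shape, expected, actual):
        if act < 0:
            return False  # all strides must be non-negative
        # Permit the non-canonical stride 1 on unit dimensions
        # (see pytorch/pytorch#99803).
        if exp >= 0 and exp != act and not (dim == 1 and act == 1):
            return False
    return True

# Empty tensor: any strides are accepted.
assert strides_compatible([5, 0, 4], [0, 4, 1], [12, 4, 1])
# Unit dimension carrying stride 1 instead of the canonical value: accepted.
assert strides_compatible([4, 1, 2], [2, 2, 1], [2, 1, 1])
# Genuine mismatch on a non-unit dimension: rejected.
assert not strides_compatible([4, 3, 2], [6, 2, 1], [1, 2, 6])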
1 change: 1 addition & 0 deletions mlir-tensorrt/executor/test/Unit/CMakeLists.txt
@@ -6,6 +6,7 @@ set_target_properties(MLIRTensorRTExecutorUnitTests PROPERTIES FOLDER "MLIR-Tens
function(add_mlir_executor_unittest target)
set(LLVM_LINK_COMPONENTS Support)
add_llvm_executable(${target} IGNORE_EXTERNALIZE_DEBUGINFO NO_INSTALL_RPATH ${ARGN})
set_target_properties(${target} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}")
add_dependencies(MLIRTensorRTExecutorUnitTests ${target})
llvm_update_compile_flags(${target})
if(TARGET gtest)
32 changes: 32 additions & 0 deletions mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp
@@ -84,6 +84,17 @@ class PyStableHLOToExecutableOptions

~PyStableHLOToExecutableOptions() { callback = nullptr; }
};

/// Python object type wrapper for `MlirPassManager`.
class PyStableHloPipeline
: public PyMTRTWrapper<PyStableHloPipeline, MlirPassManager> {
public:
using PyMTRTWrapper::PyMTRTWrapper;
DECLARE_WRAPPER_CONSTRUCTORS(PyStableHloPipeline);

static constexpr auto kMethodTable =
CAPITable<MlirPassManager>{mtrtStableHloPipelineIsNull, nullptr};
};
} // namespace

//===----------------------------------------------------------------------===//
@@ -355,6 +366,27 @@ PYBIND11_MODULE(_api, m) {
#endif
;

py::class_<PyStableHloPipeline>(m, "StableHloPipeline", py::module_local())
.def(py::init<>([](PyCompilerClient &client,
PyStableHLOToExecutableOptions &options) {
MlirPassManager pm{};
MTRT_Status status =
mtrtStableHloPipelineGetCached(client, options, &pm);
THROW_IF_MTRT_ERROR(status);
return new PyStableHloPipeline(pm);
}),
py::arg("client"), py::arg("options"));

m.def(
"get_executable",
[](PyStableHloPipeline &pm, MlirOperation module) {
MTRT_Executable exe{nullptr};
MTRT_Status status = mtrtCompilerGetExecutable(pm, module, &exe);
THROW_IF_MTRT_ERROR(status);
return new PyExecutable(exe);
},
py::arg("pm"), py::arg("module"));

m.def(
"compiler_stablehlo_to_executable",
[](PyCompilerClient &client, MlirOperation module,
@@ -16,6 +16,8 @@ __all__ = [
"StableHLOToExecutableOptions",
"Type",
"bf16",
"PyStableHloPipeline",
"get_executable",
"compiler_stablehlo_to_executable",
"device",
"f16",
@@ -292,6 +294,12 @@ class StableHLOToExecutableOptions:
class Type:
def __init__(self, cast_from_type: Type) -> None: ...

class StableHloPipeline:
def __init__(
self, arg0: CompilerClient, arg1: StableHLOToExecutableOptions
) -> None: ...

def get_executable(pm: StableHloPipeline, module: Operation) -> Executable: ...
def compiler_stablehlo_to_executable(
client: CompilerClient, module: Operation, options: StableHLOToExecutableOptions
) -> Executable: ...
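The stubs above imply a two-call workflow. Here is a hedged end-to-end sketch, assuming the mlir_tensorrt.compiler package layout and the options-constructor arguments, neither of which is shown in this diff:

# Hypothetical use of the cached-pipeline API. Only StableHloPipeline and
# get_executable come from the stubs above; package paths and constructor
# arguments are assumptions.
import mlir_tensorrt.compiler.api as api
from mlir_tensorrt.compiler import ir

ASM = """
func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  func.return %0 : tensor<2x2xf32>
}
"""

with ir.Context():
    module = ir.Module.parse(ASM)
    client = api.CompilerClient(module.context)
    opts = api.StableHLOToExecutableOptions(client, ["--tensorrt-builder-opt-level=0"])
    # Reuses a cached pass manager whenever the same options hash recurs.
    pipeline = api.StableHloPipeline(client, opts)
    exe = api.get_executable(pipeline, module.operation)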
16 changes: 16 additions & 0 deletions mlir-tensorrt/tensorrt/test/Target/TensorRT/TRT10/linspace.mlir
@@ -0,0 +1,16 @@
// RUN: %pick-one-gpu tensorrt-opt -split-input-file -pass-pipeline="builtin.module(translate-tensorrt-to-engine)" \
// RUN: -mlir-elide-elementsattrs-if-larger=32 -tensorrt-builder-opt-level=0 -tensorrt-strongly-typed %s | FileCheck %s
// RUN: %pick-one-gpu tensorrt-opt -split-input-file -pass-pipeline="builtin.module(translate-tensorrt-to-engine)" \
// RUN: -mlir-elide-elementsattrs-if-larger=32 -tensorrt-builder-opt-level=0 %s | FileCheck %s

// CHECK-LABEL: @dynamic_nd_iota_3
// CHECK-SAME: tensorrt.engine
func.func @dynamic_nd_iota_3(%arg0: tensor<2xi32> {
tensorrt.value_bounds = #tensorrt.shape_profile<min=[1, 3], opt=[4, 3], max=[12, 3]>,
tensorrt.host_tensor
}) -> tensor<?x3xi64> {
%cst_f16 = tensorrt.constant dense<0> : tensor<i64>
%cst_f16_0 = tensorrt.constant dense<[0, 1]> : tensor<2xi64>
%0 = tensorrt.linspace[%cst_f16 : tensor<i64>] [%arg0 : tensor<2xi32>] [%cst_f16_0 : tensor<2xi64>] : tensor<?x3xi64>
return %0 : tensor<?x3xi64>
}
13 changes: 0 additions & 13 deletions mlir-tensorrt/tensorrt/test/Target/TensorRT/linspace.mlir
@@ -51,16 +51,3 @@ func.func @dynamic_nd_iota_2(%arg0: tensor<2xi32> {
%0 = tensorrt.linspace[%cst_i32 : tensor<i32>] [%arg0 : tensor<2xi32>] [%cst_i32_0 : tensor<2xi32>] : tensor<?x3xi32>
return %0 : tensor<?x3xi32>
}

// CHECK-LABEL: @dynamic_nd_iota_3
// CHECK-SAME: tensorrt.engine
func.func @dynamic_nd_iota_3(%arg0: tensor<2xi32> {
tensorrt.value_bounds = #tensorrt.shape_profile<min=[1, 3], opt=[4, 3], max=[12, 3]>,
tensorrt.host_tensor
}) -> tensor<?x3xi64> {
%cst_f16 = tensorrt.constant dense<0> : tensor<i64>
%cst_f16_0 = tensorrt.constant dense<[0, 1]> : tensor<2xi64>
%0 = tensorrt.linspace[%cst_f16 : tensor<i64>] [%arg0 : tensor<2xi32>] [%cst_f16_0 : tensor<2xi64>] : tensor<?x3xi64>
return %0 : tensor<?x3xi64>
}
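Note that @dynamic_nd_iota_3 is not deleted outright: the hunk above removes it from the generic linspace.mlir suite, while the new TRT10/linspace.mlir file earlier in this commit carries the identical test, presumably gating i64 linspace coverage on TensorRT 10.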

12 changes: 12 additions & 0 deletions mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py
@@ -19,6 +19,13 @@
}
"""

empty_memref_io = """
func.func @main(%arg0: tensor<5x?x4xf32> {tensorrt.shape_profile = #tensorrt.shape_profile<min = [5, 0, 4], opt = [5, 3, 4], max = [5, 3, 4]>}) -> tensor<5x?x4xf32> {
%1 = stablehlo.add %arg0, %arg0 : (tensor<5x?x4xf32>, tensor<5x?x4xf32>) -> tensor<5x?x4xf32>
func.return %1 : tensor<5x?x4xf32>
}
"""


class Test:
def __init__(self, program: str):
@@ -97,6 +104,9 @@ def execute(self, arg: runtime.RuntimeValue):
t = Test(main_scalar_io)
print("TEST: function with no output arguments")
t.execute(t.create_memref((5, 3), np.int32))
t = Test(empty_memref_io)
print("TEST: empty tensor validation")
t.execute(t.create_memref((5, 0, 4), np.float32))

# CHECK-LABEL: TEST: runtime shape mismatch
# CHECK: MTRTException: InvalidArgument: InvalidArgument: Input argument 0 validation failed against corresponding function signature arg 0. Reason: InvalidArgument: Runtime shape mismatch. Expected [-9223372036854775808, 3, 4] but received [5, 4, 2]
@@ -113,3 +123,5 @@ def execute(self, arg: runtime.RuntimeValue):
# CHECK: MTRTException: InvalidArgument: InvalidArgument: Input argument 0 validation failed against corresponding function signature arg 0. Reason: InvalidArgument: function expects a memref type but received scalar type
# CHECK-LABEL: TEST: function with no output arguments
# CHECK: MTRTException: InvalidArgument: InvalidArgument: function expects 0 output args (destination args) but received 1
# CHECK-LABEL: TEST: empty tensor validation
# CHECK: Test passed succesfully
