Skip to content

Commit

Permalink
Resolve CI testing failure for Lazy Tensor Core (#1088)
Browse files Browse the repository at this point in the history
* Xfail unsupported ops

* Register FuncDialect

* Include dynamic_ir in build

* Code reformat

* Enable LTC tests for macOS and Source Build
  • Loading branch information
henrytwo committed Jul 29, 2022
1 parent de35244 commit a876601
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 10 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/buildAndTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ jobs:
cd $GITHUB_WORKSPACE
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
python -m e2e_testing.torchscript.main --config=tosa -v
- name: Lazy Tensor Core - TorchScript end-to-end tests
run: |
cd $GITHUB_WORKSPACE
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
python -m e2e_testing.torchscript.main --config=lazy_tensor_core -v
macOS-x86_64:
name: Build and Test macOS(x86_64) Build (Release Asserts)
Expand Down Expand Up @@ -206,3 +211,8 @@ jobs:
cd $GITHUB_WORKSPACE
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
python -m e2e_testing.torchscript.main --config=tosa -v
- name: Lazy Tensor Core - TorchScript end-to-end tests
run: |
cd $GITHUB_WORKSPACE
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
python -m e2e_testing.torchscript.main --config=lazy_tensor_core -v
13 changes: 10 additions & 3 deletions e2e_testing/torchscript/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,6 @@
"HBC_basic",
"HardTanhIntModule_basic",
"HardTanhModule_basic",
"HardswishModule_basic",
"HardswishRandomModule_basic",
"IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DFloatNonAccumulateModule_basic",
"IndexPut1DIntAccumulateModule_basic",
Expand Down Expand Up @@ -268,10 +266,10 @@
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexTensorModule3dInput_basic",
"IndexTensorModule_basic",
"IndexTensorSelectDimModule_basic",
"Matmul_dot",
"Matmul_matvec",
"Matmul_vecmat",
"MobilenetV3Module_basic",
"MulIntModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
Expand Down Expand Up @@ -308,6 +306,7 @@
"SubFloatModule_basic",
"SubIntModule_basic",
"TableBatchEmbeddingModule_basic",
"TensorsConcatNegativeDimModule_basic",
"TensorToBoolZeroRank_basic",
"TensorToBool_basic",
"TensorToFloatZeroRank_basic",
Expand All @@ -319,6 +318,14 @@
"UniformStaticModule_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"VarBiasedModule_basic",
"VarDimAllDimReduceModule_basic",
"VarDimBiasedModule_basic",
"VarDimKeepDimFalseModule_basic",
"VarDimModule_basic",
"VarDimMultiDimModule_basic",
"VarDimNegativeModule_basic",
"VarDimSingleDimModule_basic",
"VarDimUnbiasedModule_basic",
"VarUnbiasedModule_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
}
1 change: 1 addition & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ add_mlir_library(TorchMLIRInitAll
Core

LINK_LIBS PUBLIC
MLIRFuncDialect
MLIRIR
MLIRSupport

Expand Down
2 changes: 2 additions & 0 deletions lib/InitAll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "torch-mlir/InitAll.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Dialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
Expand All @@ -20,6 +21,7 @@
#include "torch-mlir/RefBackend/Passes.h"

void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
registry.insert<mlir::func::FuncDialect>();
registry.insert<mlir::torch::Torch::TorchDialect>();
registry.insert<mlir::torch::TorchConversion::TorchConversionDialect>();
registry.insert<mlir::torch::TMTensor::TMTensorDialect>();
Expand Down
1 change: 1 addition & 0 deletions python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ add_library(torch_mlir_ltc_backend SHARED
${LTC_GENERATED}
${LTC_BACKEND_DEPENDS}
backend_impl.cpp
dynamic_ir.cpp
mlir_node.cpp
ops/device_data.cpp
ops/generic.cpp
Expand Down
2 changes: 1 addition & 1 deletion python/torch_mlir/csrc/base_lazy_backend/dynamic_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/dynamic_ir.cpp
//===----------------------------------------------------------------------===//

#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
#include "dynamic_ir.h"

namespace torch {
namespace lazy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
function_(
std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)),
mlir_context_(mlirContextCreate()) {
RegisterMlirDialects();

for (auto node : post_order) {
Lower(node);
}

RegisterMlirDialects();
}

void TorchMlirLoweringContext::Lower(const Node* node) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation {

const std::string debug_string() const;

const std::string to_string() const;
const std::string to_string() const override;

private:
std::vector<std::string> parameter_names_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ struct ReferenceLazyBackendDeviceType : public BackendDeviceType {
ReferenceLazyBackendDeviceType(int8_t device_type)
: device_type_(static_cast<c10::DeviceType>(device_type)) {}

std::string toString() const override { return c10::DeviceTypeName(device_type_); }
std::string toString() const override {
return c10::DeviceTypeName(device_type_);
}

c10::DeviceType device_type_;
};
Expand Down Expand Up @@ -127,11 +129,11 @@ class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl {
/**
* Device Configuration
* */
std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType() const {
std::shared_ptr<torch::lazy::BackendDeviceType>
GetDefaultDeviceType() const override {
return std::make_shared<BackendDeviceType>(default_device_type_);
}


void SetDefaultDeviceType(int8_t device_type) override {
default_device_type_ = ReferenceLazyBackendDeviceType(device_type);
}
Expand Down

0 comments on commit a876601

Please sign in to comment.