Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Node Lowering #914

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
const std::string& name, BackendDevice device)
: LoweringContext(name, std::forward<BackendDevice>(device)),
graph_(std::make_shared<torch::jit::Graph>()),
function_(
std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)),
mlir_context_(mlirContextCreate()) {
lowering_ = TorchMlirNodeLoweringInterface::Create(this);
RegisterMlirDialects();
}

Expand All @@ -49,16 +50,31 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
std::forward<Util::EmissionMap>(emit_status)),
graph_(std::make_shared<torch::jit::Graph>()),
function_(
std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)),
mlir_context_(mlirContextCreate()) {
lowering_ = TorchMlirNodeLoweringInterface::Create(this);
for (auto node : post_order) {
bool ok = lowering_->Lower(node);
CHECK(ok) << "Failed to lower: " << *node;
Lower(node);
}

RegisterMlirDialects();
}

void TorchMlirLoweringContext::Lower(const Node* node) {
if (auto* torch_mlir_node =
dynamic_cast<const torch::lazy::TorchMlirNode*>(node)) {
TorchMlirOpVector ops = torch_mlir_node->Lower(function_, this);
CHECK(!ops.empty()) << "Failed to lower: " << *node;
CHECK_EQ(node->num_outputs(), ops.size());
for (size_t i = 0; i < ops.size(); ++i) {
AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
}
} else {
throw std::runtime_error(
"Expected torch::lazy::TorchMlirNode but could not dynamic cast");
}
}

void TorchMlirLoweringContext::SetUpAlias(
const std::vector<int64_t>& output_index, int64_t param_number,
const std::vector<int64_t>& param_index, bool must_alias) {
Expand Down Expand Up @@ -136,8 +152,7 @@ torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) {
if (it == emitted_outputs_.end()) {
auto post_order = Util::ComputePostOrder(output.node, &emit_status_);
for (auto node : post_order) {
bool ok = lowering_->Lower(node);
TORCH_CHECK(ok, "Failed to lower: ", node->ToString());
Lower(node);
}
// At this point the output better be present, otherwise there is an issue
// with the lowering code.
Expand Down
20 changes: 3 additions & 17 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,6 @@
namespace torch {
namespace lazy {

class TORCH_API TorchMlirNodeLoweringInterface {
/**
* This interface is only needed for legacy ops, and can be removed once all
* ops implement LtcMlirNode->lower().
* */
public:
TorchMlirNodeLoweringInterface() = default;

virtual ~TorchMlirNodeLoweringInterface() = default;

virtual bool Lower(const Node* node) = 0;

static std::unique_ptr<TorchMlirNodeLoweringInterface>
Create(LoweringContext* loctx);
};

class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
public:
// Describes an input/output alias as inserted by the SetUpAlias() API.
Expand All @@ -61,6 +45,8 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
c10::ArrayRef<torch::lazy::Node*> post_order,
torch::lazy::Util::EmissionMap emit_status);

void Lower(const Node* node);

// Adds a new input/output alias.
void SetUpAlias(
const std::vector<int64_t>& output_index, int64_t param_number,
Expand Down Expand Up @@ -120,11 +106,11 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
// Holds the input/output alias information populated by the SetUpAlias() API.
InputOutputAliases input_output_aliases_;
std::shared_ptr<torch::jit::Graph> graph_;
std::shared_ptr<torch::jit::GraphFunction> function_;
MlirContext mlir_context_;
std::unordered_map<BackendData::Handle, Parameter> parameters_map_;
std::vector<torch::jit::Value*> root_tuple_;
OutputMap<torch::jit::Value*> emitted_outputs_;
std::unique_ptr<TorchMlirNodeLoweringInterface> lowering_;
};

class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
Expand Down
6 changes: 0 additions & 6 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,6 @@ hash_t TorchMlirNode::hash() const { return dag_hash_; }

hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }

TorchMlirOpVector TorchMlirNode::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
return {};
}


OpKind TorchMlirTensorList::ClassOpKind() {
// Note: this OpKind is separate from ltc_ops.h since it would be a circular
// import otherwise
Expand Down
Loading