Deprecate TSNodeLoweringInterface (#78273)
Fixes #78206

Deprecate `TSNodeLoweringInterface` and refactor lower functions into IR nodes.

CC: @wconstab @desertfire
Pull Request resolved: #78273
Approved by: https://github.com/wconstab
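
To illustrate the pattern this change moves toward, here is a minimal sketch of per-node lowering: each `TsNode` subclass overrides `Lower()` and emits its own JIT IR, and the new `TSLoweringContext::Lower()` dispatches to that override. `MyAddNode`, its elided constructor, and the `LowerTSBuiltin` helper call are illustrative assumptions, not part of this commit; only the override signature is taken from the diff below.

```cpp
// Hypothetical node, sketched to show the override this PR standardizes on.
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <torch/csrc/lazy/ts_backend/ts_node_lowering.h> // for LowerTSBuiltin (assumed)

namespace torch {
namespace lazy {

class MyAddNode : public TsNode {
 public:
  // Constructor and ClassOpKind() elided for brevity.

  // With TSNodeLoweringInterface deprecated, the node itself emits its JIT IR
  // and returns one torch::jit::Value* per node output.
  TSOpVector Lower(std::shared_ptr<torch::jit::GraphFunction> function,
                   TSLoweringContext* loctx) const override {
    std::vector<torch::jit::NamedValue> arguments;
    // GetOutputOp lowers operands on demand and returns their jit::Value*.
    arguments.emplace_back(loctx->GetOutputOp(operand(0)));
    arguments.emplace_back(loctx->GetOutputOp(operand(1)));
    // LowerTSBuiltin is assumed here as the helper that inserts the builtin
    // call into the graph; any code that appends nodes to function->graph()
    // and returns the produced values satisfies the same contract.
    return LowerTSBuiltin(function, op().op, arguments);
  }
};

} // namespace lazy
} // namespace torch
```

The diffs below show the actual changes: `Lower()` declarations added to `DeviceData` and the batch-norm nodes, the new `TSLoweringContext::Lower()` dispatcher, and removal of the old interface and fallback path.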
antoniojkim authored and pytorchmergebot committed May 31, 2022
1 parent 032f8d0 commit fe67dff
Showing 7 changed files with 337 additions and 400 deletions.
6 changes: 6 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h
@@ -58,6 +58,9 @@ class TSNativeBatchNormBackward : public torch::lazy::TsNode {
 
   const std::array<bool, 3>& output_mask() const { return output_mask_; }
 
+  TSOpVector Lower(std::shared_ptr<torch::jit::GraphFunction> function,
+                   TSLoweringContext* loctx) const override;
+
  private:
   bool training_;
   double eps_;
@@ -95,6 +98,9 @@ class TSNativeBatchNormForward : public torch::lazy::TsNode {
 
   double eps() const { return eps_; }
 
+  TSOpVector Lower(std::shared_ptr<torch::jit::GraphFunction> function,
+                   TSLoweringContext* loctx) const override;
+
  private:
   bool training_;
   double momentum_;
3 changes: 3 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/device_data.h
@@ -38,6 +38,9 @@ class TORCH_API DeviceData : public TsNode {
   // instead of calling the constructor directly.
   static NodePtr Create(std::shared_ptr<BackendData> data);
 
+  TSOpVector Lower(std::shared_ptr<torch::jit::GraphFunction> function,
+                   TSLoweringContext* loctx) const override;
+
  private:
   std::shared_ptr<BackendData> data_;
 };
30 changes: 23 additions & 7 deletions torch/csrc/lazy/ts_backend/ts_lowering_context.cpp
@@ -10,21 +10,37 @@ TSLoweringContext::TSLoweringContext(
     const std::string& name,
     BackendDevice device)
     : torch::lazy::LoweringContext(name, device),
-      graph_(std::make_shared<torch::jit::Graph>()) {
-  lowering_ = TSNodeLoweringInterface::Create(this);
-}
+      graph_(std::make_shared<torch::jit::Graph>()),
+      function_(std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)) {}
 
 TSLoweringContext::TSLoweringContext(
     const std::string& name,
     BackendDevice device,
     c10::ArrayRef<Node*> post_order,
     Util::EmissionMap emit_status)
     : torch::lazy::LoweringContext(name, device, post_order, emit_status),
-      graph_(std::make_shared<torch::jit::Graph>()) {
-  lowering_ = TSNodeLoweringInterface::Create(this);
+      graph_(std::make_shared<torch::jit::Graph>()),
+      function_(std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)) {
   for (auto node : post_order) {
-    bool ok = lowering_->Lower(node);
-    CHECK(ok) << "Failed to lower: " << *node;
+    Lower(node);
   }
 }
+
+
+void TSLoweringContext::Lower(const Node* node) {
+  if (auto* tsnode = dynamic_cast<const torch::lazy::TsNode*>(node)) {
+    // First, we call the node lowering function, which exists for newly
+    // codegenned or refactored nodes
+    TSOpVector ops = tsnode->Lower(function_, this);
+    CHECK(!ops.empty()) << "Failed to lower: " << *node;
+    CHECK_EQ(node->num_outputs(), ops.size());
+    for (size_t i = 0; i < ops.size(); ++i) {
+      AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
+    }
+  }
+  else {
+    throw std::runtime_error(
+        "Expected torch::lazy::TsNode but could not dynamic cast");
+  }
+}
 
24 changes: 5 additions & 19 deletions torch/csrc/lazy/ts_backend/ts_lowering_context.h
@@ -13,22 +13,6 @@ namespace lazy {
 
 using TSOpVector = std::vector<torch::jit::Value*>;
 
-class TORCH_API TSNodeLoweringInterface {
-  /**
-   * This interface is only needed for legacy ops, and can be removed once all
-   * ops implement TSNode->lower().
-   * */
- public:
-  TSNodeLoweringInterface() = default;
-
-  virtual ~TSNodeLoweringInterface() = default;
-
-  virtual bool Lower(const Node* node) = 0;
-
-  static std::unique_ptr<TSNodeLoweringInterface> Create(
-      LoweringContext* loctx);
-};
-
 class TORCH_API TSComputation : public Computation {
  public:
   TSComputation(const std::shared_ptr<torch::jit::Graph>& graph)
@@ -99,9 +83,12 @@ class TORCH_API TSLoweringContext : public LoweringContext {
       size_t index,
       const Shape& shape,
       const std::string& name) override {
+
     TORCH_INTERNAL_ASSERT(false, "not implemented");
   }
 
+  void Lower(const Node* node);
+
   ComputationPtr Build() override {
     for (torch::jit::Value* output : root_tuple_) {
       graph_->block()->registerOutput(output);
@@ -117,8 +104,7 @@ class TORCH_API TSLoweringContext : public LoweringContext {
     if (it == emitted_outputs_.end()) {
       auto post_order = Util::ComputePostOrder(output.node, &emit_status_);
       for (auto node : post_order) {
-        bool ok = lowering_->Lower(node);
-        TORCH_CHECK(ok, "Failed to lower: ", node->ToString());
+        Lower(node);
       }
       // At this point the output better be present, otherwise there is an issue
       // with the lowering code.
@@ -157,10 +143,10 @@ class TORCH_API TSLoweringContext : public LoweringContext {
   }
 
   std::shared_ptr<torch::jit::Graph> graph_;
+  std::shared_ptr<torch::jit::GraphFunction> function_;
   std::unordered_map<BackendData::Handle, Parameter> parameters_map_;
   std::vector<torch::jit::Value*> root_tuple_;
   OutputMap<torch::jit::Value*> emitted_outputs_;
-  std::unique_ptr<TSNodeLoweringInterface> lowering_;
 };
 
 } // namespace lazy
10 changes: 0 additions & 10 deletions torch/csrc/lazy/ts_backend/ts_node.cpp
@@ -63,16 +63,6 @@ const std::string TsNode::getPythonStacktrace() const {
   return GetFirstUserFrameInPythonIfEnabled();
 }
 
-TSOpVector TsNode::Lower(std::shared_ptr<torch::jit::GraphFunction> function,
-                         TSLoweringContext* loctx) const {
-  // TODO(whc) beginning to invert the design here. Move to provide a Lower()
-  // method on each node, starting with codegen. Once we delete most
-  // non-codegen ops, make this pure-virtual and put Lower() on the remaining
-  // non-codegen ops. For now, returning empty list here triggers fallback to
-  // old lowering path.
-  return {};
-}
-
 TensorList::TensorList(OpList values)
     : TsNode(/*op=*/ClassOpKind(),
              /*operands=*/values,
