Commit

[MXNET-555] Add subgraph storage type inference to CachedOp (apache#11306)

* copy paste

* pass unit test

* remove lock

* save all inputs and outputs

* add one more test

* update test

* update backward stype inference

* + fwd inference
eric-haibin-lin authored and zheng-da committed Jun 28, 2018
1 parent 18f8aba commit 7fd12f3
Showing 5 changed files with 197 additions and 36 deletions.
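In effect, a hybridized Gluon block can now infer storage types across its cached subgraph, so sparse inputs and sparse gradients keep their sparse kernels instead of falling back to dense. Below is a minimal Python sketch of the user-facing behavior (not part of the commit), mirroring the new test_sparse_hybrid_block test added at the bottom of this page; it assumes an MXNet build that includes this change:

import mxnet as mx

class Linear(mx.gluon.HybridBlock):
    def __init__(self, units):
        super(Linear, self).__init__()
        with self.name_scope():
            self.w = self.params.get('w', shape=(units, units))

    def hybrid_forward(self, F, x, w):
        # dot(csr, dense): storage types inside the cached subgraph are now
        # inferred by CachedOp::ForwardStorageType rather than assumed dense.
        return F.dot(x, w)

block = Linear(2)
block.initialize()
block.hybridize()  # forward/backward now run through CachedOp

x = mx.nd.ones((2, 2)).tostype('csr')
with mx.autograd.record():
    y = block(x)   # ForwardStorageType runs over the forward graph
y.backward()       # BackwardStorageType runs over the full graph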
147 changes: 119 additions & 28 deletions src/imperative/cached_op.cc
@@ -22,6 +22,7 @@
#include "./cached_op.h"
#include "../executor/exec_pass.h"
#include "../profiler/profiler.h"
#include "../operator/operator_common.h"


namespace mxnet {
@@ -95,7 +96,6 @@ CachedOp::CachedOp(
using namespace imperative;
static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
static const auto _copy = Op::Get("_copy");

config_.Init(flags);

if (config_.static_shape) {
@@ -204,26 +204,17 @@ CachedOp::CachedOp(
size_t num_forward_outputs = num_outputs();
for (uint32_t i = 0; i < ograd_entries_.size(); ++i) {
if (!idx.exist(ograd_entries_[i].node.get())) continue;
- auto eid = idx.entry_id(ograd_entries_[i]);
- if (ref_count[eid] > 0) {
- bwd_ograd_dep_.push_back(i);
- }
+ bwd_ograd_dep_.push_back(i);
}
save_inputs_.resize(num_forward_inputs, false);
for (uint32_t i = 0; i < num_forward_inputs; ++i) {
- auto eid = idx.entry_id(idx.input_nodes()[i], 0);
- if (ref_count[eid] > 0) {
- save_inputs_[i] = true;
- bwd_in_dep_.push_back(i);
- }
+ save_inputs_[i] = true;
+ bwd_in_dep_.push_back(i);
}
save_outputs_.resize(idx.outputs().size(), false);
for (uint32_t i = 0; i < num_forward_outputs; ++i) {
- auto eid = idx.entry_id(idx.outputs()[i]);
- if (ref_count[eid] > 0) {
- save_outputs_[i] = true;
- bwd_out_dep_.push_back(i);
- }
+ save_outputs_[i] = true;
+ bwd_out_dep_.push_back(i);
}
}
}
@@ -233,7 +224,7 @@ CachedOp::~CachedOp() {

std::vector<nnvm::NodeEntry> CachedOp::Gradient(
const nnvm::NodePtr& node,
- const std::vector<nnvm::NodeEntry>& ograds) {
+ const std::vector<nnvm::NodeEntry>& ograds) const {
using namespace nnvm;
static const auto _backward_CachedOp = Op::Get("_backward_CachedOp");
static const auto _NoGrad = Op::Get("_NoGradient");
@@ -328,6 +319,27 @@ bool CachedOp::SetForwardGraph(
return false;
}

// Utility function to set backward input eids
void SetBackwardInputEid(const std::vector<uint32_t>& bwd_in_dep,
const std::vector<uint32_t>& bwd_out_dep,
const std::vector<uint32_t>& bwd_ograd_dep,
const std::vector<nnvm::NodeEntry>& ograd_entries,
const nnvm::IndexedGraph& idx,
std::vector<uint32_t> *bwd_input_eid) {
for (const auto& i : bwd_ograd_dep) {
auto eid = idx.entry_id(ograd_entries[i]);
bwd_input_eid->push_back(eid);
}
for (const auto& i : bwd_in_dep) {
auto eid = idx.entry_id(idx.input_nodes()[i], 0);
bwd_input_eid->push_back(eid);
}
for (const auto& i : bwd_out_dep) {
auto eid = idx.entry_id(idx.outputs()[i]);
bwd_input_eid->push_back(eid);
}
}

bool CachedOp::SetBackwardGraph(
GraphInfo* info,
const std::vector<OpReqType>& reqs,
@@ -356,18 +368,8 @@ bool CachedOp::SetBackwardGraph(

if (info->bwd_input_eid.size() != inputs.size()) {
info->bwd_input_eid.clear();
- for (const auto& i : bwd_ograd_dep_) {
- auto eid = idx.entry_id(ograd_entries_[i]);
- info->bwd_input_eid.push_back(eid);
- }
- for (const auto& i : bwd_in_dep_) {
- auto eid = idx.entry_id(idx.input_nodes()[i], 0);
- info->bwd_input_eid.push_back(eid);
- }
- for (const auto& i : bwd_out_dep_) {
- auto eid = idx.entry_id(idx.outputs()[i]);
- info->bwd_input_eid.push_back(eid);
- }
+ SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_,
+ ograd_entries_, idx, &info->bwd_input_eid);
CHECK_EQ(inputs.size(), info->bwd_input_eid.size());
}

@@ -1019,6 +1021,79 @@ void CachedOp::Backward(
Engine::Get()->set_bulk_size(prev_bulk_size);
}

bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
using namespace imperative;
nnvm::Graph g(fwd_graph_);
const auto& idx = g.indexed_graph();
const auto &outputs = idx.outputs();

// Prepare stypes and contexts based on inputs
StorageTypeVector storage_type_inputs;
storage_type_inputs.reserve(in_attrs->size());
for (size_t i = 0; i < in_attrs->size(); ++i) {
storage_type_inputs.emplace_back(in_attrs->at(i));
}
exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);

// Forward graph storage type inference
CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true);
// Retrieve result and set outputs
const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
for (size_t i = 0; i < out_attrs->size(); i++) {
const auto eid = idx.entry_id(outputs[i]);
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
}
DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
return true;
}

bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
using namespace imperative;
nnvm::Graph g(full_graph_);
const auto& idx = g.indexed_graph();
const auto &outputs = idx.outputs();
const size_t num_forward_outputs = fwd_graph_.outputs.size();
CHECK_EQ(outputs.size(), num_forward_outputs + out_attrs->size());

// Construct bwd_input_eid
std::vector<uint32_t> bwd_input_eid;
SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_,
ograd_entries_, idx, &bwd_input_eid);
CHECK_EQ(in_attrs->size(), bwd_input_eid.size());

// Prepare stypes and contexts based on inputs
StorageTypeVector stypes(idx.num_node_entries(), -1);
for (size_t i = 0; i < in_attrs->size(); ++i) {
stypes[bwd_input_eid[i]] = in_attrs->at(i);
}
// Some out_attrs are known ahead of time (e.g. the grad stype is given by the user).
// Set these before invoking storage type inference on the subgraph.
for (size_t i = 0; i < out_attrs->size(); i++) {
const auto eid = idx.entry_id(outputs[i + num_forward_outputs]);
stypes[eid] = out_attrs->at(i);
}
exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);

// Full graph storage type inference
CheckAndInferStorageType(&g, std::move(dev_masks), std::move(stypes), false);
// Retrieve result and set outputs
const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
for (size_t i = 0; i < out_attrs->size(); i++) {
const auto eid = idx.entry_id(outputs[i + num_forward_outputs]);
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
}
DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
return true;
}


NNVM_REGISTER_OP(_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs) {
@@ -1029,6 +1104,14 @@ NNVM_REGISTER_OP(_CachedOp)
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->num_outputs();
})
.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
})
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(n->attrs.parsed);
@@ -1044,6 +1127,14 @@ NNVM_REGISTER_OP(_backward_CachedOp)
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->num_inputs() - op->mutable_input_nodes().size();
})
.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
})
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<bool>("TIsBackward", true);

24 changes: 19 additions & 5 deletions src/imperative/cached_op.h
@@ -71,13 +71,13 @@ class CachedOp {
const nnvm::Symbol& sym,
const std::vector<std::pair<std::string, std::string> >& flags);
~CachedOp();
- uint32_t num_inputs() {
+ uint32_t num_inputs() const {
return fwd_graph_.indexed_graph().input_nodes().size();
}
- uint32_t num_outputs() {
+ uint32_t num_outputs() const {
return fwd_graph_.outputs.size();
}
- uint32_t num_backward_inputs() {
+ uint32_t num_backward_inputs() const {
return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
}
std::vector<bool>& save_inputs() {
@@ -86,12 +86,12 @@ class CachedOp {
std::vector<bool>& save_outputs() {
return save_outputs_;
}
- const std::unordered_set<uint32_t>& mutable_input_nodes() {
+ const std::unordered_set<uint32_t>& mutable_input_nodes() const {
return fwd_graph_.indexed_graph().mutable_input_nodes();
}
std::vector<nnvm::NodeEntry> Gradient(
const nnvm::NodePtr& node,
- const std::vector<nnvm::NodeEntry>& ograds);
+ const std::vector<nnvm::NodeEntry>& ograds) const;
void Forward(
const std::shared_ptr<CachedOp>& op_ptr,
const std::vector<NDArray*>& inputs,
@@ -102,6 +102,20 @@ class CachedOp {
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs);
// forward storage type inference
bool ForwardStorageType(
const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs);
// backward storage type inference
bool BackwardStorageType(
const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs);

private:
struct GraphInfo;
1 change: 0 additions & 1 deletion src/imperative/imperative_utils.h
@@ -668,7 +668,6 @@ inline bool CheckAndInferStorageType(nnvm::Graph* p_g, exec::DevMaskVector&& dev
g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(storage_types));
g = exec::InferStorageType(std::move(g));
}

CHECK_EQ(g.GetAttr<size_t>("storage_type_num_unknown_nodes"), 0U);
return false;
}
Expand Down
4 changes: 2 additions & 2 deletions src/operator/operator_common.h
@@ -256,7 +256,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
*/
#define STORAGE_TYPE_ASSIGN_CHECK(type_array, index, type) \
{ \
- if (!type_assign(&(type_array)[index], type)) { \
+ if (!::mxnet::op::type_assign(&(type_array)[index], type)) { \
std::ostringstream os; \
os << "Storage type inconsistent, Provided = " \
<< common::stype_string((type_array)[index]) << ',' \
@@ -274,7 +274,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
*/
#define DISPATCH_MODE_ASSIGN_CHECK(type_array, index, type) \
{ \
- if (!dispatch_mode_assign(&(type_array)[index], type)) { \
+ if (!::mxnet::op::dispatch_mode_assign(&(type_array)[index], type)) { \
std::ostringstream os; \
os << "Dispatch mode inconsistent, Provided = " \
<< common::dispatch_mode_string((type_array)[index]) << ',' \
57 changes: 57 additions & 0 deletions tests/python/unittest/test_gluon.py
@@ -1293,6 +1293,63 @@ def test_legacy_save_params():
model.load_params('test.params', ctx=mx.cpu())


@with_seed()
def test_sparse_hybrid_block_grad():
class Embedding(mx.gluon.HybridBlock):
def __init__(self, num_tokens, embedding_size):
super(Embedding, self).__init__()
self.num_tokens = num_tokens

with self.name_scope():
self.embedding = mx.gluon.nn.Embedding(
num_tokens, embedding_size, sparse_grad=True)

def hybrid_forward(self, F, words):
emb = self.embedding(words)
return emb + F.ones_like(emb)

embedding = Embedding(20, 3)
embedding.initialize()
embedding.hybridize()

with mx.autograd.record():
emb0 = embedding(mx.nd.arange(10)).sum()
emb1 = embedding(mx.nd.arange(10)).sum()
loss = emb0 + emb1
loss.backward()
grad = embedding.embedding.weight.grad().asnumpy()
assert (grad[:10] == 2).all()
assert (grad[10:] == 0).all()

@with_seed()
def test_sparse_hybrid_block():
class Linear(mx.gluon.HybridBlock):
def __init__(self, units):
super(Linear, self).__init__()
with self.name_scope():
self.w = self.params.get('w', shape=(units, units))

def hybrid_forward(self, F, x, w):
return F.dot(x, w)

class SparseBlock(mx.gluon.HybridBlock):
def __init__(self, units):
super(SparseBlock, self).__init__()
with self.name_scope():
self.net = Linear(units)

def hybrid_forward(self, F, x):
return self.net(x) * x

block = SparseBlock(2)
block.initialize()
block.hybridize()
x = mx.nd.ones((2,2)).tostype('csr')
with mx.autograd.record():
z = block(x) + block(x)
z.backward()
assert (block.net.w.grad().asnumpy() == 4).all()

def test_hybrid_static_memory_recording():
net = gluon.model_zoo.vision.get_resnet(
1, 18, pretrained=True, ctx=mx.context.current_context())
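
As a follow-on check (hypothetical, not part of the commit): with sparse_grad=True the embedding weight's gradient comes back as row_sparse, which is the storage type the new BackwardStorageType inference has to propagate through the cached backward graph. A short sketch, assuming an MXNet build that includes this change:

import mxnet as mx

embedding = mx.gluon.nn.Embedding(20, 3, sparse_grad=True)
embedding.initialize()
embedding.hybridize()

with mx.autograd.record():
    loss = embedding(mx.nd.arange(10)).sum()
loss.backward()

# sparse_grad=True requests a row_sparse gradient; the CachedOp backward
# pass must infer and preserve that storage type end to end.
assert embedding.weight.grad().stype == 'row_sparse'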
