
[Memory][WIP] Recompute policy in memory optimization #9089

Closed
wants to merge 14 commits into from
2 changes: 1 addition & 1 deletion paddle/fluid/framework/CMakeLists.txt
@@ -32,7 +32,7 @@ cc_test(variable_test SRCS variable_test.cc)
cc_library(threadpool SRCS threadpool.cc DEPS enforce)
cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)

cc_library(scope SRCS scope.cc DEPS glog threadpool)
cc_library(scope SRCS scope.cc DEPS glog threadpool lod_tensor)
cc_test(scope_test SRCS scope_test.cc DEPS scope)

cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor)
11 changes: 10 additions & 1 deletion paddle/fluid/framework/scope.cc
@@ -18,6 +18,7 @@ limitations under the License. */
#include <mutex> // for call_once
#include <set>
#include "glog/logging.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/string/printf.h"

@@ -103,16 +104,24 @@ void Scope::DeleteScope(Scope* scope) {
  }
}

void Scope::EraseVars(std::vector<std::string>& var_names) {
void Scope::ReleaseVarsMemory(std::vector<std::string>& var_names) {
  std::set<std::string> var_set(var_names.begin(), var_names.end());
  std::vector<std::string> delete_vars;
  for (auto it = vars_.begin(); it != vars_.end();) {
    if (var_set.find(it->first) != var_set.end()) {
      delete_vars.push_back(it->first);
      delete it->second;
      it = vars_.erase(it);
    } else {
      ++it;
    }
  }

  for (auto& var_name : delete_vars) {
    Variable* p = this->Var(var_name);
    // It only works for LoDTensor.
    p->GetMutable<LoDTensor>();
  }
}

void Scope::Rename(const std::string& origin_name,
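The renamed method does more than the old EraseVars: after the matched entries are deleted (which frees their tensor buffers), each name is re-created through Var() and given an empty LoDTensor, so later lookups by name still succeed while the memory itself is released. A rough sketch of that behaviour in plain Python, with a dict standing in for the scope's variable map (illustration only, not the Paddle API):

# Sketch only: `scope_vars` is a plain dict standing in for Scope::vars_,
# and None stands in for a freshly constructed, empty LoDTensor.
def release_vars_memory(scope_vars, var_names):
    targets = set(var_names)
    released = [name for name in list(scope_vars) if name in targets]
    for name in released:
        del scope_vars[name]      # drop the old Variable and its buffer
    for name in released:
        scope_vars[name] = None   # re-create the name with an empty placeholder
    return released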
2 changes: 1 addition & 1 deletion paddle/fluid/framework/scope.h
@@ -51,7 +51,7 @@ class Scope {
  /// Create a variable with a scope-unique name.
  Variable* Var(std::string* name = nullptr);

  void EraseVars(std::vector<std::string>& var_names);
  void ReleaseVarsMemory(std::vector<std::string>& var_names);

  /// Find a variable in the scope or any of its ancestors. Returns
  /// nullptr if cannot find.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/delete_var_op.cc
@@ -28,7 +28,7 @@ class DeleteVarOp : public framework::OperatorBase {
    dev_ctx.Wait();

    auto delete_var_names = Inputs("X");
    const_cast<framework::Scope &>(scope).EraseVars(delete_var_names);
    const_cast<framework::Scope &>(scope).ReleaseVarsMemory(delete_var_names);
  }
};

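delete_var is the runtime hook for this release: when the op executes it waits for the device context and then calls ReleaseVarsMemory on its "X" inputs. The transpiler further down plants such ops into the program. A minimal sketch of that insertion, using only the block-desc calls that already appear in this PR (block_desc, pos, and var_names are assumed to come from the caller):

# Sketch: plant a delete_var op at position `pos` so the listed variables
# are released as soon as execution reaches that point.
def insert_delete_var(block_desc, pos, var_names):
    op_desc = block_desc.insert_op(pos)   # new, empty op desc at index `pos`
    op_desc.set_type("delete_var")        # DeleteVarOp -> Scope::ReleaseVarsMemory
    op_desc.set_input("X", var_names)     # names whose buffers should be freed
    return op_desc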
3 changes: 2 additions & 1 deletion python/paddle/fluid/__init__.py
@@ -37,7 +37,7 @@
from concurrency import (Go, make_channel, channel_send, channel_recv,
                         channel_close, Select)
import clip
from memory_optimization_transpiler import memory_optimize, release_memory
from memory_optimization_transpiler import memory_optimize, release_memory, recomputation
import profiler
import unique_name
import recordio_writer
@@ -65,6 +65,7 @@
    'DistributeTranspiler',
    'memory_optimize',
    'release_memory',
    'recomputation',
    'profiler',
    'unique_name',
    'recordio_writer',
89 changes: 89 additions & 0 deletions python/paddle/fluid/memory_optimization_transpiler.py
@@ -147,6 +147,17 @@ def _update_skip_opt_set(self):
            if op.type() == "fill_constant" and op.attr("force_cpu") == True:
                self._skip_opt.update(op.output_arg_names())

    def dataflow_analyze(self):
        self._dataflow_analyze()

    def find_variable_liveness_op(self, var_name):
        for i in range(self.op_size):
            op = self._ops[i]
            in_diff, out_diff = self._get_diff(self._live_in[i],
                                               self._live_out[i])
            if var_name in in_diff:
                return i

    def release_memory(self):
        self._dataflow_analyze()
        self._update_skip_opt_set()
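find_variable_liveness_op leans on the existing liveness analysis: a variable that appears in an op's live-in set but not in its live-out set dies at that op, so the returned index is the op of the variable's last use (or None if the variable never dies). A toy illustration of the same query on hand-written live sets (stand-alone sketch, not the ControlFlowGraph class):

# Toy stand-in: an op is the last use of `var` when var is live going into
# the op but no longer live going out of it.
def last_use_index(live_in, live_out, var):
    for i in range(len(live_in)):
        if var in live_in[i] and var not in live_out[i]:
            return i
    return None

# Example: three ops, variable "h" is produced by op 0 and last read by op 1.
live_in = [{"x"}, {"h", "w"}, {"z"}]
live_out = [{"h"}, {"z"}, set()]
assert last_use_index(live_in, live_out, "h") == 1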
@@ -341,3 +352,81 @@ def release_memory(input_program):
    cfgs = _get_cfgs(input_program)
    for cfg in cfgs:
        cfg.release_memory()


activation_ops = [
    "sigmoid", "logsigmoid", "exp", "relu", "tanh", "tanh_shrink", "softshrink",
    "sqrt", "abs", "ceil", "floor", "round", "reciprocal", "log", "square",
    "softplus", "softsign", "brelu", "leaky_relu", "soft_relu", "elu", "relu6",
    "pow", "stanh", "hard_shrink", "thresholded_relu", "hard_sigmoid", "swish"
]


def _check_match_pattern(ops, pattern):
    match_flag = True
    for j in range(len(ops)):
        if pattern[j] == "activation":
            if ops[j].type() not in activation_ops:
                match_flag = False
                break
        else:
            if ops[j].type() != pattern[j]:
                match_flag = False
                break
    return match_flag


def _find_forward_index(block_desc):
    pattern = ["mean", "fill_constant", "mean_grad"]
    op_size = block_desc.op_size()
    for i in range(op_size - len(pattern) + 1):
        ops = [block_desc.op(i + j) for j in range(len(pattern))]
        match_flag = _check_match_pattern(ops, pattern)
        if match_flag:
            return i
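
# The ["mean", "fill_constant", "mean_grad"] pattern above is a heuristic for
# the forward/backward boundary: the mean loss is usually the last forward op,
# fill_constant seeds the loss gradient, and mean_grad is the first backward
# op.  Toy illustration of the same search on plain type strings (editor's
# sketch for readers of this diff, not part of the transpiler itself):
#
#     op_types = ["mul", "relu", "mul", "mean",
#                 "fill_constant", "mean_grad", "relu_grad"]
#     # the pattern starts at index 3, so ops 0..2 are treated as the
#     # forward pass and everything from "mean" onward as loss/backward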


def recomputation(input_program, pattern="activation"):
    pdesc = input_program.get_desc()
    block_desc = pdesc.block(0)
    op_size = block_desc.op_size()

    match_ops = []
    for i in range(op_size):
        op = block_desc.op(i)
        if pattern == "activation":
            if op.type() in activation_ops:
                match_ops.append(op)

    forward_index = _find_forward_index(block_desc)
    forward_ops = [block_desc.op(i) for i in range(forward_index)]
    forward_cfg = ControlFlowGraph(input_program, forward_ops, forward_index,
                                   set())
    forward_cfg.dataflow_analyze()

    match_indexs = []
    for op in match_ops:
        # print op.type()
        # print op.output("Out")[0]
        index = forward_cfg.find_variable_liveness_op(op.output("Out")[0])
        # print index
        match_indexs.append(index)

    forward_hit = zip(match_ops, match_indexs)
    forward_hit = sorted(forward_hit, key=lambda x: x[1])

    for i, (fwd_op, fwd_index) in enumerate(forward_hit):
        delete_op = block_desc.insert_op(fwd_index + i + 1)
        delete_op.set_type("delete_var")
        delete_op.set_input("X", fwd_op.output("Out"))

    current_op_size = block_desc.op_size()
    for i, (fwd_op, fwd_index) in enumerate(forward_hit):
        bwd_index = 0
        for j in range(forward_index + len(forward_hit), current_op_size):
            bwd_op = block_desc.op(j)
            if fwd_op.output("Out")[0] in bwd_op.input_arg_names():
                new_op = block_desc.insert_op(j + bwd_index)
                new_op.copy_from(fwd_op)
                bwd_index += 1
                break
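
Putting the pieces together: the first loop plants a delete_var op right after each activation output's last forward use (the + i + 1 offset compensates for the ops already inserted ahead of it), and the second loop copies the activation op back into the block just before the first backward op that reads its output, so the freed result is recomputed on demand. A rough way to eyeball the rewritten program from Python, reusing only the desc calls that already appear above (dump_op_types is a hypothetical helper and prog is assumed to be the transformed program):

def dump_op_types(prog):
    # List the op types of block 0 so the inserted delete_var ops and the
    # duplicated activation ops are visible in program order.
    block = prog.get_desc().block(0)
    return [block.op(i).type() for i in range(block.op_size())]

# e.g. after fluid.recomputation(prog) one would expect the forward part to
# gain delete_var entries and the backward part to repeat activation types,
# along the lines of:
#   [..., 'relu', 'mul', 'delete_var', 'mean', 'fill_constant', 'mean_grad',
#    'relu', 'mul_grad', 'relu_grad', ...]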
@@ -125,9 +125,11 @@ def conv_block(input, num_filter, groups, dropouts):
batch_size = fluid.layers.create_tensor(dtype='int64')
batch_acc = fluid.layers.accuracy(input=predict, label=label, total=batch_size)

# fluid.memory_optimize(fluid.default_main_program(), level=0)
fluid.recomputation(fluid.default_main_program())
fluid.release_memory(fluid.default_main_program())

# fluid.memory_optimize(fluid.default_main_program(), level=0)

BATCH_SIZE = 16
PASS_NUM = 1
