Skip to content

Commit

Permalink
Cxx prim custom vjp (#8)
Browse files Browse the repository at this point in the history
* [CINN]Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* Pr 50885 (#7)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix code in a dy2static-friendly way.

* [dystatic] add hooker for prim

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix cast prim and vjp dtype mapping error bug

* [dy2static-ci] fix dy2static ci errors.

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
  • Loading branch information
4 people committed Mar 14, 2023
1 parent 5dda91a commit e4a93b0
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 18 deletions.
1 change: 0 additions & 1 deletion python/paddle/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -3751,7 +3751,6 @@ def __init__(self, program, idx):
self.vars = collections.OrderedDict() # var_name --> var
self.ops = list() # operator list
self.program = program
self.removed_vars = collections.OrderedDict()

def __str__(self):
return self._to_readable_code()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,12 @@ def train(self, use_prim):
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x)[1]
.train_program.block(0)
.ops
]
# Ensure that softmax is splitted into small ops
self.assertTrue('softmax' not in fwd_ops)

Expand Down Expand Up @@ -128,7 +133,12 @@ def train(self, use_prim):
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x)[1]
.train_program.block(0)
.ops
]
all_ops = [
op.type
for op in net.forward.program_cache.last()[-1][-1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def _train(self, use_prim, approximate, data):
net = apply_to_static(net, use_prim)

res = []
self.x = data
for _ in range(10):
out = net(data)
loss = paddle.mean(out)
Expand All @@ -92,7 +93,12 @@ def _train(self, use_prim, approximate, data):
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x)[1]
.train_program.block(0)
.ops
]
# Ensure that gelu is splitted into small ops
self.assertTrue('gelu' not in fwd_ops)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,14 @@ def train(self, use_prim):
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x, self.w, self.b)[
1
]
.train_program.block(0)
.ops
]
# Ensure that layer_norm is splitted into small ops
self.assertTrue('layer_norm' not in fwd_ops)

Expand Down Expand Up @@ -150,7 +157,14 @@ def train(self, use_prim):
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x, self.w, self.b)[
1
]
.train_program.block(0)
.ops
]
# Ensure that layer_norm is splitted into small ops
self.assertTrue('layer_norm' not in fwd_ops)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def _train(self, use_prim, data, axis, keep_dim):
net = apply_to_static(net, use_prim)

res = []
self.x = data
for _ in range(10):
out = net(data)
loss = paddle.mean(out, axis, keep_dim)
Expand All @@ -99,7 +100,12 @@ def _train(self, use_prim, data, axis, keep_dim):
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x)[1]
.train_program.block(0)
.ops
]
# Ensure that reduce_mean is splitted into small ops
self.assertTrue('reduce_mean' not in fwd_ops)

Expand Down Expand Up @@ -150,6 +156,7 @@ def _train(self, use_prim, data, axis, keep_dim):
net = apply_to_static(net, use_prim)

res = []
self.x = data
for _ in range(10):
out = net(data)
loss = paddle.mean(out, axis, keep_dim)
Expand All @@ -166,7 +173,12 @@ def _train(self, use_prim, data, axis, keep_dim):
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x)[1]
.train_program.block(0)
.ops
]
# Ensure that reduce_mean is splitted into small ops
self.assertTrue('reduce_mean' not in fwd_ops)

Expand Down
20 changes: 12 additions & 8 deletions python/paddle/jit/dy2static/partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,7 @@ def _create_pure_fp16_program(self, is_infer_mode=False):
def _create_forward_backward_train_program(self):
whole_program = self._train_program
# _, forward_end_op_index = self._infer_info('fp32', self._create_program)
forward_end_op_index = self._forward_end_index_map[
_hash_with_id(whole_program, self)
]
forward_end_op_index = self.get_forward_end_op_idx(whole_program)
assert forward_end_op_index >= 0

return self._get_forward_backward_program_form(
Expand Down Expand Up @@ -438,11 +436,14 @@ def _infer_pure_fp16_program_id(self):
def _param_grad_names(self):
return _param_grad_names(self._train_program.desc, self._params)

def get_forward_end_op_idx(self, program):
return self._forward_end_index_map[_hash_with_id(program, self)]

@LazyInitialized
def _out_grad_names(self):
return _out_grad_names(
self._train_program.desc,
self._create_program(is_infer_mode=True).desc.block(0).op_size(),
self.get_forward_end_op_idx(self._train_program),
len(self._outputs.var_ids),
)

Expand Down Expand Up @@ -642,6 +643,7 @@ def _append_backward_desc(self, main_program):
if isinstance(out, framework.Variable):
targets.append(program.global_block().var(out.name))

start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
if targets:
# TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
core.check_and_set_prim_all_enabled()
Expand All @@ -652,12 +654,11 @@ def _append_backward_desc(self, main_program):
program, start_idx = self._hooker.after_append_backward(
self, program, start_idx
)
self._forward_end_index_map[
_hash_with_id(program, self)
] = start_idx - len(self._outputs.tolist())
# TODO: prim make this complicate
self.prepare_gradient_aggregation(start_idx, main_program, program)

self._forward_end_index_map[
_hash_with_id(program, self)
] = start_idx - len(self._outputs.tolist())
return program

def _prune_unused_params(self, program):
Expand Down Expand Up @@ -1155,5 +1156,8 @@ def add_build_strategy_for(
if hasattr(compiled_program._program, 'lr_sheduler'):
builded_program.lr_sheduler = compiled_program._program.lr_sheduler
else:
# can't just create a new program, we need copy the vardesc.
builded_program = paddle.static.Program()
for var in program.block(0).vars.values():
builded_program.block(0)._clone_variable(var, False)
return builded_program
1 change: 0 additions & 1 deletion python/paddle/jit/dy2static/program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,6 @@ def after_infer(self, partial_program_layer, infer_program):
partial_program.set_hooker(PrimHooker())
return concrete_program, partial_program


def __getitem__(self, item):
if not isinstance(item, CacheKey):
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/jit/dy2static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,7 +1568,7 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
):
op = program_desc.block(0).op(i)
if op.type() == 'fill_any_like':
if op.type() in ['fill_any_like', "fill_constant"]:
var_name = op.output('Out')[0]
names.append(var_name)
return names
Expand Down

0 comments on commit e4a93b0

Please sign in to comment.