[model_zoo/gpt-3] Fix bugs from PR-61236 which cleared paddle.jit.dy2static.utils_helper #7989

Merged: 2 commits, Feb 20, 2024
First changed file:

@@ -49,6 +49,11 @@
 except:
     flash_attention = None

+try:
+    from paddle.jit.api import set_dynamic_shape
+except:
+    from paddle.jit.dy2static.utils_helper import set_dynamic_shape
+
 def shard_op_for_sequence_parallel_linear(tgt, mesh):
     # FIXME Hack to shard op for module (linear)
     # we only shard the second to the last op (matmul) leave the last op (elementwise_add) un-touched

@@ -1206,7 +1211,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f

         attn_mask = model_kwargs["attention_mask"]
         # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
-        paddle.jit.dy2static.utils_helper.set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
+        set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
         model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
         max_length = paddle.to_tensor(max_length)
         while cur_len < max_length:
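All three files get the same two-part change: import set_dynamic_shape from its new location, paddle.jit.api, falling back to the old paddle.jit.dy2static.utils_helper path on Paddle builds that predate PR-61236, and then call the imported name directly instead of the removed fully qualified path. The snippet below is a minimal, self-contained sketch of that pattern together with the call site it feeds; the tensor name and shape are invented for illustration, and except ImportError is used where the PR keeps a bare except.

import paddle

# Prefer the new location; fall back for older Paddle wheels that still ship
# paddle.jit.dy2static.utils_helper (the PR itself uses a bare `except:`).
try:
    from paddle.jit.api import set_dynamic_shape
except ImportError:
    from paddle.jit.dy2static.utils_helper import set_dynamic_shape

# Hypothetical attention mask; the [batch, heads, q_len, kv_len] shape is
# illustrative and not taken from the PR.
attention_mask = paddle.ones([1, 1, 8, 8], dtype="float32")

# Same call as the patched decode loop: mark every dim as dynamic (-1) so a
# dy2static-exported graph does not hard-code the trace-time mask shape.
# In eager mode this is expected to be a no-op.
set_dynamic_shape(attention_mask, [-1, -1, -1, -1])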
Second changed file:

@@ -62,11 +62,17 @@
     from paddle.nn.functional.flash_attention import flash_attention
 except:
     flash_attention = None

 try:
     from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd
 except:
     FusedDropoutAdd = None
+
+try:
+    from paddle.jit.api import set_dynamic_shape
+except:
+    from paddle.jit.dy2static.utils_helper import set_dynamic_shape
+
 def get_attr(layer, name):
     if getattr(layer, name, None) is not None:
         return getattr(layer, name, None)

@@ -1501,7 +1507,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f

         attn_mask = model_kwargs["attention_mask"]
         # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
-        paddle.jit.dy2static.utils_helper.set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
+        set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
         model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
         while cur_len < max_length:
             # Note(GuoxiaWang): Remove outputs = _forward_(**model_kwargs)
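Besides the set_dynamic_shape fix, this file already guards an optional fused kernel the same way: FusedDropoutAdd falls back to None when the incubate layer is missing. The sketch below shows one common way such an optional layer is consumed, reverting to plain dropout plus residual add when the fused layer or a CUDA build is unavailable; it is an illustration under those assumptions, not code from this PR, and the ResidualDropout class name is made up.

import paddle
import paddle.nn as nn

try:
    from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd
except ImportError:
    FusedDropoutAdd = None


class ResidualDropout(nn.Layer):
    """Hypothetical wrapper computing dropout(x) + residual, fused when available."""

    def __init__(self, p=0.1):
        super().__init__()
        # Assumption: the fused layer targets CUDA builds, so only enable it
        # when both the import and the build support it.
        use_fused = FusedDropoutAdd is not None and paddle.is_compiled_with_cuda()
        self.fused = FusedDropoutAdd(p) if use_fused else None
        self.dropout = nn.Dropout(p)

    def forward(self, x, residual):
        if self.fused is not None:
            return self.fused(x, residual)
        return residual + self.dropout(x)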
Third changed file:

@@ -43,6 +43,10 @@
 except:
     flash_attention = None

+try:
+    from paddle.jit.api import set_dynamic_shape
+except:
+    from paddle.jit.dy2static.utils_helper import set_dynamic_shape

 def get_attr(layer, name):
     if getattr(layer, name, None) is not None:

@@ -1077,7 +1081,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f

         attn_mask = model_kwargs["attention_mask"]
         # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
-        paddle.jit.dy2static.utils_helper.set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
+        set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
         model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
         if hasattr(paddle.framework, "_no_check_dy2st_diff"):
             # TODO(wanghuancoder): _no_check_dy2st_diff is used to turn off the checking of behavior