Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add static graph support for "scaled_dot_product_attention" #59498

Merged
merged 2 commits into from
Nov 30, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 56 additions & 16 deletions python/paddle/nn/functional/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,22 +498,62 @@ def scaled_dot_product_attention(
>>> print(output)
>>> # doctest: -SKIP
"""

if attn_mask is None:
# downgraded to ordinary flash attention implementation
out, _ = flash_attention(query, key, value, dropout_p, is_causal)
return out
else:
fixed_seed_offset = (None,)
return_softmax = False
rng_name = ""
out, _ = _C_ops.flash_attn(
query,
key,
value,
fixed_seed_offset,
attn_mask,
dropout_p,
is_causal,
return_softmax,
not training,
rng_name,
)
return out
if in_dynamic_mode():
fixed_seed_offset = (None,)
return_softmax = False
rng_name = ""
out, _ = _C_ops.flash_attn(
query,
key,
value,
fixed_seed_offset,
attn_mask,
dropout_p,
is_causal,
return_softmax,
not training,
rng_name,
)
return out
else:
helper = LayerHelper('flash_attn', **locals())
dtype = helper.input_dtype(input_param_name='q')
out = helper.create_variable_for_type_inference(dtype)
softmax = helper.create_variable_for_type_inference(dtype)
softmax_lse = helper.create_variable_for_type_inference(
paddle.float32
)
seed_offset = helper.create_variable_for_type_inference(
paddle.int64
)
inputs = {
'q': query,
'k': key,
'v': value,
'attn_mask': attn_mask,
}
outputs = {
'out': out,
'softmax': softmax,
'softmax_lse': softmax_lse,
'seed_offset': seed_offset,
}
helper.append_op(
type='flash_attn',
inputs=inputs,
outputs=outputs,
attrs={
'dropout': dropout_p,
'causal': is_causal,
'return_softmax': False,
'is_test': not training,
'rng_name': '',
},
)
return out