From 563f512f621029b87d336da6031c1681f8421bd9 Mon Sep 17 00:00:00 2001
From: lchdl
Date: Tue, 28 Nov 2023 09:37:37 +0000
Subject: [PATCH 1/2] Added static graph support for 'scaled_dot_product_attention'

---
 .../paddle/nn/functional/flash_attention.py | 67 ++++++++++++++-----
 1 file changed, 52 insertions(+), 15 deletions(-)

diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 98da4e717feb3..13599b3f78b88 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -498,22 +498,59 @@ def scaled_dot_product_attention(
         >>> print(output)
         >>> # doctest: -SKIP
     """
+
     if attn_mask is None:
+        # downgraded to ordinary flash attention implementation
         out, _ = flash_attention(query, key, value, dropout_p, is_causal)
+        return out
     else:
-        fixed_seed_offset = (None,)
-        return_softmax = False
-        rng_name = ""
-        out, _ = _C_ops.flash_attn(
-            query,
-            key,
-            value,
-            fixed_seed_offset,
-            attn_mask,
-            dropout_p,
-            is_causal,
-            return_softmax,
-            not training,
-            rng_name,
+        if in_dynamic_mode():
+            fixed_seed_offset = (None,)
+            return_softmax = False
+            rng_name = ""
+            out, _ = _C_ops.flash_attn(
+                query,
+                key,
+                value,
+                fixed_seed_offset,
+                attn_mask,
+                dropout_p,
+                is_causal,
+                return_softmax,
+                not training,
+                rng_name,
+            )
+            return out
+
+        helper = LayerHelper('flash_attn', **locals())
+        dtype = helper.input_dtype(input_param_name='q')
+        out = helper.create_variable_for_type_inference(dtype)
+        softmax = helper.create_variable_for_type_inference(dtype)
+        softmax_lse = helper.create_variable_for_type_inference(paddle.float32)
+        seed_offset = helper.create_variable_for_type_inference(paddle.int64)
+        inputs = {
+            'q': query,
+            'k': key,
+            'v': value,
+            'fixed_seed_offset': fixed_seed_offset,
+            'attn_mask': attn_mask,
+        }
+        outputs = {
+            'out': out,
+            'softmax': softmax,
+            'softmax_lse': softmax_lse,
+            'seed_offset': seed_offset,
+        }
+        helper.append_op(
+            type='flash_attn',
+            inputs=inputs,
+            outputs=outputs,
+            attrs={
+                'dropout': dropout_p,
+                'causal': is_causal,
+                'return_softmax': False,
+                'is_test': not training,
+                'rng_name': '',
+            },
         )
-    return out
+        return out

From c6452170e44e6656d6d82011755b5462bddf7345 Mon Sep 17 00:00:00 2001
From: lchdl
Date: Wed, 29 Nov 2023 08:09:34 +0000
Subject: [PATCH 2/2] Add static graph support for "scaled_dot_product_attention"

---
 .../paddle/nn/functional/flash_attention.py | 69 ++++++++++---------
 1 file changed, 36 insertions(+), 33 deletions(-)

diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 13599b3f78b88..7e5cc60fa4478 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -521,36 +521,39 @@ def scaled_dot_product_attention(
                 rng_name,
             )
             return out
-
-        helper = LayerHelper('flash_attn', **locals())
-        dtype = helper.input_dtype(input_param_name='q')
-        out = helper.create_variable_for_type_inference(dtype)
-        softmax = helper.create_variable_for_type_inference(dtype)
-        softmax_lse = helper.create_variable_for_type_inference(paddle.float32)
-        seed_offset = helper.create_variable_for_type_inference(paddle.int64)
-        inputs = {
-            'q': query,
-            'k': key,
-            'v': value,
-            'fixed_seed_offset': fixed_seed_offset,
-            'attn_mask': attn_mask,
-        }
-        outputs = {
-            'out': out,
-            'softmax': softmax,
-            'softmax_lse': softmax_lse,
-            'seed_offset': seed_offset,
-        }
-        helper.append_op(
-            type='flash_attn',
-            inputs=inputs,
-            outputs=outputs,
-            attrs={
-                'dropout': dropout_p,
-                'causal': is_causal,
-                'return_softmax': False,
-                'is_test': not training,
-                'rng_name': '',
-            },
-        )
-        return out
+        else:
+            helper = LayerHelper('flash_attn', **locals())
+            dtype = helper.input_dtype(input_param_name='q')
+            out = helper.create_variable_for_type_inference(dtype)
+            softmax = helper.create_variable_for_type_inference(dtype)
+            softmax_lse = helper.create_variable_for_type_inference(
+                paddle.float32
+            )
+            seed_offset = helper.create_variable_for_type_inference(
+                paddle.int64
+            )
+            inputs = {
+                'q': query,
+                'k': key,
+                'v': value,
+                'attn_mask': attn_mask,
+            }
+            outputs = {
+                'out': out,
+                'softmax': softmax,
+                'softmax_lse': softmax_lse,
+                'seed_offset': seed_offset,
+            }
+            helper.append_op(
+                type='flash_attn',
+                inputs=inputs,
+                outputs=outputs,
+                attrs={
+                    'dropout': dropout_p,
+                    'causal': is_causal,
+                    'return_softmax': False,
+                    'is_test': not training,
+                    'rng_name': '',
+                },
+            )
+            return out
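
Usage sketch (not part of the patch series): with the static-graph branch above in place, the masked path can be exercised roughly as follows. The tensor shapes, the float16 dtype, and the mask layout are illustrative assumptions, not taken from the patch.

    import paddle
    import paddle.nn.functional as F

    paddle.enable_static()  # build a static program instead of executing eagerly

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        # Layout [batch, seq_len, num_heads, head_dim]; shapes and float16
        # dtype are assumptions for illustration only.
        q = paddle.static.data(name='q', shape=[2, 128, 8, 64], dtype='float16')
        k = paddle.static.data(name='k', shape=[2, 128, 8, 64], dtype='float16')
        v = paddle.static.data(name='v', shape=[2, 128, 8, 64], dtype='float16')
        mask = paddle.static.data(
            name='mask', shape=[2, 8, 128, 128], dtype='float16'
        )
        # attn_mask is not None and in_dynamic_mode() is False, so the new
        # LayerHelper branch appends a `flash_attn` op to main_prog instead
        # of calling the dynamic-mode _C_ops path.
        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, training=False
        )

    print(main_prog)  # the printed program should now contain a flash_attn op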