From a386366ac34688fe1b96364b1d0b5980a8d73158 Mon Sep 17 00:00:00 2001 From: Ruibiao Chen Date: Tue, 28 Nov 2023 21:43:37 +0800 Subject: [PATCH] Fix fused_rotary_position_embedding in static mode (#59399) --- .../nn/functional/fused_rotary_position_embedding.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py index adfcdc233fe56..e9b8f08ae6a82 100644 --- a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py +++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py @@ -100,6 +100,13 @@ def fused_rotary_position_embedding( out_v = ( helper.create_variable_for_type_inference(dtype=v.dtype) if v else None ) + + outputs = {'out_q': out_q} + if out_k: + outputs.update({'out_k': out_k}) + if out_v: + outputs.update({'out_v': out_v}) + helper.append_op( type='fused_rotary_position_embedding', inputs={ @@ -110,7 +117,7 @@ def fused_rotary_position_embedding( 'cos': cos, 'position_ids': position_ids, }, - outputs={'out_q': out_q, 'out_k': out_k, 'out_v': out_v}, + outputs=outputs, attrs={ 'use_neox_rotary_style': use_neox_rotary_style, },