Skip to content

Commit

Permalink
Fix fused_rotary_position_embedding in static mode (#59399)
Browse files Browse the repository at this point in the history
  • Loading branch information
From00 authored Nov 28, 2023
1 parent 63bd41a commit a386366
Showing 1 changed file with 8 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ def fused_rotary_position_embedding(
out_v = (
helper.create_variable_for_type_inference(dtype=v.dtype) if v else None
)

outputs = {'out_q': out_q}
if out_k:
outputs.update({'out_k': out_k})
if out_v:
outputs.update({'out_v': out_v})

helper.append_op(
type='fused_rotary_position_embedding',
inputs={
Expand All @@ -110,7 +117,7 @@ def fused_rotary_position_embedding(
'cos': cos,
'position_ids': position_ids,
},
outputs={'out_q': out_q, 'out_k': out_k, 'out_v': out_v},
outputs=outputs,
attrs={
'use_neox_rotary_style': use_neox_rotary_style,
},
Expand Down

0 comments on commit a386366

Please sign in to comment.