fix contrib interleaved_matmul_selfatt_valatt not render correctly (a…
ys2843 committed Jun 29, 2020
1 parent f3c7b13 commit 2c16502
Showing 1 changed file with 29 additions and 24 deletions.
53 changes: 29 additions & 24 deletions src/operator/contrib/transformer.cc
@@ -701,14 +701,16 @@ of queries, keys and values following the layout:
 and the attention weights following the layout:
 (batch_size, seq_length, seq_length)
-the equivalent code would be:
-tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
-v_proj = mx.nd.transpose(tmp[:,:,:,2,:], axes=(1, 2, 0, 3))
-v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True)
-output = mx.nd.batch_dot(attention, v_proj)
-output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True)
-output = mx.nd.transpose(output, axes=(2, 0, 1, 3))
-output = mx.nd.reshape(output, shape=(0, 0, -1))
+the equivalent code would be::
+
+  tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
+  v_proj = mx.nd.transpose(tmp[:,:,:,2,:], axes=(1, 2, 0, 3))
+  v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True)
+  output = mx.nd.batch_dot(attention, v_proj)
+  output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True)
+  output = mx.nd.transpose(output, axes=(2, 0, 1, 3))
+  output = mx.nd.reshape(output, shape=(0, 0, -1))
 )code" ADD_FILELINE)
 .set_num_inputs(2)
 .set_num_outputs(1)
@@ -745,14 +747,16 @@ the inputs must be a tensor of projections of queries following the layout:
 and a tensor of interleaved projections of values and keys following the layout:
 (seq_length, batch_size, num_heads * head_dim * 2)
-the equivalent code would be:
-q_proj = mx.nd.transpose(queries, axes=(1, 2, 0, 3))
-q_proj = mx.nd.reshape(q_proj, shape=(-1, 0, 0), reverse=True)
-q_proj = mx.nd.contrib.div_sqrt_dim(q_proj)
-tmp = mx.nd.reshape(keys_values, shape=(0, 0, num_heads, 2, -1))
-k_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3))
-k_proj = mx.nd.reshape(k_proj, shape=(-1, 0, 0), reverse=True)
-output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True)
+the equivalent code would be::
+
+  q_proj = mx.nd.transpose(queries, axes=(1, 2, 0, 3))
+  q_proj = mx.nd.reshape(q_proj, shape=(-1, 0, 0), reverse=True)
+  q_proj = mx.nd.contrib.div_sqrt_dim(q_proj)
+  tmp = mx.nd.reshape(keys_values, shape=(0, 0, num_heads, 2, -1))
+  k_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3))
+  k_proj = mx.nd.reshape(k_proj, shape=(-1, 0, 0), reverse=True)
+  output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True)
 )code" ADD_FILELINE)
 .set_num_inputs(2)
 .set_num_outputs(1)
@@ -790,15 +794,16 @@ keys and values following the layout:
 and the attention weights following the layout:
 (batch_size, seq_length, seq_length)
-the equivalent code would be:
+the equivalent code would be::

-tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
-v_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3))
-v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True)
-output = mx.nd.batch_dot(attention, v_proj, transpose_b=True)
-output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True)
-output = mx.nd.transpose(output, axes=(0, 2, 1, 3))
-output = mx.nd.reshape(output, shape=(0, 0, -1))
+  tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
+  v_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3))
+  v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True)
+  output = mx.nd.batch_dot(attention, v_proj, transpose_b=True)
+  output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True)
+  output = mx.nd.transpose(output, axes=(0, 2, 1, 3))
+  output = mx.nd.reshape(output, shape=(0, 0, -1))
 )code" ADD_FILELINE)
 .set_num_inputs(2)
 .set_num_outputs(1)
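For reference, the "equivalent code" quoted in the first hunk (interleaved_matmul_selfatt_valatt) can be exercised end to end with ordinary NDArray ops. Below is a minimal sketch, not part of the commit: it assumes MXNet 1.x (the mx.nd API used in the docstring), the function name selfatt_valatt_reference and the toy shapes are illustrative only, and the attention input is given the (batch_size * num_heads, seq_length, seq_length) shape that the batch_dot in the snippet implies.

import mxnet as mx

def selfatt_valatt_reference(queries_keys_values, attention, num_heads):
    # queries_keys_values: (seq_length, batch_size, num_heads * head_dim * 3)
    # attention:           (batch_size * num_heads, seq_length, seq_length)
    # returns:             (seq_length, batch_size, num_heads * head_dim)
    # Split the interleaved Q/K/V projections and keep the value slice (index 2).
    tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
    v_proj = mx.nd.transpose(tmp[:, :, :, 2, :], axes=(1, 2, 0, 3))
    v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True)
    # Apply the attention weights per head, then restore the
    # (seq_length, batch_size, num_heads * head_dim) layout.
    output = mx.nd.batch_dot(attention, v_proj)
    output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True)
    output = mx.nd.transpose(output, axes=(2, 0, 1, 3))
    return mx.nd.reshape(output, shape=(0, 0, -1))

if __name__ == "__main__":
    seq_length, batch_size, num_heads, head_dim = 4, 2, 3, 5
    qkv = mx.nd.random.uniform(shape=(seq_length, batch_size, num_heads * head_dim * 3))
    att = mx.nd.random.uniform(shape=(batch_size * num_heads, seq_length, seq_length))
    print(selfatt_valatt_reference(qkv, att, num_heads).shape)  # expected: (4, 2, 15)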
