Hey @Zha0q1, thanks for submitting the PR!
CI supported jobs: [centos-gpu, unix-gpu, miscellaneous, edge, website, windows-cpu, unix-cpu, windows-gpu, sanity, centos-cpu, clang]

```python
assert_almost_equal(pred, pred_onx[0], rtol=1.e-04, atol=1.5e-03)

def verify_one_step_ahead_decoder():
```
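For context, this kind of tolerance check roughly follows the criterion elementwise |a − b| ≤ atol + rtol·|b|; a minimal NumPy sketch of that check (the arrays and the `almost_equal` helper are illustrative, not part of the PR):

```python
import numpy as np

def almost_equal(a, b, rtol=1e-4, atol=1.5e-3):
    """Elementwise tolerance check mirroring the combined rtol/atol criterion."""
    return bool(np.all(np.abs(a - b) <= atol + rtol * np.abs(b)))

pred = np.array([1.0, 2.5, -0.3])
pred_onx = pred + 1e-4            # small ONNX-vs-native drift stays within tolerance
assert almost_equal(pred, pred_onx)
assert not almost_equal(pred, pred + 1.0)  # a large gap fails the check
```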
@sxjscience would you help take a quick look at this func? Thanks!
```python
batch = 7
seq_length = 16
C_in = 512
```
What are C_in and C_out? Should we also test the case where C_in != C_out?
You can refer to this file https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py for C_in and C_out. They are defined in the pretrained model, so we need to set them to the same values as in the pretrained model.
```python
prefix = "%s/one_step_ahead_decoder" % tmp_path

# the input data order
perm = [2, 0, 1]
```
Could we put the correct order when instantiating the list instead of using perm?
I used a perm list so that the actual in_shapes and in_types lists can keep the same order as the inputs passed to the native model. It's just that the converted ONNX model somehow takes them in a different order. I think this is more consistent; what do you think?
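As a small sketch of the reordering being discussed (the shapes and types below are illustrative placeholders, not the PR's actual values), applying `perm = [2, 0, 1]` maps the model-order lists into the order the exported ONNX graph expects:

```python
# Hypothetical in_shapes/in_types, listed in the native model's input order.
in_shapes = [(7, 16, 512), (7, 16), (7,)]
in_types = ['float32', 'float32', 'float32']

perm = [2, 0, 1]  # position i of the ONNX input list takes native input perm[i]

onnx_shapes = [in_shapes[i] for i in perm]
onnx_types = [in_types[i] for i in perm]

assert onnx_shapes == [(7,), (7, 16, 512), (7, 16)]
assert onnx_types == ['float32', 'float32', 'float32']
```

This keeps the lists in the code readable in native order while still feeding the converter what it expects.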
```python
seq_length = 16
C_in = 512
C_out = 512
src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
```
Curious, doesn't src need to be an int type?
No, it's float in the original MXNet model too. I think this shouldn't matter, because the operator applies ceiling/flooring.
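A small NumPy sketch of why float token ids can still index an embedding table (the vocabulary size matches the upper bound used above, but this only illustrates floor-then-index behavior, not MXNet's exact kernel):

```python
import numpy as np

vocab_size, emb_dim = 36794, 4
table = np.random.rand(vocab_size, emb_dim)

# Float "token ids", as produced by the uniform sampler in the test.
src = np.array([[12.7, 100.2], [0.9, 36793.1]], dtype='float32')
ids = np.floor(src).astype('int64')   # flooring yields valid integer indices

emb = table[ids]
assert emb.shape == (2, 2, emb_dim)
assert np.array_equal(emb[0, 0], table[12])
```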
LGTM, thanks!
This PR adds support for the pretrained
transformer_en_de_512
model. We break the transformer into encoder, decoder, embedding, and projection components and test each part separately.
To get one_step_ahead_decoder to work, seq_len is made dynamic.
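A dynamic seq_len means the exported graph leaves that axis of the input shape unspecified, so the decoder can be run one token at a time with a growing sequence. A minimal sketch of such a shape contract, where `None` marks the dynamic axis (the helper and shapes are illustrative, not the PR's API):

```python
def matches(shape, spec):
    """Return True if a concrete shape fits a spec where None means 'any size'."""
    return len(shape) == len(spec) and all(
        s is None or s == d for d, s in zip(shape, spec))

# Decoder input spec: fixed batch and feature dims, dynamic sequence length.
spec = (7, None, 512)

assert matches((7, 16, 512), spec)      # seq_len 16 fits
assert matches((7, 31, 512), spec)      # any seq_len fits
assert not matches((8, 16, 512), spec)  # batch mismatch fails
```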