This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Onnx Support for Transformer #20048

Merged: 7 commits merged into apache:v1.x on Mar 23, 2021

Conversation

@Zha0q1 (Contributor) commented Mar 19, 2021

This PR adds ONNX export support for the pretrained transformer_en_de_512 model.
We break the transformer into encoder, decoder, embedding, and projection components and test each part separately.
To get one_step_ahead_decoder to work, seq_len is exported as a dynamic dimension, as sketched below.
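
A minimal sketch of the export idea, not the PR's actual test code. It assumes the v1.x mx.onnx.export_model API with dynamic-shape support; the symbol/params file paths are hypothetical. A None entry in dynamic_input_shapes marks seq_length as dynamic:

import mxnet as mx
import numpy as np

batch, seq_length, C_in = 7, 16, 512

# hypothetical paths to a hybridized decoder exported to symbol/params files
sym_file = './one_step_ahead_decoder-symbol.json'
params_file = './one_step_ahead_decoder-0000.params'

# dynamic=True plus a None dimension lets the ONNX graph accept any seq_length
mx.onnx.export_model(sym_file, params_file,
                     in_shapes=[(batch, seq_length, C_in)],
                     in_types=[np.float32],
                     onnx_file_path='decoder.onnx',
                     dynamic=True,
                     dynamic_input_shapes=[(batch, None, C_in)])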

@Zha0q1 requested a review from szha as a code owner on March 19, 2021 00:59
@mxnet-bot

Hey @Zha0q1, thanks for submitting the PR!
All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands:

  • To trigger all jobs: @mxnet-bot run ci [all]
  • To trigger specific jobs: @mxnet-bot run ci [job1, job2]

CI supported jobs: [centos-gpu, unix-gpu, miscellaneous, edge, website, windows-cpu, unix-cpu, windows-gpu, sanity, centos-cpu, clang]


Note:
Only the following 3 categories can trigger CI: PR Author, MXNet Committer, Jenkins Admin.
All CI tests must pass before the PR can be merged.

@Zha0q1 requested a review from sxjscience on March 19, 2021 18:16

assert_almost_equal(pred, pred_onx[0], rtol=1.e-04, atol=1.5e-03)

def verify_one_step_ahead_decoder():
@Zha0q1 (Contributor, Author) commented:

@sxjscience would you help take a quick look at this function? Thanks!

@Zha0q1 changed the title from [wip][v1.x] Onnx Support for Transformer to [v1.x] Onnx Support for Transformer on Mar 19, 2021

batch = 7
seq_length = 16
C_in = 512
A reviewer (Contributor) commented:

What are C_in and C_out? Should we also test when C_in != C_out?

@Zha0q1 (Contributor, Author) replied:

You can refer to https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py for C_in and C_out. They are fixed by the pretrained model's architecture, so we need to set them to the same values the pretrained model uses.
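
For context, a sketch of where those dimensions come from, assuming the gluonnlp v0.10 model-zoo API (the get_model call and its return values are written from memory, so treat the exact signature as an assumption):

import gluonnlp as nlp

# transformer_en_de_512 uses a 512-dim embedding/hidden size throughout,
# which is why the test sets C_in = C_out = 512
model, src_vocab, tgt_vocab = nlp.model.get_model(
    'transformer_en_de_512', dataset_name='WMT2014', pretrained=True)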

prefix = "%s/one_step_ahead_decoder" %tmp_path

# the input data order
perm = [2, 0, 1]
A reviewer (Contributor) commented:

Could we put the correct order when instantiating the list instead of using perm?

@Zha0q1 (Contributor, Author) replied:

I used a perm list so that the in_shapes and in_types lists can keep the same order as the inputs passed to the native model; the converted ONNX model just happens to take its inputs in a different order. I think this is more consistent. What do you think?
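
A small sketch of that idea (variable names are illustrative, not the PR's actual code): in_shapes and in_types stay in the native argument order, and perm is only applied when feeding the exported ONNX model:

perm = [2, 0, 1]                               # order the ONNX graph expects
native_inputs = [step_data, state_0, state_1]  # native MXNet argument order
onnx_inputs = [native_inputs[i] for i in perm]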

seq_length = 16
C_in = 512
C_out = 512
src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
A reviewer (Contributor) commented:

Curious, doesn't src need to be an int type?

@Zha0q1 (Contributor, Author) replied:

No, it's float in the original MXNet model too. This shouldn't matter, I think, because the operator will apply ceiling/flooring to the indices.
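
A quick sketch of the behavior described above (an illustration against the Gluon Embedding block; the flooring claim is the author's):

import mxnet as mx

emb = mx.gluon.nn.Embedding(input_dim=36794, output_dim=4)
emb.initialize()
idx_float = mx.nd.array([3.7, 10.2])  # float "token ids"
idx_int = mx.nd.array([3, 10])
# if the operator floors float indices, both lookups return identical rows
print((emb(idx_float) == emb(idx_int)).asnumpy().all())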

@josephevans (Contributor) left a comment:

LGTM, thanks!

@waytrue17 (Contributor):

LGTM, thanks

@Zha0q1 merged commit 833cb89 into apache:v1.x on Mar 23, 2021
Labels: pr-awaiting-review (PR is waiting for code review)
5 participants