Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DeepVision Port] SegFormer and Mix-Transformers #1946

Merged
merged 57 commits into from
Aug 24, 2023

Conversation

DavidLandup0
Copy link
Contributor

@DavidLandup0 DavidLandup0 commented Jul 13, 2023

What does this PR do?

As discussed in #1933 - setting up a draft PR for porting SegFormer and associated layers into KCV. Draft PR for now with placeholder main model dump, layers and tests incoming soon. Will tag once ready for review.

Demo Notebooks

Questions and API Considerations

  • DeepLabV3 takes any backbone, but SegFormer is meant to be used with MiT (Mix Transformers), and depends on the output channels which is a field defined in the model. Should we make it generally usable with other backbones? IMO, no, since the head is really just an MLP head, and the crux of the paper is MiT.
  • There's no name for the type of attention they use, but they refer to it as efficient attention in the paper. What name should we use? SegFormerMultiHeadAttention sounds like a mouthful.
  • How do we expose the API if we don't support a backbone argument? Just SegFormer.from_preset()?

Before submitting

  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue? Please add a link
    to it if that's the case. Porting DeepVision into KerasCV #1933
  • Did you write any new necessary tests?
  • If this adds a new model, can you run a few training steps on TPU in Colab to ensure that no XLA incompatible OP are used?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ianstenbit just tagging so you can follow the progress as it comes in. Otherwise, no need to spend time until it's un-drafted for review :)

@DavidLandup0
Copy link
Contributor Author

DavidLandup0 commented Jul 17, 2023

Just to update you @ianstenbit - the port is going well, but it took me a bit longer than anticipated to get used to Keras Core + the new API 😅

I ran into a small blocker and documented it here since I'm not sure what the intended usage is when returning tensors and non-tensors from a call(). If you've encountered this before, any idea for a workaround would be greatly appreciated 🙇

Copy link
Contributor

@ianstenbit ianstenbit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks David! I'm taking a look at the non-tensor return issue.

We have a few options for workarounds, including:

  • Computing the shape outside of the layer (because from my cursory view the returned shape is just the input shape / stride)
  • Making the PatchingEmbeddingLayer offer a new method which computes these values which callers can use.

That said, I think we should be able to make this work. I'm taking a look at your issue on Keras Core to see if I can get a working fix.

@DavidLandup0
Copy link
Contributor Author

Thanks! I opened it as an issue since I'm not sure if it's the intended usage. If so, I'd go with computing the values outside the layer/with an extra method

class MiTBackbone(Backbone):
def __init__(
self,
input_shape=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default to (None, None, 3) so that channel dims can be known at build time for conv layers

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This'll have to default to (224, 224, 3) actually, since the input shape will have to be known at instantiation time

@DavidLandup0
Copy link
Contributor Author

DavidLandup0 commented Jul 23, 2023

@ianstenbit looks like MiTs are shaped up. Here's a demo notebook showcasing the components, inputs/output shapes, pyramid levels, from_preset() usage and training MiTs on a classification task: https://colab.research.google.com/drive/1Q3m9-LKICrFzuUhVMIPd7pY2l9Z3BLhg?usp=sharing

There are a couple of weird-looking ops.cast() calls that aren't very clean, and a custom reshaping layer since keras.Reshape() caused errors for some reason. I'd like to clean these up and sync up on whether there's a cleaner alternative for them :)

99% of the work are MiTs - SegFormer is just MiT+seg top. Could you please review the backbone while I shape up SegFormers? With a green light, I'll write up the unit tests and add proper docstrings.

Copy link
Contributor

@ianstenbit ianstenbit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally looks good! Left you a few minor comments.

Copy link
Contributor

@ianstenbit ianstenbit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a few things to clean up -- in the meantime I am seeing if the tests need any fixing.

keras_cv/layers/hierarchical_transformer_encoder.py Outdated Show resolved Hide resolved
keras_cv/layers/hierarchical_transformer_encoder.py Outdated Show resolved Hide resolved
class SegFormerB0(SegFormer):
def __new__(
cls,
num_classes=19,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't specify a default for num_classes as it needs to be user-specified. A silent default could be very confusing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed. Should it be requested as a mandatory arg?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think this should be required at init time.

@ianstenbit
Copy link
Contributor

/gcbrun

@ianstenbit
Copy link
Contributor

Looks like tests are passing locally, but on CI (for TF) it will depend on us getting a new release of Keras Core which includes keras-team/keras-core#722

In the meantime, @DavidLandup0 I left a few review comments for you to take a look at -- thanks!

@DavidLandup0
Copy link
Contributor Author

Awesome, thanks! Getting to these soon. Thanks for the review pass! :)

Copy link
Contributor

@ianstenbit ianstenbit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you David!

I think this PR is basically all set, I just need to merge #2037 to fix CI

@ianstenbit
Copy link
Contributor

/gcbrun

@DavidLandup0
Copy link
Contributor Author

Awesome, thank you for the help in the final stretch! @ianstenbit 🎉

@ianstenbit
Copy link
Contributor

It looks like the GCB failure is because I need to update the Docker image of our GCB runners to use the newest Keras Core version -- doing that now.

@ianstenbit
Copy link
Contributor

CI failures are unrelated -- seems like DeepLab + YOLOV8 have some breakages with the latest Keras Core version. I'll open a separate PR for those.

@ianstenbit ianstenbit merged commit ab812d1 into keras-team:master Aug 24, 2023
8 of 9 checks passed
@DavidLandup0
Copy link
Contributor Author

Need a hand with DLV3 or YOLO?

@ianstenbit
Copy link
Contributor

ianstenbit commented Aug 24, 2023

You're welcome to look if you'd like -- for DeepLab it's a deserialization issue. Haven't looked at YOLO yet.

You can repro by installing latest Keras Core version and running the large tests of those models with TF backend.
edit: probably best to just work on CLIP instead -- I can handle this part, it's not very fun anyway!

@DavidLandup0
Copy link
Contributor Author

Sure! Sign me up for YOLO if it's not too urgent then :)

ghost pushed a commit to y-vectorfield/keras-cv that referenced this pull request Nov 16, 2023
* initial dump

* add all basic layers, port roughly to keras core ops

* updated .gitignore

* segformer head and formatting

* cleanup

* remove tf call

* remove tf

* migrating to more keras ops

* cleanups and fixes

* fix reshaping

* comments

* from presets api, keras.ops -> ops

* embed_dims -> embedding_dims

* addressing some PR comments

* docstrings, argument update

* depths arg

* sync

* compute output shapes

* segformer progress

* head

* softmax

* remove softmax

* undo compute_output_shapes()

* efficientmultiheadattention -> segformermultiheadattention

* docstrings

* softmax output

* segformer presets

* updating segformer presets

* segformer presets

* import aliases

* refactoring

* pr comments

* pr comments

* add aliases

* aliases ot init

* refactor fix

* import keras_cv_export

* fix presets/aliases and add copyright

* linter warnings

* linter errors

* consistency in presets

* return config

* fix serialization

* Some cleanup + more tests

* Fix DropPath layer (need to update tests + add shim for tf.keras

* Finish DropPath layer

* Use static shape in backbone

* Formatting

* Switch back to ops.shape

* documentation

* documentation

* remove default num classes

* fix docs

---------

Co-authored-by: ianjjohnson <3072903+ianstenbit@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants