
A Python Library for High-Level Semantic Segmentation Models based on TensorFlow and Keras with pretrained backbones.



JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models


TensorFlow Advanced Segmentation Models

A Python Library for High-Level Semantic Segmentation Models.

Preface

Since the breakthrough of Deep Learning, Computer Vision has been one of the core problems that researchers all over the world work on, creating better models every day. One Computer Vision area that has received huge attention in the last couple of years is Semantic Segmentation. The task of classifying every pixel of a given image has led to the invention of many great models, starting with the classical U-Net and moving on to more and more complex neural network structures. But even though many new algorithms have been developed, the availability of easy-to-use open source libraries with High-Level APIs that make the technology accessible to everyone still lags far behind the huge amount of research that is published continuously.

Inspired by qubvel's segmentation_models, this repository builds upon his work and extends it with a variety of recently developed models that achieved great results on the Cityscapes, PASCAL VOC 2012, PASCAL Context and ADE20K datasets, among many others.

To date, the library contains 14 different Semantic Segmentation model architectures for multi-class semantic segmentation, as well as many backbones pretrained on ImageNet. An important new feature is the upgrade to TensorFlow 2.x, including the use of the advanced model subclassing feature to build customized segmentation models. Furthermore, the library is now compatible with all major platforms, which means that tasm runs on Windows, Linux and MacOS.

Main Library Features

  • High Level API
  • 14 Segmentation Model Architectures for multi-class semantic segmentation
    • New: HRNet + OCR Model
  • Many pretrained backbones for each architecture
  • Many useful segmentation losses (Dice, Focal, Tversky, Jaccard and many more combinations of them)
  • New: Models can be used as Subclassed or Functional Models
  • New: TASM now works on all platforms, i.e. Windows, Linux and MacOS with Intel or Apple Silicon chips
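The overlap losses named in the feature list all compare predicted probabilities with the ground-truth mask through soft true positives (TP), false positives (FP) and false negatives (FN). The minimal pure-Python sketch below uses the textbook single-class definitions; the function names and the eps smoothing term are illustrative assumptions, and TASM's own implementations may differ in smoothing, class weighting and averaging.

```python
# Textbook soft overlap losses for ONE class on flattened pixels.
# y_true: 0/1 ground truth, y_pred: predicted probabilities in [0, 1].
# Illustrative sketch only; not TASM's implementation.

def _soft_counts(y_true, y_pred):
    """Soft TP, FP and FN summed over all pixels."""
    tp = sum(t * p for t, p in zip(y_true, y_pred))
    fp = sum((1 - t) * p for t, p in zip(y_true, y_pred))
    fn = sum(t * (1 - p) for t, p in zip(y_true, y_pred))
    return tp, fp, fn


def dice_loss(y_true, y_pred, eps=1e-7):
    tp, fp, fn = _soft_counts(y_true, y_pred)
    return 1 - (2 * tp + eps) / (2 * tp + fp + fn + eps)


def jaccard_loss(y_true, y_pred, eps=1e-7):
    tp, fp, fn = _soft_counts(y_true, y_pred)
    return 1 - (tp + eps) / (tp + fp + fn + eps)


def tversky_loss(y_true, y_pred, alpha=0.5, beta=0.5, eps=1e-7):
    # alpha weights false positives, beta weights false negatives
    tp, fp, fn = _soft_counts(y_true, y_pred)
    return 1 - (tp + eps) / (tp + alpha * fp + beta * fn + eps)
```

With alpha = beta = 0.5 the Tversky loss equals the Dice loss, and with alpha = beta = 1 it equals the Jaccard loss; raising beta penalises missed foreground pixels more strongly.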


Installation and Setup

To get the repository running, check the following requirements.

Requirements (Windows or Linux)

  1. Python 3.6 or higher
  2. tensorflow >= 2.3.0 (>= 2.0.0 is sufficient if no efficientnet backbone is used)
  3. numpy
  4. matplotlib

MacOS

  1. Python 3.9 or higher
  2. tensorflow-macos >= 2.5.0
  3. numpy >= 1.21.0
  4. matplotlib
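The version bounds above can be checked before installing. The sketch below compares version strings numerically, so that "2.10.0" correctly counts as newer than "2.3.0" (a plain string comparison gets this wrong); all helper names are hypothetical and not part of TASM.

```python
import sys


def version_tuple(version):
    """'2.5.0rc1' -> (2, 5, 0): keep the leading digits of the first three fields."""
    parts = []
    for field in version.split(".")[:3]:
        digits = ""
        for ch in field:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts + [0] * (3 - len(parts)))


def at_least(installed, required):
    """Numeric comparison, so '2.10.0' >= '2.3.0' is True (string comparison says False)."""
    return version_tuple(installed) >= version_tuple(required)


def minimum_tensorflow(uses_efficientnet):
    # Windows/Linux bounds from the list above: EfficientNet needs >= 2.3.0, others >= 2.0.0
    return "2.3.0" if uses_efficientnet else "2.0.0"


def python_ok(macos=False):
    # Python 3.6+ on Windows/Linux, 3.9+ on MacOS
    return sys.version_info[:2] >= ((3, 9) if macos else (3, 6))
```

For example, after import tensorflow as tf, at_least(tf.__version__, minimum_tensorflow(uses_efficientnet=True)) tells you whether an EfficientNet backbone can be used.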

Then execute the following command to download the git repository:

Clone Repository

$ git clone https://github.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models.git

or directly install it:
Pip Install Repository

$ pip install git+https://github.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models.git

Training Pipeline

Make sure that TensorFlow is installed on your computer.

To import TensorFlow and the library, use the standard Python import statements:

import tensorflow as tf
import tensorflow_advanced_segmentation_models as tasm

Then pick any model backbone from the list below and define weights, height and width:

BACKBONE_NAME = "efficientnetb3"
WEIGHTS = "imagenet"
HEIGHT = 320
WIDTH = 320

Load the data (get_data(...) is a placeholder for your own data-loading code):

TrainingGenerator, ValidationGenerator = get_data(...)
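As a hypothetical stand-in for get_data(...), the sketch below builds two infinite generators of random images paired with one-hot masks of shape (batch, height, width, n_classes), the layout a categorical loss such as a categorical focal loss expects. make_batches and make_dummy_generators are illustrative names, not TASM functions, and the random data is only useful for smoke-testing a pipeline.

```python
import numpy as np


def make_batches(images, masks, n_classes, batch_size=8, shuffle=True, seed=0):
    """Yield (image_batch, one_hot_mask_batch) pairs forever, one pass per epoch."""
    rng = np.random.default_rng(seed)
    while True:
        order = rng.permutation(len(images)) if shuffle else np.arange(len(images))
        for start in range(0, len(images), batch_size):
            idx = order[start:start + batch_size]
            x = images[idx].astype(np.float32)
            # Integer class ids (B, H, W) -> one-hot masks (B, H, W, n_classes)
            y = np.eye(n_classes, dtype=np.float32)[masks[idx]]
            yield x, y


def make_dummy_generators(height=320, width=320, n_classes=3,
                          n_train=16, n_val=4, batch_size=8):
    """HYPOTHETICAL stand-in for get_data(...): random images and random label maps."""
    rng = np.random.default_rng(42)

    def random_split(n):
        images = rng.random((n, height, width, 3), dtype=np.float32)  # values in [0, 1)
        masks = rng.integers(0, n_classes, size=(n, height, width))   # class ids
        return images, masks

    train = make_batches(*random_split(n_train), n_classes=n_classes, batch_size=batch_size)
    val = make_batches(*random_split(n_val), n_classes=n_classes,
                       batch_size=batch_size, shuffle=False)
    return train, val
```

Because these generators never stop, model.fit needs steps_per_epoch and validation_steps (for example ceil(n_train / batch_size)) when they are used.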

Create the base model that works as backbone for the segmentation model:

base_model, layers, layer_names = tasm.create_base_model(name=BACKBONE_NAME, weights=WEIGHTS, height=HEIGHT, width=WIDTH)
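Many encoder backbones halve the spatial resolution in each of five stages, so the deepest feature map is 32 times smaller than the input, and decoders that upsample and merge skip connections need every stage to divide evenly. Assuming that five-stage pattern (this README does not say which layers create_base_model returns), the hypothetical helpers below list the stage sizes for a given input and reject sizes that are not multiples of 32.

```python
def feature_map_sizes(height, width, n_stages=5):
    """(H, W) after each encoder stage, ASSUMING every stage halves both dimensions."""
    sizes = []
    for _ in range(n_stages):
        height, width = height // 2, width // 2
        sizes.append((height, width))
    return sizes


def check_input_size(height, width, n_stages=5):
    """Raise if height/width are not multiples of 2**n_stages (32 for five stages)."""
    factor = 2 ** n_stages
    if height % factor or width % factor:
        raise ValueError(
            f"height and width should be multiples of {factor}, got {height}x{width}"
        )
```

With HEIGHT = WIDTH = 320 the stages are 160, 80, 40, 20 and 10 pixels wide, whereas 300 gives 150, 75, 37, 18 and 9; upsampling the 18-pixel map yields 36, which no longer matches the 37-pixel stage.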

Define a Model and compile it with an appropriate loss:

model = tasm.DANet(n_classes=3, base_model=base_model, output_layers=layers, backbone_trainable=False)  # keep the pretrained backbone frozen
model.compile(tf.keras.optimizers.Adam(0.0001), loss=tasm.losses.CategoricalFocalLoss(), metrics=[tasm.metrics.IOUScore(threshold=0.5)])
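CategoricalFocalLoss refers to the focal loss, which scales the per-pixel cross-entropy by (1 - p)^gamma so that confidently classified pixels contribute less. The NumPy sketch below shows the textbook formula, -alpha * y * (1 - p)^gamma * log(p) summed over classes and averaged over pixels; alpha = 0.25 and gamma = 2.0 are the common defaults from the focal loss paper and are assumed here, not taken from TASM.

```python
import numpy as np


def categorical_focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0, eps=1e-7):
    """Mean categorical focal loss over all pixels (illustrative sketch).

    y_true: one-hot array (..., n_classes); y_pred: softmax probabilities, same shape.
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    per_class = -alpha * y_true * (1.0 - y_pred) ** gamma * np.log(y_pred)
    return float(np.mean(np.sum(per_class, axis=-1)))
```

With alpha = 1 and gamma = 0 the expression reduces to ordinary categorical cross-entropy.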

If you want to use the Functional Model class instead, define:

model = tasm.DANet(n_classes=3, base_model=base_model, output_layers=layers, backbone_trainable=False).model()  # .model() returns the Functional Model
model.compile(tf.keras.optimizers.Adam(0.0001), loss=tasm.losses.CategoricalFocalLoss(), metrics=[tasm.metrics.IOUScore(threshold=0.5)])
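The IOUScore(threshold=0.5) metric passed to compile measures the overlap between prediction and ground truth. A plausible reading, assumed in the sketch below, is that probabilities above the threshold count as positive before intersection over union is computed per class and averaged; TASM's exact averaging (per image, per class, class weights) may differ.

```python
import numpy as np


def iou_score(y_true, y_pred, threshold=0.5, eps=1e-7):
    """Mean per-class IoU after binarising predictions at `threshold` (sketch).

    y_true: one-hot (..., n_classes); y_pred: probabilities, same shape.
    """
    pred = (y_pred > threshold).astype(np.float32)
    axes = tuple(range(y_true.ndim - 1))  # sum over every axis except the class axis
    inter = np.sum(y_true * pred, axis=axes)
    union = np.sum(y_true, axis=axes) + np.sum(pred, axis=axes) - inter
    return float(np.mean((inter + eps) / (union + eps)))
```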

Finally, train the model:

history = model.fit(
    TrainingGenerator,
    # batches come from the generator, so no batch_size argument is passed
    epochs=50,
    validation_data=ValidationGenerator,
)
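model.fit returns a Keras History object whose history attribute is a dict mapping each logged quantity (loss, val_loss and the compiled metrics) to a list with one value per epoch. The hypothetical helper below picks the best epoch from that dict; the key val_iou_score used in the example is an assumption, since the actual key depends on the metric's name.

```python
def best_epoch(history, monitor="val_loss", mode="min"):
    """Return (1-based epoch, value) of the best epoch for `monitor` in History.history."""
    values = history[monitor]
    choose = min if mode == "min" else max
    best_index = choose(range(len(values)), key=lambda i: values[i])
    return best_index + 1, values[best_index]
```

For example, best_epoch(history.history, "val_iou_score", mode="max") returns the epoch with the highest validation IoU.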

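model.fit also accepts Python generators and tf.keras.utils.Sequence objects directly, e.g. if you want to apply augmentations to the data; the older fit_generator method is deprecated in TensorFlow 2.x. For complete training pipelines, go to the Examples folder.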
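When augmenting segmentation data, every geometric transform must be applied identically to the image and its mask, while photometric changes such as brightness must touch the image only. The NumPy sketch below wraps any generator of (images, one-hot masks) batches with random flips and brightness jitter; augment_pair and augmented_batches are hypothetical names, and images are assumed to be float arrays scaled to [0, 1].

```python
import numpy as np


def augment_pair(image, mask, rng):
    """Apply the same random flips to an (H, W, C) image and its (H, W, n_classes) mask."""
    if rng.random() < 0.5:  # horizontal flip: reverse the width axis of both arrays
        image, mask = image[:, ::-1], mask[:, ::-1]
    if rng.random() < 0.5:  # vertical flip: reverse the height axis of both arrays
        image, mask = image[::-1], mask[::-1]
    # Brightness jitter is photometric, so it changes the image but never the mask
    image = np.clip(image * rng.uniform(0.8, 1.2), 0.0, 1.0)
    return image, mask


def augmented_batches(batches, seed=0):
    """Wrap a generator of (images, masks) batches with per-sample augmentation."""
    rng = np.random.default_rng(seed)
    for images, masks in batches:
        pairs = [augment_pair(img, msk, rng) for img, msk in zip(images, masks)]
        yield (np.stack([p[0] for p in pairs]).astype(np.float32),
               np.stack([p[1] for p in pairs]))
```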

Examples

  • [Jupyter Notebook] Multi-class (3 classes) segmentation (sky, building, background) on CamVid dataset here
  • [Jupyter Notebook] Multi-class (11 classes) segmentation on CamVid dataset here
  • [Jupyter Notebook] Multi-class (11 classes) segmentation on CamVid dataset with a custom training loop here
  • [Jupyter Notebook] Two-class (2 classes) segmentation on Caltech-Birds-2010 dataset here

Models and Backbones

Models

Backbones (For Details see here.)

Type Names
VGG 'vgg16' 'vgg19'
ResNet 'resnet50' 'resnet50v2' 'resnet101' 'resnet101v2' 'resnet152' 'resnet152v2'
Xception 'xception'
MobileNet 'mobilenet' 'mobilenetv2' 'mobilenetv3small'
NASNet 'nasnetlarge' 'nasnetmobile'
DenseNet 'densenet121' 'densenet169' 'densenet201'
EfficientNet 'efficientnetb0' 'efficientnetb1' 'efficientnetb2' 'efficientnetb3' 'efficientnetb4' 'efficientnetb5' 'efficientnetb6' 'efficientnetb7'
All backbones have weights trained on 2012 ILSVRC ImageNet dataset.
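Because BACKBONE_NAME is a plain string, a typo only surfaces when the base model is created. The sketch below validates a name against the backbones listed in the table and suggests close matches with difflib.get_close_matches; check_backbone is a hypothetical helper, and the list is copied from this README, so it may lag behind the library itself.

```python
import difflib

# Backbone names as listed in the table above (may lag behind the library)
BACKBONES = {
    "VGG": ["vgg16", "vgg19"],
    "ResNet": ["resnet50", "resnet50v2", "resnet101", "resnet101v2",
               "resnet152", "resnet152v2"],
    "Xception": ["xception"],
    "MobileNet": ["mobilenet", "mobilenetv2", "mobilenetv3small"],
    "NASNet": ["nasnetlarge", "nasnetmobile"],
    "DenseNet": ["densenet121", "densenet169", "densenet201"],
    "EfficientNet": [f"efficientnetb{i}" for i in range(8)],
}
ALL_BACKBONES = sorted(name for names in BACKBONES.values() for name in names)


def check_backbone(name):
    """Return `name` unchanged if it is a listed backbone, else raise with suggestions."""
    if name in ALL_BACKBONES:
        return name
    suggestions = difflib.get_close_matches(name, ALL_BACKBONES, n=3)
    raise ValueError(f"Unknown backbone {name!r}; did you mean one of {suggestions}?")
```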

Further Model Information

A new feature makes it possible to define the model either as a Subclassed Model or as a Functional Model. To define a Subclassed Model, write tasm.UNet(...) to define the UNet, or replace it with any other model. To define a Functional Model instead, append .model() to the constructed model, i.e. tasm.UNet(...).model(). This enables further TensorFlow features, such as saving the model in the "tf" format.

Citing

@misc{Kezmann:2020,
  Author = {Jan-Marcel Kezmann},
  Title = {Tensorflow Advanced Segmentation Models},
  Year = {2020},
  Publisher = {GitHub},
  Journal = {GitHub repository},
  Howpublished = {\url{https://github.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models}}
}

License

The project is distributed under the MIT License.

References

Thanks to all the papers that made this repository possible, and especially to Pavel Yakubovskiy for his initial segmentation models repository.