Skip to content

Commit

Permalink
release codes for training and evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiaming Zhang committed Dec 27, 2021
1 parent 74a3d37 commit b42ddde
Show file tree
Hide file tree
Showing 52 changed files with 68,685 additions and 6 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
__pycache__/
*.pth
run/
117 changes: 111 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,25 +1,130 @@
# ISSAFE: Improving Semantic Segmentation in Accidents by Fusing Event-based Data
# ISSAFE & EDCNet

## Introduction

This is the implementation of our papers below, including the code and dataset (DADA-seg).

Improving Semantic Segmentation in Accidents by Fusing Event-based Data, IROS 2021, [[paper](arxiv.org/pdf/2008.08974.pdf)].

Exploring Event-Driven Dynamic Context for Accident Scene Segmentation, T-ITS 2021, [[paper](arxiv.org/pdf/2112.05006.pdf)].


![issafe_demo](demo/issafe.gif)

**ISSAFE**, is a multi-modal semantic segmentation architecture by fusing the RGB image and the event-based data, aiming to improve the model's robustness against driving edge-cases (critical or accidental situations). More details can be found in our [paper](https://arxiv.org/pdf/2008.08974.pdf).


## Installation

## Requirement
The requirements are listed in the `requirement.txt` file. To create your own environment, an example is:

TODO
```bash
conda create -n issafe python=3.7
conda activate issafe
cd /path/to/ISSAFE
pip install -r requirement.txt
```



## Datasets

Get dataset from [Cityscapes](https://www.cityscapes-dataset.com/), and [DADA-2000](https://github.com/JWFangit/LOTVS-DADA). Our proposed DADA-seg dataset is a subset from DADA-2000. Our annotations have the same labeling rule as Cityscapes and will be release soon.
For the basic setting of this work, please prepare datasets of [Cityscapes](https://www.cityscapes-dataset.com/), and DADA-seg.

Our proposed DADA-seg dataset is a subset from [DADA-2000](https://github.com/JWFangit/LOTVS-DADA). Our annotations have the same labeling rule as Cityscapes.

The DADA-seg dataset is now available in [Google Drive]().

The event generation can be found in [EventGAN](https://github.com/alexzzhu/EventGAN). The anchor and its previous frames are needed for event generation. The generated event volume is saved as `.npy` format for this work.

A structure of dataset is following:

```
dataset
├── Cityscapes
│   ├── event
│   │   ├── train
│   │ │ ├─aachen
│ │ │ │ ├─aachen_000000_000019_gtFine_event.npy # event volume
│   │   └── val
│   ├── gtFine
│   │   ├── train
│   │   └── val
│   ├─leftImg8bit_prev # for event synthesic
│   │ ├─train
│   │ │ ├─aachen
│   │ │ │ ├─aachen_000000_000019_leftImg8bit_prev.png
│   │ └─val
│   ├── leftImg8bit
│   │   ├── train
│   └── └── val
└── DADA_seg
├── dof
│   └── val
├── event
│   └── val
├── gtFine
│   └── val
└── leftImg8bit
   ├── train
   └── val
```

(option) other sources: [BDD3K](https://bdd-data.berkeley.edu/), [KITTI-360](http://www.cvlibs.net/datasets/kitti-360/), [ApolloScape](http://apolloscape.auto/).

(option) other modalities: dense optical flow.



## Training

The model of EDCNet can be found at `models/edcnet.py`.

Before run the training script, please modify your own path configurations at `mypath.py`.

The training configurations can be adjusted at `train.py`.

An example of training is `python train.py`



## Evaluation

The evaluation configurations can be adjusted at `eval.py`.

To achieve the evaluation result of EDCNet in D2S mode with 2 event time bins, the weights can be downloaded in [Google Drive](https://drive.google.com/drive/folders/19hUd8Mfj6K76G48AT9txq-PX9bHQN0qs?usp=sharing).

Put the weight at `run/cityscapesevent/test_EDCNet_r18/model_best.pth`.

An example of evaluation of the EDCNet at `B=2` event time bins is `python eval.py`.



## License

This repository is under the Apache-2.0 license. For commercial use, please contact with the authors.



## Citation

If you are interested in this work, please cite the following work:

```
@INPROCEEDINGS{zhang2021issafe,
author={Zhang, Jiaming and Yang, Kailun and Stiefelhagen, Rainer},
booktitle={2021 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
title={ISSAFE: Improving Semantic Segmentation in Accidents by Fusing Event-based Data},
year={2021},
pages={1132-1139},
doi={10.1109/IROS51168.2021.9636109}}
@ARTICLE{zhang2021edcnet,
author={Zhang, Jiaming and Yang, Kailun and Stiefelhagen, Rainer},
journal={IEEE Transactions on Intelligent Transportation Systems},
title={Exploring Event-Driven Dynamic Context for Accident Scene Segmentation},
year={2021},
pages={1-17},
doi={10.1109/TITS.2021.3134828}}
```

51 changes: 51 additions & 0 deletions dataloaders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from dataloaders import cityscapes, dada, apolloscape, kitti, bdd, merge3
from torch.utils.data import DataLoader


def make_data_loader(args, **kwargs):
if args.dataset == 'cityscapesevent':
train_set = cityscapes.CityscapesRGBEvent(args, split='train')
val_set = cityscapes.CityscapesRGBEvent(args, split='val')
num_class = val_set.NUM_CLASSES
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs)
test_loader = None
elif args.dataset == 'dadaevent':
val_set = dada.DADARGBEvent(args, split='val')
num_class = val_set.NUM_CLASSES
train_loader = None
val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs)
test_loader = None
elif args.dataset == 'apolloscapeevent':
train_set = apolloscape.ApolloscapeRGBEvent(args, split='train')
val_set = apolloscape.ApolloscapeRGBEvent(args, split='val')
num_class = val_set.NUM_CLASSES
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs)
test_loader = None
elif args.dataset == 'kittievent':
train_set = kitti.KITTIRGBEvent(args, split='train')
val_set = kitti.KITTIRGBEvent(args, split='val')
num_class = val_set.NUM_CLASSES
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs)
test_loader = None
elif args.dataset == 'bdd':
train_set = bdd.BDD(args, split='train')
val_set = bdd.BDD(args, split='val')
num_class = val_set.NUM_CLASSES
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs)
test_loader = None
elif args.dataset == 'merge3':
train_set = merge3.Merge3(args, split='train')
val_set = merge3.Merge3(args, split='val')
num_class = val_set.NUM_CLASSES
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs)
test_loader = None
else:
raise NotImplementedError
return train_loader, val_loader, test_loader, num_class


141 changes: 141 additions & 0 deletions dataloaders/apolloscape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import os
import numpy as np
from PIL import Image
from torch.utils import data
from mypath import Path
from torchvision import transforms
from dataloaders import custom_transforms as tr
from dataloaders.mapping import APOLLO2CS, APOLLO16


class ApolloscapeRGBEvent(data.Dataset):
"""return dict with img, event, label of Apolloscape"""
NUM_CLASSES = 16

def __init__(self, args, root=Path.db_root_dir('apolloscapeevent'), split="train"):
self.root = root
self.split = split
self.args = args
self.images = {}
self.event = {}
self.labels = {}

with open('dataloaders/apolloscape_txt/colors_{}.txt'.format(split), 'r') as colors_f, \
open('dataloaders/apolloscape_txt/events_{}.txt'.format(split), 'r') as events_f, \
open('dataloaders/apolloscape_txt/labels_{}.txt'.format(split), 'r') as labels_f:
self.images[split] = colors_f.read().splitlines()
self.event[split] = events_f.read().splitlines()
self.labels[split] = labels_f.read().splitlines()
print("Found %d %s RGB images" % (len(self.images[split]), split))

self.ignore_index = 255

def __len__(self):
return len(self.labels[self.split])

def __getitem__(self, index):
sample = dict()
lbl_path = self.root + self.labels[self.split][index].rstrip()
sample['label'] = self.relabel(lbl_path)

img_path = self.root + self.images[self.split][index].rstrip()
sample['image'] = Image.open(img_path).convert('RGB')

if self.args.event_dim:
event_path = self.root + self.event[self.split][index].rstrip()
sample['event'] = self.get_event(event_path)

# data augment
if self.split == 'train':
return self.transform_tr(sample)
elif self.split == 'val':
return self.transform_val(sample), lbl_path
elif self.split == 'test':
raise NotImplementedError

def get_event(self, event_path):
event_volume = np.load(event_path)['data']
neg_volume = event_volume[:9, ...]
pos_volume = event_volume[9:, ...]
if self.args.event_dim == 18:
event_volume = np.concatenate((neg_volume, pos_volume), axis=0)
elif self.args.event_dim == 2:
neg_img = np.sum(neg_volume, axis=0, keepdims=True)
pos_img = np.sum(pos_volume, axis=0, keepdims=True)
event_volume = np.concatenate((neg_img, pos_img), axis=0)
elif self.args.event_dim == 1:
neg_img = np.sum(neg_volume, axis=0, keepdims=True)
pos_img = np.sum(pos_volume, axis=0, keepdims=True)
event_volume = neg_img + pos_img
return event_volume

def relabel(self, label_path):
"""from apollo to the 18 class (Cityscapes without 'train', cls=16)"""
_temp = np.array(Image.open(label_path))
label_mapping1 = {m[0]: m[1] for m in APOLLO2CS}
for k, v in label_mapping1.items():
_temp[_temp == k] = v

label_mapping2 = {m[0]: m[1] for m in APOLLO16}
for k, v in label_mapping2.items():
_temp[_temp == k] = v
return Image.fromarray(_temp.astype(np.uint8))

def transform_tr(self, sample):
composed_transforms = transforms.Compose([
tr.FixedResize(size=(1024, 2048)),
tr.ColorJitter(),
tr.RandomGaussianBlur(),
tr.RandomMotionBlur(),
tr.RandomHorizontalFlip(),
tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255),
tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
tr.ToTensor()])
return composed_transforms(sample)

def transform_val(self, sample):
composed_transforms = transforms.Compose([
tr.FixedResize(size=self.args.crop_size),
tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
tr.ToTensor()])
return composed_transforms(sample)


if __name__ == '__main__':
from dataloaders.utils import decode_segmap
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import argparse

parser = argparse.ArgumentParser()
args = parser.parse_args()
args.base_size = 1024
args.crop_size = 512
args.event_dim = 2

apolloscape_train = ApolloscapeRGBEvent(args, split='train')
dataloader = DataLoader(apolloscape_train, batch_size=2, shuffle=True, num_workers=0)

for ii, sample in enumerate(dataloader):
for jj in range(sample["image"].size()[0]):
img = sample['image'].numpy()
gt = sample['label'].numpy()
tmp = np.array(gt[jj]).astype(np.uint8)
segmap = decode_segmap(tmp, dataset='apolloscapeevent')
img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
img_tmp *= (0.229, 0.224, 0.225)
img_tmp += (0.485, 0.456, 0.406)
img_tmp *= 255.0
img_tmp = img_tmp.astype(np.uint8)
plt.figure()
plt.title('display')
plt.subplot(311)
plt.imshow(img_tmp)
plt.subplot(312)
plt.imshow(segmap)

if ii == 2:
break

plt.show(block=True)

Loading

0 comments on commit b42ddde

Please sign in to comment.