release codes for training and evaluation
Jiaming Zhang committed on Dec 27, 2021 · 1 parent 74a3d37 · commit b42ddde
Showing 52 changed files with 68,685 additions and 6 deletions.
@@ -0,0 +1,3 @@
__pycache__/
*.pth
run/
@@ -1,25 +1,130 @@
# ISSAFE & EDCNet

## Introduction

This is the implementation of our papers below, including the code and dataset (DADA-seg).

Improving Semantic Segmentation in Accidents by Fusing Event-based Data, IROS 2021, [[paper](https://arxiv.org/pdf/2008.08974.pdf)].

Exploring Event-Driven Dynamic Context for Accident Scene Segmentation, T-ITS 2021, [[paper](https://arxiv.org/pdf/2112.05006.pdf)].

![issafe_demo](demo/issafe.gif)

**ISSAFE** is a multi-modal semantic segmentation architecture that fuses RGB images and event-based data, aiming to improve the model's robustness in driving edge cases (critical or accidental situations). More details can be found in our [paper](https://arxiv.org/pdf/2008.08974.pdf).
## Installation

The requirements are listed in `requirement.txt`. To create your own environment, an example is:

```bash
conda create -n issafe python=3.7
conda activate issafe
cd /path/to/ISSAFE
pip install -r requirement.txt
```
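If `requirement.txt` includes PyTorch and torchvision (both are imported by the code in this repository; this is an assumption about the requirements file, which is not shown here), a quick sanity check of the new environment is:

```bash
# verify that PyTorch and torchvision import correctly and whether CUDA is visible
python -c "import torch, torchvision; print(torch.__version__, torchvision.__version__, torch.cuda.is_available())"
```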
## Datasets

For the basic setting of this work, please prepare the [Cityscapes](https://www.cityscapes-dataset.com/) and DADA-seg datasets.

Our proposed DADA-seg dataset is a subset of [DADA-2000](https://github.com/JWFangit/LOTVS-DADA). Our annotations follow the same labeling rules as Cityscapes.

The DADA-seg dataset is now available in [Google Drive]().

The event generation can be found in [EventGAN](https://github.com/alexzzhu/EventGAN). The anchor frame and its previous frames are needed for event generation. The generated event volume is saved in `.npy` format for this work.

The dataset structure is as follows:

```
dataset
├── Cityscapes
│   ├── event
│   │   ├── train
│   │   │   ├── aachen
│   │   │   │   ├── aachen_000000_000019_gtFine_event.npy  # event volume
│   │   └── val
│   ├── gtFine
│   │   ├── train
│   │   └── val
│   ├── leftImg8bit_prev  # for event synthesis
│   │   ├── train
│   │   │   ├── aachen
│   │   │   │   ├── aachen_000000_000019_leftImg8bit_prev.png
│   │   └── val
│   ├── leftImg8bit
│   │   ├── train
│   │   └── val
└── DADA_seg
    ├── dof
    │   └── val
    ├── event
    │   └── val
    ├── gtFine
    │   └── val
    └── leftImg8bit
        ├── train
        └── val
```
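As a rough sketch of what a stored event volume looks like: the dataloaders in this commit read an 18-bin volume (9 negative-polarity and 9 positive-polarity time bins) under a `'data'` key and optionally collapse it to 2 channels. The snippet below is only an illustration under that assumption; the file path is a placeholder.

```python
import numpy as np

# Hypothetical path; real files follow <city>_<seq>_<frame>_gtFine_event.npy.
event_path = "dataset/Cityscapes/event/train/aachen/aachen_000000_000019_gtFine_event.npy"

# The dataloaders access the volume as np.load(...)['data'], i.e. an archive
# holding an 18 x H x W array under the key 'data'.
event_volume = np.load(event_path)['data']

neg_volume = event_volume[:9]   # 9 negative-polarity time bins
pos_volume = event_volume[9:]   # 9 positive-polarity time bins

# Collapse to a 2-channel representation (event_dim=2), mirroring get_event() below.
event_2ch = np.concatenate(
    (neg_volume.sum(axis=0, keepdims=True),
     pos_volume.sum(axis=0, keepdims=True)),
    axis=0,
)
print(event_volume.shape, event_2ch.shape)
```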
|
||
(option) other sources: [BDD3K](https://bdd-data.berkeley.edu/), [KITTI-360](http://www.cvlibs.net/datasets/kitti-360/), [ApolloScape](http://apolloscape.auto/). | ||
|
||
(option) other modalities: dense optical flow. | ||
|
||
|
||
|
||
## Training

The model of EDCNet can be found at `models/edcnet.py`.

Before running the training script, please modify your own path configuration in `mypath.py`.

The training configurations can be adjusted in `train.py`.

An example of training is `python train.py`.
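The dataloaders resolve dataset roots through `Path.db_root_dir(...)` from `mypath.py` (for example, `Path.db_root_dir('apolloscapeevent')` in `dataloaders/apolloscape.py`). The contents of `mypath.py` are not part of this excerpt; a minimal sketch of such a file, with placeholder directories to replace by your own paths, might look like:

```python
# mypath.py (sketch): maps a dataset name to its root directory.
# The directory values are placeholders; point them to your own data.
class Path(object):
    @staticmethod
    def db_root_dir(dataset):
        roots = {
            'cityscapesevent': '/path/to/dataset/Cityscapes/',
            'dadaevent': '/path/to/dataset/DADA_seg/',
            'apolloscapeevent': '/path/to/dataset/ApolloScape/',
        }
        if dataset in roots:
            return roots[dataset]
        raise NotImplementedError('Dataset {} not available.'.format(dataset))
```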
## Evaluation

The evaluation configurations can be adjusted in `eval.py`.

To reproduce the evaluation result of EDCNet in D2S mode with 2 event time bins, the weights can be downloaded from [Google Drive](https://drive.google.com/drive/folders/19hUd8Mfj6K76G48AT9txq-PX9bHQN0qs?usp=sharing).

Put the weights at `run/cityscapesevent/test_EDCNet_r18/model_best.pth`.

An example of evaluating EDCNet at `B=2` event time bins is `python eval.py`.
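Putting these steps together, a possible sequence of shell commands (assuming the checkpoint has already been downloaded to the current directory and is named `model_best.pth`) is:

```bash
# create the expected run directory and place the downloaded checkpoint there
mkdir -p run/cityscapesevent/test_EDCNet_r18
mv model_best.pth run/cityscapesevent/test_EDCNet_r18/model_best.pth

# run the evaluation with the settings configured in eval.py
python eval.py
```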
## License

This repository is under the Apache-2.0 license. For commercial use, please contact the authors.
## Citation

If you are interested in this work, please cite the following works:

```
@INPROCEEDINGS{zhang2021issafe,
  author={Zhang, Jiaming and Yang, Kailun and Stiefelhagen, Rainer},
  booktitle={2021 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
  title={ISSAFE: Improving Semantic Segmentation in Accidents by Fusing Event-based Data},
  year={2021},
  pages={1132-1139},
  doi={10.1109/IROS51168.2021.9636109}}

@ARTICLE{zhang2021edcnet,
  author={Zhang, Jiaming and Yang, Kailun and Stiefelhagen, Rainer},
  journal={IEEE Transactions on Intelligent Transportation Systems},
  title={Exploring Event-Driven Dynamic Context for Accident Scene Segmentation},
  year={2021},
  pages={1-17},
  doi={10.1109/TITS.2021.3134828}}
```
@@ -0,0 +1,51 @@
from dataloaders import cityscapes, dada, apolloscape, kitti, bdd, merge3
from torch.utils.data import DataLoader


def make_data_loader(args, **kwargs):
    """Build (train_loader, val_loader, test_loader, num_class) for the dataset named in args.dataset."""
    if args.dataset == 'cityscapesevent':
        train_set = cityscapes.CityscapesRGBEvent(args, split='train')
        val_set = cityscapes.CityscapesRGBEvent(args, split='val')
        num_class = val_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs)
        test_loader = None
    elif args.dataset == 'dadaevent':
        # DADA-seg is used for evaluation only, so no training loader is built.
        val_set = dada.DADARGBEvent(args, split='val')
        num_class = val_set.NUM_CLASSES
        train_loader = None
        val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs)
        test_loader = None
    elif args.dataset == 'apolloscapeevent':
        train_set = apolloscape.ApolloscapeRGBEvent(args, split='train')
        val_set = apolloscape.ApolloscapeRGBEvent(args, split='val')
        num_class = val_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs)
        test_loader = None
    elif args.dataset == 'kittievent':
        train_set = kitti.KITTIRGBEvent(args, split='train')
        val_set = kitti.KITTIRGBEvent(args, split='val')
        num_class = val_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs)
        test_loader = None
    elif args.dataset == 'bdd':
        train_set = bdd.BDD(args, split='train')
        val_set = bdd.BDD(args, split='val')
        num_class = val_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs)
        test_loader = None
    elif args.dataset == 'merge3':
        train_set = merge3.Merge3(args, split='train')
        val_set = merge3.Merge3(args, split='val')
        num_class = val_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs)
        test_loader = None
    else:
        raise NotImplementedError
    return train_loader, val_loader, test_loader, num_class
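As a usage sketch (a hypothetical script, not part of the commit): the factory only needs an `args` object carrying the fields read by the datasets. The fields below mirror the `__main__` demo of `dataloaders/apolloscape.py` plus the batch sizes; the Cityscapes loader may read further fields not shown in this excerpt, and the import assumes the factory lives in the `dataloaders` package (its imports suggest `dataloaders/__init__.py`).

```python
# Hypothetical usage of make_data_loader; requires the dataset roots configured in mypath.py.
import argparse
from dataloaders import make_data_loader  # assumes the factory above is dataloaders/__init__.py

args = argparse.Namespace(
    dataset='cityscapesevent',  # one of the names handled in make_data_loader
    batch_size=4,
    val_batch_size=1,
    base_size=1024,             # used by RandomScaleCrop in the training transforms
    crop_size=512,
    event_dim=2,                # collapse the 18 event bins to 2 polarity channels
)

train_loader, val_loader, test_loader, num_class = make_data_loader(args, num_workers=4)
print(num_class, len(train_loader), len(val_loader))
```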
@@ -0,0 +1,141 @@
import os
import numpy as np
from PIL import Image
from torch.utils import data
from mypath import Path
from torchvision import transforms
from dataloaders import custom_transforms as tr
from dataloaders.mapping import APOLLO2CS, APOLLO16


class ApolloscapeRGBEvent(data.Dataset):
    """return dict with img, event, label of Apolloscape"""
    NUM_CLASSES = 16

    def __init__(self, args, root=Path.db_root_dir('apolloscapeevent'), split="train"):
        self.root = root
        self.split = split
        self.args = args
        self.images = {}
        self.event = {}
        self.labels = {}

        with open('dataloaders/apolloscape_txt/colors_{}.txt'.format(split), 'r') as colors_f, \
                open('dataloaders/apolloscape_txt/events_{}.txt'.format(split), 'r') as events_f, \
                open('dataloaders/apolloscape_txt/labels_{}.txt'.format(split), 'r') as labels_f:
            self.images[split] = colors_f.read().splitlines()
            self.event[split] = events_f.read().splitlines()
            self.labels[split] = labels_f.read().splitlines()
        print("Found %d %s RGB images" % (len(self.images[split]), split))

        self.ignore_index = 255

    def __len__(self):
        return len(self.labels[self.split])

    def __getitem__(self, index):
        sample = dict()
        lbl_path = self.root + self.labels[self.split][index].rstrip()
        sample['label'] = self.relabel(lbl_path)

        img_path = self.root + self.images[self.split][index].rstrip()
        sample['image'] = Image.open(img_path).convert('RGB')

        if self.args.event_dim:
            event_path = self.root + self.event[self.split][index].rstrip()
            sample['event'] = self.get_event(event_path)

        # data augmentation depends on the split
        if self.split == 'train':
            return self.transform_tr(sample)
        elif self.split == 'val':
            return self.transform_val(sample), lbl_path
        elif self.split == 'test':
            raise NotImplementedError

    def get_event(self, event_path):
        """Load an event volume and reshape it according to args.event_dim (18, 2 or 1 channels)."""
        event_volume = np.load(event_path)['data']
        neg_volume = event_volume[:9, ...]   # 9 negative-polarity time bins
        pos_volume = event_volume[9:, ...]   # 9 positive-polarity time bins
        if self.args.event_dim == 18:
            event_volume = np.concatenate((neg_volume, pos_volume), axis=0)
        elif self.args.event_dim == 2:
            # one summed channel per polarity
            neg_img = np.sum(neg_volume, axis=0, keepdims=True)
            pos_img = np.sum(pos_volume, axis=0, keepdims=True)
            event_volume = np.concatenate((neg_img, pos_img), axis=0)
        elif self.args.event_dim == 1:
            # single channel: sum over both polarities
            neg_img = np.sum(neg_volume, axis=0, keepdims=True)
            pos_img = np.sum(pos_volume, axis=0, keepdims=True)
            event_volume = neg_img + pos_img
        return event_volume

    def relabel(self, label_path):
        """from apollo to the 18 class (Cityscapes without 'train', cls=16)"""
        _temp = np.array(Image.open(label_path))
        label_mapping1 = {m[0]: m[1] for m in APOLLO2CS}
        for k, v in label_mapping1.items():
            _temp[_temp == k] = v

        label_mapping2 = {m[0]: m[1] for m in APOLLO16}
        for k, v in label_mapping2.items():
            _temp[_temp == k] = v
        return Image.fromarray(_temp.astype(np.uint8))

    def transform_tr(self, sample):
        composed_transforms = transforms.Compose([
            tr.FixedResize(size=(1024, 2048)),
            tr.ColorJitter(),
            tr.RandomGaussianBlur(),
            tr.RandomMotionBlur(),
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])
        return composed_transforms(sample)

    def transform_val(self, sample):
        composed_transforms = transforms.Compose([
            tr.FixedResize(size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])
        return composed_transforms(sample)

if __name__ == '__main__':
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    import argparse

    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.base_size = 1024
    args.crop_size = 512
    args.event_dim = 2

    apolloscape_train = ApolloscapeRGBEvent(args, split='train')
    dataloader = DataLoader(apolloscape_train, batch_size=2, shuffle=True, num_workers=0)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            tmp = np.array(gt[jj]).astype(np.uint8)
            segmap = decode_segmap(tmp, dataset='apolloscapeevent')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
            img_tmp *= (0.229, 0.224, 0.225)
            img_tmp += (0.485, 0.456, 0.406)
            img_tmp *= 255.0
            img_tmp = img_tmp.astype(np.uint8)
            plt.figure()
            plt.title('display')
            plt.subplot(311)
            plt.imshow(img_tmp)
            plt.subplot(312)
            plt.imshow(segmap)

        if ii == 2:
            break

    plt.show(block=True)
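The `__main__` block above is a small visual check of the loader. Assuming the ApolloScape root is set in `mypath.py` and the split files under `dataloaders/apolloscape_txt/` are present, it can be run as a module from the repository root so that the package-level imports resolve:

```bash
# sketch: run the dataloader demo from the repository root
python -m dataloaders.apolloscape
```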