From b9aa18520205a4037b109c23f2b3622d57de5805 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Wed, 31 Oct 2018 15:52:50 +0800 Subject: [PATCH 1/3] TGS salt example --- examples/trials/tgs-salt/README.md | 56 ++ examples/trials/tgs-salt/augmentation.py | 267 +++++++++ examples/trials/tgs-salt/config.yml | 20 + examples/trials/tgs-salt/focal_loss.py | 80 +++ examples/trials/tgs-salt/loader.py | 291 ++++++++++ examples/trials/tgs-salt/lovasz_losses.py | 252 +++++++++ examples/trials/tgs-salt/metrics.py | 85 +++ examples/trials/tgs-salt/models.py | 622 +++++++++++++++++++++ examples/trials/tgs-salt/postprocessing.py | 63 +++ examples/trials/tgs-salt/predict.py | 223 ++++++++ examples/trials/tgs-salt/preprocess.py | 97 ++++ examples/trials/tgs-salt/settings.py | 45 ++ examples/trials/tgs-salt/train.py | 258 +++++++++ examples/trials/tgs-salt/utils.py | 187 +++++++ 14 files changed, 2546 insertions(+) create mode 100644 examples/trials/tgs-salt/README.md create mode 100644 examples/trials/tgs-salt/augmentation.py create mode 100644 examples/trials/tgs-salt/config.yml create mode 100644 examples/trials/tgs-salt/focal_loss.py create mode 100644 examples/trials/tgs-salt/loader.py create mode 100644 examples/trials/tgs-salt/lovasz_losses.py create mode 100644 examples/trials/tgs-salt/metrics.py create mode 100644 examples/trials/tgs-salt/models.py create mode 100644 examples/trials/tgs-salt/postprocessing.py create mode 100644 examples/trials/tgs-salt/predict.py create mode 100644 examples/trials/tgs-salt/preprocess.py create mode 100644 examples/trials/tgs-salt/settings.py create mode 100644 examples/trials/tgs-salt/train.py create mode 100644 examples/trials/tgs-salt/utils.py diff --git a/examples/trials/tgs-salt/README.md b/examples/trials/tgs-salt/README.md new file mode 100644 index 0000000000..f0ed660d39 --- /dev/null +++ b/examples/trials/tgs-salt/README.md @@ -0,0 +1,56 @@ +## 33rd place solution code for Kaggle [TGS Salt Identification Chanllenge](https://www.kaggle.com/c/tgs-salt-identification-challenge) + +This example shows how to enable AutoML for competition code by running it on NNI without any code change. +To run this code on NNI, firstly you need to run it standalone, then configure the config.yml and: +``` +nnictl create --config config.yml +``` + +This code can still run standalone, the code is for reference, it requires at least one week effort to reproduce the competition result. + +[Solution summary](https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/69593) + +Preparation: + +Download competition data, run preprocess.py to prepare training data. + +Stage 1: + +Train fold 0-3 for 100 epochs, for each fold, train 3 models: +``` +python3 train.py --ifolds 0 --epochs 100 --model_name UNetResNetV4 +python3 train.py --ifolds 0 --epochs 100 --model_name UNetResNetV5 --layers 50 +python3 train.py --ifolds 0 --epochs 100 --model_name UNetResNetV6 +``` + +Stage 2: + +Fine tune stage 1 models for 300 epochs with cosine annealing lr scheduler: + +``` +python3 train.py --ifolds 0 --epochs 300 --lrs cosine --lr 0.001 --min_lr 0.0001 --model_name UNetResNetV4 +``` + +Stage 3: + +Fine tune Stage 2 models with depths channel: + +``` +python3 train.py --ifolds 0 --epochs 300 --lrs cosine --lr 0.001 --min_lr 0.0001 --model_name UNetResNetV4 --depths +``` + +Stage 4: + +Make prediction for each model, then ensemble the result to generate peasdo labels. + +Stage 5: + +Fine tune stage 3 models with pseudo labels + +``` +python3 train.py --ifolds 0 --epochs 300 --lrs cosine --lr 0.001 --min_lr 0.0001 --model_name UNetResNetV4 --depths --pseudo +``` + +Stage 6: +Ensemble all stage 3 and stage 5 models. + diff --git a/examples/trials/tgs-salt/augmentation.py b/examples/trials/tgs-salt/augmentation.py new file mode 100644 index 0000000000..633a02cb21 --- /dev/null +++ b/examples/trials/tgs-salt/augmentation.py @@ -0,0 +1,267 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import os +import cv2 +import numpy as np +import random +import torchvision.transforms.functional as F +from torchvision.transforms import RandomResizedCrop, ColorJitter, RandomAffine +import PIL +from PIL import Image +import collections + +import settings + + +class RandomHFlipWithMask(object): + def __init__(self, p=0.5): + self.p = p + def __call__(self, *imgs): + if random.random() < self.p: + return map(F.hflip, imgs) + else: + return imgs + +class RandomVFlipWithMask(object): + def __init__(self, p=0.5): + self.p = p + def __call__(self, *imgs): + if random.random() < self.p: + return map(F.vflip, imgs) + else: + return imgs + +class RandomResizedCropWithMask(RandomResizedCrop): + def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): + super(RandomResizedCropWithMask, self).__init__(size, scale, ratio, interpolation) + def __call__(self, *imgs): + i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio) + #print(i,j,h,w) + return map(lambda x: F.resized_crop(x, i, j, h, w, self.size, self.interpolation), imgs) + +class RandomAffineWithMask(RandomAffine): + def __init__(self, degrees, translate=None, scale=None, shear=None, resample='edge'): + super(RandomAffineWithMask, self).__init__(degrees, translate, scale, shear, resample) + def __call__(self, *imgs): + ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, imgs[0].size) + w, h = imgs[0].size + imgs = map(lambda x: F.pad(x, w//2, 0, self.resample), imgs) + imgs = map(lambda x: F.affine(x, *ret, resample=0), imgs) + imgs = map(lambda x: F.center_crop(x, (w, h)), imgs) + return imgs + +class RandomRotateWithMask(object): + def __init__(self, degrees, pad_mode='reflect', expand=False, center=None): + self.pad_mode = pad_mode + self.expand = expand + self.center = center + self.degrees = degrees + + def __call__(self, *imgs): + angle = self.get_angle() + if angle == int(angle) and angle % 90 == 0: + if angle == 0: + return imgs + else: + #print(imgs) + return map(lambda x: F.rotate(x, angle, False, False, None), imgs) + else: + return map(lambda x: self._pad_rotate(x, angle), imgs) + + def get_angle(self): + if isinstance(self.degrees, collections.Sequence): + index = int(random.random() * len(self.degrees)) + return self.degrees[index] + else: + return random.uniform(-self.degrees, self.degrees) + + def _pad_rotate(self, img, angle): + w, h = img.size + img = F.pad(img, w//2, 0, self.pad_mode) + img = F.rotate(img, angle, False, self.expand, self.center) + img = F.center_crop(img, (w, h)) + return img + +class CropWithMask(object): + def __init__(self, i, j, h, w): + self.i = i + self.j = j + self.h = h + self.w = w + def __call__(self, *imgs): + return map(lambda x: F.crop(x, self.i, self.j, self.h, self.w), imgs) + +class PadWithMask(object): + def __init__(self, padding, padding_mode): + self.padding = padding + self.padding_mode = padding_mode + def __call__(self, *imgs): + return map(lambda x: F.pad(x, self.padding, padding_mode=self.padding_mode), imgs) + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, *imgs): + for t in self.transforms: + imgs = t(*imgs) + return imgs + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + +def get_img_mask_augments(train_mode, pad_mode): + if pad_mode == 'resize': + img_mask_aug_train = Compose([ + RandomHFlipWithMask(), + RandomAffineWithMask(10, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=None) + ]) + img_mask_aug_val = None + else: + img_mask_aug_train = Compose([ + PadWithMask((28, 28), padding_mode=pad_mode), + RandomHFlipWithMask(), + RandomAffineWithMask(10, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=None), + RandomResizedCropWithMask(128, scale=(1., 1.), ratio=(1., 1.)) + ]) + img_mask_aug_val = PadWithMask((13, 14), padding_mode=pad_mode) + + return img_mask_aug_train, img_mask_aug_val + + +def test_transform(): + img_id = '0b73b427d1.png' + img = Image.open(os.path.join(settings.TRAIN_IMG_DIR, img_id)).convert('RGB') + mask = Image.open(os.path.join(settings.TRAIN_MASK_DIR, img_id)).convert('L').point(lambda x: 0 if x < 128 else 1, 'L') + + img_id = '0a1ea1af4.jpg' + img = Image.open(os.path.join(r'D:\data\ship\train_v2', img_id)).convert('RGB') + mask = Image.open(os.path.join(r'D:\data\ship\train_masks', img_id)).convert('L').point(lambda x: 0 if x < 128 else 1, 'L') + + #trans = RandomResizedCropWithMask(768, scale=(0.6, 1)) + trans = Compose([ + RandomHFlipWithMask(), + RandomVFlipWithMask(), + RandomRotateWithMask([0, 90, 180, 270]), + #RandomRotateWithMask(15), + RandomResizedCropWithMask(768, scale=(0.81, 1)) + ]) + + trans2 = RandomAffineWithMask(45, (0.2,0.2), (0.9, 1.1)) + #trans = RandomRotateWithMask([0, 90, 180, 270]) + trans3, trans4 = get_img_mask_augments(True, 'edge') + + img, mask = trans4(img, mask) + + img.show() + mask.point(lambda x: x*255).show() + +def test_color_trans(): + img_id = '00abc623a.jpg' + img = Image.open(os.path.join(settings.TRAIN_IMG_DIR, img_id)).convert('RGB') + trans = ColorJitter(0.1, 0.1, 0.1, 0.1) + + img2 = trans(img) + img.show() + img2.show() + + +class TTATransform(object): + def __init__(self, index): + self.index = index + def __call__(self, img): + trans = { + 0: lambda x: x, + 1: lambda x: F.hflip(x), + 2: lambda x: F.vflip(x), + 3: lambda x: F.vflip(F.hflip(x)), + 4: lambda x: F.rotate(x, 90, False, False), + 5: lambda x: F.hflip(F.rotate(x, 90, False, False)), + 6: lambda x: F.vflip(F.rotate(x, 90, False, False)), + 7: lambda x: F.vflip(F.hflip(F.rotate(x, 90, False, False))) + } + return trans[self.index](img) + +# i is tta index, 0: no change, 1: horizon flip, 2: vertical flip, 3: do both +def tta_back_mask_np(img, index): + print(img.shape) + trans = { + 0: lambda x: x, + 1: lambda x: np.flip(x, 2), + 2: lambda x: np.flip(x, 1), + 3: lambda x: np.flip(np.flip(x, 2), 1), + 4: lambda x: np.rot90(x, 3, axes=(1,2)), + 5: lambda x: np.rot90(np.flip(x, 2), 3, axes=(1,2)), + 6: lambda x: np.rot90(np.flip(x, 1), 3, axes=(1,2)), + 7: lambda x: np.rot90(np.flip(np.flip(x,2), 1), 3, axes=(1,2)) + } + + return trans[index](img) + +def test_tta(): + img_f = os.path.join(settings.TEST_IMG_DIR, '0c2637aa9.jpg') + img = Image.open(img_f) + img = img.convert('RGB') + + tta_index = 7 + trans1 = TTATransform(tta_index) + img = trans1(img) + #img.show() + + img_np = np.array(img) + img_np = np.expand_dims(img_np, 0) + print(img_np.shape) + img_np = tta_back_mask_np(img_np, tta_index) + img_np = np.reshape(img_np, (768, 768, 3)) + img_back = F.to_pil_image(img_np) + img_back.show() + + +def test_rotate(): + img_f = os.path.join(settings.TEST_IMG_DIR, '0c2637aa9.jpg') + img = Image.open(img_f) + img = img.convert('RGB') + #img_np = np.array(img) + #img_np_r90 = np.rot90(img_np,1) + #img_np_r90 = np.rot90(img_np_r90,3) + #img_2 = F.to_pil_image(img_np_r90) + #img = F.rotate(img, 90, False, False) + #ImageDraw.Draw(img_2) + #img_2.show() + #img.show() + + img_aug = tta_7(img) + #img_aug = tta_7_back(img_aug) + img_aug = tta_back_np(img_aug, 7) + img_aug.show() + + +if __name__ == '__main__': + #test_augment() + #test_rotate() + #test_tta() + test_transform() + #test_color_trans() \ No newline at end of file diff --git a/examples/trials/tgs-salt/config.yml b/examples/trials/tgs-salt/config.yml new file mode 100644 index 0000000000..1a0db8a51f --- /dev/null +++ b/examples/trials/tgs-salt/config.yml @@ -0,0 +1,20 @@ +authorName: default +experimentName: example_tgs +trialConcurrency: 2 +maxExecDuration: 10h +maxTrialNum: 10 +#choice: local, remote, pai +trainingServicePlatform: local +#choice: true, false +useAnnotation: true +tuner: + #choice: TPE, Random, Anneal, Evolution, BatchTuner + #SMAC (SMAC should be installed through nnictl) + builtinTunerName: TPE + classArgs: + #choice: maximize, minimize + optimize_mode: maximize +trial: + command: python3 train.py + codeDir: . + gpuNum: 1 diff --git a/examples/trials/tgs-salt/focal_loss.py b/examples/trials/tgs-salt/focal_loss.py new file mode 100644 index 0000000000..e987ef847e --- /dev/null +++ b/examples/trials/tgs-salt/focal_loss.py @@ -0,0 +1,80 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class FocalLoss2d(nn.Module): + + def __init__(self, gamma=2, size_average=True): + super(FocalLoss2d, self).__init__() + self.gamma = gamma + self.size_average = size_average + + + def forward(self, logit, target, class_weight=None, type='sigmoid'): + target = target.view(-1, 1).long() + + if type=='sigmoid': + if class_weight is None: + class_weight = [1]*2 #[0.5, 0.5] + + prob = torch.sigmoid(logit) + prob = prob.view(-1, 1) + prob = torch.cat((1-prob, prob), 1) + select = torch.FloatTensor(len(prob), 2).zero_().cuda() + select.scatter_(1, target, 1.) + + elif type=='softmax': + B,C,H,W = logit.size() + if class_weight is None: + class_weight =[1]*C #[1/C]*C + + logit = logit.permute(0, 2, 3, 1).contiguous().view(-1, C) + prob = F.softmax(logit,1) + select = torch.FloatTensor(len(prob), C).zero_().cuda() + select.scatter_(1, target, 1.) + + class_weight = torch.FloatTensor(class_weight).cuda().view(-1,1) + class_weight = torch.gather(class_weight, 0, target) + + prob = (prob*select).sum(1).view(-1,1) + prob = torch.clamp(prob,1e-8,1-1e-8) + batch_loss = - class_weight *(torch.pow((1-prob), self.gamma))*prob.log() + + if self.size_average: + loss = batch_loss.mean() + else: + loss = batch_loss + + return loss + + +if __name__ == '__main__': + L = FocalLoss2d() + out = torch.randn(2, 3, 3).cuda() + #target = torch.ones(2, 3, 3).cuda() + target = (torch.sigmoid(out) > 0.5).float() + #print(target, out) + loss = L(out, target) + print(loss) + #pass \ No newline at end of file diff --git a/examples/trials/tgs-salt/loader.py b/examples/trials/tgs-salt/loader.py new file mode 100644 index 0000000000..089e48b903 --- /dev/null +++ b/examples/trials/tgs-salt/loader.py @@ -0,0 +1,291 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import os, cv2, glob +import numpy as np +from PIL import Image + +import torch +import torch.utils.data as data +from torchvision import datasets, models, transforms +from utils import read_masks, get_test_meta, get_nfold_split +import augmentation as aug +from settings import * + +class ImageDataset(data.Dataset): + def __init__(self, train_mode, meta, augment_with_target=None, + image_augment=None, image_transform=None, mask_transform=None): + self.augment_with_target = augment_with_target + self.image_augment = image_augment + self.image_transform = image_transform + self.mask_transform = mask_transform + + self.train_mode = train_mode + self.meta = meta + + self.img_ids = meta[ID_COLUMN].values + self.salt_exists = meta['salt_exists'].values + self.is_train = meta['is_train'].values + + if self.train_mode: + self.mask_filenames = meta[Y_COLUMN].values + + def __getitem__(self, index): + base_img_fn = '{}.png'.format(self.img_ids[index]) + if self.is_train[index]: #self.train_mode: + img_fn = os.path.join(TRAIN_IMG_DIR, base_img_fn) + else: + img_fn = os.path.join(TEST_IMG_DIR, base_img_fn) + img = self.load_image(img_fn) + + if self.train_mode: + base_mask_fn = '{}.png'.format(self.img_ids[index]) + if self.is_train[index]: + mask_fn = os.path.join(TRAIN_MASK_DIR, base_mask_fn) + else: + mask_fn = os.path.join(TEST_DIR, 'masks', base_mask_fn) + mask = self.load_image(mask_fn, True) + img, mask = self.aug_image(img, mask) + return img, mask, self.salt_exists[index] + else: + img = self.aug_image(img) + return [img] + + def aug_image(self, img, mask=None): + if mask is not None: + if self.augment_with_target is not None: + img, mask = self.augment_with_target(img, mask) + if self.image_augment is not None: + img = self.image_augment(img) + if self.mask_transform is not None: + mask = self.mask_transform(mask) + if self.image_transform is not None: + img = self.image_transform(img) + return img, mask + else: + if self.image_augment is not None: + img = self.image_augment(img) + if self.image_transform is not None: + img = self.image_transform(img) + return img + + def load_image(self, img_filepath, grayscale=False): + image = Image.open(img_filepath, 'r') + if not grayscale: + image = image.convert('RGB') + else: + image = image.convert('L').point(lambda x: 0 if x < 128 else 1, 'L') + return image + + def __len__(self): + return len(self.img_ids) + + def collate_fn(self, batch): + imgs = [x[0] for x in batch] + inputs = torch.stack(imgs) + + if self.train_mode: + masks = [x[1] for x in batch] + labels = torch.stack(masks) + + salt_target = [x[2] for x in batch] + return inputs, labels, torch.FloatTensor(salt_target) + else: + return inputs + +def mask_to_tensor(x): + x = np.array(x).astype(np.float32) + x = np.expand_dims(x, axis=0) + x = torch.from_numpy(x) + return x + +img_transforms = [ + transforms.Grayscale(num_output_channels=3), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ] + +def get_tta_transforms(index, pad_mode): + tta_transforms = { + 0: [], + 1: [transforms.RandomHorizontalFlip(p=2.)], + 2: [transforms.RandomVerticalFlip(p=2.)], + 3: [transforms.RandomHorizontalFlip(p=2.), transforms.RandomVerticalFlip(p=2.)] + } + if pad_mode == 'resize': + return transforms.Compose([transforms.Resize((H, W)), *(tta_transforms[index]), *img_transforms]) + else: + return transforms.Compose([*(tta_transforms[index]), *img_transforms]) + +def get_image_transform(pad_mode): + if pad_mode == 'resize': + return transforms.Compose([transforms.Resize((H, W)), *img_transforms]) + else: + return transforms.Compose(img_transforms) + +def get_mask_transform(pad_mode): + if pad_mode == 'resize': + return transforms.Compose( + [ + transforms.Resize((H, W)), + transforms.Lambda(mask_to_tensor), + ] + ) + else: + return transforms.Compose( + [ + transforms.Lambda(mask_to_tensor), + ] + ) + +def get_img_mask_augments(pad_mode, depths_channel=False): + if depths_channel: + affine_aug = aug.RandomAffineWithMask(5, translate=(0.1, 0.), scale=(0.9, 1.1), shear=None) + else: + affine_aug = aug.RandomAffineWithMask(15, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=None) + + if pad_mode == 'resize': + img_mask_aug_train = aug.Compose([ + aug.RandomHFlipWithMask(), + affine_aug + ]) + img_mask_aug_val = None + else: + img_mask_aug_train = aug.Compose([ + aug.PadWithMask((28, 28), padding_mode=pad_mode), + aug.RandomHFlipWithMask(), + affine_aug, + aug.RandomResizedCropWithMask(H, scale=(1., 1.), ratio=(1., 1.)) + ]) + img_mask_aug_val = aug.PadWithMask((13, 13, 14, 14), padding_mode=pad_mode) + + return img_mask_aug_train, img_mask_aug_val + +def get_train_loaders(ifold, batch_size=8, dev_mode=False, pad_mode='edge', meta_version=1, pseudo_label=False, depths=False): + train_shuffle = True + train_meta, val_meta = get_nfold_split(ifold, nfold=10, meta_version=meta_version) + + if pseudo_label: + test_meta = get_test_meta() + train_meta = train_meta.append(test_meta, sort=True) + + if dev_mode: + train_shuffle = False + train_meta = train_meta.iloc[:10] + val_meta = val_meta.iloc[:10] + #print(val_meta[X_COLUMN].values[:5]) + #print(val_meta[Y_COLUMN].values[:5]) + print(train_meta.shape, val_meta.shape) + img_mask_aug_train, img_mask_aug_val = get_img_mask_augments(pad_mode, depths) + + train_set = ImageDataset(True, train_meta, + augment_with_target=img_mask_aug_train, + image_augment=transforms.ColorJitter(0.2, 0.2, 0.2, 0.2), + image_transform=get_image_transform(pad_mode), + mask_transform=get_mask_transform(pad_mode)) + + train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=train_shuffle, num_workers=4, collate_fn=train_set.collate_fn, drop_last=True) + train_loader.num = len(train_set) + + val_set = ImageDataset(True, val_meta, + augment_with_target=img_mask_aug_val, + image_augment=None, + image_transform=get_image_transform(pad_mode), + mask_transform=get_mask_transform(pad_mode)) + val_loader = data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=val_set.collate_fn) + val_loader.num = len(val_set) + val_loader.y_true = read_masks(val_meta[ID_COLUMN].values) + + return train_loader, val_loader + +def get_test_loader(batch_size=16, index=0, dev_mode=False, pad_mode='edge'): + test_meta = get_test_meta() + if dev_mode: + test_meta = test_meta.iloc[:10] + test_set = ImageDataset(False, test_meta, + image_augment=None if pad_mode == 'resize' else transforms.Pad((13,13,14,14), padding_mode=pad_mode), + image_transform=get_tta_transforms(index, pad_mode)) + test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=test_set.collate_fn, drop_last=False) + test_loader.num = len(test_set) + test_loader.meta = test_set.meta + + return test_loader + +depth_channel_tensor = None + +def get_depth_tensor(pad_mode): + global depth_channel_tensor + + if depth_channel_tensor is not None: + return depth_channel_tensor + + depth_tensor = None + + if pad_mode == 'resize': + depth_tensor = np.zeros((H, W)) + for row, const in enumerate(np.linspace(0, 1, H)): + depth_tensor[row, :] = const + else: + depth_tensor = np.zeros((ORIG_H, ORIG_W)) + for row, const in enumerate(np.linspace(0, 1, ORIG_H)): + depth_tensor[row, :] = const + depth_tensor = np.pad(depth_tensor, (14,14), mode=pad_mode) # edge or reflect + depth_tensor = depth_tensor[:H, :W] + + depth_channel_tensor = torch.Tensor(depth_tensor) + return depth_channel_tensor + +def add_depth_channel(img_tensor, pad_mode): + ''' + img_tensor: N, C, H, W + ''' + img_tensor[:, 1] = get_depth_tensor(pad_mode) + img_tensor[:, 2] = img_tensor[:, 0] * get_depth_tensor(pad_mode) + + +def test_train_loader(): + train_loader, val_loader = get_train_loaders(1, batch_size=4, dev_mode=False, pad_mode='edge', meta_version=2, pseudo_label=True) + print(train_loader.num, val_loader.num) + for i, data in enumerate(train_loader): + imgs, masks, salt_exists = data + #pdb.set_trace() + print(imgs.size(), masks.size(), salt_exists.size()) + print(salt_exists) + add_depth_channel(imgs, 'resize') + print(masks) + break + #print(imgs) + #print(masks) + +def test_test_loader(): + test_loader = get_test_loader(4, pad_mode='resize') + print(test_loader.num) + for i, data in enumerate(test_loader): + print(data.size()) + if i > 5: + break + +if __name__ == '__main__': + test_test_loader() + #test_train_loader() + #small_dict, img_ids = load_small_train_ids() + #print(img_ids[:10]) + #print(get_tta_transforms(3, 'edge')) diff --git a/examples/trials/tgs-salt/lovasz_losses.py b/examples/trials/tgs-salt/lovasz_losses.py new file mode 100644 index 0000000000..7d86a19af9 --- /dev/null +++ b/examples/trials/tgs-salt/lovasz_losses.py @@ -0,0 +1,252 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +from __future__ import print_function, division + +import torch +from torch.autograd import Variable +import torch.nn.functional as F +import numpy as np + +try: + from itertools import ifilterfalse +except ImportError: # py3k + from itertools import filterfalse + + +def lovasz_grad(gt_sorted): + """ + Computes gradient of the Lovasz extension w.r.t sorted errors + See Alg. 1 in paper + """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts - gt_sorted.float().cumsum(0) + union = gts + (1 - gt_sorted).float().cumsum(0) + jaccard = 1. - intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): + """ + IoU for foreground class + binary: 1 foreground, 0 background + """ + if not per_image: + preds, labels = (preds,), (labels,) + ious = [] + for pred, label in zip(preds, labels): + intersection = ((label == 1) & (pred == 1)).sum() + union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() + if not union: + iou = EMPTY + else: + iou = float(intersection) / union + ious.append(iou) + iou = mean(ious) # mean accross images if per_image + return 100 * iou + + +def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): + """ + Array of IoU for each (non ignored) class + """ + if not per_image: + preds, labels = (preds,), (labels,) + ious = [] + for pred, label in zip(preds, labels): + iou = [] + for i in range(C): + if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes) + intersection = ((label == i) & (pred == i)).sum() + union = ((label == i) | ((pred == i) & (label != ignore))).sum() + if not union: + iou.append(EMPTY) + else: + iou.append(float(intersection) / union) + ious.append(iou) + ious = map(mean, zip(*ious)) # mean accross images if per_image + return 100 * np.array(ious) + + +# --------------------------- BINARY LOSSES --------------------------- + + +def lovasz_hinge(logits, labels, per_image=True, ignore=None): + """ + Binary Lovasz hinge loss + logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) + labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) + per_image: compute the loss per image instead of per batch + ignore: void class id + """ + if per_image: + loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) + for log, lab in zip(logits, labels)) + else: + loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) + return loss + + +def lovasz_hinge_flat(logits, labels): + """ + Binary Lovasz hinge loss + logits: [P] Variable, logits at each prediction (between -\infty and +\infty) + labels: [P] Tensor, binary ground truth labels (0 or 1) + ignore: label to ignore + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0. + signs = 2. * labels.float() - 1. + errors = (1. - logits * Variable(signs)) + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = lovasz_grad(gt_sorted) + loss = torch.dot(F.elu(errors_sorted)+1, Variable(grad)) + #loss = torch.dot(F.relu(errors_sorted), Variable(grad)) + + return loss + + +def flatten_binary_scores(scores, labels, ignore=None): + """ + Flattens predictions in the batch (binary case) + Remove labels equal to 'ignore' + """ + scores = scores.view(-1) + labels = labels.view(-1) + if ignore is None: + return scores, labels + valid = (labels != ignore) + vscores = scores[valid] + vlabels = labels[valid] + return vscores, vlabels + + +class StableBCELoss(torch.nn.modules.Module): + def __init__(self): + super(StableBCELoss, self).__init__() + def forward(self, input, target): + neg_abs = - input.abs() + loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() + return loss.mean() + + +def binary_xloss(logits, labels, ignore=None): + """ + Binary Cross entropy loss + logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) + labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) + ignore: void class id + """ + logits, labels = flatten_binary_scores(logits, labels, ignore) + loss = StableBCELoss()(logits, Variable(labels.float())) + return loss + + +# --------------------------- MULTICLASS LOSSES --------------------------- + + +def lovasz_softmax(probas, labels, only_present=False, per_image=False, ignore=None): + """ + Multi-class Lovasz-Softmax loss + probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1) + labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) + only_present: average only on classes present in ground truth + per_image: compute the loss per image instead of per batch + ignore: void class labels + """ + if per_image: + loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), only_present=only_present) + for prob, lab in zip(probas, labels)) + else: + loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), only_present=only_present) + return loss + + +def lovasz_softmax_flat(probas, labels, only_present=False): + """ + Multi-class Lovasz-Softmax loss + probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) + labels: [P] Tensor, ground truth labels (between 0 and C - 1) + only_present: average only on classes present in ground truth + """ + C = probas.size(1) + losses = [] + for c in range(C): + fg = (labels == c).float() # foreground for class c + if only_present and fg.sum() == 0: + continue + errors = (Variable(fg) - probas[:, c]).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) + return mean(losses) + + +def flatten_probas(probas, labels, ignore=None): + """ + Flattens predictions in the batch + """ + B, C, H, W = probas.size() + probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C + labels = labels.view(-1) + if ignore is None: + return probas, labels + valid = (labels != ignore) + vprobas = probas[valid.nonzero().squeeze()] + vlabels = labels[valid] + return vprobas, vlabels + +def xloss(logits, labels, ignore=None): + """ + Cross entropy loss + """ + return F.cross_entropy(logits, Variable(labels), ignore_index=255) + + +# --------------------------- HELPER FUNCTIONS --------------------------- + +def mean(l, ignore_nan=False, empty=0): + """ + nanmean compatible with generators. + """ + l = iter(l) + if ignore_nan: + l = ifilterfalse(np.isnan, l) + try: + n = 1 + acc = next(l) + except StopIteration: + if empty == 'raise': + raise ValueError('Empty mean') + return empty + for n, v in enumerate(l, 2): + acc += v + if n == 1: + return acc + return acc / n diff --git a/examples/trials/tgs-salt/metrics.py b/examples/trials/tgs-salt/metrics.py new file mode 100644 index 0000000000..e253fec5cd --- /dev/null +++ b/examples/trials/tgs-salt/metrics.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import numpy as np +from pycocotools import mask as cocomask +from utils import get_segmentations + + +def iou(gt, pred): + gt[gt > 0] = 1. + pred[pred > 0] = 1. + intersection = gt * pred + union = gt + pred + union[union > 0] = 1. + intersection = np.sum(intersection) + union = np.sum(union) + if union == 0: + union = 1e-09 + return intersection / union + + +def compute_ious(gt, predictions): + gt_ = get_segmentations(gt) + predictions_ = get_segmentations(predictions) + + if len(gt_) == 0 and len(predictions_) == 0: + return np.ones((1, 1)) + elif len(gt_) != 0 and len(predictions_) == 0: + return np.zeros((1, 1)) + else: + iscrowd = [0 for _ in predictions_] + ious = cocomask.iou(gt_, predictions_, iscrowd) + if not np.array(ious).size: + ious = np.zeros((1, 1)) + return ious + + +def compute_precision_at(ious, threshold): + mx1 = np.max(ious, axis=0) + mx2 = np.max(ious, axis=1) + tp = np.sum(mx2 >= threshold) + fp = np.sum(mx2 < threshold) + fn = np.sum(mx1 < threshold) + return float(tp) / (tp + fp + fn) + + +def compute_eval_metric(gt, predictions): + thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95] + ious = compute_ious(gt, predictions) + precisions = [compute_precision_at(ious, th) for th in thresholds] + return sum(precisions) / len(precisions) + + +def intersection_over_union(y_true, y_pred): + ious = [] + for y_t, y_p in list(zip(y_true, y_pred)): + iou = compute_ious(y_t, y_p) + iou_mean = 1.0 * np.sum(iou) / len(iou) + ious.append(iou_mean) + return np.mean(ious) + + +def intersection_over_union_thresholds(y_true, y_pred): + iouts = [] + for y_t, y_p in list(zip(y_true, y_pred)): + iouts.append(compute_eval_metric(y_t, y_p)) + return np.mean(iouts) diff --git a/examples/trials/tgs-salt/models.py b/examples/trials/tgs-salt/models.py new file mode 100644 index 0000000000..ef941df886 --- /dev/null +++ b/examples/trials/tgs-salt/models.py @@ -0,0 +1,622 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +from torch import nn +from torch.nn import functional as F +import torch +from torchvision import models +from torchvision.models import resnet34, resnet101, resnet50, resnet152 +import torchvision + + +def conv3x3(in_, out): + return nn.Conv2d(in_, out, 3, padding=1) + + +class ConvRelu(nn.Module): + def __init__(self, in_, out): + super().__init__() + self.conv = conv3x3(in_, out) + self.activation = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.activation(x) + return x + + +class ConvBn2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=(3,3), stride=(1,1), padding=(1,1)): + super(ConvBn2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm2d(out_channels) + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + +# Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks +# https://arxiv.org/abs/1803.02579 + +class ChannelAttentionGate(nn.Module): + def __init__(self, channel, reduction=16): + super(ChannelAttentionGate, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return y + + +class SpatialAttentionGate(nn.Module): + def __init__(self, channel, reduction=16): + super(SpatialAttentionGate, self).__init__() + self.fc1 = nn.Conv2d(channel, reduction, kernel_size=1, padding=0) + self.fc2 = nn.Conv2d(reduction, 1, kernel_size=1, padding=0) + + def forward(self, x): + x = self.fc1(x) + x = F.relu(x, inplace=True) + x = self.fc2(x) + x = torch.sigmoid(x) + #print(x.size()) + return x + +class DecoderBlock(nn.Module): + def __init__(self, in_channels, middle_channels, out_channels): + super(DecoderBlock, self).__init__() + self.conv1 = ConvBn2d(in_channels, middle_channels) + self.conv2 = ConvBn2d(middle_channels, out_channels) + #self.deconv = nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2, padding=1) + #self.bn = nn.BatchNorm2d(out_channels) + self.spatial_gate = SpatialAttentionGate(out_channels) + self.channel_gate = ChannelAttentionGate(out_channels) + + def forward(self, x, e=None): + x = F.upsample(x, scale_factor=2, mode='bilinear', align_corners=True) + if e is not None: + x = torch.cat([x,e], 1) + + x = F.relu(self.conv1(x), inplace=True) + x = F.relu(self.conv2(x), inplace=True) + + g1 = self.spatial_gate(x) + g2 = self.channel_gate(x) + x = x*g1 + x*g2 + + return x + +class EncoderBlock(nn.Module): + def __init__(self, block, out_channels): + super(EncoderBlock, self).__init__() + self.block = block + self.out_channels = out_channels + self.spatial_gate = SpatialAttentionGate(out_channels) + self.channel_gate = ChannelAttentionGate(out_channels) + + def forward(self, x): + x = self.block(x) + g1 = self.spatial_gate(x) + g2 = self.channel_gate(x) + + return x*g1 + x*g2 + + +def create_resnet(layers): + if layers == 34: + return resnet34(pretrained=True), 512 + elif layers == 50: + return resnet50(pretrained=True), 2048 + elif layers == 101: + return resnet101(pretrained=True), 2048 + elif layers == 152: + return resnet152(pretrained=True), 2048 + else: + raise NotImplementedError('only 34, 50, 101, 152 version of Resnet are implemented') + +class UNetResNetV4(nn.Module): + def __init__(self, encoder_depth, num_classes=1, num_filters=32, dropout_2d=0.4, + pretrained=True, is_deconv=True): + super(UNetResNetV4, self).__init__() + self.name = 'UNetResNetV4_'+str(encoder_depth) + self.num_classes = num_classes + self.dropout_2d = dropout_2d + + self.resnet, bottom_channel_nr = create_resnet(encoder_depth) + + self.encoder1 = EncoderBlock( + nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu), + num_filters*2 + ) + self.encoder2 = EncoderBlock(self.resnet.layer1, bottom_channel_nr//8) + self.encoder3 = EncoderBlock(self.resnet.layer2, bottom_channel_nr//4) + self.encoder4 = EncoderBlock(self.resnet.layer3, bottom_channel_nr//2) + self.encoder5 = EncoderBlock(self.resnet.layer4, bottom_channel_nr) + + center_block = nn.Sequential( + ConvBn2d(bottom_channel_nr, bottom_channel_nr, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ConvBn2d(bottom_channel_nr, bottom_channel_nr//2, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2) + ) + self.center = EncoderBlock(center_block, bottom_channel_nr//2) + + self.decoder5 = DecoderBlock(bottom_channel_nr + bottom_channel_nr // 2, num_filters * 16, 64) + self.decoder4 = DecoderBlock(64 + bottom_channel_nr // 2, num_filters * 8, 64) + self.decoder3 = DecoderBlock(64 + bottom_channel_nr // 4, num_filters * 4, 64) + self.decoder2 = DecoderBlock(64 + bottom_channel_nr // 8, num_filters * 2, 64) + self.decoder1 = DecoderBlock(64, num_filters, 64) + + self.logit = nn.Sequential( + nn.Conv2d(320, 64, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, 1, kernel_size=1, padding=0) + ) + + def forward(self, x): + x = self.encoder1(x) #; print('x:', x.size()) + e2 = self.encoder2(x) #; print('e2:', e2.size()) + e3 = self.encoder3(e2) #; print('e3:', e3.size()) + e4 = self.encoder4(e3) #; print('e4:', e4.size()) + e5 = self.encoder5(e4) #; print('e5:', e5.size()) + + center = self.center(e5) #; print('center:', center.size()) + + d5 = self.decoder5(center, e5) #; print('d5:', d5.size()) + d4 = self.decoder4(d5, e4) #; print('d4:', d4.size()) + d3 = self.decoder3(d4, e3) #; print('d3:', d3.size()) + d2 = self.decoder2(d3, e2) #; print('d2:', d2.size()) + d1 = self.decoder1(d2) #; print('d1:', d1.size()) + + f = torch.cat([ + d1, + F.upsample(d2, scale_factor=2, mode='bilinear', align_corners=False), + F.upsample(d3, scale_factor=4, mode='bilinear', align_corners=False), + F.upsample(d4, scale_factor=8, mode='bilinear', align_corners=False), + F.upsample(d5, scale_factor=16, mode='bilinear', align_corners=False), + ], 1) + + f = F.dropout2d(f, p=self.dropout_2d) + + return self.logit(f), None + + def freeze_bn(self): + '''Freeze BatchNorm layers.''' + for layer in self.modules(): + if isinstance(layer, nn.BatchNorm2d): + layer.eval() + + def get_params(self, base_lr): + group1 = [self.encoder1, self.encoder2, self.encoder3, self.encoder4, self.encoder5] + group2 = [self.decoder1, self.decoder2, self.decoder3, self.decoder4, self.decoder5, self.center, self.logit] + + params1 = [] + for x in group1: + for p in x.parameters(): + params1.append(p) + + param_group1 = {'params': params1, 'lr': base_lr / 5} + + params2 = [] + for x in group2: + for p in x.parameters(): + params2.append(p) + param_group2 = {'params': params2, 'lr': base_lr} + + return [param_group1, param_group2] + +class DecoderBlockV5(nn.Module): + def __init__(self, in_channels_x, in_channels_e, middle_channels, out_channels): + super(DecoderBlockV5, self).__init__() + self.in_channels = in_channels_x + in_channels_e + self.conv1 = ConvBn2d(self.in_channels, middle_channels) + self.conv2 = ConvBn2d(middle_channels, out_channels) + self.deconv = nn.ConvTranspose2d(in_channels_x, in_channels_x, kernel_size=4, stride=2, padding=1) + self.bn = nn.BatchNorm2d(self.in_channels) + self.spatial_gate = SpatialAttentionGate(out_channels) + self.channel_gate = ChannelAttentionGate(out_channels) + + def forward(self, x, e=None): + #x = F.upsample(x, scale_factor=2, mode='bilinear', align_corners=True) + x = self.deconv(x) + if e is not None: + x = torch.cat([x,e], 1) + x = self.bn(x) + + x = F.relu(self.conv1(x), inplace=True) + x = F.relu(self.conv2(x), inplace=True) + + g1 = self.spatial_gate(x) + g2 = self.channel_gate(x) + x = x*g1 + x*g2 + + return x + + + +class UNetResNetV5(nn.Module): + def __init__(self, encoder_depth, num_classes=1, num_filters=32, dropout_2d=0.5): + super(UNetResNetV5, self).__init__() + self.name = 'UNetResNetV5_'+str(encoder_depth) + self.num_classes = num_classes + self.dropout_2d = dropout_2d + + self.resnet, bottom_channel_nr = create_resnet(encoder_depth) + + self.encoder1 = EncoderBlock( + nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu), + num_filters*2 + ) + self.encoder2 = EncoderBlock(self.resnet.layer1, bottom_channel_nr//8) + self.encoder3 = EncoderBlock(self.resnet.layer2, bottom_channel_nr//4) + self.encoder4 = EncoderBlock(self.resnet.layer3, bottom_channel_nr//2) + self.encoder5 = EncoderBlock(self.resnet.layer4, bottom_channel_nr) + + center_block = nn.Sequential( + ConvBn2d(bottom_channel_nr, bottom_channel_nr, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ConvBn2d(bottom_channel_nr, bottom_channel_nr//2, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2) + ) + self.center = EncoderBlock(center_block, bottom_channel_nr//2) + + self.decoder5 = DecoderBlockV5(bottom_channel_nr // 2, bottom_channel_nr, num_filters * 16, 64) + self.decoder4 = DecoderBlockV5(64, bottom_channel_nr // 2, num_filters * 8, 64) + self.decoder3 = DecoderBlockV5(64, bottom_channel_nr // 4, num_filters * 4, 64) + self.decoder2 = DecoderBlockV5(64, bottom_channel_nr // 8, num_filters * 2, 64) + self.decoder1 = DecoderBlockV5(64, 0, num_filters, 64) + + self.logit = nn.Sequential( + nn.Conv2d(320, 64, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, 1, kernel_size=1, padding=0) + ) + + def forward(self, x): + x = self.encoder1(x) #; print('x:', x.size()) + e2 = self.encoder2(x) #; print('e2:', e2.size()) + e3 = self.encoder3(e2) #; print('e3:', e3.size()) + e4 = self.encoder4(e3) #; print('e4:', e4.size()) + e5 = self.encoder5(e4) #; print('e5:', e5.size()) + + center = self.center(e5) #; print('center:', center.size()) + + d5 = self.decoder5(center, e5) #; print('d5:', d5.size()) + d4 = self.decoder4(d5, e4) #; print('d4:', d4.size()) + d3 = self.decoder3(d4, e3) #; print('d3:', d3.size()) + d2 = self.decoder2(d3, e2) #; print('d2:', d2.size()) + d1 = self.decoder1(d2) #; print('d1:', d1.size()) + + f = torch.cat([ + d1, + F.interpolate(d2, scale_factor=2, mode='bilinear', align_corners=False), + F.interpolate(d3, scale_factor=4, mode='bilinear', align_corners=False), + F.interpolate(d4, scale_factor=8, mode='bilinear', align_corners=False), + F.interpolate(d5, scale_factor=16, mode='bilinear', align_corners=False), + ], 1) + + f = F.dropout2d(f, p=self.dropout_2d) + + return self.logit(f), None + +class UNetResNetV6(nn.Module): + ''' + 1. Remove first pool from UNetResNetV5, such that resolution is doubled + 2. Remove scSE from center block + 3. Increase default dropout + ''' + def __init__(self, encoder_depth, num_filters=32, dropout_2d=0.5): + super(UNetResNetV6, self).__init__() + assert encoder_depth == 34, 'UNetResNetV6: only 34 layers is supported!' + self.name = 'UNetResNetV6_'+str(encoder_depth) + self.dropout_2d = dropout_2d + + self.resnet, bottom_channel_nr = create_resnet(encoder_depth) + + self.encoder1 = EncoderBlock( + nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu), + num_filters*2 + ) + + self.encoder2 = EncoderBlock(self.resnet.layer1, bottom_channel_nr//8) + self.encoder3 = EncoderBlock(self.resnet.layer2, bottom_channel_nr//4) + self.encoder4 = EncoderBlock(self.resnet.layer3, bottom_channel_nr//2) + self.encoder5 = EncoderBlock(self.resnet.layer4, bottom_channel_nr) + + self.center = nn.Sequential( + ConvBn2d(bottom_channel_nr, bottom_channel_nr, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ConvBn2d(bottom_channel_nr, bottom_channel_nr//2, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2) + ) + #self.center = EncoderBlock(center_block, bottom_channel_nr//2) + + self.decoder5 = DecoderBlockV5(bottom_channel_nr // 2, bottom_channel_nr, num_filters * 16, 64) + self.decoder4 = DecoderBlockV5(64, bottom_channel_nr // 2, num_filters * 8, 64) + self.decoder3 = DecoderBlockV5(64, bottom_channel_nr // 4, num_filters * 4, 64) + self.decoder2 = DecoderBlockV5(64, bottom_channel_nr // 8, num_filters * 2, 64) + self.decoder1 = DecoderBlockV5(64, 0, num_filters, 64) + + self.logit = nn.Sequential( + nn.Conv2d(512, 64, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, 1, kernel_size=1, padding=0) + ) + + self.logit_image = nn.Sequential( + nn.Linear(512, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 1) + ) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + x = self.encoder1(x) #; print('x:', x.size()) + e2 = self.encoder2(x) #; print('e2:', e2.size()) + e3 = self.encoder3(e2) #; print('e3:', e3.size()) + e4 = self.encoder4(e3) #; print('e4:', e4.size()) + e5 = self.encoder5(e4) #; print('e5:', e5.size()) + + center = self.center(e5) #; print('center:', center.size()) + + d5 = self.decoder5(center, e5) #; print('d5:', d5.size()) + d4 = self.decoder4(d5, e4) #; print('d4:', d4.size()) + d3 = self.decoder3(d4, e3) #; print('d3:', d3.size()) + d2 = self.decoder2(d3, e2) #; print('d2:', d2.size()) + #d1 = self.decoder1(d2) ; print('d1:', d1.size()) + + f = torch.cat([ + d2, + F.interpolate(d3, scale_factor=2, mode='bilinear', align_corners=False), + F.interpolate(d4, scale_factor=4, mode='bilinear', align_corners=False), + F.interpolate(d5, scale_factor=8, mode='bilinear', align_corners=False), + F.interpolate(center, scale_factor=16, mode='bilinear', align_corners=False), + ], 1) + + f = F.dropout2d(f, p=self.dropout_2d, training=self.training) + + # empty mask classifier + img_f = F.adaptive_avg_pool2d(e5, 1).view(x.size(0), -1) + img_f = F.dropout(img_f, p=0.5, training=self.training) + img_logit = self.logit_image(img_f).view(-1) + + return self.logit(f), img_logit + + +class DecoderBlockV7(nn.Module): + def __init__(self, in_channels_x, in_channels_e, middle_channels, out_channels): + super(DecoderBlockV7, self).__init__() + self.in_channels = in_channels_x + in_channels_e + self.conv1 = ConvBn2d(self.in_channels, middle_channels) + self.conv2 = ConvBn2d(middle_channels, out_channels) + self.deconv = nn.ConvTranspose2d(in_channels_x, in_channels_x, kernel_size=4, stride=2, padding=1) + self.bn = nn.BatchNorm2d(self.in_channels) + self.spatial_gate = SpatialAttentionGate(out_channels) + self.channel_gate = ChannelAttentionGate(out_channels) + + def forward(self, x, e=None, upsample=True): + #x = F.upsample(x, scale_factor=2, mode='bilinear', align_corners=True) + if upsample: + x = self.deconv(x) + if e is not None: + x = torch.cat([x,e], 1) + x = self.bn(x) + + x = F.relu(self.conv1(x), inplace=True) + x = F.relu(self.conv2(x), inplace=True) + + g1 = self.spatial_gate(x) + g2 = self.channel_gate(x) + x = x*g1 + x*g2 + + return x + +class UNet7(nn.Module): + def __init__(self, encoder_depth, num_classes=1, num_filters=32, dropout_2d=0.5): + super(UNet7, self).__init__() + nf = num_filters + self.name = 'UNet7_'+str(encoder_depth)+'_nf'+str(nf) + self.num_classes = num_classes + self.dropout_2d = dropout_2d + + self.resnet, nbtm = create_resnet(encoder_depth) + + self.encoder1 = EncoderBlock( + nn.Sequential( + nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + ), + 64 + ) + self.encoder2 = EncoderBlock( + nn.Sequential( + nn.MaxPool2d(kernel_size=2, stride=2), + self.resnet.layer1, + ), + nbtm//8 + ) + self.encoder3 = EncoderBlock(self.resnet.layer2, nbtm//4) + self.encoder4 = EncoderBlock(self.resnet.layer3, nbtm//2) + self.encoder5 = EncoderBlock(self.resnet.layer4, nbtm) + + center_block = nn.Sequential( + ConvBn2d(nbtm, nbtm, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ConvBn2d(nbtm, nbtm//2, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + #nn.MaxPool2d(kernel_size=2, stride=2) # remove + ) + self.center = EncoderBlock(center_block, nbtm//2) + + self.decoder5 = DecoderBlockV7(nbtm // 2, nbtm, nf * 16, nf*2) + self.decoder4 = DecoderBlockV7(nf*2, nbtm // 2, nf * 8, nf*2) + self.decoder3 = DecoderBlockV7(nf*2, nbtm // 4, nf * 4, nf*2) + self.decoder2 = DecoderBlockV7(nf*2, nbtm // 8, nf * 2, nf*2) + self.decoder1 = DecoderBlockV7(nf*2, 64, nf*2, nf*2) + + self.logit = nn.Sequential( + nn.Conv2d(nf*10, 64, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, 1, kernel_size=1, padding=0) + ) + + self.logit_image = nn.Sequential( + nn.Linear(nbtm, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 1), + ) + + def forward(self, x): + e1 = self.encoder1(x) #; print('e1:', e1.size()) + e2 = self.encoder2(e1) #; print('e2:', e2.size()) + e3 = self.encoder3(e2) #; print('e3:', e3.size()) + e4 = self.encoder4(e3) #; print('e4:', e4.size()) + e5 = self.encoder5(e4) #; print('e5:', e5.size()) + + center = self.center(e5) #; print('center:', center.size()) + + d5 = self.decoder5(center, e5, upsample=False) #; print('d5:', d5.size()) + d4 = self.decoder4(d5, e4) #; print('d4:', d4.size()) + d3 = self.decoder3(d4, e3) #; print('d3:', d3.size()) + d2 = self.decoder2(d3, e2) #; print('d2:', d2.size()) + d1 = self.decoder1(d2, e1) #; print('d1:', d1.size()) + + f = torch.cat([ + d1, + F.interpolate(d2, scale_factor=2, mode='bilinear', align_corners=False), + F.interpolate(d3, scale_factor=4, mode='bilinear', align_corners=False), + F.interpolate(d4, scale_factor=8, mode='bilinear', align_corners=False), + F.interpolate(d5, scale_factor=16, mode='bilinear', align_corners=False), + ], 1) + + f = F.dropout2d(f, p=self.dropout_2d) + + # empty mask classifier + img_f = F.adaptive_avg_pool2d(e5, 1).view(x.size(0), -1) + img_f = F.dropout(img_f, p=0.5, training=self.training) + img_logit = self.logit_image(img_f).view(-1) + + return self.logit(f), img_logit + + +class UNet8(nn.Module): + def __init__(self, encoder_depth, num_classes=1, num_filters=32, dropout_2d=0.5): + super(UNet8, self).__init__() + nf = num_filters + self.name = 'UNet8_'+str(encoder_depth)+'_nf'+str(nf) + self.num_classes = num_classes + self.dropout_2d = dropout_2d + + self.resnet, nbtm = create_resnet(encoder_depth) + + self.encoder1 = EncoderBlock( + nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu), + 64 + ) + + self.encoder2 = EncoderBlock(self.resnet.layer1, nbtm//8) + self.encoder3 = EncoderBlock(self.resnet.layer2, nbtm//4) + self.encoder4 = EncoderBlock(self.resnet.layer3, nbtm//2) + self.encoder5 = EncoderBlock(self.resnet.layer4, nbtm) + + center_block = nn.Sequential( + ConvBn2d(nbtm, nbtm, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ConvBn2d(nbtm, nbtm//2, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + #nn.MaxPool2d(kernel_size=2, stride=2) # remove + ) + self.center = EncoderBlock(center_block, nbtm//2) + + self.decoder5 = DecoderBlockV7(nbtm // 2, nbtm, nf * 16, nf*2) + self.decoder4 = DecoderBlockV7(nf*2, nbtm // 2, nf * 8, nf*2) + self.decoder3 = DecoderBlockV7(nf*2, nbtm // 4, nf * 4, nf*2) + self.decoder2 = DecoderBlockV7(nf*2, nbtm // 8, nf * 2, nf*2) + self.decoder1 = DecoderBlockV7(nf*2+64, 3, nf*2, nf*2) + + self.logit = nn.Sequential( + nn.Conv2d(nf*10, 64, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, 1, kernel_size=1, padding=0) + ) + + self.logit_image = nn.Sequential( + nn.Linear(nbtm, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 1), + ) + + def forward(self, x): + e1 = self.encoder1(x) #; print('e1:', e1.size()) + e2 = self.encoder2(e1) #; print('e2:', e2.size()) + e3 = self.encoder3(e2) #; print('e3:', e3.size()) + e4 = self.encoder4(e3) #; print('e4:', e4.size()) + e5 = self.encoder5(e4) #; print('e5:', e5.size()) + + center = self.center(e5) #; print('center:', center.size()) + + d5 = self.decoder5(center, e5, upsample=False) #; print('d5:', d5.size()) + d4 = self.decoder4(d5, e4) #; print('d4:', d4.size()) + d3 = self.decoder3(d4, e3) #; print('d3:', d3.size()) + d2 = self.decoder2(d3, e2) #; print('d2:', d2.size()) + d1 = self.decoder1(torch.cat([d2, e1], 1), x) #; print('d1:', d1.size()) + + f = torch.cat([ + d1, + F.interpolate(d2, scale_factor=2, mode='bilinear', align_corners=False), + F.interpolate(d3, scale_factor=4, mode='bilinear', align_corners=False), + F.interpolate(d4, scale_factor=8, mode='bilinear', align_corners=False), + F.interpolate(d5, scale_factor=16, mode='bilinear', align_corners=False), + ], 1) + + f = F.dropout2d(f, p=self.dropout_2d) + + # empty mask classifier + img_f = F.adaptive_avg_pool2d(e5, 1).view(x.size(0), -1) + img_f = F.dropout(img_f, p=0.5, training=self.training) + img_logit = self.logit_image(img_f).view(-1) + + return self.logit(f), img_logit + + +def test(): + model = UNet8(50, num_filters=32).cuda() + inputs = torch.randn(2,3,128,128).cuda() + out, _ = model(inputs) + #print(model) + print(out.size(), _.size()) #, cls_taret.size()) + #print(out) + + +if __name__ == '__main__': + test() diff --git a/examples/trials/tgs-salt/postprocessing.py b/examples/trials/tgs-salt/postprocessing.py new file mode 100644 index 0000000000..9da2b8a7e7 --- /dev/null +++ b/examples/trials/tgs-salt/postprocessing.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import os +import numpy as np +import pandas as pd +from scipy import ndimage as ndi +import cv2 + +from utils import get_crop_pad_sequence, run_length_decoding +import settings + +def resize_image(image, target_size): + resized_image = cv2.resize(image, target_size) + return resized_image + +def crop_image(image, target_size): + top_crop, right_crop, bottom_crop, left_crop = get_crop_pad_sequence(image.shape[0] - target_size[0], + image.shape[1] - target_size[1]) + cropped_image = image[top_crop:image.shape[0] - bottom_crop, left_crop:image.shape[1] - right_crop] + return cropped_image + +def binarize(image, threshold): + image_binarized = (image > threshold).astype(np.uint8) + return image_binarized + +def save_pseudo_label_masks(submission_file): + df = pd.read_csv(submission_file, na_filter=False) + print(df.head()) + + img_dir = os.path.join(settings.TEST_DIR, 'masks') + + for i, row in enumerate(df.values): + decoded_mask = run_length_decoding(row[1], (101,101)) + filename = os.path.join(img_dir, '{}.png'.format(row[0])) + rgb_mask = cv2.cvtColor(decoded_mask,cv2.COLOR_GRAY2RGB) + print(filename) + cv2.imwrite(filename, decoded_mask) + if i % 100 == 0: + print(i) + + + +if __name__ == '__main__': + save_pseudo_label_masks('V456_ensemble_1011.csv') \ No newline at end of file diff --git a/examples/trials/tgs-salt/predict.py b/examples/trials/tgs-salt/predict.py new file mode 100644 index 0000000000..ff4bb66f43 --- /dev/null +++ b/examples/trials/tgs-salt/predict.py @@ -0,0 +1,223 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import os +import glob +import argparse +import numpy as np +import torch +import torch.optim as optim +import torch.nn.functional as F + +import settings +from loader import get_test_loader, add_depth_channel +from models import UNetResNetV4, UNetResNetV5, UNetResNetV6, UNet7, UNet8 +from postprocessing import crop_image, binarize, resize_image +from metrics import intersection_over_union, intersection_over_union_thresholds +from utils import create_submission + +def do_tta_predict(args, model, ckp_path, tta_num=4): + ''' + return 18000x128x128 np array + ''' + model.eval() + preds = [] + meta = None + + # i is tta index, 0: no change, 1: horizon flip, 2: vertical flip, 3: do both + for flip_index in range(tta_num): + print('flip_index:', flip_index) + test_loader = get_test_loader(args.batch_size, index=flip_index, dev_mode=False, pad_mode=args.pad_mode) + meta = test_loader.meta + outputs = None + with torch.no_grad(): + for i, img in enumerate(test_loader): + add_depth_channel(img, args.pad_mode) + img = img.cuda() + output, _ = model(img) + output = torch.sigmoid(output) + if outputs is None: + outputs = output.squeeze() + else: + outputs = torch.cat([outputs, output.squeeze()], 0) + + print('{} / {}'.format(args.batch_size*(i+1), test_loader.num), end='\r') + outputs = outputs.cpu().numpy() + # flip back masks + if flip_index == 1: + outputs = np.flip(outputs, 2) + elif flip_index == 2: + outputs = np.flip(outputs, 1) + elif flip_index == 3: + outputs = np.flip(outputs, 2) + outputs = np.flip(outputs, 1) + #print(outputs.shape) + preds.append(outputs) + + parent_dir = ckp_path+'_out' + if not os.path.exists(parent_dir): + os.makedirs(parent_dir) + np_file = os.path.join(parent_dir, 'pred.npy') + + model_pred_result = np.mean(preds, 0) + np.save(np_file, model_pred_result) + + return model_pred_result, meta + +def predict(args, model, checkpoint, out_file): + print('predicting {}...'.format(checkpoint)) + pred, meta = do_tta_predict(args, model, checkpoint, tta_num=2) + print(pred.shape) + y_pred_test = generate_preds(pred, (settings.ORIG_H, settings.ORIG_W), pad_mode=args.pad_mode) + + submission = create_submission(meta, y_pred_test) + submission.to_csv(out_file, index=None, encoding='utf-8') + + +def ensemble(args, model, checkpoints): + preds = [] + meta = None + for checkpoint in checkpoints: + model.load_state_dict(torch.load(checkpoint)) + model = model.cuda() + print('predicting...', checkpoint) + + pred, meta = do_tta_predict(args, model, checkpoint, tta_num=2) + preds.append(pred) + + y_pred_test = generate_preds(np.mean(preds, 0), (settings.ORIG_H, settings.ORIG_W), args.pad_mode) + + submission = create_submission(meta, y_pred_test) + submission.to_csv(args.sub_file, index=None, encoding='utf-8') + +def ensemble_np(args, np_files, save_np=None): + preds = [] + for np_file in np_files: + pred = np.load(np_file) + print(np_file, pred.shape) + preds.append(pred) + + y_pred_test = generate_preds(np.mean(preds, 0), (settings.ORIG_H, settings.ORIG_W), args.pad_mode) + + if save_np is not None: + np.save(save_np, np.mean(preds, 0)) + + meta = get_test_loader(args.batch_size, index=0, dev_mode=False, pad_mode=args.pad_mode).meta + + submission = create_submission(meta, y_pred_test) + #submission_filepath = 'v4_1378.csv' + submission.to_csv(args.sub_file, index=None, encoding='utf-8') + +def generate_preds(outputs, target_size, pad_mode, threshold=0.5): + preds = [] + + for output in outputs: + #print(output.shape) + if pad_mode == 'resize': + cropped = resize_image(output, target_size=target_size) + else: + cropped = crop_image_softmax(output, target_size=target_size) + pred = binarize(cropped, threshold) + preds.append(pred) + + return preds + + +def ensemble_predict(args): + model = eval(args.model_name)(args.layers, num_filters=args.nf) + + #checkpoints = [ + # r'G:\salt\models\152_new\best_0.pth', r'G:\salt\models\152_new\best_1.pth', + # r'G:\salt\models\152_new\best_2.pth' + #] + #checkpoints = [ LB841 + # r'D:\data\salt\models\UNetResNetV4_34\best_0.pth', r'D:\data\salt\models\UNetResNetV4_34\best_1.pth', + # r'D:\data\salt\models\UNetResNetV4_34\best_2.pth'#, r'D:\data\salt\models\UNetResNetV4_34\best_3.pth' + #] + + # LB861 + #checkpoints = [ + # r'D:\data\salt\models\depths\UNetResNetV4_34\edge\best_0.pth', + # r'D:\data\salt\models\depths\UNetResNetV4_34\edge\best_1.pth', + # r'D:\data\salt\models\depths\UNetResNetV4_34\edge\best_2.pth', + # r'D:\data\salt\models\depths\UNetResNetV4_34\edge\best_3.pth' + #] + + #checkpoints= glob.glob(r'D:\data\salt\models\pseudo\UNetResNetV4_34\edge\best*.pth') + checkpoints = [ + r'D:\data\salt\models\pseudo\UNetResNetV4_34\edge\best_5.pth', + r'D:\data\salt\models\pseudo\UNetResNetV4_34\edge\best_6.pth', + r'D:\data\salt\models\pseudo\UNetResNetV4_34\edge\best_8.pth', + r'D:\data\salt\models\pseudo\UNetResNetV4_34\edge\best_9.pth' + ] + print(checkpoints) + #ensemble(checkpoints) + + ensemble(args, model, checkpoints) + +def ensemble_np_results(args): + np_files1 = glob.glob(r'D:\data\salt\models\depths\UNetResNetV5_50\edge\*pth_out\*.npy') + np_files2 = glob.glob(r'D:\data\salt\models\depths\UNetResNetV4_34\edge\*pth_out\*.npy') + np_files3 = glob.glob(r'D:\data\salt\models\depths\UNetResNetV6_34\edge\*pth_out\*.npy') + #np_files4 = glob.glob(r'D:\data\salt\models\pseudo\UNetResNetV4_34\edge\*pth_out\*.npy') + #np_files5 = glob.glob(r'D:\data\salt\models\pseudo\UNetResNetV6_34\edge\*pth_out\*.npy') + np_files6 = glob.glob(r'D:\data\salt\models\ensemble\*.npy') + np_files = np_files6 #np_files1 + np_files2 + np_files3 #+ np_files4 + np_files5 + #np_files = + print(np_files) + ensemble_np(args, np_files) #, save_np=os.path.join(settings.MODEL_DIR, 'ensemble', 'v456_lb864.npy')) + +def predict_model(args): + model = eval(args.model_name)(args.layers, num_filters=args.nf) + model_subdir = args.pad_mode + if args.meta_version == 2: + model_subdir = args.pad_mode+'_meta2' + if args.exp_name is None: + model_file = os.path.join(settings.MODEL_DIR, model.name,model_subdir, 'best_{}.pth'.format(args.ifold)) + else: + model_file = os.path.join(settings.MODEL_DIR, args.exp_name, model.name, model_subdir, 'best_{}.pth'.format(args.ifold)) + + if os.path.exists(model_file): + print('loading {}...'.format(model_file)) + model.load_state_dict(torch.load(model_file)) + else: + raise ValueError('model file not found: {}'.format(model_file)) + model = model.cuda() + predict(args, model, model_file, args.sub_file) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Salt segmentation') + parser.add_argument('--model_name', required=True, type=str, help='') + parser.add_argument('--layers', default=34, type=int, help='model layers') + parser.add_argument('--nf', default=32, type=int, help='num_filters param for model') + parser.add_argument('--ifold', required=True, type=int, help='kfold indices') + parser.add_argument('--batch_size', default=32, type=int, help='batch_size') + parser.add_argument('--pad_mode', required=True, choices=['reflect', 'edge', 'resize'], help='pad method') + parser.add_argument('--exp_name', default='depths', type=str, help='exp name') + parser.add_argument('--meta_version', default=2, type=int, help='meta version') + parser.add_argument('--sub_file', default='all_ensemble.csv', type=str, help='submission file') + + args = parser.parse_args() + + predict_model(args) + #ensemble_predict(args) + #ensemble_np_results(args) diff --git a/examples/trials/tgs-salt/preprocess.py b/examples/trials/tgs-salt/preprocess.py new file mode 100644 index 0000000000..1f80252089 --- /dev/null +++ b/examples/trials/tgs-salt/preprocess.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import os +import pandas as pd +import numpy as np +import json +import torch +import torch.nn as nn +from keras.preprocessing.image import load_img +from sklearn.model_selection import StratifiedKFold +import settings +import utils + +DATA_DIR = settings.DATA_DIR + +def prepare_metadata(): + print('creating metadata') + meta = utils.generate_metadata(train_images_dir=settings.TRAIN_DIR, + test_images_dir=settings.TEST_DIR, + depths_filepath=settings.DEPTHS_FILE + ) + meta.to_csv(settings.META_FILE, index=None) + +def cov_to_class(val): + for i in range(0, 11): + if val * 10 <= i : + return i + +def generate_stratified_metadata(): + train_df = pd.read_csv(os.path.join(DATA_DIR, "train.csv"), index_col="id", usecols=[0]) + depths_df = pd.read_csv(os.path.join(DATA_DIR, "depths.csv"), index_col="id") + train_df = train_df.join(depths_df) + train_df["masks"] = [np.array(load_img(os.path.join(DATA_DIR, "train", "masks", "{}.png".format(idx)), grayscale=True)) / 255 for idx in train_df.index] + train_df["coverage"] = train_df.masks.map(np.sum) / pow(settings.ORIG_H, 2) + train_df["coverage_class"] = train_df.coverage.map(cov_to_class) + train_df["salt_exists"] = train_df.coverage_class.map(lambda x: 0 if x == 0 else 1) + train_df["is_train"] = 1 + train_df["file_path_image"] = train_df.index.map(lambda x: os.path.join(settings.TRAIN_IMG_DIR, '{}.png'.format(x))) + train_df["file_path_mask"] = train_df.index.map(lambda x: os.path.join(settings.TRAIN_MASK_DIR, '{}.png'.format(x))) + + train_df.to_csv(os.path.join(settings.DATA_DIR, 'train_meta2.csv'), + columns=['file_path_image','file_path_mask','is_train','z','salt_exists', 'coverage_class', 'coverage']) + train_splits = {} + + kf = StratifiedKFold(n_splits=10) + for i, (train_index, valid_index) in enumerate(kf.split(train_df.index.values.reshape(-1), train_df.coverage_class.values.reshape(-1))): + train_splits[str(i)] = { + 'train_index': train_index.tolist(), + 'val_index': valid_index.tolist() + } + with open(os.path.join(settings.DATA_DIR, 'train_split.json'), 'w') as f: + json.dump(train_splits, f, indent=4) + + print('done') + + +def test(): + meta = pd.read_csv(settings.META_FILE) + meta_train = meta[meta['is_train'] == 1] + print(type(meta_train)) + + cv = utils.KFoldBySortedValue() + for train_idx, valid_idx in cv.split(meta_train[settings.DEPTH_COLUMN].values.reshape(-1)): + print(len(train_idx), len(valid_idx)) + print(train_idx[:10]) + print(valid_idx[:10]) + #break + + meta_train_split, meta_valid_split = meta_train.iloc[train_idx], meta_train.iloc[valid_idx] + print(type(meta_train_split)) + print(meta_train_split[settings.X_COLUMN].values[:10]) + +if __name__ == '__main__': + #prepare_metadata() + #convert_model2() + #get_mask_existence() + generate_stratified_metadata() + #get_nfold_split2(0) \ No newline at end of file diff --git a/examples/trials/tgs-salt/settings.py b/examples/trials/tgs-salt/settings.py new file mode 100644 index 0000000000..a5d232bb8c --- /dev/null +++ b/examples/trials/tgs-salt/settings.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import os + +DATA_DIR = r'/mnt/chicm/data/salt' + +TRAIN_DIR = os.path.join(DATA_DIR, 'train') +TEST_DIR = os.path.join(DATA_DIR, 'test') + +TRAIN_IMG_DIR = os.path.join(TRAIN_DIR, 'images') +TRAIN_MASK_DIR = os.path.join(TRAIN_DIR, 'masks') +TEST_IMG_DIR = os.path.join(TEST_DIR, 'images') + +LABEL_FILE = os.path.join(DATA_DIR, 'train.csv') +DEPTHS_FILE = os.path.join(DATA_DIR, 'depths.csv') +META_FILE = os.path.join(DATA_DIR, 'meta.csv') + +MODEL_DIR = os.path.join(DATA_DIR, 'models') + +ID_COLUMN = 'id' +DEPTH_COLUMN = 'z' +X_COLUMN = 'file_path_image' +Y_COLUMN = 'file_path_mask' + +H = W = 128 +ORIG_H = ORIG_W = 101 \ No newline at end of file diff --git a/examples/trials/tgs-salt/train.py b/examples/trials/tgs-salt/train.py new file mode 100644 index 0000000000..a627bef4c4 --- /dev/null +++ b/examples/trials/tgs-salt/train.py @@ -0,0 +1,258 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import os +import argparse +import time + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau + +from loader import get_train_loaders, add_depth_channel +from models import UNetResNetV4, UNetResNetV5, UNetResNetV6 +from lovasz_losses import lovasz_hinge +from focal_loss import FocalLoss2d +from postprocessing import binarize, crop_image, resize_image +from metrics import intersection_over_union, intersection_over_union_thresholds +import settings + +MODEL_DIR = settings.MODEL_DIR +focal_loss2d = FocalLoss2d() + +def weighted_loss(args, output, target, epoch=0): + mask_output, salt_output = output + mask_target, salt_target = target + + lovasz_loss = lovasz_hinge(mask_output, mask_target) + focal_loss = focal_loss2d(mask_output, mask_target) + + focal_weight = 0.2 + + if salt_output is not None and args.train_cls: + salt_loss = F.binary_cross_entropy_with_logits(salt_output, salt_target) + return salt_loss, focal_loss.item(), lovasz_loss.item(), salt_loss.item(), lovasz_loss.item() + focal_loss.item()*focal_weight + + return lovasz_loss+focal_loss*focal_weight, focal_loss.item(), lovasz_loss.item(), 0., lovasz_loss.item() + focal_loss.item()*focal_weight + +def train(args): + print('start training...') + + """@nni.variable(nni.choice('UNetResNetV4', 'UNetResNetV5', 'UNetResNetV6'), name=model_name)""" + model_name = args.model_name + + model = eval(model_name)(args.layers, num_filters=args.nf) + model_subdir = args.pad_mode + if args.meta_version == 2: + model_subdir = args.pad_mode+'_meta2' + if args.exp_name is None: + model_file = os.path.join(MODEL_DIR, model.name,model_subdir, 'best_{}.pth'.format(args.ifold)) + else: + model_file = os.path.join(MODEL_DIR, args.exp_name, model.name, model_subdir, 'best_{}.pth'.format(args.ifold)) + + parent_dir = os.path.dirname(model_file) + if not os.path.exists(parent_dir): + os.makedirs(parent_dir) + + if args.init_ckp is not None: + CKP = args.init_ckp + else: + CKP = model_file + if os.path.exists(CKP): + print('loading {}...'.format(CKP)) + model.load_state_dict(torch.load(CKP)) + model = model.cuda() + + if args.optim == 'Adam': + optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0001) + else: + optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0001) + + train_loader, val_loader = get_train_loaders(args.ifold, batch_size=args.batch_size, dev_mode=args.dev_mode, \ + pad_mode=args.pad_mode, meta_version=args.meta_version, pseudo_label=args.pseudo, depths=args.depths) + + if args.lrs == 'plateau': + lr_scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=args.factor, patience=args.patience, min_lr=args.min_lr) + else: + lr_scheduler = CosineAnnealingLR(optimizer, args.t_max, eta_min=args.min_lr) + + print('epoch | lr | % | loss | avg | f loss | lovaz | iou | iout | best | time | save | salt |') + + best_iout, _iou, _f, _l, _salt, best_mix_score = validate(args, model, val_loader, args.start_epoch) + print('val | | | | | {:.4f} | {:.4f} | {:.4f} | {:.4f} | {:.4f} | | | {:.4f} |'.format( + _f, _l, _iou, best_iout, best_iout, _salt)) + if args.val: + return + + model.train() + + if args.lrs == 'plateau': + lr_scheduler.step(best_iout) + else: + lr_scheduler.step() + + for epoch in range(args.start_epoch, args.epochs): + train_loss = 0 + + current_lr = get_lrs(optimizer) + bg = time.time() + for batch_idx, data in enumerate(train_loader): + img, target, salt_target = data + if args.depths: + add_depth_channel(img, args.pad_mode) + img, target, salt_target = img.cuda(), target.cuda(), salt_target.cuda() + optimizer.zero_grad() + output, salt_out = model(img) + + loss, *_ = weighted_loss(args, (output, salt_out), (target, salt_target), epoch=epoch) + loss.backward() + + if args.optim == 'Adam' and args.adamw: + wd = 0.0001 + for group in optimizer.param_groups: + for param in group['params']: + param.data = param.data.add(-wd * group['lr'], param.data) + + optimizer.step() + + train_loss += loss.item() + print('\r {:4d} | {:.5f} | {:4d}/{} | {:.4f} | {:.4f} |'.format( + epoch, float(current_lr[0]), args.batch_size*(batch_idx+1), train_loader.num, loss.item(), train_loss/(batch_idx+1)), end='') + + iout, iou, focal_loss, lovaz_loss, salt_loss, mix_score = validate(args, model, val_loader, epoch=epoch) + """@nni.report_intermediate_result(iout)""" + + _save_ckp = '' + if iout > best_iout: + best_iout = iout + torch.save(model.state_dict(), model_file) + _save_ckp = '*' + if args.store_loss_model and mix_score > best_mix_score: + best_mix_score = mix_score + torch.save(model.state_dict(), model_file+'_loss') + _save_ckp += '.' + print(' {:.4f} | {:.4f} | {:.4f} | {:.4f} | {:.4f} | {:.2f} | {:4s} | {:.4f} |'.format( + focal_loss, lovaz_loss, iou, iout, best_iout, (time.time() - bg) / 60, _save_ckp, salt_loss)) + + model.train() + + if args.lrs == 'plateau': + lr_scheduler.step(best_iout) + else: + lr_scheduler.step() + + del model, train_loader, val_loader, optimizer, lr_scheduler + """@nni.report_final_result(best_iout)""" + +def get_lrs(optimizer): + lrs = [] + for pgs in optimizer.state_dict()['param_groups']: + lrs.append(pgs['lr']) + lrs = ['{:.6f}'.format(x) for x in lrs] + return lrs + +def validate(args, model, val_loader, epoch=0, threshold=0.5): + model.eval() + outputs = [] + focal_loss, lovaz_loss, salt_loss, w_loss = 0, 0, 0, 0 + with torch.no_grad(): + for img, target, salt_target in val_loader: + if args.depths: + add_depth_channel(img, args.pad_mode) + img, target, salt_target = img.cuda(), target.cuda(), salt_target.cuda() + output, salt_out = model(img) + + _, floss, lovaz, _salt_loss, _w_loss = weighted_loss(args, (output, salt_out), (target, salt_target), epoch=epoch) + focal_loss += floss + lovaz_loss += lovaz + salt_loss += _salt_loss + w_loss += _w_loss + output = torch.sigmoid(output) + + for o in output.cpu(): + outputs.append(o.squeeze().numpy()) + + n_batches = val_loader.num // args.batch_size if val_loader.num % args.batch_size == 0 else val_loader.num // args.batch_size + 1 + + # y_pred, list of np array, each np array's shape is 101,101 + y_pred = generate_preds(args, outputs, (settings.ORIG_H, settings.ORIG_W), threshold) + + iou_score = intersection_over_union(val_loader.y_true, y_pred) + iout_score = intersection_over_union_thresholds(val_loader.y_true, y_pred) + + return iout_score, iou_score, focal_loss / n_batches, lovaz_loss / n_batches, salt_loss / n_batches, iout_score*4 - w_loss + + +def generate_preds(args, outputs, target_size, threshold=0.5): + preds = [] + + for output in outputs: + if args.pad_mode == 'resize': + cropped = resize_image(output, target_size=target_size) + else: + cropped = crop_image(output, target_size=target_size) + pred = binarize(cropped, threshold) + preds.append(pred) + + return preds + +if __name__ == '__main__': + + parser = argparse.ArgumentParser(description='TGS Salt segmentation') + parser.add_argument('--layers', default=34, type=int, help='model layers') + parser.add_argument('--nf', default=32, type=int, help='num_filters param for model') + parser.add_argument('--lr', default=0.001, type=float, help='learning rate') + parser.add_argument('--min_lr', default=0.0001, type=float, help='min learning rate') + parser.add_argument('--ifolds', default='0', type=str, help='kfold indices') + parser.add_argument('--batch_size', default=32, type=int, help='batch_size') + parser.add_argument('--start_epoch', default=0, type=int, help='start epoch') + parser.add_argument('--epochs', default=200, type=int, help='epoch') + parser.add_argument('--optim', default='SGD', choices=['SGD', 'Adam'], help='optimizer') + parser.add_argument('--lrs', default='cosine', choices=['cosine', 'plateau'], help='LR sceduler') + parser.add_argument('--patience', default=6, type=int, help='lr scheduler patience') + parser.add_argument('--factor', default=0.5, type=float, help='lr scheduler factor') + parser.add_argument('--t_max', default=15, type=int, help='lr scheduler patience') + parser.add_argument('--pad_mode', default='edge', choices=['reflect', 'edge', 'resize'], help='pad method') + parser.add_argument('--exp_name', default=None, type=str, help='exp name') + parser.add_argument('--model_name', default='UNetResNetV4', type=str, help='') + parser.add_argument('--init_ckp', default=None, type=str, help='resume from checkpoint path') + parser.add_argument('--val', action='store_true') + parser.add_argument('--store_loss_model', action='store_true') + parser.add_argument('--train_cls', action='store_true') + parser.add_argument('--meta_version', default=2, type=int, help='meta version') + parser.add_argument('--pseudo', action='store_true') + parser.add_argument('--depths', action='store_true') + parser.add_argument('--dev_mode', action='store_true') + parser.add_argument('--adamw', action='store_true') + + args = parser.parse_args() + + '''@nni.get_next_parameter()''' + + print(args) + ifolds = [int(x) for x in args.ifolds.split(',')] + print(ifolds) + + for i in ifolds: + args.ifold = i + train(args) diff --git a/examples/trials/tgs-salt/utils.py b/examples/trials/tgs-salt/utils.py new file mode 100644 index 0000000000..67b1867c95 --- /dev/null +++ b/examples/trials/tgs-salt/utils.py @@ -0,0 +1,187 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import os +import json +import sys +import time +import numpy as np +import pandas as pd +from PIL import Image +from tqdm import tqdm +from pycocotools import mask as cocomask +from sklearn.model_selection import KFold + +import settings + +def create_submission(meta, predictions): + output = [] + for image_id, mask in zip(meta['id'].values, predictions): + rle_encoded = ' '.join(str(rle) for rle in run_length_encoding(mask)) + output.append([image_id, rle_encoded]) + + submission = pd.DataFrame(output, columns=['id', 'rle_mask']).astype(str) + return submission + + +def encode_rle(predictions): + return [run_length_encoding(mask) for mask in predictions] + + +def read_masks(img_ids): + masks = [] + for img_id in img_ids: + base_filename = '{}.png'.format(img_id) + mask = Image.open(os.path.join(settings.TRAIN_MASK_DIR, base_filename)) + mask = np.asarray(mask.convert('L').point(lambda x: 0 if x < 128 else 1)).astype(np.uint8) + masks.append(mask) + return masks + + +def run_length_encoding(x): + bs = np.where(x.T.flatten())[0] + + rle = [] + prev = -2 + for b in bs: + if (b > prev + 1): rle.extend((b + 1, 0)) + rle[-1] += 1 + prev = b + return rle + + +def run_length_decoding(mask_rle, shape): + s = mask_rle.split() + starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])] + starts -= 1 + ends = starts + lengths + img = np.zeros(shape[1] * shape[0], dtype=np.uint8) + for lo, hi in zip(starts, ends): + img[lo:hi] = 255 + return img.reshape((shape[1], shape[0])).T + +def get_salt_existence(): + train_mask = pd.read_csv(settings.LABEL_FILE) + salt_exists_dict = {} + for row in train_mask.values: + #print(row[1] is np.nan) + salt_exists_dict[row[0]] = 0 if (row[1] is np.nan or len(row[1]) < 1) else 1 + return salt_exists_dict + +def generate_metadata(train_images_dir, test_images_dir, depths_filepath): + depths = pd.read_csv(depths_filepath) + salt_exists_dict = get_salt_existence() + + metadata = {} + for filename in tqdm(os.listdir(os.path.join(train_images_dir, 'images'))): + image_filepath = os.path.join(train_images_dir, 'images', filename) + mask_filepath = os.path.join(train_images_dir, 'masks', filename) + image_id = filename.split('.')[0] + depth = depths[depths['id'] == image_id]['z'].values[0] + + metadata.setdefault('file_path_image', []).append(image_filepath) + metadata.setdefault('file_path_mask', []).append(mask_filepath) + metadata.setdefault('is_train', []).append(1) + metadata.setdefault('id', []).append(image_id) + metadata.setdefault('z', []).append(depth) + metadata.setdefault('salt_exists', []).append(salt_exists_dict[image_id]) + + for filename in tqdm(os.listdir(os.path.join(test_images_dir, 'images'))): + image_filepath = os.path.join(test_images_dir, 'images', filename) + image_id = filename.split('.')[0] + depth = depths[depths['id'] == image_id]['z'].values[0] + + metadata.setdefault('file_path_image', []).append(image_filepath) + metadata.setdefault('file_path_mask', []).append(None) + metadata.setdefault('is_train', []).append(0) + metadata.setdefault('id', []).append(image_id) + metadata.setdefault('z', []).append(depth) + metadata.setdefault('salt_exists', []).append(0) + + return pd.DataFrame(metadata) + +def rle_from_binary(prediction): + prediction = np.asfortranarray(prediction) + return cocomask.encode(prediction) + + +def binary_from_rle(rle): + return cocomask.decode(rle) + + +def get_segmentations(labeled): + nr_true = labeled.max() + segmentations = [] + for i in range(1, nr_true + 1): + msk = labeled == i + segmentation = rle_from_binary(msk.astype('uint8')) + segmentation['counts'] = segmentation['counts'].decode("UTF-8") + segmentations.append(segmentation) + return segmentations + + +def get_crop_pad_sequence(vertical, horizontal): + top = int(vertical / 2) + bottom = vertical - top + right = int(horizontal / 2) + left = horizontal - right + return (top, right, bottom, left) + + +def get_nfold_split(ifold, nfold=10, meta_version=1): + if meta_version == 2: + return get_nfold_split2(ifold, nfold) + + meta = pd.read_csv(settings.META_FILE, na_filter=False) + meta_train = meta[meta['is_train'] == 1] + + kf = KFold(n_splits=nfold) + for i, (train_index, valid_index) in enumerate(kf.split(meta_train[settings.ID_COLUMN].values.reshape(-1))): + if i == ifold: + break + #print(train_index[:10], train_index[-10:]) + #print(valid_index[:10], valid_index[-10:]) + + return meta_train.iloc[train_index], meta_train.iloc[valid_index] + +def get_nfold_split2(ifold, nfold=10): + meta_train = pd.read_csv(os.path.join(settings.DATA_DIR, 'train_meta2.csv')) + + with open(os.path.join(settings.DATA_DIR, 'train_split.json'), 'r') as f: + train_splits = json.load(f) + train_index = train_splits[str(ifold)]['train_index'] + valid_index = train_splits[str(ifold)]['val_index'] + #print(train_index[:10], train_index[-10:]) + #print(valid_index[:10], valid_index[-10:]) + #print(meta_train.iloc[train_index].head()) + + return meta_train.iloc[train_index], meta_train.iloc[valid_index] + + +def get_test_meta(): + meta = pd.read_csv(settings.META_FILE, na_filter=False) + test_meta = meta[meta['is_train'] == 0] + print(len(test_meta.values)) + return test_meta + +if __name__ == '__main__': + #get_test_meta() + get_nfold_split(2) From f9c599a19c8829c7024693de1e6e9e39d957b5b8 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Wed, 31 Oct 2018 16:00:47 +0800 Subject: [PATCH 2/3] updates --- examples/trials/tgs-salt/augmentation.py | 26 ----------------------- examples/trials/tgs-salt/focal_loss.py | 3 --- examples/trials/tgs-salt/predict.py | 27 ++---------------------- examples/trials/tgs-salt/preprocess.py | 4 ---- examples/trials/tgs-salt/utils.py | 8 ------- 5 files changed, 2 insertions(+), 66 deletions(-) diff --git a/examples/trials/tgs-salt/augmentation.py b/examples/trials/tgs-salt/augmentation.py index 633a02cb21..7e558ef130 100644 --- a/examples/trials/tgs-salt/augmentation.py +++ b/examples/trials/tgs-salt/augmentation.py @@ -161,7 +161,6 @@ def test_transform(): img = Image.open(os.path.join(r'D:\data\ship\train_v2', img_id)).convert('RGB') mask = Image.open(os.path.join(r'D:\data\ship\train_masks', img_id)).convert('L').point(lambda x: 0 if x < 128 else 1, 'L') - #trans = RandomResizedCropWithMask(768, scale=(0.6, 1)) trans = Compose([ RandomHFlipWithMask(), RandomVFlipWithMask(), @@ -171,7 +170,6 @@ def test_transform(): ]) trans2 = RandomAffineWithMask(45, (0.2,0.2), (0.9, 1.1)) - #trans = RandomRotateWithMask([0, 90, 180, 270]) trans3, trans4 = get_img_mask_augments(True, 'edge') img, mask = trans4(img, mask) @@ -239,29 +237,5 @@ def test_tta(): img_back = F.to_pil_image(img_np) img_back.show() - -def test_rotate(): - img_f = os.path.join(settings.TEST_IMG_DIR, '0c2637aa9.jpg') - img = Image.open(img_f) - img = img.convert('RGB') - #img_np = np.array(img) - #img_np_r90 = np.rot90(img_np,1) - #img_np_r90 = np.rot90(img_np_r90,3) - #img_2 = F.to_pil_image(img_np_r90) - #img = F.rotate(img, 90, False, False) - #ImageDraw.Draw(img_2) - #img_2.show() - #img.show() - - img_aug = tta_7(img) - #img_aug = tta_7_back(img_aug) - img_aug = tta_back_np(img_aug, 7) - img_aug.show() - - if __name__ == '__main__': - #test_augment() - #test_rotate() - #test_tta() test_transform() - #test_color_trans() \ No newline at end of file diff --git a/examples/trials/tgs-salt/focal_loss.py b/examples/trials/tgs-salt/focal_loss.py index e987ef847e..1ed8887a31 100644 --- a/examples/trials/tgs-salt/focal_loss.py +++ b/examples/trials/tgs-salt/focal_loss.py @@ -72,9 +72,6 @@ def forward(self, logit, target, class_weight=None, type='sigmoid'): if __name__ == '__main__': L = FocalLoss2d() out = torch.randn(2, 3, 3).cuda() - #target = torch.ones(2, 3, 3).cuda() target = (torch.sigmoid(out) > 0.5).float() - #print(target, out) loss = L(out, target) print(loss) - #pass \ No newline at end of file diff --git a/examples/trials/tgs-salt/predict.py b/examples/trials/tgs-salt/predict.py index ff4bb66f43..28a9d1f183 100644 --- a/examples/trials/tgs-salt/predict.py +++ b/examples/trials/tgs-salt/predict.py @@ -123,7 +123,6 @@ def ensemble_np(args, np_files, save_np=None): meta = get_test_loader(args.batch_size, index=0, dev_mode=False, pad_mode=args.pad_mode).meta submission = create_submission(meta, y_pred_test) - #submission_filepath = 'v4_1378.csv' submission.to_csv(args.sub_file, index=None, encoding='utf-8') def generate_preds(outputs, target_size, pad_mode, threshold=0.5): @@ -144,24 +143,6 @@ def generate_preds(outputs, target_size, pad_mode, threshold=0.5): def ensemble_predict(args): model = eval(args.model_name)(args.layers, num_filters=args.nf) - #checkpoints = [ - # r'G:\salt\models\152_new\best_0.pth', r'G:\salt\models\152_new\best_1.pth', - # r'G:\salt\models\152_new\best_2.pth' - #] - #checkpoints = [ LB841 - # r'D:\data\salt\models\UNetResNetV4_34\best_0.pth', r'D:\data\salt\models\UNetResNetV4_34\best_1.pth', - # r'D:\data\salt\models\UNetResNetV4_34\best_2.pth'#, r'D:\data\salt\models\UNetResNetV4_34\best_3.pth' - #] - - # LB861 - #checkpoints = [ - # r'D:\data\salt\models\depths\UNetResNetV4_34\edge\best_0.pth', - # r'D:\data\salt\models\depths\UNetResNetV4_34\edge\best_1.pth', - # r'D:\data\salt\models\depths\UNetResNetV4_34\edge\best_2.pth', - # r'D:\data\salt\models\depths\UNetResNetV4_34\edge\best_3.pth' - #] - - #checkpoints= glob.glob(r'D:\data\salt\models\pseudo\UNetResNetV4_34\edge\best*.pth') checkpoints = [ r'D:\data\salt\models\pseudo\UNetResNetV4_34\edge\best_5.pth', r'D:\data\salt\models\pseudo\UNetResNetV4_34\edge\best_6.pth', @@ -169,7 +150,6 @@ def ensemble_predict(args): r'D:\data\salt\models\pseudo\UNetResNetV4_34\edge\best_9.pth' ] print(checkpoints) - #ensemble(checkpoints) ensemble(args, model, checkpoints) @@ -177,13 +157,10 @@ def ensemble_np_results(args): np_files1 = glob.glob(r'D:\data\salt\models\depths\UNetResNetV5_50\edge\*pth_out\*.npy') np_files2 = glob.glob(r'D:\data\salt\models\depths\UNetResNetV4_34\edge\*pth_out\*.npy') np_files3 = glob.glob(r'D:\data\salt\models\depths\UNetResNetV6_34\edge\*pth_out\*.npy') - #np_files4 = glob.glob(r'D:\data\salt\models\pseudo\UNetResNetV4_34\edge\*pth_out\*.npy') - #np_files5 = glob.glob(r'D:\data\salt\models\pseudo\UNetResNetV6_34\edge\*pth_out\*.npy') np_files6 = glob.glob(r'D:\data\salt\models\ensemble\*.npy') - np_files = np_files6 #np_files1 + np_files2 + np_files3 #+ np_files4 + np_files5 - #np_files = + np_files = np_files1 + np_files2 + np_files3 + np_files6 print(np_files) - ensemble_np(args, np_files) #, save_np=os.path.join(settings.MODEL_DIR, 'ensemble', 'v456_lb864.npy')) + ensemble_np(args, np_files) def predict_model(args): model = eval(args.model_name)(args.layers, num_filters=args.nf) diff --git a/examples/trials/tgs-salt/preprocess.py b/examples/trials/tgs-salt/preprocess.py index 1f80252089..f23cb419af 100644 --- a/examples/trials/tgs-salt/preprocess.py +++ b/examples/trials/tgs-salt/preprocess.py @@ -90,8 +90,4 @@ def test(): print(meta_train_split[settings.X_COLUMN].values[:10]) if __name__ == '__main__': - #prepare_metadata() - #convert_model2() - #get_mask_existence() generate_stratified_metadata() - #get_nfold_split2(0) \ No newline at end of file diff --git a/examples/trials/tgs-salt/utils.py b/examples/trials/tgs-salt/utils.py index 67b1867c95..fa8c8bbba5 100644 --- a/examples/trials/tgs-salt/utils.py +++ b/examples/trials/tgs-salt/utils.py @@ -82,7 +82,6 @@ def get_salt_existence(): train_mask = pd.read_csv(settings.LABEL_FILE) salt_exists_dict = {} for row in train_mask.values: - #print(row[1] is np.nan) salt_exists_dict[row[0]] = 0 if (row[1] is np.nan or len(row[1]) < 1) else 1 return salt_exists_dict @@ -157,9 +156,6 @@ def get_nfold_split(ifold, nfold=10, meta_version=1): for i, (train_index, valid_index) in enumerate(kf.split(meta_train[settings.ID_COLUMN].values.reshape(-1))): if i == ifold: break - #print(train_index[:10], train_index[-10:]) - #print(valid_index[:10], valid_index[-10:]) - return meta_train.iloc[train_index], meta_train.iloc[valid_index] def get_nfold_split2(ifold, nfold=10): @@ -169,9 +165,6 @@ def get_nfold_split2(ifold, nfold=10): train_splits = json.load(f) train_index = train_splits[str(ifold)]['train_index'] valid_index = train_splits[str(ifold)]['val_index'] - #print(train_index[:10], train_index[-10:]) - #print(valid_index[:10], valid_index[-10:]) - #print(meta_train.iloc[train_index].head()) return meta_train.iloc[train_index], meta_train.iloc[valid_index] @@ -183,5 +176,4 @@ def get_test_meta(): return test_meta if __name__ == '__main__': - #get_test_meta() get_nfold_split(2) From 0cabe79c0aa26dd58105c0432d7c53d88866ad5c Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Wed, 31 Oct 2018 17:38:20 +0800 Subject: [PATCH 3/3] updates --- examples/trials/{tgs-salt => kaggle-tgs-salt}/README.md | 0 examples/trials/{tgs-salt => kaggle-tgs-salt}/augmentation.py | 0 examples/trials/{tgs-salt => kaggle-tgs-salt}/config.yml | 0 examples/trials/{tgs-salt => kaggle-tgs-salt}/focal_loss.py | 0 examples/trials/{tgs-salt => kaggle-tgs-salt}/loader.py | 0 examples/trials/{tgs-salt => kaggle-tgs-salt}/lovasz_losses.py | 0 examples/trials/{tgs-salt => kaggle-tgs-salt}/metrics.py | 0 examples/trials/{tgs-salt => kaggle-tgs-salt}/models.py | 0 examples/trials/{tgs-salt => kaggle-tgs-salt}/postprocessing.py | 0 examples/trials/{tgs-salt => kaggle-tgs-salt}/predict.py | 0 examples/trials/{tgs-salt => kaggle-tgs-salt}/preprocess.py | 0 examples/trials/{tgs-salt => kaggle-tgs-salt}/settings.py | 0 examples/trials/{tgs-salt => kaggle-tgs-salt}/train.py | 0 examples/trials/{tgs-salt => kaggle-tgs-salt}/utils.py | 0 14 files changed, 0 insertions(+), 0 deletions(-) rename examples/trials/{tgs-salt => kaggle-tgs-salt}/README.md (100%) rename examples/trials/{tgs-salt => kaggle-tgs-salt}/augmentation.py (100%) rename examples/trials/{tgs-salt => kaggle-tgs-salt}/config.yml (100%) rename examples/trials/{tgs-salt => kaggle-tgs-salt}/focal_loss.py (100%) rename examples/trials/{tgs-salt => kaggle-tgs-salt}/loader.py (100%) rename examples/trials/{tgs-salt => kaggle-tgs-salt}/lovasz_losses.py (100%) rename examples/trials/{tgs-salt => kaggle-tgs-salt}/metrics.py (100%) rename examples/trials/{tgs-salt => kaggle-tgs-salt}/models.py (100%) rename examples/trials/{tgs-salt => kaggle-tgs-salt}/postprocessing.py (100%) rename examples/trials/{tgs-salt => kaggle-tgs-salt}/predict.py (100%) rename examples/trials/{tgs-salt => kaggle-tgs-salt}/preprocess.py (100%) rename examples/trials/{tgs-salt => kaggle-tgs-salt}/settings.py (100%) rename examples/trials/{tgs-salt => kaggle-tgs-salt}/train.py (100%) rename examples/trials/{tgs-salt => kaggle-tgs-salt}/utils.py (100%) diff --git a/examples/trials/tgs-salt/README.md b/examples/trials/kaggle-tgs-salt/README.md similarity index 100% rename from examples/trials/tgs-salt/README.md rename to examples/trials/kaggle-tgs-salt/README.md diff --git a/examples/trials/tgs-salt/augmentation.py b/examples/trials/kaggle-tgs-salt/augmentation.py similarity index 100% rename from examples/trials/tgs-salt/augmentation.py rename to examples/trials/kaggle-tgs-salt/augmentation.py diff --git a/examples/trials/tgs-salt/config.yml b/examples/trials/kaggle-tgs-salt/config.yml similarity index 100% rename from examples/trials/tgs-salt/config.yml rename to examples/trials/kaggle-tgs-salt/config.yml diff --git a/examples/trials/tgs-salt/focal_loss.py b/examples/trials/kaggle-tgs-salt/focal_loss.py similarity index 100% rename from examples/trials/tgs-salt/focal_loss.py rename to examples/trials/kaggle-tgs-salt/focal_loss.py diff --git a/examples/trials/tgs-salt/loader.py b/examples/trials/kaggle-tgs-salt/loader.py similarity index 100% rename from examples/trials/tgs-salt/loader.py rename to examples/trials/kaggle-tgs-salt/loader.py diff --git a/examples/trials/tgs-salt/lovasz_losses.py b/examples/trials/kaggle-tgs-salt/lovasz_losses.py similarity index 100% rename from examples/trials/tgs-salt/lovasz_losses.py rename to examples/trials/kaggle-tgs-salt/lovasz_losses.py diff --git a/examples/trials/tgs-salt/metrics.py b/examples/trials/kaggle-tgs-salt/metrics.py similarity index 100% rename from examples/trials/tgs-salt/metrics.py rename to examples/trials/kaggle-tgs-salt/metrics.py diff --git a/examples/trials/tgs-salt/models.py b/examples/trials/kaggle-tgs-salt/models.py similarity index 100% rename from examples/trials/tgs-salt/models.py rename to examples/trials/kaggle-tgs-salt/models.py diff --git a/examples/trials/tgs-salt/postprocessing.py b/examples/trials/kaggle-tgs-salt/postprocessing.py similarity index 100% rename from examples/trials/tgs-salt/postprocessing.py rename to examples/trials/kaggle-tgs-salt/postprocessing.py diff --git a/examples/trials/tgs-salt/predict.py b/examples/trials/kaggle-tgs-salt/predict.py similarity index 100% rename from examples/trials/tgs-salt/predict.py rename to examples/trials/kaggle-tgs-salt/predict.py diff --git a/examples/trials/tgs-salt/preprocess.py b/examples/trials/kaggle-tgs-salt/preprocess.py similarity index 100% rename from examples/trials/tgs-salt/preprocess.py rename to examples/trials/kaggle-tgs-salt/preprocess.py diff --git a/examples/trials/tgs-salt/settings.py b/examples/trials/kaggle-tgs-salt/settings.py similarity index 100% rename from examples/trials/tgs-salt/settings.py rename to examples/trials/kaggle-tgs-salt/settings.py diff --git a/examples/trials/tgs-salt/train.py b/examples/trials/kaggle-tgs-salt/train.py similarity index 100% rename from examples/trials/tgs-salt/train.py rename to examples/trials/kaggle-tgs-salt/train.py diff --git a/examples/trials/tgs-salt/utils.py b/examples/trials/kaggle-tgs-salt/utils.py similarity index 100% rename from examples/trials/tgs-salt/utils.py rename to examples/trials/kaggle-tgs-salt/utils.py