Commit
adjust for fixed train-val split
Tim Scherr committed Mar 21, 2022
1 parent 39c028e commit 0846601
Showing 4 changed files with 84 additions and 139 deletions.
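In short, this commit drops the runtime percentage split (`--eval_split`, default 80% train) in favor of a fixed, pre-computed train/validation split stored under training_data/conic_fixed_train_valid. A minimal preflight sketch of the four input arrays the new code expects (hypothetical standalone check, not part of the commit; file names taken from the eval.py diff below):

from pathlib import Path

# Hypothetical preflight check: the fixed split expects these four arrays
# to exist before the training sets can be created.
path_data = Path('training_data') / 'conic_fixed_train_valid'
for name in ('train_imgs.npy', 'train_anns.npy', 'valid_imgs.npy', 'valid_anns.npy'):
    if not (path_data / name).is_file():
        raise FileNotFoundError(f'{name} not found in {path_data}')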
65 changes: 27 additions & 38 deletions eval.py
@@ -24,28 +24,24 @@ def main():
     # Get arguments
     parser = argparse.ArgumentParser(description='Conic Challenge - Evaluation')
     parser.add_argument('--model', '-m', required=True, type=str, help='Model to use')
-    parser.add_argument('--dataset', '-ds', default='conic_patches', type=str, help='"conic_patches" or "lizard"')
     parser.add_argument('--batch_size', '-bs', default=8, type=int, help='Batch size')
     parser.add_argument('--multi_gpu', '-mgpu', default=False, action='store_true', help='Use multiple GPUs')
     parser.add_argument('--save_raw_pred', '-srp', default=False, action='store_true', help='Save raw predictions')
     parser.add_argument('--th_cell', '-tc', default=0.07, nargs='+', help='Threshold for adjusting cell size')
     parser.add_argument('--th_seed', '-ts', default=0.45, nargs='+', help='Threshold for seeds')
     parser.add_argument('--tta', '-tta', default=False, action='store_true', help='Use test-time augmentation')
-    parser.add_argument('--eval_split', '-es', default=80, type=int, help='Train split in %')
     parser.add_argument('--upsample', '-u', default=False, action='store_true', help='Apply rescaling (1.25) for inference')
     parser.add_argument('--calc_perfect_class_metric', '-cpcm', default=False, action='store_true',
                         help='Calculate metric for predicted segmentation and ground truth classification')
     args = parser.parse_args()

     # Paths
+    path_data = Path(__file__).parent / 'training_data' / 'conic_fixed_train_valid'
     path_models = Path(__file__).parent / 'models'
     if args.upsample:
-        path_train_data = Path(__file__).parent / 'training_data' / args.dataset / 'upsampled'
+        path_train_data = path_data / 'upsampled'
     else:
-        path_train_data = Path(__file__).parent / 'training_data' / args.dataset / 'original_scale'
-
-    if args.dataset == 'lizard':
-        raise NotImplementedError
+        path_train_data = path_data / 'original_scale'

     # Set device for using CPU or GPU
     device, num_gpus = torch.device("cuda" if torch.cuda.is_available() else "cpu"), 1
@@ -54,38 +50,38 @@ def main():
     if args.multi_gpu:
         num_gpus = torch.cuda.device_count()

-    # Check if data to evaluate exists
-    if not (path_train_data / 'images.npy').is_file() or not (path_train_data / 'labels.npy').is_file() \
-            or not (path_train_data / 'gts.npy').is_file():
+    # Check if training data (labels_train.npy) already exist
+    if not (path_train_data / 'train_labels.npy').is_file() or not (path_train_data / 'valid_labels.npy').is_file():
         # Create training sets
         print(f'No training data found. Creating training data.\nUse upsampling: {args.upsample}')
-        if not (path_train_data.parent / 'images.npy').is_file():
-            raise Exception('images.npy not found in {}'.format(path_train_data.parent))
-        if not (path_train_data.parent / 'labels.npy').is_file():
-            raise Exception('labels.npy not found in {}'.format(path_train_data.parent))
+        if not (path_data / 'train_imgs.npy').is_file():
+            raise Exception('train_imgs.npy not found in {}'.format(path_data))
+        if not (path_data / 'train_anns.npy').is_file():
+            raise Exception('train_anns.npy not found in {}'.format(path_data))
+        if not (path_data / 'valid_imgs.npy').is_file():
+            raise Exception('valid_imgs.npy not found in {}'.format(path_data))
+        if not (path_data / 'valid_anns.npy').is_file():
+            raise Exception('valid_anns.npy not found in {}'.format(path_data))
         path_train_data.mkdir(exist_ok=True)
-        create_conic_training_sets(path_data=path_train_data.parent,
-                                   path_train_data=path_train_data,
-                                   upsample=args.upsample)
+        create_conic_training_sets(path_data=path_data, path_train_data=path_train_data, upsample=args.upsample,
+                                   mode='train')
+        create_conic_training_sets(path_data=path_data, path_train_data=path_train_data, upsample=args.upsample,
+                                   mode='valid')

     # Load model
     model = path_models / "{}.pth".format(args.model)

     # Directory for results
-    path_seg_results = path_train_data / f"{model.stem}_{args.eval_split}"
+    path_seg_results = path_train_data / f"{model.stem}"
     path_seg_results.mkdir(exist_ok=True)
     print(f"Evaluation of {model.stem}. Seed thresholds: {args.th_seed}, mask thresholds: {args.th_cell}, "
           f"upsampling: {args.upsample}, tta: {args.tta}")

     inference_args = deepcopy(args)

-    if args.dataset == "conic_patches":
-        dataset = ConicDataset(root_dir=path_train_data,
-                               mode="eval",
-                               transform=ToTensor(min_value=0, max_value=255),
-                               train_split=args.eval_split)
-    else:
-        raise NotImplementedError(f'Dataset {args.dataset} not implemented')
+    dataset = ConicDataset(root_dir=path_train_data,
+                           mode="eval",
+                           transform=ToTensor(min_value=0, max_value=255))

     inference_2d(model=model,
                  dataset=dataset,
@@ -121,20 +117,13 @@ def main():
     else:
         metrics_perfect_class = -1

-    # r2 metric
-    pred_counts = pd.read_csv(path_seg_results_th / "counts.csv")
-    gt_counts = dataset.counts
-    gt_counts = gt_counts.sort_index()
-    r2 = get_multi_r2(gt_counts, pred_counts)
-    print(f" R2: {r2}")

-    result = pd.DataFrame([[args.model, args.dataset, args.upsample, th[0], th[1], metrics[0], metrics[1],
-                            metrics_perfect_class, r2, args.tta]],
-                          columns=["model_name", "dataset", "upsampling", "th_cell", "th_seed", "multi_pq+", "pq_metrics_avg",
-                                   "multi_pq+_perfect_class", "R2", "tta"])
+    result = pd.DataFrame([[args.model, args.upsample, th[0], th[1], metrics[0], metrics[1],
+                            metrics_perfect_class, args.tta]],
+                          columns=["model_name", "upsampling", "th_cell", "th_seed", "multi_pq+", "pq_metrics_avg",
+                                   "multi_pq+_perfect_class", "tta"])

-    result.to_csv(Path(__file__).parent / f"scores{args.eval_split}.csv",
-                  header=not (Path(__file__).parent / f"scores{args.eval_split}.csv").exists(),
+    result.to_csv(Path(__file__).parent / "scores_post-challenge-analysis.csv",
+                  header=not (Path(__file__).parent / "scores_post-challenge-analysis.csv").exists(),
                   index=False,
                   mode="a")
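Note on the result logging above: every evaluation run appends one row to a single scores_post-challenge-analysis.csv, and the header is written only when the file does not yet exist. A self-contained sketch of that append pattern (column values are placeholders, not real scores):

from pathlib import Path

import pandas as pd

# Append one result row per run; write the header only for a fresh file.
scores_csv = Path('scores_post-challenge-analysis.csv')
result = pd.DataFrame([['example_model', False, 0.07, 0.45, 0.0, 0.0, -1, False]],
                      columns=["model_name", "upsampling", "th_cell", "th_seed", "multi_pq+",
                               "pq_metrics_avg", "multi_pq+_perfect_class", "tta"])
result.to_csv(scores_csv, header=not scores_csv.exists(), index=False, mode="a")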
72 changes: 18 additions & 54 deletions segmentation/training/cell_segmentation_dataset.py
@@ -1,13 +1,12 @@
 import numpy as np
-import pandas as pd

 from torch.utils.data import Dataset


 class ConicDataset(Dataset):
     """ Pytorch data set for CoNIC Challenge """

-    def __init__(self, root_dir, mode='train', transform=lambda x: x, train_split=80):
+    def __init__(self, root_dir, mode, transform=lambda x: x):
         """
         :param root_dir: Directory containing the dataset.
@@ -16,30 +15,29 @@ def __init__(self, root_dir, mode='train', transform=lambda x: x, train_split=80):
         :type mode: str
         :param transform: transforms.
         :type transform:
-        :param train_split: percent of the data used for training
-        :type train_split: int
         :return: Dict (image, cell_label, border_label, id).
         """

-        imgs = np.load(root_dir/"images.npy")
-
-        if mode in ['train', 'val']:
-            labels = np.load(root_dir / "labels.npy")
-            assert imgs.shape[0] == labels.shape[0], "Missmatch between images.npy and labels_train.npy"
-            counts = pd.read_csv(root_dir / "counts.csv")
+        if mode == 'train':
+            self.imgs = np.load(root_dir / "train_images.npy")
+            self.labels = np.load(root_dir / "train_labels.npy")
+            # Add some randomness
+            ids = np.arange(len(self.imgs))
+            np.random.shuffle(ids)
+            self.imgs = self.imgs[ids]
+            self.labels = self.labels[ids]
+            assert self.imgs.shape[0] == self.labels.shape[0], "Missmatch between images.npy and labels_train.npy"
+        elif mode == 'val':
+            self.imgs = np.load(root_dir / "valid_images.npy")
+            self.labels = np.load(root_dir / "valid_labels.npy")
+            assert self.imgs.shape[0] == self.labels.shape[0], "Missmatch between images.npy and labels_train.npy"
         elif mode == 'eval':
-            labels = np.load(root_dir / "gts.npy").astype(np.int64)  # pytorchs default_colate cannot handle uint16
-            counts = pd.read_csv(root_dir / "counts.csv")
+            self.imgs = np.load(root_dir / "valid_images.npy")
+            self.labels = np.load(root_dir / "valid_gts.npy").astype(np.int64)  # pytorchs default_colate cannot handle uint16

         self.root_dir = root_dir
         self.mode = mode
-        self.train_split = train_split
-        self.ids = self.extract_train_val_ids(imgs.shape[0], 0)
-        self.imgs = imgs[self.ids, ...]
-        self.len = len(self.ids)
-        if mode in ['train', 'val', 'eval']:
-            self.labels = labels[self.ids, ...]
-            self.counts = self.get_counts(counts=counts)
+        self.len = len(self.imgs)
         self.transform = transform

@@ -48,40 +46,6 @@ def __len__(self):
     def __getitem__(self, idx):
         sample = {'image': np.copy(self.imgs[idx, ...]),
                   'label': np.copy(self.labels[idx, ...]),
-                  'id': self.ids[idx]}
+                  'id': idx}
         sample = self.transform(sample)
         return sample
-
-    def extract_train_val_ids(self, n_imgs, seed):
-        """
-        :param n_imgs:
-        :param seed:
-        :return:
-        """
-        np.random.seed(seed)  # seed numpy to always get the same images for the same seed
-        ids = np.arange(n_imgs)
-        np.random.shuffle(ids)  # shuffle inplace
-        if self.mode == "train":
-            ids = ids[0:int(np.round(len(ids)*self.train_split/100))]
-        elif self.mode in ["val", "eval"]:
-            ids = ids[int(np.round(len(ids)*self.train_split/100)):]
-        else:  # use all ids
-            pass
-        return ids
-
-    def get_counts(self, counts):
-        """
-        :param counts:
-        :type counts: pandas DataFrame
-        :return: sorted nuclear composition DataFrame
-        """
-        total_counts = counts.iloc[self.ids].sum(axis=0)
-        total_counts.name = "counts"
-        total_counts.to_csv(self.root_dir / f"total_counts_{self.mode}_{self.train_split}.csv", index=False)
-        counts = counts.iloc[self.ids]
-        counts = counts.sort_index()
-        counts.to_csv(self.root_dir / f"counts_{self.mode}_{self.train_split}.csv", index=False)
-
-        return counts
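With this rewrite, mode is a required argument that selects among fixed files instead of slicing a shuffled index list: 'train' loads train_images.npy/train_labels.npy and shuffles them once at load time, 'val' loads the valid_* counterparts, and 'eval' pairs valid_images.npy with valid_gts.npy. A hypothetical usage sketch (paths assumed, default identity transform):

from pathlib import Path

from torch.utils.data import DataLoader

from segmentation.training.cell_segmentation_dataset import ConicDataset

root = Path('training_data/conic_fixed_train_valid/original_scale')
train_set = ConicDataset(root_dir=root, mode='train')  # train_images.npy / train_labels.npy
val_set = ConicDataset(root_dir=root, mode='val')      # valid_images.npy / valid_labels.npy
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)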
41 changes: 19 additions & 22 deletions segmentation/training/create_training_sets.py
@@ -7,7 +7,7 @@
 from segmentation.training.train_data_representations import distance_label


-def create_conic_training_sets(path_data, path_train_data, upsample):
+def create_conic_training_sets(path_data, path_train_data, upsample, mode):
     """ Create training sets for CoNIC Challenge data.

     :param path_data: Path to the directory containing the CoNIC Challenge data / training data.
@@ -16,16 +16,15 @@ def create_conic_training_sets(path_data, path_train_data, upsample):
     :type path_data: Pathlib Path object.
     :param upsample: Apply upsampling (factor 1.25).
     :type upsample: bool
+    :param mode: 'train' or 'valid'
+    :type mode: str
     :return: None
     """

-    imgs = np.load(path_data / "images.npy")
-    gts = np.load(path_data / "labels.npy")
-    counts = pd.read_csv(path_data / "counts.csv")
+    print(f"Create data for mode {mode}.")

-    print("0.1/99.9 percentile channel 0: {}".format(np.percentile(imgs[..., 0], (0.1, 99.9))))
-    print("0.1/99.9 percentile channel 1: {}".format(np.percentile(imgs[..., 1], (0.1, 99.9))))
-    print("0.1/99.9 percentile channel 2: {}".format(np.percentile(imgs[..., 2], (0.1, 99.9))))
+    imgs = np.load(path_data / f"{mode}_imgs.npy")
+    gts = np.load(path_data / f"{mode}_anns.npy")

     if upsample:  # results for conic patches in 320-by-320 patches
         scale = 1.25
@@ -72,23 +71,21 @@ def create_conic_training_sets(path_data, path_train_data, upsample):
     imgs = np.delete(imgs, np.array(slice_ids), axis=0)
     labels_train = np.delete(labels_train, np.array(slice_ids), axis=0)
     gts = np.delete(gts, np.array(slice_ids), axis=0)
-    counts = counts.drop(slice_ids)

-    np.save(path_train_data / "images.npy", imgs)
-    np.save(path_train_data / "labels.npy", labels_train)
-    np.save(path_train_data / "gts.npy", gts)
-    counts.to_csv(path_train_data / "counts.csv", index=False)
+    np.save(path_train_data / f"{mode}_images.npy", imgs)
+    np.save(path_train_data / f"{mode}_labels.npy", labels_train)
+    np.save(path_train_data / f"{mode}_gts.npy", gts)

     # save tiffs for imagej visualization
-    tifffile.imsave(path_train_data / "labels_channel_0.tiff", labels_train[..., 0])
-    tifffile.imsave(path_train_data / "labels_channel_1.tiff", labels_train[..., 1])
-    tifffile.imsave(path_train_data / "labels_channel_2.tiff", labels_train[..., 2])
-    tifffile.imsave(path_train_data / "labels_channel_3.tiff", labels_train[..., 3])
-    tifffile.imsave(path_train_data / "labels_channel_4.tiff", labels_train[..., 4])
-    tifffile.imsave(path_train_data / "labels_channel_5.tiff", labels_train[..., 5])
-    tifffile.imsave(path_train_data / "labels_channel_6.tiff", labels_train[..., 6])
-    tifffile.imsave(path_train_data / "gts_instance.tiff", gts[..., 0])
-    tifffile.imsave(path_train_data / "gts_class.tiff", gts[..., 1])
-    tifffile.imsave(path_train_data / "images.tiff", imgs)
+    tifffile.imsave(path_train_data / f"{mode}_labels_channel_0.tiff", labels_train[..., 0])
+    tifffile.imsave(path_train_data / f"{mode}_labels_channel_1.tiff", labels_train[..., 1])
+    tifffile.imsave(path_train_data / f"{mode}_labels_channel_2.tiff", labels_train[..., 2])
+    tifffile.imsave(path_train_data / f"{mode}_labels_channel_3.tiff", labels_train[..., 3])
+    tifffile.imsave(path_train_data / f"{mode}_labels_channel_4.tiff", labels_train[..., 4])
+    tifffile.imsave(path_train_data / f"{mode}_labels_channel_5.tiff", labels_train[..., 5])
+    tifffile.imsave(path_train_data / f"{mode}_labels_channel_6.tiff", labels_train[..., 6])
+    tifffile.imsave(path_train_data / f"{mode}_gts_instance.tiff", gts[..., 0])
+    tifffile.imsave(path_train_data / f"{mode}_gts_class.tiff", gts[..., 1])
+    tifffile.imsave(path_train_data / f"{mode}_images.tiff", imgs)

     return None
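create_conic_training_sets is now called once per split, mirroring the two calls in the eval.py diff above; each call writes {mode}_images.npy, {mode}_labels.npy and {mode}_gts.npy plus per-mode TIFFs for ImageJ. A sketch of the expected invocation (paths assumed):

from pathlib import Path

from segmentation.training.create_training_sets import create_conic_training_sets

path_data = Path('training_data') / 'conic_fixed_train_valid'
path_train_data = path_data / 'original_scale'  # or 'upsampled' together with upsample=True
path_train_data.mkdir(exist_ok=True)

# One call per split of the fixed train/validation partition.
for mode in ('train', 'valid'):
    create_conic_training_sets(path_data=path_data, path_train_data=path_train_data,
                               upsample=False, mode=mode)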
