diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4f95ae8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,148 @@ +# Custom ignores +# Starting from pip 21.3 setup.py is not needed anymore +# and we rely only on setup.cfg for env +setup.py +deepsea/_version.py + +# Ignore macOS files +.DS_Store + +__pycache__/ +*.pyc +setup_acdc.docx +setup_acdc.pdf + +# Byte-compiled / optimized / DLL files +__pycache__/ +/*.py[cod] +!cellacdc/javabridge/_javabridge.cp37-win_amd64.pyd +!cellacdc/javabridge/_javabridge.cp38-win_amd64.pyd +/*.dll + +*$py.class + +# C extensions +/*.so +!cellacdc/javabridge/_javabridge.cpython-38-x86_64-linux-gnu.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/deepsea/__init__.py b/deepsea/__init__.py new file mode 100644 index 0000000..6899555 --- /dev/null +++ b/deepsea/__init__.py @@ -0,0 +1,8 @@ +try: + from setuptools_scm import get_version + __version__ = get_version(root='..', relative_to=__file__) +except Exception as e: + try: + from ._version import version as __version__ + except ImportError: + __version__ = "not-installed" \ No newline at end of file diff --git a/data.py b/deepsea/data.py similarity index 98% rename from data.py rename to deepsea/data.py index bc24fbe..5388739 100644 --- a/data.py +++ b/deepsea/data.py @@ -1,181 +1,181 @@ -import cv2 -from torch.utils.data import Dataset -from pathlib import Path -import os -from PIL import Image -import numpy as np -from tqdm import tqdm -from scipy import ndimage as ndi - -class BasicSegmentationDataset(Dataset): - def __init__(self, images_dir: str, masks_dir: str,unetwmaps_dir: str, transforms=None, mask_suffix: str = '',if_train_aug=False,train_aug_iter=1): - self.images_dir = Path(images_dir) - self.masks_dir = Path(masks_dir) - self.unetwmaps_dir = Path(unetwmaps_dir) - self.mask_suffix = mask_suffix - self.transforms=transforms - if if_train_aug: - self.ids = [os.path.splitext(file)[0] for file in sorted(os.listdir(images_dir)) if not file.startswith('.')] - tmp=[] - for i in range(train_aug_iter): - tmp+=self.ids - self.ids=tmp - else: - self.ids = [os.path.splitext(file)[0] for file in sorted(os.listdir(images_dir)) if not file.startswith('.')] - if not self.ids: - raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there') - - def __len__(self): - return len(self.ids) - - @classmethod - def preprocess(cls, pil_img, pil_mask,pil_wmap,transforms): - tensor_img,tensor_mask,tensor_wmap=transforms(pil_img,pil_mask,pil_wmap) - return tensor_img,tensor_mask,tensor_wmap - - def __getitem__(self, idx): - name = self.ids[idx] - mask_file = list(self.masks_dir.glob(name + self.mask_suffix + '.*')) - unetwmap_file = list(self.unetwmaps_dir.glob(name + self.mask_suffix + '.*')) - img_file = list(self.images_dir.glob(name + '.*')) - - assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}' - assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}' - - mask = cv2.imread(mask_file[0].as_posix())[:, :, 0] > 0 - wmap=cv2.imread(unetwmap_file[0].as_posix())[:, :, 0] - - mask=mask & (wmap<255) - mask = Image.fromarray(np.uint8(mask.astype('float32') * 255)) - wmap = Image.fromarray(np.uint8(wmap)) - - img = Image.open(img_file[0]) - - - assert img.size == mask.size, \ - 'Image and mask {name} should be the same size, but are {img.size} and {mask.size}' - - tensor_img,tensor_mask,tensor_wmap = self.preprocess(img, mask,wmap,self.transforms) - - return { - 'image': tensor_img, - 'mask': tensor_mask, - 'wmap': tensor_wmap - } - -class BasicTrackerDataset(Dataset): - def __init__(self, data_dir: str, transforms=None, mask_suffix: str = '_cell_area_masked',label_suffix: str = 
'_cell_pos_labels',window_size=70,if_train_aug=False,train_aug_iter=1,if_test=False): - self.mask_suffix = mask_suffix - self.transforms=transforms - self.ids={} - cell_id=0 - print('INFO: Read subfolders of video sequences and prepare training dataset ...') - print('INFO: Wait until finished, it takes relatively long time depending on dataset size and complexity ...') - for subfolder in tqdm(sorted(os.listdir(data_dir))): - image_list=sorted(os.listdir(os.path.join(data_dir,subfolder,'images'))) - for idx in range(len(image_list)-1): - img_curr_name = image_list[idx+1] - image_curr = cv2.imread(os.path.join(data_dir, subfolder, 'images', img_curr_name))[:, :, 0] - mask_curr = cv2.imread(os.path.join(data_dir, subfolder, 'masks', img_curr_name.replace('.png', mask_suffix + '.png')))[:,:, 0] - - img_prev_name=image_list[idx] - image_prev = cv2.imread(os.path.join(data_dir, subfolder, 'images', img_prev_name))[:, :, 0] - mask_prev=cv2.imread(os.path.join(data_dir,subfolder,'masks',img_prev_name.replace('.png',mask_suffix+'.png')))[:, :, 0] - lines=open(os.path.join(data_dir,subfolder,'labels',img_prev_name.replace('.png',label_suffix+'.txt'))).readlines() - lines.pop(0) - labels_prev,centroids_prev=[],[] - for line in lines: - line_info=line.replace('\n','').split('\t') - labels_prev.append(line_info[0]) - centroids_prev.append([int(line_info[1]),int(line_info[2])]) - - lines = open(os.path.join(data_dir, subfolder, 'labels',img_curr_name.replace('.png', label_suffix + '.txt'))).readlines() - lines.pop(0) - labels_curr, centroids_curr = [], [] - for line in lines: - line_info = line.replace('\n', '').split('\t') - labels_curr.append(line_info[0]) - centroids_curr.append([int(line_info[1]), int(line_info[2])]) - - markers_prev, num_labels_prev = ndi.label(mask_prev) - markers_curr, num_labels_curr = ndi.label(mask_curr) - for i in range(1,num_labels_prev): - if np.sum(markers_prev==i)<20: - continue - centroid = np.array(ndi.measurements.center_of_mass(mask_prev, markers_prev, i)) - distances = np.array([np.sqrt(np.sum((centroid - np.array([point[1], point[0]])) ** 2)) for point in centroids_prev]) - if len(np.where(distances < 5)[0])==0: - continue - target_cell_prev_idx = np.where(distances < 5)[0][0] - label_prev=labels_prev[target_cell_prev_idx] - if label_prev in labels_curr: - target_cell_curr_idx = [labels_curr.index(label_prev)] - elif label_prev+'_1' in labels_curr and label_prev+'_2' in labels_curr: - target_cell_curr_idx=[labels_curr.index(label_prev+'_1'),labels_curr.index(label_prev+'_2')] - else: - continue - - crop_prev = markers_prev.copy() - crop_prev[crop_prev != i] = 0 - crop_prev[crop_prev>0] = 1 - tmp = np.where(crop_prev > 0) - crop_prev = crop_prev.astype('float32') * image_prev - crop_curr = (mask_curr > 0).astype('float32') * image_curr - window_size = np.max([(np.max(tmp[0]) - np.min(tmp[0])) * 6, (np.max(tmp[1]) - np.min(tmp[1])) * 6]) - crop_prev = Image.fromarray(crop_prev) - crop_prev = crop_prev.crop((int(centroid[1] - window_size / 2), int(centroid[0] - window_size / 2), - int(centroid[1] + window_size / 2), int(centroid[0] + window_size / 2))) - crop_prev = np.uint8(np.asarray(crop_prev)) - - crop_curr = Image.fromarray(crop_curr) - crop_curr = crop_curr.crop((int(centroid[1] - window_size / 2), int(centroid[0] - window_size / 2), - int(centroid[1] + window_size / 2), int(centroid[0] + window_size / 2))) - crop_curr = np.uint8(np.asarray(crop_curr)) - - crop_out=np.zeros(markers_curr.shape) - for idx in target_cell_curr_idx: - 
traget_marker_curr_val=markers_curr[centroids_curr[idx][1],centroids_curr[idx][0]] - tmp = markers_curr.copy() - tmp[tmp != traget_marker_curr_val] = 0 - tmp[tmp > 0] = 1 - crop_out+=tmp - crop_out[crop_out>0]=1 - crop_out = Image.fromarray(crop_out) - crop_out = crop_out.crop((int(centroid[1] - window_size / 2), int(centroid[0] - window_size / 2), - int(centroid[1] + window_size / 2), int(centroid[0] + window_size / 2))) - crop_out = np.uint8(np.asarray(crop_out).astype('float32') * 255) - - if np.sum(crop_out): - num_copies=1 - if len(target_cell_curr_idx)>1 and if_test==False: - num_copies=50 - if if_train_aug: - for _ in range(num_copies*train_aug_iter): - self.ids[cell_id]=[crop_prev,crop_curr,crop_out] - cell_id += 1 - else: - for _ in range(num_copies): - self.ids[cell_id]=[crop_prev,crop_curr,crop_out] - cell_id += 1 - - - def __len__(self): - return len(self.ids) - - @classmethod - def preprocess(cls, pil_img_prev, pil_curr,pil_mask,transforms): - tensor_img_prev,tensor_img_curr,tensor_mask=transforms(pil_img_prev, pil_curr,pil_mask) - return tensor_img_prev,tensor_img_curr,tensor_mask - - def __getitem__(self, idx): - img_prev,img_curr,mask = self.ids[idx] - img_prev=Image.fromarray(img_prev) - img_curr = Image.fromarray(img_curr) - mask = Image.fromarray(mask) - tensor_img_prev,tensor_img_curr,tensor_mask = self.preprocess(img_prev,img_curr,mask,self.transforms) - - return { - 'image_prev': tensor_img_prev, - 'image_curr': tensor_img_curr, - 'mask': tensor_mask - } +import cv2 +from torch.utils.data import Dataset +from pathlib import Path +import os +from PIL import Image +import numpy as np +from tqdm import tqdm +from scipy import ndimage as ndi + +class BasicSegmentationDataset(Dataset): + def __init__(self, images_dir: str, masks_dir: str,unetwmaps_dir: str, transforms=None, mask_suffix: str = '',if_train_aug=False,train_aug_iter=1): + self.images_dir = Path(images_dir) + self.masks_dir = Path(masks_dir) + self.unetwmaps_dir = Path(unetwmaps_dir) + self.mask_suffix = mask_suffix + self.transforms=transforms + if if_train_aug: + self.ids = [os.path.splitext(file)[0] for file in sorted(os.listdir(images_dir)) if not file.startswith('.')] + tmp=[] + for i in range(train_aug_iter): + tmp+=self.ids + self.ids=tmp + else: + self.ids = [os.path.splitext(file)[0] for file in sorted(os.listdir(images_dir)) if not file.startswith('.')] + if not self.ids: + raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there') + + def __len__(self): + return len(self.ids) + + @classmethod + def preprocess(cls, pil_img, pil_mask,pil_wmap,transforms): + tensor_img,tensor_mask,tensor_wmap=transforms(pil_img,pil_mask,pil_wmap) + return tensor_img,tensor_mask,tensor_wmap + + def __getitem__(self, idx): + name = self.ids[idx] + mask_file = list(self.masks_dir.glob(name + self.mask_suffix + '.*')) + unetwmap_file = list(self.unetwmaps_dir.glob(name + self.mask_suffix + '.*')) + img_file = list(self.images_dir.glob(name + '.*')) + + assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}' + assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}' + + mask = cv2.imread(mask_file[0].as_posix())[:, :, 0] > 0 + wmap=cv2.imread(unetwmap_file[0].as_posix())[:, :, 0] + + mask=mask & (wmap<255) + mask = Image.fromarray(np.uint8(mask.astype('float32') * 255)) + wmap = Image.fromarray(np.uint8(wmap)) + + img = Image.open(img_file[0]) + + + assert img.size == mask.size, \ + 'Image and mask 
{name} should be the same size, but are {img.size} and {mask.size}' + + tensor_img,tensor_mask,tensor_wmap = self.preprocess(img, mask,wmap,self.transforms) + + return { + 'image': tensor_img, + 'mask': tensor_mask, + 'wmap': tensor_wmap + } + +class BasicTrackerDataset(Dataset): + def __init__(self, data_dir: str, transforms=None, mask_suffix: str = '_cell_area_masked',label_suffix: str = '_cell_pos_labels',window_size=70,if_train_aug=False,train_aug_iter=1,if_test=False): + self.mask_suffix = mask_suffix + self.transforms=transforms + self.ids={} + cell_id=0 + print('INFO: Read subfolders of video sequences and prepare training dataset ...') + print('INFO: Wait until finished, it takes relatively long time depending on dataset size and complexity ...') + for subfolder in tqdm(sorted(os.listdir(data_dir))): + image_list=sorted(os.listdir(os.path.join(data_dir,subfolder,'images'))) + for idx in range(len(image_list)-1): + img_curr_name = image_list[idx+1] + image_curr = cv2.imread(os.path.join(data_dir, subfolder, 'images', img_curr_name))[:, :, 0] + mask_curr = cv2.imread(os.path.join(data_dir, subfolder, 'masks', img_curr_name.replace('.png', mask_suffix + '.png')))[:,:, 0] + + img_prev_name=image_list[idx] + image_prev = cv2.imread(os.path.join(data_dir, subfolder, 'images', img_prev_name))[:, :, 0] + mask_prev=cv2.imread(os.path.join(data_dir,subfolder,'masks',img_prev_name.replace('.png',mask_suffix+'.png')))[:, :, 0] + lines=open(os.path.join(data_dir,subfolder,'labels',img_prev_name.replace('.png',label_suffix+'.txt'))).readlines() + lines.pop(0) + labels_prev,centroids_prev=[],[] + for line in lines: + line_info=line.replace('\n','').split('\t') + labels_prev.append(line_info[0]) + centroids_prev.append([int(line_info[1]),int(line_info[2])]) + + lines = open(os.path.join(data_dir, subfolder, 'labels',img_curr_name.replace('.png', label_suffix + '.txt'))).readlines() + lines.pop(0) + labels_curr, centroids_curr = [], [] + for line in lines: + line_info = line.replace('\n', '').split('\t') + labels_curr.append(line_info[0]) + centroids_curr.append([int(line_info[1]), int(line_info[2])]) + + markers_prev, num_labels_prev = ndi.label(mask_prev) + markers_curr, num_labels_curr = ndi.label(mask_curr) + for i in range(1,num_labels_prev): + if np.sum(markers_prev==i)<20: + continue + centroid = np.array(ndi.measurements.center_of_mass(mask_prev, markers_prev, i)) + distances = np.array([np.sqrt(np.sum((centroid - np.array([point[1], point[0]])) ** 2)) for point in centroids_prev]) + if len(np.where(distances < 5)[0])==0: + continue + target_cell_prev_idx = np.where(distances < 5)[0][0] + label_prev=labels_prev[target_cell_prev_idx] + if label_prev in labels_curr: + target_cell_curr_idx = [labels_curr.index(label_prev)] + elif label_prev+'_1' in labels_curr and label_prev+'_2' in labels_curr: + target_cell_curr_idx=[labels_curr.index(label_prev+'_1'),labels_curr.index(label_prev+'_2')] + else: + continue + + crop_prev = markers_prev.copy() + crop_prev[crop_prev != i] = 0 + crop_prev[crop_prev>0] = 1 + tmp = np.where(crop_prev > 0) + crop_prev = crop_prev.astype('float32') * image_prev + crop_curr = (mask_curr > 0).astype('float32') * image_curr + window_size = np.max([(np.max(tmp[0]) - np.min(tmp[0])) * 6, (np.max(tmp[1]) - np.min(tmp[1])) * 6]) + crop_prev = Image.fromarray(crop_prev) + crop_prev = crop_prev.crop((int(centroid[1] - window_size / 2), int(centroid[0] - window_size / 2), + int(centroid[1] + window_size / 2), int(centroid[0] + window_size / 2))) + crop_prev = 
np.uint8(np.asarray(crop_prev)) + + crop_curr = Image.fromarray(crop_curr) + crop_curr = crop_curr.crop((int(centroid[1] - window_size / 2), int(centroid[0] - window_size / 2), + int(centroid[1] + window_size / 2), int(centroid[0] + window_size / 2))) + crop_curr = np.uint8(np.asarray(crop_curr)) + + crop_out=np.zeros(markers_curr.shape) + for idx in target_cell_curr_idx: + traget_marker_curr_val=markers_curr[centroids_curr[idx][1],centroids_curr[idx][0]] + tmp = markers_curr.copy() + tmp[tmp != traget_marker_curr_val] = 0 + tmp[tmp > 0] = 1 + crop_out+=tmp + crop_out[crop_out>0]=1 + crop_out = Image.fromarray(crop_out) + crop_out = crop_out.crop((int(centroid[1] - window_size / 2), int(centroid[0] - window_size / 2), + int(centroid[1] + window_size / 2), int(centroid[0] + window_size / 2))) + crop_out = np.uint8(np.asarray(crop_out).astype('float32') * 255) + + if np.sum(crop_out): + num_copies=1 + if len(target_cell_curr_idx)>1 and if_test==False: + num_copies=50 + if if_train_aug: + for _ in range(num_copies*train_aug_iter): + self.ids[cell_id]=[crop_prev,crop_curr,crop_out] + cell_id += 1 + else: + for _ in range(num_copies): + self.ids[cell_id]=[crop_prev,crop_curr,crop_out] + cell_id += 1 + + + def __len__(self): + return len(self.ids) + + @classmethod + def preprocess(cls, pil_img_prev, pil_curr,pil_mask,transforms): + tensor_img_prev,tensor_img_curr,tensor_mask=transforms(pil_img_prev, pil_curr,pil_mask) + return tensor_img_prev,tensor_img_curr,tensor_mask + + def __getitem__(self, idx): + img_prev,img_curr,mask = self.ids[idx] + img_prev=Image.fromarray(img_prev) + img_curr = Image.fromarray(img_curr) + mask = Image.fromarray(mask) + tensor_img_prev,tensor_img_curr,tensor_mask = self.preprocess(img_prev,img_curr,mask,self.transforms) + + return { + 'image_prev': tensor_img_prev, + 'image_curr': tensor_img_curr, + 'mask': tensor_mask + } diff --git a/evaluate.py b/deepsea/evaluate.py similarity index 97% rename from evaluate.py rename to deepsea/evaluate.py index cff8c2e..b2360d6 100644 --- a/evaluate.py +++ b/deepsea/evaluate.py @@ -1,277 +1,277 @@ -import numpy as np -from tqdm import tqdm -import torch -import torch.nn.functional as F -from skimage.morphology import remove_small_objects -import copy -import cv2 -import os -from loss import multiclass_dice_coeff -from scipy.optimize import linear_sum_assignment -from utils import visualize_segmentation -from scipy import ndimage as ndi - - -def evaluate_segmentation(net, valid_iterator, device,n_valid_examples,is_avg_prec=False,prec_thresholds=[0.5,0.7,0.9],output_dir=None): - net.eval() - num_val_batches = len(valid_iterator) - dice_score = 0 - mask_list, pred_list,wmap_list = [], [],[] - # iterate over the validation set - # loss=0 - with tqdm(total=n_valid_examples, desc='Segmentation Val round', unit='img') as pbar: - for batch_idx,batch in enumerate(valid_iterator): - images, true_masks,wmap = batch['image'], batch['mask'],batch['wmap'] - images_device = images.to(device=device, dtype=torch.float32) - true_masks = true_masks.to(device=device, dtype=torch.long) - true_masks = torch.squeeze(true_masks, dim=1) - true_masks_copy = copy.deepcopy(true_masks) - true_masks = F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float() - - with torch.no_grad(): - # predict the mask - mask_pred,edge_pred = net(images_device) - - # convert to one-hot format - mask_pred_copy = copy.deepcopy(mask_pred.argmax(dim=1)) - mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float() - # compute the Dice score, 
ignoring background - dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], true_masks[:, 1:, ...], - reduce_batch_first=False) - # loss += dice_loss(F.softmax(mask_pred, dim=1).float(), - # F.one_hot(true_masks_copy, net.n_classes).permute(0, 3, 1, 2).float(), - # multiclass=True) - if is_avg_prec: - true_masks_copy=true_masks_copy.cpu().numpy() - mask_pred_copy = mask_pred_copy.cpu().numpy() - for i in range(true_masks_copy.shape[0]): - mask,_=ndi.label(remove_small_objects(true_masks_copy[i,:,:]>0,min_size=20,connectivity=1)) - mask_list.append(mask) - pred, _ = ndi.label(remove_small_objects(mask_pred_copy[i, :, :]>0,min_size=20,connectivity=1)) - if output_dir: - img=images[i].cpu().numpy()[0, :, :] - img=(img-np.min(img))/(np.max(img)-np.min(img))*255 - overlay_img = visualize_segmentation(pred, inp_img=img, overlay_img=True) - cv2.imwrite(os.path.join(output_dir,'input_segmentation_images','images_{:04d}.png'.format(batch_idx*true_masks_copy.shape[0]+i)),img) - cv2.imwrite(os.path.join(output_dir, 'segmentation_predictions', 'images_{:04d}.png'.format(batch_idx*true_masks_copy.shape[0]+ i)),overlay_img) - - wmap_list.append(wmap[i].cpu().numpy()[0, :, :]) - pred_list.append(pred) - - pbar.update(images.shape[0]) - - - if is_avg_prec: - avg_list=average_precision(mask_list, pred_list, threshold=prec_thresholds)[0] - easy_samples,hard_samples=[],[] - for i,wmap in enumerate(wmap_list): - if np.sum(wmap)==0: - easy_samples.append(avg_list[i]) - else: - hard_samples.append(avg_list[i]) - if output_dir: - np.savetxt(os.path.join(output_dir,'precisions.txt'), avg_list, delimiter=',') - avg_prec=np.mean(avg_list,axis=0) - easy_avg_prec=np.mean(np.array(easy_samples),axis=0) - hard_avg_prec = np.mean(np.array(hard_samples), axis=0) - return dice_score.cpu().numpy() / num_val_batches, avg_prec, easy_avg_prec, hard_avg_prec - - return dice_score.cpu().numpy() / num_val_batches, None,None,None - - -def evaluate_tracker(net, valid_iterator, device,n_valid_examples,is_avg_prec=False,prec_thresholds=[0.5,0.7,0.9],output_dir=None): - net.eval() - num_val_batches = len(valid_iterator) - dice_score = 0 - mask_list, pred_list= [], [] - # iterate over the validation set - with tqdm(total=n_valid_examples, desc='Tracking Val round', unit='img') as pbar: - for batch_idx,batch in enumerate(valid_iterator): - images_prev, images_curr,true_masks = batch['image_prev'], batch['image_curr'],batch['mask'] - - images_device_prev = images_prev.to(device=device, dtype=torch.float32) - images_device_curr = images_curr.to(device=device, dtype=torch.float32) - true_masks = true_masks.to(device=device, dtype=torch.long) - true_masks = torch.squeeze(true_masks, dim=1) - true_masks_copy = copy.deepcopy(true_masks) - true_masks = F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float() - - with torch.no_grad(): - # predict the mask - mask_pred = net(images_device_prev,images_device_curr) - - # convert to one-hot format - mask_pred_copy = copy.deepcopy(mask_pred.argmax(dim=1)) - # edge_pred_copy = copy.deepcopy(edge_pred.argmax(dim=1)) - mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float() - # compute the Dice score, ignoring background - dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], true_masks[:, 1:, ...], - reduce_batch_first=False) - if is_avg_prec: - true_masks_copy=true_masks_copy.cpu().numpy() - mask_pred_copy = mask_pred_copy.cpu().numpy() - for i in range(true_masks_copy.shape[0]): - 
mask,_=ndi.label(remove_small_objects(true_masks_copy[i,:,:]>0,min_size=20,connectivity=1)) - mask_list.append(mask) - pred, _ = ndi.label(remove_small_objects(mask_pred_copy[i, :, :]>0,min_size=20,connectivity=1)) - - if output_dir: - img=images_prev[i].cpu().numpy()[0, :, :] - img=(img-np.min(img))/(np.max(img)-np.min(img))*255 - overlay_img = visualize_segmentation(pred, inp_img=img, overlay_img=True) - cv2.imwrite(os.path.join(output_dir,'input_tracking_images','images_{:04d}.png'.format(batch_idx*true_masks_copy.shape[0]+i)),img) - cv2.imwrite(os.path.join(output_dir, 'tracking_predictions', 'images_{:04d}.png'.format(batch_idx*true_masks_copy.shape[0]+ i)),overlay_img) - - pred_list.append(pred) - - pbar.update(images_prev.shape[0]) - - if is_avg_prec: - avg_list = average_precision(mask_list, pred_list, threshold=prec_thresholds)[0] - single_cell_tracking, mitosis_tracking = [], [] - for i, mask in enumerate(mask_list): - if np.max(mask) >1: - mitosis_tracking.append(avg_list[i]) - else: - single_cell_tracking.append(avg_list[i]) - - if output_dir: - np.savetxt(os.path.join(output_dir, 'precisions.txt'), avg_list, delimiter=',') - avg_prec = np.mean(avg_list, axis=0) - single_cell_avg_prec = np.mean(np.array(single_cell_tracking), axis=0) - mitosis_avg_prec = np.mean(np.array(mitosis_tracking), axis=0) - return dice_score.cpu().numpy() / num_val_batches, avg_prec, single_cell_avg_prec, mitosis_avg_prec - - return dice_score.cpu().numpy() / num_val_batches, None, None, None - -def _label_overlap(x, y): - """ fast function to get pixel overlaps between masks in x and y - - Parameters - ------------ - - x: ND-array, int - where 0=NO masks; 1,2... are mask labels - y: ND-array, int - where 0=NO masks; 1,2... are mask labels - - Returns - ------------ - - overlap: ND-array, int - matrix of pixel overlaps of size [x.max()+1, y.max()+1] - - """ - x = x.ravel() - y = y.ravel() - overlap = np.zeros((1 + x.max(), 1 + y.max()), dtype=np.uint) - for i in range(len(x)): - overlap[x[i], y[i]] += 1 - return overlap - -def _intersection_over_union(masks_true, masks_pred): - """ intersection over union of all mask pairs - - Parameters - ------------ - - masks_true: ND-array, int - ground truth masks, where 0=NO masks; 1,2... are mask labels - masks_pred: ND-array, int - predicted masks, where 0=NO masks; 1,2... 
are mask labels - - Returns - ------------ - - iou: ND-array, float - matrix of IOU pairs of size [x.max()+1, y.max()+1] - - """ - overlap = _label_overlap(masks_true, masks_pred) - n_pixels_pred = np.sum(overlap, axis=0, keepdims=True) - n_pixels_true = np.sum(overlap, axis=1, keepdims=True) - iou = overlap / (n_pixels_pred + n_pixels_true - overlap+1e-6) - iou[np.isnan(iou)] = 0.0 - return iou - - -def _true_positive(iou, th): - """ true positive at threshold th - - Parameters - ------------ - - iou: float, ND-array - array of IOU pairs - th: float - threshold on IOU for positive label - - Returns - ------------ - - tp: float - number of true positives at threshold - - """ - n_min = min(iou.shape[0], iou.shape[1]) - costs = -(iou >= th).astype(float) - iou / (2 * n_min+1e-6) - true_ind, pred_ind = linear_sum_assignment(costs) - match_ok = iou[true_ind, pred_ind] >= th - tp = match_ok.sum() - return tp - - - -def average_precision(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]): - """ average precision estimation: AP = TP / (TP + FP + FN) - - This function is based heavily on the *fast* stardist matching functions - (https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py) - - Parameters - ------------ - - masks_true: list of ND-arrays (int) or ND-array (int) - where 0=NO masks; 1,2... are mask labels - masks_pred: list of ND-arrays (int) or ND-array (int) - ND-array (int) where 0=NO masks; 1,2... are mask labels - - Returns - ------------ - - ap: array [len(masks_true) x len(threshold)] - average precision at thresholds - tp: array [len(masks_true) x len(threshold)] - number of true positives at thresholds - fp: array [len(masks_true) x len(threshold)] - number of false positives at thresholds - fn: array [len(masks_true) x len(threshold)] - number of false negatives at thresholds - - """ - not_list = False - if not isinstance(masks_true, list): - masks_true = [masks_true] - masks_pred = [masks_pred] - not_list = True - if not isinstance(threshold, list) and not isinstance(threshold, np.ndarray): - threshold = [threshold] - ap = np.zeros((len(masks_true), len(threshold)), np.float32) - tp = np.zeros((len(masks_true), len(threshold)), np.float32) - fp = np.zeros((len(masks_true), len(threshold)), np.float32) - fn = np.zeros((len(masks_true), len(threshold)), np.float32) - n_true = np.array(list(map(np.max, masks_true))) - n_pred = np.array(list(map(np.max, masks_pred))) - with tqdm(total=len(masks_true), desc='Precision measurement', unit='img') as pbar: - for n in range(len(masks_true)): - if n_pred[n] > 0: - iou = _intersection_over_union(masks_true[n], masks_pred[n])[1:, 1:] - for k, th in enumerate(threshold): - tp[n, k] = _true_positive(iou, th) - fp[n] = n_pred[n] - tp[n] - fn[n] = n_true[n] - tp[n] - ap[n] = tp[n] / (tp[n] + fp[n] + fn[n]+1e-6) - pbar.update(1) - if not_list: - ap, tp, fp, fn = ap[0], tp[0], fp[0], fn[0] +import numpy as np +from tqdm import tqdm +import torch +import torch.nn.functional as F +from skimage.morphology import remove_small_objects +import copy +import cv2 +import os +from loss import multiclass_dice_coeff +from scipy.optimize import linear_sum_assignment +from utils import visualize_segmentation +from scipy import ndimage as ndi + + +def evaluate_segmentation(net, valid_iterator, device,n_valid_examples,is_avg_prec=False,prec_thresholds=[0.5,0.7,0.9],output_dir=None): + net.eval() + num_val_batches = len(valid_iterator) + dice_score = 0 + mask_list, pred_list,wmap_list = [], [],[] + # iterate over the validation set + # loss=0 + 
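# Dice is accumulated per batch in the loop below; per-image labelled masks and predictions for the average-precision measurement are only collected when is_avg_prec is True +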
with tqdm(total=n_valid_examples, desc='Segmentation Val round', unit='img') as pbar: + for batch_idx,batch in enumerate(valid_iterator): + images, true_masks,wmap = batch['image'], batch['mask'],batch['wmap'] + images_device = images.to(device=device, dtype=torch.float32) + true_masks = true_masks.to(device=device, dtype=torch.long) + true_masks = torch.squeeze(true_masks, dim=1) + true_masks_copy = copy.deepcopy(true_masks) + true_masks = F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float() + + with torch.no_grad(): + # predict the mask + mask_pred,edge_pred = net(images_device) + + # convert to one-hot format + mask_pred_copy = copy.deepcopy(mask_pred.argmax(dim=1)) + mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float() + # compute the Dice score, ignoring background + dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], true_masks[:, 1:, ...], + reduce_batch_first=False) + # loss += dice_loss(F.softmax(mask_pred, dim=1).float(), + # F.one_hot(true_masks_copy, net.n_classes).permute(0, 3, 1, 2).float(), + # multiclass=True) + if is_avg_prec: + true_masks_copy=true_masks_copy.cpu().numpy() + mask_pred_copy = mask_pred_copy.cpu().numpy() + for i in range(true_masks_copy.shape[0]): + mask,_=ndi.label(remove_small_objects(true_masks_copy[i,:,:]>0,min_size=20,connectivity=1)) + mask_list.append(mask) + pred, _ = ndi.label(remove_small_objects(mask_pred_copy[i, :, :]>0,min_size=20,connectivity=1)) + if output_dir: + img=images[i].cpu().numpy()[0, :, :] + img=(img-np.min(img))/(np.max(img)-np.min(img))*255 + overlay_img = visualize_segmentation(pred, inp_img=img, overlay_img=True) + cv2.imwrite(os.path.join(output_dir,'input_segmentation_images','images_{:04d}.png'.format(batch_idx*true_masks_copy.shape[0]+i)),img) + cv2.imwrite(os.path.join(output_dir, 'segmentation_predictions', 'images_{:04d}.png'.format(batch_idx*true_masks_copy.shape[0]+ i)),overlay_img) + + wmap_list.append(wmap[i].cpu().numpy()[0, :, :]) + pred_list.append(pred) + + pbar.update(images.shape[0]) + + + if is_avg_prec: + avg_list=average_precision(mask_list, pred_list, threshold=prec_thresholds)[0] + easy_samples,hard_samples=[],[] + for i,wmap in enumerate(wmap_list): + if np.sum(wmap)==0: + easy_samples.append(avg_list[i]) + else: + hard_samples.append(avg_list[i]) + if output_dir: + np.savetxt(os.path.join(output_dir,'precisions.txt'), avg_list, delimiter=',') + avg_prec=np.mean(avg_list,axis=0) + easy_avg_prec=np.mean(np.array(easy_samples),axis=0) + hard_avg_prec = np.mean(np.array(hard_samples), axis=0) + return dice_score.cpu().numpy() / num_val_batches, avg_prec, easy_avg_prec, hard_avg_prec + + return dice_score.cpu().numpy() / num_val_batches, None,None,None + + +def evaluate_tracker(net, valid_iterator, device,n_valid_examples,is_avg_prec=False,prec_thresholds=[0.5,0.7,0.9],output_dir=None): + net.eval() + num_val_batches = len(valid_iterator) + dice_score = 0 + mask_list, pred_list= [], [] + # iterate over the validation set + with tqdm(total=n_valid_examples, desc='Tracking Val round', unit='img') as pbar: + for batch_idx,batch in enumerate(valid_iterator): + images_prev, images_curr,true_masks = batch['image_prev'], batch['image_curr'],batch['mask'] + + images_device_prev = images_prev.to(device=device, dtype=torch.float32) + images_device_curr = images_curr.to(device=device, dtype=torch.float32) + true_masks = true_masks.to(device=device, dtype=torch.long) + true_masks = torch.squeeze(true_masks, dim=1) + true_masks_copy = copy.deepcopy(true_masks) + 
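# keep an integer-label copy for the average-precision step; the one-hot tensor built next is only used for the Dice score +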
true_masks = F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float() + + with torch.no_grad(): + # predict the mask + mask_pred = net(images_device_prev,images_device_curr) + + # convert to one-hot format + mask_pred_copy = copy.deepcopy(mask_pred.argmax(dim=1)) + # edge_pred_copy = copy.deepcopy(edge_pred.argmax(dim=1)) + mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float() + # compute the Dice score, ignoring background + dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], true_masks[:, 1:, ...], + reduce_batch_first=False) + if is_avg_prec: + true_masks_copy=true_masks_copy.cpu().numpy() + mask_pred_copy = mask_pred_copy.cpu().numpy() + for i in range(true_masks_copy.shape[0]): + mask,_=ndi.label(remove_small_objects(true_masks_copy[i,:,:]>0,min_size=20,connectivity=1)) + mask_list.append(mask) + pred, _ = ndi.label(remove_small_objects(mask_pred_copy[i, :, :]>0,min_size=20,connectivity=1)) + + if output_dir: + img=images_prev[i].cpu().numpy()[0, :, :] + img=(img-np.min(img))/(np.max(img)-np.min(img))*255 + overlay_img = visualize_segmentation(pred, inp_img=img, overlay_img=True) + cv2.imwrite(os.path.join(output_dir,'input_tracking_images','images_{:04d}.png'.format(batch_idx*true_masks_copy.shape[0]+i)),img) + cv2.imwrite(os.path.join(output_dir, 'tracking_predictions', 'images_{:04d}.png'.format(batch_idx*true_masks_copy.shape[0]+ i)),overlay_img) + + pred_list.append(pred) + + pbar.update(images_prev.shape[0]) + + if is_avg_prec: + avg_list = average_precision(mask_list, pred_list, threshold=prec_thresholds)[0] + single_cell_tracking, mitosis_tracking = [], [] + for i, mask in enumerate(mask_list): + if np.max(mask) >1: + mitosis_tracking.append(avg_list[i]) + else: + single_cell_tracking.append(avg_list[i]) + + if output_dir: + np.savetxt(os.path.join(output_dir, 'precisions.txt'), avg_list, delimiter=',') + avg_prec = np.mean(avg_list, axis=0) + single_cell_avg_prec = np.mean(np.array(single_cell_tracking), axis=0) + mitosis_avg_prec = np.mean(np.array(mitosis_tracking), axis=0) + return dice_score.cpu().numpy() / num_val_batches, avg_prec, single_cell_avg_prec, mitosis_avg_prec + + return dice_score.cpu().numpy() / num_val_batches, None, None, None + +def _label_overlap(x, y): + """ fast function to get pixel overlaps between masks in x and y + + Parameters + ------------ + + x: ND-array, int + where 0=NO masks; 1,2... are mask labels + y: ND-array, int + where 0=NO masks; 1,2... are mask labels + + Returns + ------------ + + overlap: ND-array, int + matrix of pixel overlaps of size [x.max()+1, y.max()+1] + + """ + x = x.ravel() + y = y.ravel() + overlap = np.zeros((1 + x.max(), 1 + y.max()), dtype=np.uint) + for i in range(len(x)): + overlap[x[i], y[i]] += 1 + return overlap + +def _intersection_over_union(masks_true, masks_pred): + """ intersection over union of all mask pairs + + Parameters + ------------ + + masks_true: ND-array, int + ground truth masks, where 0=NO masks; 1,2... are mask labels + masks_pred: ND-array, int + predicted masks, where 0=NO masks; 1,2... 
are mask labels + + Returns + ------------ + + iou: ND-array, float + matrix of IOU pairs of size [x.max()+1, y.max()+1] + + """ + overlap = _label_overlap(masks_true, masks_pred) + n_pixels_pred = np.sum(overlap, axis=0, keepdims=True) + n_pixels_true = np.sum(overlap, axis=1, keepdims=True) + iou = overlap / (n_pixels_pred + n_pixels_true - overlap+1e-6) + iou[np.isnan(iou)] = 0.0 + return iou + + +def _true_positive(iou, th): + """ true positive at threshold th + + Parameters + ------------ + + iou: float, ND-array + array of IOU pairs + th: float + threshold on IOU for positive label + + Returns + ------------ + + tp: float + number of true positives at threshold + + """ + n_min = min(iou.shape[0], iou.shape[1]) + costs = -(iou >= th).astype(float) - iou / (2 * n_min+1e-6) + true_ind, pred_ind = linear_sum_assignment(costs) + match_ok = iou[true_ind, pred_ind] >= th + tp = match_ok.sum() + return tp + + + +def average_precision(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]): + """ average precision estimation: AP = TP / (TP + FP + FN) + + This function is based heavily on the *fast* stardist matching functions + (https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py) + + Parameters + ------------ + + masks_true: list of ND-arrays (int) or ND-array (int) + where 0=NO masks; 1,2... are mask labels + masks_pred: list of ND-arrays (int) or ND-array (int) + ND-array (int) where 0=NO masks; 1,2... are mask labels + + Returns + ------------ + + ap: array [len(masks_true) x len(threshold)] + average precision at thresholds + tp: array [len(masks_true) x len(threshold)] + number of true positives at thresholds + fp: array [len(masks_true) x len(threshold)] + number of false positives at thresholds + fn: array [len(masks_true) x len(threshold)] + number of false negatives at thresholds + + """ + not_list = False + if not isinstance(masks_true, list): + masks_true = [masks_true] + masks_pred = [masks_pred] + not_list = True + if not isinstance(threshold, list) and not isinstance(threshold, np.ndarray): + threshold = [threshold] + ap = np.zeros((len(masks_true), len(threshold)), np.float32) + tp = np.zeros((len(masks_true), len(threshold)), np.float32) + fp = np.zeros((len(masks_true), len(threshold)), np.float32) + fn = np.zeros((len(masks_true), len(threshold)), np.float32) + n_true = np.array(list(map(np.max, masks_true))) + n_pred = np.array(list(map(np.max, masks_pred))) + with tqdm(total=len(masks_true), desc='Precision measurement', unit='img') as pbar: + for n in range(len(masks_true)): + if n_pred[n] > 0: + iou = _intersection_over_union(masks_true[n], masks_pred[n])[1:, 1:] + for k, th in enumerate(threshold): + tp[n, k] = _true_positive(iou, th) + fp[n] = n_pred[n] - tp[n] + fn[n] = n_true[n] - tp[n] + ap[n] = tp[n] / (tp[n] + fp[n] + fn[n]+1e-6) + pbar.update(1) + if not_list: + ap, tp, fp, fn = ap[0], tp[0], fp[0], fn[0] return ap, tp, fp, fn \ No newline at end of file diff --git a/deepsea/evaluate_test_set_segmentation.py b/deepsea/evaluate_test_set_segmentation.py new file mode 100644 index 0000000..81780a6 --- /dev/null +++ b/deepsea/evaluate_test_set_segmentation.py @@ -0,0 +1,126 @@ +<<<<<<< HEAD:evaluate_test_set_segmentation.py +import torch.utils.data as data +import segmentation_transforms as transforms +import numpy as np +import argparse +import os +import random +from model import DeepSeaSegmentation +from data import BasicSegmentationDataset +import torch +from evaluate import evaluate_segmentation +from utils import get_n_params + +SEED = 1234 
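+# fix the RNG seeds so the test-set evaluation is reproducible across runs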
+random.seed(SEED) +np.random.seed(SEED) +torch.manual_seed(SEED) +torch.cuda.manual_seed(SEED) +torch.backends.cudnn.deterministic = True + + +def test(args,image_size = [383,512],image_means = [0.5],image_stds= [0.5],batch_size=1): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + test_transforms = transforms.Compose([ + transforms.Resize(image_size), + transforms.Grayscale(num_output_channels=1), + transforms.ToTensor(), + transforms.Normalize(mean = image_means, + std = image_stds) + ]) + + + test_data = BasicSegmentationDataset(os.path.join(args.test_set_dir, 'images'), os.path.join(args.test_set_dir, 'masks'),os.path.join(args.test_set_dir, 'wmaps'),transforms=test_transforms) + + test_iterator = data.DataLoader(test_data,batch_size = batch_size,shuffle=False) + + model=DeepSeaSegmentation(n_channels=1, n_classes=2, bilinear=True) + print('INFO: Num of model parameters:',get_n_params(model)) + model.load_state_dict(torch.load(args.ckpt_dir)) + model = model.to(device) + + test_score, test_avg_precision,test_easy_avg_precision,test_hard_avg_precision = evaluate_segmentation(model, test_iterator, device,len(test_data),is_avg_prec=True,prec_thresholds=[0.5,0.6,0.7,0.8,0.9],output_dir=args.output_dir) + print('INFO: Dice score:', test_score) + print('INFO: Average precision at ordered thresholds:', test_avg_precision) + print('INFO: Easy samples average precision at ordered thresholds:', test_easy_avg_precision) + print('INFO: Hard samples average precision at ordered thresholds:', test_hard_avg_precision) + +if __name__ == "__main__": + ap = argparse.ArgumentParser() + ap.add_argument("--test_set_dir",required=True,type=str,help="path for the test dataset") + ap.add_argument("--ckpt_dir",required=True,type=str,help="path for the checkpoint of segmentation model to test") + ap.add_argument("--output_dir", required=True, type=str, help="path for saving the test outputs") + + args = ap.parse_args() + + assert os.path.isdir(args.test_set_dir), 'No such file or directory: ' + args.test_set_dir + if not os.path.isdir(os.path.join(args.output_dir,'input_segmentation_images')): + os.makedirs(os.path.join(args.output_dir,'input_segmentation_images')) + if not os.path.isdir(os.path.join(args.output_dir,'segmentation_predictions')): + os.makedirs(os.path.join(args.output_dir,'segmentation_predictions')) + +======= +import torch.utils.data as data +import segmentation_transforms as transforms +import numpy as np +import argparse +import os +import random +from model import DeepSeaSegmentation +from data import BasicSegmentationDataset +import torch +from evaluate import evaluate_segmentation +from utils import get_n_params + +SEED = 1234 +random.seed(SEED) +np.random.seed(SEED) +torch.manual_seed(SEED) +torch.cuda.manual_seed(SEED) +torch.backends.cudnn.deterministic = True + + +def test(args,image_size = [383,512],image_means = [0.5],image_stds= [0.5],batch_size=1): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + test_transforms = transforms.Compose([ + transforms.Resize(image_size), + transforms.Grayscale(num_output_channels=1), + transforms.ToTensor(), + transforms.Normalize(mean = image_means, + std = image_stds) + ]) + + + test_data = BasicSegmentationDataset(os.path.join(args.test_dir, 'images'), os.path.join(args.test_dir, 'masks'),os.path.join(args.test_dir, 'wmaps'),transforms=test_transforms) + + test_iterator = data.DataLoader(test_data,batch_size = batch_size,shuffle=False) + + model=DeepSeaSegmentation(n_channels=1, n_classes=2, 
bilinear=True) + print('INFO: Num of model parameters:',get_n_params(model)) + model.load_state_dict(torch.load(args.ckpt_dir)) + model = model.to(device) + + test_score, test_avg_precision,test_easy_avg_precision,test_hard_avg_precision = evaluate_segmentation(model, test_iterator, device,len(test_data),is_avg_prec=True,prec_thresholds=[0.5,0.6,0.7,0.8,0.9],output_dir=args.output_dir) + print('INFO: Dice score:', test_score) + print('INFO: Average precision at ordered thresholds:', test_avg_precision) + print('INFO: Easy samples average precision at ordered thresholds:', test_easy_avg_precision) + print('INFO: Hard samples average precision at ordered thresholds:', test_hard_avg_precision) + +if __name__ == "__main__": + ap = argparse.ArgumentParser() + ap.add_argument("--test_dir",required=True,type=str,help="path for the test dataset") + ap.add_argument("--ckpt_dir",required=True,type=str,help="path for the checkpoint of segmentation model to test") + ap.add_argument("--output_dir", required=True, type=str, help="path for saving the test outputs") + + args = ap.parse_args() + + assert os.path.isdir(args.test_dir), 'No such file or directory: ' + args.test_dir + if not os.path.isdir(args.output_dir+'/input_segmentation_images'): + os.makedirs(args.output_dir+'/input_segmentation_images') + if not os.path.isdir(args.output_dir+'/segmentation_predictions'): + os.makedirs(args.output_dir+'/segmentation_predictions') + +>>>>>>> 688bc95bb88028284a86f389a4e204f93b8d0f83:deepsea/test_segmentation.py + test(args) \ No newline at end of file diff --git a/deepsea/evaluate_test_set_tracking.py b/deepsea/evaluate_test_set_tracking.py new file mode 100644 index 0000000..1dc7b5f --- /dev/null +++ b/deepsea/evaluate_test_set_tracking.py @@ -0,0 +1,132 @@ +<<<<<<< HEAD:evaluate_test_set_tracking.py +import os +import torch.utils.data as data +import tracker_transforms as transforms +import numpy as np +import argparse +import random +from model import DeepSeaTracker +from data import BasicTrackerDataset +import torch +from evaluate import evaluate_tracker +from utils import get_n_params + +SEED = 1234 +random.seed(SEED) +np.random.seed(SEED) +torch.manual_seed(SEED) +torch.cuda.manual_seed(SEED) +torch.backends.cudnn.deterministic = True + + + +def test(args,image_size = [128,128],image_means = [0.5],image_stds= [0.5],batch_size=1): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + test_transforms = transforms.Compose([ + transforms.Resize(image_size), + transforms.Grayscale(num_output_channels=1), + transforms.ToTensor(), + transforms.Normalize(mean = image_means, + std = image_stds) + ]) + + + test_data = BasicTrackerDataset(os.path.join(args.test_set_dir), transforms=test_transforms,if_test=True) + test_iterator = data.DataLoader(test_data,batch_size = batch_size) + + model=DeepSeaTracker(n_channels=1, n_classes=2, bilinear=True) + print('INFO: Num of model parameters:',get_n_params(model)) + + model.load_state_dict(torch.load(args.ckpt_dir)) + model = model.to(device) + + test_score, test_avg_precision,test_single_cell_avg_precision,test_mitosis_avg_precision = evaluate_tracker(model, test_iterator, device,len(test_data),is_avg_prec=True,prec_thresholds=[0.2,0.6,0.7,0.8,0.9],output_dir=args.output_dir) + + print('INFO: Dice score:', test_score) + print('INFO: Average precision:', test_avg_precision) + print('INFO: Single cells average precision:', test_single_cell_avg_precision) + print('INFO: Mitosis average precision:', test_mitosis_avg_precision) + + +if __name__ == 
"__main__": + ap = argparse.ArgumentParser() + ap.add_argument("--test_set_dir",required=True,type=str,help="path for the test dataset") + ap.add_argument("--ckpt_dir",required=True,type=str,help="path for the checkpoint of tracking model to test") + ap.add_argument("--output_dir", required=True, type=str, help="path for saving the test outputs") + + args = ap.parse_args() + + assert os.path.isdir(args.test_set_dir), 'No such file or directory: ' + args.test_set_dir + if not os.path.isdir(os.path.join(args.output_dir,'input_crops')): + os.makedirs(os.path.join(args.output_dir,'input_crops')) + if not os.path.isdir(os.path.join(args.output_dir,'tracking_predictions')): + os.makedirs(os.path.join(args.output_dir,'tracking_predictions')) + +======= +import os +import torch.utils.data as data +import tracker_transforms as transforms +import numpy as np +import argparse +import random +from model import DeepSeaTracker +from data import BasicTrackerDataset +import torch +from evaluate import evaluate_tracker +from utils import get_n_params + +SEED = 1234 +random.seed(SEED) +np.random.seed(SEED) +torch.manual_seed(SEED) +torch.cuda.manual_seed(SEED) +torch.backends.cudnn.deterministic = True + + + +def test(args,image_size = [128,128],image_means = [0.5],image_stds= [0.5],batch_size=1): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + test_transforms = transforms.Compose([ + transforms.Resize(image_size), + transforms.Grayscale(num_output_channels=1), + transforms.ToTensor(), + transforms.Normalize(mean = image_means, + std = image_stds) + ]) + + + test_data = BasicTrackerDataset(os.path.join(args.test_dir), transforms=test_transforms,if_test=True) + test_iterator = data.DataLoader(test_data,batch_size = batch_size) + + model=DeepSeaTracker(n_channels=1, n_classes=2, bilinear=True) + print('INFO: Num of model parameters:',get_n_params(model)) + + model.load_state_dict(torch.load(args.ckpt_dir)) + model = model.to(device) + + test_score, test_avg_precision,test_single_cell_avg_precision,test_mitosis_avg_precision = evaluate_tracker(model, test_iterator, device,len(test_data),is_avg_prec=True,prec_thresholds=[0.2,0.6,0.7,0.8,0.9],output_dir=args.output_dir) + + print('INFO: Dice score:', test_score) + print('INFO: Average precision:', test_avg_precision) + print('INFO: Single cells average precision:', test_single_cell_avg_precision) + print('INFO: Mitosis average precision:', test_mitosis_avg_precision) + + +if __name__ == "__main__": + ap = argparse.ArgumentParser() + ap.add_argument("--test_dir",required=True,type=str,help="path for the test dataset") + ap.add_argument("--ckpt_dir",required=True,type=str,help="path for the checkpoint of tracking model to test") + ap.add_argument("--output_dir", required=True, type=str, help="path for saving the test outputs") + + args = ap.parse_args() + + assert os.path.isdir(args.test_dir), 'No such file or directory: ' + args.test_dir + if not os.path.isdir(args.output_dir+'/input_tracking_images'): + os.makedirs(args.output_dir+'/input_tracking_images') + if not os.path.isdir(args.output_dir+'/tracking_predictions'): + os.makedirs(args.output_dir+'/tracking_predictions') + +>>>>>>> 688bc95bb88028284a86f389a4e204f93b8d0f83:deepsea/test_tracker.py + test(args) \ No newline at end of file diff --git a/loss.py b/deepsea/loss.py similarity index 97% rename from loss.py rename to deepsea/loss.py index 8e68b8d..a6171d0 100644 --- a/loss.py +++ b/deepsea/loss.py @@ -1,37 +1,37 @@ -from torch import Tensor -import torch - -def 
multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6): - # Average of Dice coefficient for all classes - assert input.size() == target.size() - dice = 0 - for channel in range(input.shape[1]): - dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon) - - return dice / input.shape[1] - -def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False): - # Dice loss (objective to minimize) between 0 and 1 - assert input.size() == target.size() - fn = multiclass_dice_coeff if multiclass else dice_coeff - return 1 - fn(input, target, reduce_batch_first=True) - -def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6): - # Average of Dice coefficient for all batches, or for a single mask - assert input.size() == target.size() - if input.dim() == 2 and reduce_batch_first: - raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})') - - if input.dim() == 2 or reduce_batch_first: - inter = torch.dot(input.reshape(-1), target.reshape(-1)) - sets_sum = torch.sum(input) + torch.sum(target) - if sets_sum.item() == 0: - sets_sum = 2 * inter - - return (2 * inter + epsilon) / (sets_sum + epsilon) - else: - # compute and average metric for each batch element - dice = 0 - for i in range(input.shape[0]): - dice += dice_coeff(input[i, ...], target[i, ...]) +from torch import Tensor +import torch + +def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6): + # Average of Dice coefficient for all classes + assert input.size() == target.size() + dice = 0 + for channel in range(input.shape[1]): + dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon) + + return dice / input.shape[1] + +def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False): + # Dice loss (objective to minimize) between 0 and 1 + assert input.size() == target.size() + fn = multiclass_dice_coeff if multiclass else dice_coeff + return 1 - fn(input, target, reduce_batch_first=True) + +def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6): + # Average of Dice coefficient for all batches, or for a single mask + assert input.size() == target.size() + if input.dim() == 2 and reduce_batch_first: + raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})') + + if input.dim() == 2 or reduce_batch_first: + inter = torch.dot(input.reshape(-1), target.reshape(-1)) + sets_sum = torch.sum(input) + torch.sum(target) + if sets_sum.item() == 0: + sets_sum = 2 * inter + + return (2 * inter + epsilon) / (sets_sum + epsilon) + else: + # compute and average metric for each batch element + dice = 0 + for i in range(input.shape[0]): + dice += dice_coeff(input[i, ...], target[i, ...]) return dice / input.shape[0] \ No newline at end of file diff --git a/measure_MOTA.py b/deepsea/measure_MOTA.py similarity index 97% rename from measure_MOTA.py rename to deepsea/measure_MOTA.py index 93ee825..744506e 100644 --- a/measure_MOTA.py +++ b/deepsea/measure_MOTA.py @@ -108,6 +108,3 @@ def main(args,seg_img_size= [383,512],tracking_image_size = [128,128],image_mean main(args) - - - diff --git a/model.py b/deepsea/model.py similarity index 97% rename from model.py rename to deepsea/model.py index 6db32fe..df0d576 100644 --- a/model.py +++ b/deepsea/model.py @@ -1,109 +1,109 @@ -import torch.nn as nn -import 
torch -import torch.nn.functional as F - -class DeepSeaUp(nn.Module): - def __init__(self, in_channels, bilinear=True): - super().__init__() - if bilinear: - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - else: - self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) - - def forward(self, x1, x2): - x1 = self.up(x1) - diffY = x2.size()[2] - x1.size()[2] - diffX = x2.size()[3] - x1.size()[3] - x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2]) - x = torch.cat([x2, x1], dim=1) - return x - -class ResBlock(nn.Module): - def __init__(self, in_channels, out_channels, mid_channels=None): - super().__init__() - if not mid_channels: - mid_channels = out_channels - self.double_conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - nn.Dropout(0.2), - nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), - ) - self.bn=nn.BatchNorm2d(out_channels) - self.relu=nn.ReLU(inplace=True) - self.conv1d = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) - - def forward(self, x): - x1=self.double_conv(x) - x2 = self.conv1d(x) - x=x1+x2 - x=self.bn(x) - x=self.relu(x) - return x - -class DeepSeaSegmentation(nn.Module): - def __init__(self, n_channels, n_classes, bilinear=True): - super(DeepSeaSegmentation, self).__init__() - self.n_channels = n_channels - self.n_classes = n_classes - self.bilinear = bilinear - self.res1=ResBlock(n_channels,64) - self.down1 = nn.MaxPool2d(2) - self.res2 = ResBlock(64, 128) - self.down2 = nn.MaxPool2d(2) - self.res3 = ResBlock(128, 256) - self.up1 = DeepSeaUp(256, 128) - self.res4 = ResBlock(384, 128) - self.up2 = DeepSeaUp(128, 64) - self.res5 = ResBlock(192, 64) - self.conv3 = nn.Conv2d(64, n_classes, kernel_size=1, padding=0) - self.conv4 = nn.Conv2d(64, n_classes, kernel_size=1, padding=0) - - - def forward(self, x): - x1=self.res1(x) - x2=self.down1(x1) - x2 = self.res2(x2) - x3 = self.down2(x2) - x3 = self.res3(x3) - x4=self.up1(x3,x2) - x4 = self.res4(x4) - x5 = self.up2(x4,x1) - x6 = self.res5(x5) - logits=self.conv3(x6) - edges = self.conv4(x6) - return logits,edges - -class DeepSeaTracker(nn.Module): - def __init__(self, n_channels, n_classes, bilinear=True): - super(DeepSeaTracker, self).__init__() - self.n_channels = n_channels - self.n_classes = n_classes - self.bilinear = bilinear - self.res1=ResBlock(n_channels,64) - self.res2 = ResBlock(n_channels, 64) - self.down1 = nn.MaxPool2d(2) - self.res3 = ResBlock(128, 128) - self.down2 = nn.MaxPool2d(2) - self.res4 = ResBlock(128, 256) - self.up1 = DeepSeaUp(256, 128) - self.res5 = ResBlock(384, 128) - self.up2 = DeepSeaUp(128, 64) - self.res6 = ResBlock(256, 64) - self.conv3 = nn.Conv2d(64, n_classes, kernel_size=1, padding=0) - - def forward(self, img_prev,img_curr): - img_prev=self.res1(img_prev) - img_curr = self.res2(img_curr) - x1=torch.cat((img_prev, img_curr), 1) - x2=self.down1(x1) - x2 = self.res3(x2) - x3 = self.down2(x2) - x3 = self.res4(x3) - x4=self.up1(x3,x2) - x4 = self.res5(x4) - x5 = self.up2(x4,x1) - x6=self.res6(x5) - logits=self.conv3(x6) +import torch.nn as nn +import torch +import torch.nn.functional as F + +class DeepSeaUp(nn.Module): + def __init__(self, in_channels, bilinear=True): + super().__init__() + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + else: + self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) + + def 
forward(self, x1, x2): + x1 = self.up(x1) + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2]) + x = torch.cat([x2, x1], dim=1) + return x + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Dropout(0.2), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), + ) + self.bn=nn.BatchNorm2d(out_channels) + self.relu=nn.ReLU(inplace=True) + self.conv1d = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) + + def forward(self, x): + x1=self.double_conv(x) + x2 = self.conv1d(x) + x=x1+x2 + x=self.bn(x) + x=self.relu(x) + return x + +class DeepSeaSegmentation(nn.Module): + def __init__(self, n_channels, n_classes, bilinear=True): + super(DeepSeaSegmentation, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.bilinear = bilinear + self.res1=ResBlock(n_channels,64) + self.down1 = nn.MaxPool2d(2) + self.res2 = ResBlock(64, 128) + self.down2 = nn.MaxPool2d(2) + self.res3 = ResBlock(128, 256) + self.up1 = DeepSeaUp(256, 128) + self.res4 = ResBlock(384, 128) + self.up2 = DeepSeaUp(128, 64) + self.res5 = ResBlock(192, 64) + self.conv3 = nn.Conv2d(64, n_classes, kernel_size=1, padding=0) + self.conv4 = nn.Conv2d(64, n_classes, kernel_size=1, padding=0) + + + def forward(self, x): + x1=self.res1(x) + x2=self.down1(x1) + x2 = self.res2(x2) + x3 = self.down2(x2) + x3 = self.res3(x3) + x4=self.up1(x3,x2) + x4 = self.res4(x4) + x5 = self.up2(x4,x1) + x6 = self.res5(x5) + logits=self.conv3(x6) + edges = self.conv4(x6) + return logits,edges + +class DeepSeaTracker(nn.Module): + def __init__(self, n_channels, n_classes, bilinear=True): + super(DeepSeaTracker, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.bilinear = bilinear + self.res1=ResBlock(n_channels,64) + self.res2 = ResBlock(n_channels, 64) + self.down1 = nn.MaxPool2d(2) + self.res3 = ResBlock(128, 128) + self.down2 = nn.MaxPool2d(2) + self.res4 = ResBlock(128, 256) + self.up1 = DeepSeaUp(256, 128) + self.res5 = ResBlock(384, 128) + self.up2 = DeepSeaUp(128, 64) + self.res6 = ResBlock(256, 64) + self.conv3 = nn.Conv2d(64, n_classes, kernel_size=1, padding=0) + + def forward(self, img_prev,img_curr): + img_prev=self.res1(img_prev) + img_curr = self.res2(img_curr) + x1=torch.cat((img_prev, img_curr), 1) + x2=self.down1(x1) + x2 = self.res3(x2) + x3 = self.down2(x2) + x3 = self.res4(x3) + x4=self.up1(x3,x2) + x4 = self.res5(x4) + x5 = self.up2(x4,x1) + x6=self.res6(x5) + logits=self.conv3(x6) return logits \ No newline at end of file diff --git a/segmentation_transforms.py b/deepsea/segmentation_transforms.py similarity index 97% rename from segmentation_transforms.py rename to deepsea/segmentation_transforms.py index c084470..a6469ce 100644 --- a/segmentation_transforms.py +++ b/deepsea/segmentation_transforms.py @@ -1,1901 +1,1901 @@ -import math -import numbers -import random -import warnings -from collections.abc import Sequence -from typing import Tuple, List, Optional -from PIL import Image -import torch -from torch import Tensor - -try: - import accimage -except ImportError: - accimage = None - - -from torchvision.transforms import functional as F - -__all__ = ["Compose", 
"ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", - "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", - "RandomHorizontalFlip", "RandomVerticalFlip", "TenCrop", - "LinearTransformation", "ColorJitter", "RandomRotation", "Grayscale", "RandomGrayscale", - "RandomPerspective", "RandomErasing", "GaussianBlur"] - - -class Compose: - """Composes several transforms together. This transform does not support torchscript. - Please, see the note below. - - Args: - transforms (list of ``Transform`` objects): list of transforms to compose. - - Example: - >>> transforms.Compose([ - >>> transforms.CenterCrop(10), - >>> transforms.ToTensor(), - >>> ]) - - .. note:: - In order to script the transformations, please use ``torch.nn.Sequential`` as below. - - >>> transforms = torch.nn.Sequential( - >>> transforms.CenterCrop(10), - >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - >>> ) - >>> scripted_transforms = torch.jit.script(transforms) - - Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require - `lambda` functions or ``PIL.Image``. - - """ - - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, img,mask,wmap): - for t in self.transforms: - img,mask,wmap = t(img,mask,wmap) - return img,mask,wmap - - def __repr__(self): - format_string = self.__class__.__name__ + '(' - for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' - return format_string - - -class ToTensor: - """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript. - - Converts a PIL Image or numpy.ndarray (H x W x C) in the range - [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] - if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) - or if the numpy.ndarray has dtype = np.uint8 - - In the other cases, tensors are returned without scaling. - - .. note:: - Because the input image is scaled to [0.0, 1.0], this transformation should not be used when - transforming target image masks. See the `references`_ for implementing the transforms for image masks. - - .. _references: https://github.com/pytorch/vision/tree/master/references/segmentation - """ - - def __call__(self, img,mask,wmap): - """ - Args: - pic (PIL Image or numpy.ndarray): Image to be converted to tensor. - - Returns: - Tensor: Converted image. - """ - return F.to_tensor(img),F.to_tensor(mask),F.to_tensor(wmap) - - def __repr__(self): - return self.__class__.__name__ + '()' - - -class PILToTensor: - """Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript. - - Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W). - """ - - def __call__(self, pic): - """ - Args: - pic (PIL Image): Image to be converted to tensor. - - Returns: - Tensor: Converted image. - """ - return F.pil_to_tensor(pic) - - def __repr__(self): - return self.__class__.__name__ + '()' - - -class ConvertImageDtype(torch.nn.Module): - """Convert a tensor image to the given ``dtype`` and scale the values accordingly - This function does not support PIL Image. - - Args: - dtype (torch.dtype): Desired data type of the output - - .. note:: - - When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. - If converted back and forth, this mismatch has no effect. 
- - Raises: - RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as - well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to - overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range - of the integer ``dtype``. - """ - - def __init__(self, dtype: torch.dtype) -> None: - super().__init__() - self.dtype = dtype - - def forward(self, image): - return F.convert_image_dtype(image, self.dtype) - - -class ToPILImage: - """Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript. - - Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape - H x W x C to a PIL Image while preserving the value range. - - Args: - mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). - If ``mode`` is ``None`` (default) there are some assumptions made about the input data: - - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``. - - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``. - - If the input has 2 channels, the ``mode`` is assumed to be ``LA``. - - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``, - ``short``). - - .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes - """ - def __init__(self, mode=None): - self.mode = mode - - def __call__(self, pic): - """ - Args: - pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. - - Returns: - PIL Image: Image converted to PIL Image. - - """ - return F.to_pil_image(pic, self.mode) - - def __repr__(self): - format_string = self.__class__.__name__ + '(' - if self.mode is not None: - format_string += 'mode={0}'.format(self.mode) - format_string += ')' - return format_string - - -class Normalize(torch.nn.Module): - """Normalize a tensor image with mean and standard deviation. - This transform does not support PIL Image. - Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` - channels, this transform will normalize each channel of the input - ``torch.*Tensor`` i.e., - ``output[channel] = (input[channel] - mean[channel]) / std[channel]`` - - .. note:: - This transform acts out of place, i.e., it does not mutate the input tensor. - - Args: - mean (sequence): Sequence of means for each channel. - std (sequence): Sequence of standard deviations for each channel. - inplace(bool,optional): Bool to make this operation in-place. - - """ - - def __init__(self, mean, std, inplace=False): - super().__init__() - self.mean = mean - self.std = std - self.inplace = inplace - - def forward(self, img: Tensor,mask: Tensor,wmap: Tensor) -> Tensor: - """ - Args: - tensor (Tensor): Tensor image to be normalized. - - Returns: - Tensor: Normalized Tensor image. - """ - return F.normalize(img, self.mean, self.std, self.inplace),mask,wmap - - def __repr__(self): - return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) - - -class Resize(torch.nn.Module): - """Resize the input image to the given size. - The image can be a PIL Image or a torch Tensor, in which case it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions - - Args: - size (sequence or int): Desired output size. If size is a sequence like - (h, w), output size will be matched to this. If size is an int, - smaller edge of the image will be matched to this number. 
- i.e, if height > width, then image will be rescaled to - (size * height / width, size). - In torchscript mode padding as single int is not supported, use a tuple or - list of length 1: ``[size, ]``. - interpolation (int, optional): Desired interpolation enum defined by `filters`_. - Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` - and ``PIL.Image.BICUBIC`` are supported. - """ - - def __init__(self, size, img_interpolation=Image.BILINEAR,mask_interpolation=Image.NEAREST): - super().__init__() - if not isinstance(size, (int, Sequence)): - raise TypeError("Size should be int or sequence. Got {}".format(type(size))) - if isinstance(size, Sequence) and len(size) not in (1, 2): - raise ValueError("If size is a sequence, it should have 1 or 2 values") - self.size = size - self.img_interpolation = img_interpolation - self.mask_interpolation = mask_interpolation - - def forward(self, img,mask,wmap): - """ - Args: - img (PIL Image or Tensor): Image to be scaled. - - Returns: - PIL Image or Tensor: Rescaled image. - """ - - - return F.resize(img, self.size, self.img_interpolation),F.resize(mask, self.size, self.mask_interpolation),F.resize(wmap, self.size, self.mask_interpolation) - - def __repr__(self): - interpolate_str = self.interpolation.value - return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2}, antialias={3})'.format( - self.size, interpolate_str, self.max_size, self.antialias) - - -class Scale(Resize): - """ - Note: This transform is deprecated in favor of Resize. - """ - def __init__(self, *args, **kwargs): - warnings.warn("The use of the transforms.Scale transform is deprecated, " + - "please use transforms.Resize instead.") - super(Scale, self).__init__(*args, **kwargs) - - -class CenterCrop(torch.nn.Module): - """Crops the given image at the center. - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. - If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. - - Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). - """ - - def __init__(self, size): - super().__init__() - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be cropped. - - Returns: - PIL Image or Tensor: Cropped image. - """ - return F.center_crop(img, self.size) - - def __repr__(self): - return self.__class__.__name__ + '(size={0})'.format(self.size) - - -class Pad(torch.nn.Module): - """Pad the given image on all sides with the given "pad" value. - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric, - at most 3 leading dimensions for mode edge, - and an arbitrary number of leading dimensions for mode constant - - Args: - padding (int or sequence): Padding on each border. If a single int is provided this - is used to pad all borders. If sequence of length 2 is provided this is the padding - on left/right and top/bottom respectively. If a sequence of length 4 is provided - this is the padding for the left, top, right and bottom borders respectively. - - .. 
note:: - In torchscript mode padding as single int is not supported, use a sequence of - length 1: ``[padding, ]``. - fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of - length 3, it is used to fill R, G, B channels respectively. - This value is only used when the padding_mode is constant. - Only number is supported for torch Tensor. - Only int or str or tuple value is supported for PIL Image. - padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. - Default is constant. - - - constant: pads with a constant value, this value is specified with fill - - - edge: pads with the last value at the edge of the image. - If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2 - - - reflect: pads with reflection of image without repeating the last value on the edge. - For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode - will result in [3, 2, 1, 2, 3, 4, 3, 2] - - - symmetric: pads with reflection of image repeating the last value on the edge. - For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode - will result in [2, 1, 1, 2, 3, 4, 4, 3] - """ - - def __init__(self, padding, fill=0, padding_mode="constant"): - super().__init__() - if not isinstance(padding, (numbers.Number, tuple, list)): - raise TypeError("Got inappropriate padding arg") - - if not isinstance(fill, (numbers.Number, str, tuple)): - raise TypeError("Got inappropriate fill arg") - - if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: - raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") - - if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]: - raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + - "{} element tuple".format(len(padding))) - - self.padding = padding - self.fill = fill - self.padding_mode = padding_mode - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be padded. - - Returns: - PIL Image or Tensor: Padded image. - """ - return F.pad(img, self.padding, self.fill, self.padding_mode) - - def __repr__(self): - return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ - format(self.padding, self.fill, self.padding_mode) - - -class Lambda: - """Apply a user-defined lambda as a transform. This transform does not support torchscript. - - Args: - lambd (function): Lambda/function to be used for transform. - """ - - def __init__(self, lambd): - if not callable(lambd): - raise TypeError("Argument lambd should be callable, got {}".format(repr(type(lambd).__name__))) - self.lambd = lambd - - def __call__(self, img): - return self.lambd(img) - - def __repr__(self): - return self.__class__.__name__ + '()' - - -class RandomTransforms: - """Base class for a list of transformations with randomness - - Args: - transforms (sequence): list of transformations - """ - - def __init__(self, transforms): - if not isinstance(transforms, Sequence): - raise TypeError("Argument transforms should be a sequence") - self.transforms = transforms - - def __call__(self, *args, **kwargs): - raise NotImplementedError() - - def __repr__(self): - format_string = self.__class__.__name__ + '(' - for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' - return format_string - - -class RandomApply(torch.nn.Module): - """Apply randomly a list of transformations with a given probability. - - .. 
note:: - In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of - transforms as shown below: - - >>> transforms = transforms.RandomApply(torch.nn.ModuleList([ - >>> transforms.ColorJitter(), - >>> ]), p=0.3) - >>> scripted_transforms = torch.jit.script(transforms) - - Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require - `lambda` functions or ``PIL.Image``. - - Args: - transforms (sequence or torch.nn.Module): list of transformations - p (float): probability - """ - - def __init__(self, transforms, p=0.5): - super().__init__() - self.transforms = transforms - self.p = p - - def forward(self, img,mask,wmap): - if self.p < torch.rand(1): - return img,mask,wmap - for t in self.transforms: - img,mask,wmap = t(img,mask,wmap) - return img,mask,wmap - - def __repr__(self): - format_string = self.__class__.__name__ + '(' - format_string += '\n p={}'.format(self.p) - for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' - return format_string - - -class RandomOrder(RandomTransforms): - """Apply a list of transformations in a random order. This transform does not support torchscript. - """ - def __call__(self, img,mask,wmap): - order = list(range(len(self.transforms))) - random.shuffle(order) - for i in order: - img,mask,wmap = self.transforms[i](img,mask,wmap) - return img,mask,wmap - - -class RandomChoice(RandomTransforms): - """Apply single transformation randomly picked from a list. This transform does not support torchscript. - """ - def __call__(self, img,mask,wmap): - t = random.choice(self.transforms) - return t(img,mask,wmap) - - -class RandomCrop(torch.nn.Module): - """Crop the given image at a random location. - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions, - but if non-constant padding is used, the input is expected to have at most 2 leading dimensions - - Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). - padding (int or sequence, optional): Optional padding on each border - of the image. Default is None. If a single int is provided this - is used to pad all borders. If sequence of length 2 is provided this is the padding - on left/right and top/bottom respectively. If a sequence of length 4 is provided - this is the padding for the left, top, right and bottom borders respectively. - - .. note:: - In torchscript mode padding as single int is not supported, use a sequence of - length 1: ``[padding, ]``. - pad_if_needed (boolean): It will pad the image if smaller than the - desired size to avoid raising an exception. Since cropping is done - after padding, the padding seems to be done at a random offset. - fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of - length 3, it is used to fill R, G, B channels respectively. - This value is only used when the padding_mode is constant. - Only number is supported for torch Tensor. - Only int or str or tuple value is supported for PIL Image. - padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. - Default is constant. 
- - - constant: pads with a constant value, this value is specified with fill - - - edge: pads with the last value at the edge of the image. - If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2 - - - reflect: pads with reflection of image without repeating the last value on the edge. - For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode - will result in [3, 2, 1, 2, 3, 4, 3, 2] - - - symmetric: pads with reflection of image repeating the last value on the edge. - For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode - will result in [2, 1, 1, 2, 3, 4, 4, 3] - """ - - @staticmethod - def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]: - """Get parameters for ``crop`` for a random crop. - - Args: - img (PIL Image or Tensor): Image to be cropped. - output_size (tuple): Expected output size of the crop. - - Returns: - tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. - """ - w, h = F._get_image_size(img) - th, tw = output_size - - if h + 1 < th or w + 1 < tw: - raise ValueError( - "Required crop size {} is larger then input image size {}".format((th, tw), (h, w)) - ) - - if w == tw and h == th: - return 0, 0, h, w - - i = torch.randint(0, h - th + 1, size=(1, )).item() - j = torch.randint(0, w - tw + 1, size=(1, )).item() - return i, j, th, tw - - def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): - super().__init__() - - self.size = tuple(_setup_size( - size, error_msg="Please provide only two dimensions (h, w) for size." - )) - - self.padding = padding - self.pad_if_needed = pad_if_needed - self.fill = fill - self.padding_mode = padding_mode - - def forward(self, img,mask,wmap): - """ - Args: - img (PIL Image or Tensor): Image to be cropped. - - Returns: - PIL Image or Tensor: Cropped image. - """ - if self.padding is not None: - img = F.pad(img, self.padding, self.fill, self.padding_mode) - - width, height = F._get_image_size(img) - # pad the width if needed - if self.pad_if_needed and width < self.size[1]: - padding = [self.size[1] - width, 0] - img = F.pad(img, padding, self.fill, self.padding_mode) - # pad the height if needed - if self.pad_if_needed and height < self.size[0]: - padding = [0, self.size[0] - height] - img = F.pad(img, padding, self.fill, self.padding_mode) - - i, j, h, w = self.get_params(img, self.size) - - return F.crop(img, i, j, h, w),F.crop(mask, i, j, h, w),F.crop(wmap, i, j, h, w) - - def __repr__(self): - return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding) - - -class Pass(torch.nn.Module): - """ - No Transforms - """ - - def __init__(self): - super().__init__() - - def forward(self, img,mask,wmap): - return img,mask,wmap - - def __repr__(self): - return self.__class__.__name__ - -class RandomHorizontalFlip(torch.nn.Module): - """Horizontally flip the given image randomly with a given probability. - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading - dimensions - - Args: - p (float): probability of the image being flipped. Default value is 0.5 - """ - - def __init__(self, p=0.5): - super().__init__() - self.p = p - - def forward(self, img,mask,wmap): - """ - Args: - img (PIL Image or Tensor): Image to be flipped. - - Returns: - PIL Image or Tensor: Randomly flipped image. 
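A key property of RandomCrop, RandomHorizontalFlip and RandomVerticalFlip in this module is that the random decision (crop offsets, flip outcome) is drawn once per call and then applied identically to the image, the mask and the weight map, which is what keeps the targets spatially aligned. A hedged sketch of a small augmentation chain built from these classes together with RandomApply (names as defined in this file; the PIL inputs img, mask and wmap are assumed to exist):

import deepsea.segmentation_transforms as T

train_aug = T.Compose([
    T.RandomApply([T.RandomHorizontalFlip(p=1.0),
                   T.RandomVerticalFlip(p=1.0)], p=0.5),  # 50% chance of applying both flips
    T.RandomCrop((96, 96)),  # the same (i, j, h, w) window is cut out of all three inputs
    T.ToTensor(),
])

aug_img, aug_mask, aug_wmap = train_aug(img, mask, wmap)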
- """ - if torch.rand(1) < self.p: - return F.hflip(img),F.hflip(mask),F.hflip(wmap) - return img,mask,wmap - - def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) - - -class RandomVerticalFlip(torch.nn.Module): - """Vertically flip the given image randomly with a given probability. - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading - dimensions - - Args: - p (float): probability of the image being flipped. Default value is 0.5 - """ - - def __init__(self, p=0.5): - super().__init__() - self.p = p - - def forward(self, img,mask,wmap): - """ - Args: - img (PIL Image or Tensor): Image to be flipped. - - Returns: - PIL Image or Tensor: Randomly flipped image. - """ - if torch.rand(1) < self.p: - return F.vflip(img),F.vflip(mask),F.vflip(wmap) - return img,mask,wmap - - def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) - - -class RandomPerspective(torch.nn.Module): - """Performs a random perspective transformation of the given image with a given probability. - The image can be a PIL Image or a Tensor, in which case it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. - - Args: - distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1. - Default is 0.5. - p (float): probability of the image being transformed. Default is 0.5. - interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and - ``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for PIL images and Tensors. - fill (n-tuple or int or float): Pixel fill value for area outside the rotated - image. If int or float, the value is used for all bands respectively. Default is 0. - This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor - input. Fill value for the area outside the transform in the output image is always 0. - - """ - - def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BILINEAR, fill=0): - super().__init__() - self.p = p - self.interpolation = interpolation - self.distortion_scale = distortion_scale - self.fill = fill - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be Perspectively transformed. - - Returns: - PIL Image or Tensor: Randomly transformed image. - """ - if torch.rand(1) < self.p: - width, height = F._get_image_size(img) - startpoints, endpoints = self.get_params(width, height, self.distortion_scale) - return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill) - return img - - @staticmethod - def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]: - """Get parameters for ``perspective`` for a random perspective transform. - - Args: - width (int): width of the image. - height (int): height of the image. - distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1. - - Returns: - List containing [top-left, top-right, bottom-right, bottom-left] of the original image, - List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image. 
- """ - half_height = height // 2 - half_width = width // 2 - topleft = [ - int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), - int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) - ] - topright = [ - int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), - int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) - ] - botright = [ - int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), - int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) - ] - botleft = [ - int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), - int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) - ] - startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] - endpoints = [topleft, topright, botright, botleft] - return startpoints, endpoints - - def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) - - -# class RandomResizedCrop(torch.nn.Module): -# """Crop a random portion of image and resize it to a given size. -# -# If the image is torch Tensor, it is expected -# to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions -# -# A crop of the original image is made: the crop has a random area (H * W) -# and a random aspect ratio. This crop is finally resized to the given -# size. This is popularly used to train the Inception networks. -# -# Args: -# size (int or sequence): expected output size of the crop, for each edge. If size is an -# int instead of sequence like (h, w), a square output size ``(size, size)`` is -# made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). -# -# .. note:: -# In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. -# scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop, -# before resizing. The scale is defined with respect to the area of the original image. -# ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before -# resizing. -# interpolation (InterpolationMode): Desired interpolation enum defined by -# :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. -# If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and -# ``InterpolationMode.BICUBIC`` are supported. -# For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. -# -# """ -# -# def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationMode.BILINEAR): -# super().__init__() -# self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") -# -# if not isinstance(scale, Sequence): -# raise TypeError("Scale should be a sequence") -# if not isinstance(ratio, Sequence): -# raise TypeError("Ratio should be a sequence") -# if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): -# warnings.warn("Scale and ratio should be of kind (min, max)") -# -# # Backward compatibility with integer value -# if isinstance(interpolation, int): -# warnings.warn( -# "Argument interpolation should be of type InterpolationMode instead of int. " -# "Please, use InterpolationMode enum." 
-# ) -# interpolation = _interpolation_modes_from_int(interpolation) -# -# self.interpolation = interpolation -# self.scale = scale -# self.ratio = ratio -# -# @staticmethod -# def get_params( -# img: Tensor, scale: List[float], ratio: List[float] -# ) -> Tuple[int, int, int, int]: -# """Get parameters for ``crop`` for a random sized crop. -# -# Args: -# img (PIL Image or Tensor): Input image. -# scale (list): range of scale of the origin size cropped -# ratio (list): range of aspect ratio of the origin aspect ratio cropped -# -# Returns: -# tuple: params (i, j, h, w) to be passed to ``crop`` for a random -# sized crop. -# """ -# width, height = F._get_image_size(img) -# area = height * width -# -# log_ratio = torch.log(torch.tensor(ratio)) -# for _ in range(10): -# target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() -# aspect_ratio = torch.exp( -# torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) -# ).item() -# -# w = int(round(math.sqrt(target_area * aspect_ratio))) -# h = int(round(math.sqrt(target_area / aspect_ratio))) -# -# if 0 < w <= width and 0 < h <= height: -# i = torch.randint(0, height - h + 1, size=(1,)).item() -# j = torch.randint(0, width - w + 1, size=(1,)).item() -# return i, j, h, w -# -# # Fallback to central crop -# in_ratio = float(width) / float(height) -# if in_ratio < min(ratio): -# w = width -# h = int(round(w / min(ratio))) -# elif in_ratio > max(ratio): -# h = height -# w = int(round(h * max(ratio))) -# else: # whole image -# w = width -# h = height -# i = (height - h) // 2 -# j = (width - w) // 2 -# return i, j, h, w -# -# def forward(self, img): -# """ -# Args: -# img (PIL Image or Tensor): Image to be cropped and resized. -# -# Returns: -# PIL Image or Tensor: Randomly cropped and resized image. -# """ -# i, j, h, w = self.get_params(img, self.scale, self.ratio) -# return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) -# -# def __repr__(self): -# interpolate_str = self.interpolation.value -# format_string = self.__class__.__name__ + '(size={0}'.format(self.size) -# format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) -# format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) -# format_string += ', interpolation={0})'.format(interpolate_str) -# return format_string - - -# class RandomSizedCrop(RandomResizedCrop): -# """ -# Note: This transform is deprecated in favor of RandomResizedCrop. -# """ -# def __init__(self, *args, **kwargs): -# warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + -# "please use transforms.RandomResizedCrop instead.") -# super(RandomSizedCrop, self).__init__(*args, **kwargs) -# -# -# class FiveCrop(torch.nn.Module): -# """Crop the given image into four corners and the central crop. -# If the image is torch Tensor, it is expected -# to have [..., H, W] shape, where ... means an arbitrary number of leading -# dimensions -# -# .. Note:: -# This transform returns a tuple of images and there may be a mismatch in the number of -# inputs and targets your Dataset returns. See below for an example of how to deal with -# this. -# -# Args: -# size (sequence or int): Desired output size of the crop. If size is an ``int`` -# instead of sequence like (h, w), a square crop of size (size, size) is made. -# If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). 
-# -# Example: -# >>> transform = Compose([ -# >>> FiveCrop(size), # this is a list of PIL Images -# >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor -# >>> ]) -# >>> #In your test loop you can do the following: -# >>> input, target = batch # input is a 5d tensor, target is 2d -# >>> bs, ncrops, c, h, w = input.size() -# >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops -# >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops -# """ -# -# def __init__(self, size): -# super().__init__() -# self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") -# -# def forward(self, img): -# """ -# Args: -# img (PIL Image or Tensor): Image to be cropped. -# -# Returns: -# tuple of 5 images. Image can be PIL Image or Tensor -# """ -# return F.five_crop(img, self.size) -# -# def __repr__(self): -# return self.__class__.__name__ + '(size={0})'.format(self.size) - - -class TenCrop(torch.nn.Module): - """Crop the given image into four corners and the central crop plus the flipped version of - these (horizontal flipping is used by default). - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading - dimensions - - .. Note:: - This transform returns a tuple of images and there may be a mismatch in the number of - inputs and targets your Dataset returns. See below for an example of how to deal with - this. - - Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). - vertical_flip (bool): Use vertical flipping instead of horizontal - - Example: - >>> transform = Compose([ - >>> TenCrop(size), # this is a list of PIL Images - >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor - >>> ]) - >>> #In your test loop you can do the following: - >>> input, target = batch # input is a 5d tensor, target is 2d - >>> bs, ncrops, c, h, w = input.size() - >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops - >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops - """ - - def __init__(self, size, vertical_flip=False): - super().__init__() - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - self.vertical_flip = vertical_flip - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be cropped. - - Returns: - tuple of 10 images. Image can be PIL Image or Tensor - """ - return F.ten_crop(img, self.size, self.vertical_flip) - - def __repr__(self): - return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) - - -class LinearTransformation(torch.nn.Module): - """Transform a tensor image with a square transformation matrix and a mean_vector computed - offline. - This transform does not support PIL Image. - Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and - subtract mean_vector from it which is then followed by computing the dot - product with the transformation matrix and then reshaping the tensor to its - original shape. - - Applications: - whitening transformation: Suppose X is a column vector zero-centered data. 
- Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X), - perform SVD on this matrix and pass it as transformation_matrix. - - Args: - transformation_matrix (Tensor): tensor [D x D], D = C x H x W - mean_vector (Tensor): tensor [D], D = C x H x W - """ - - def __init__(self, transformation_matrix, mean_vector): - super().__init__() - if transformation_matrix.size(0) != transformation_matrix.size(1): - raise ValueError("transformation_matrix should be square. Got " + - "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) - - if mean_vector.size(0) != transformation_matrix.size(0): - raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) + - " as any one of the dimensions of the transformation_matrix [{}]" - .format(tuple(transformation_matrix.size()))) - - if transformation_matrix.device != mean_vector.device: - raise ValueError("Input tensors should be on the same device. Got {} and {}" - .format(transformation_matrix.device, mean_vector.device)) - - self.transformation_matrix = transformation_matrix - self.mean_vector = mean_vector - - def forward(self, tensor: Tensor) -> Tensor: - """ - Args: - tensor (Tensor): Tensor image to be whitened. - - Returns: - Tensor: Transformed image. - """ - shape = tensor.shape - n = shape[-3] * shape[-2] * shape[-1] - if n != self.transformation_matrix.shape[0]: - raise ValueError("Input tensor and transformation matrix have incompatible shape." + - "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) + - "{}".format(self.transformation_matrix.shape[0])) - - if tensor.device.type != self.mean_vector.device.type: - raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. " - "Got {} vs {}".format(tensor.device, self.mean_vector.device)) - - flat_tensor = tensor.view(-1, n) - self.mean_vector - transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) - tensor = transformed_tensor.view(shape) - return tensor - - def __repr__(self): - format_string = self.__class__.__name__ + '(transformation_matrix=' - format_string += (str(self.transformation_matrix.tolist()) + ')') - format_string += (", (mean_vector=" + str(self.mean_vector.tolist()) + ')') - return format_string - - -class ColorJitter(torch.nn.Module): - """Randomly change the brightness, contrast, saturation and hue of an image. - If the image is torch Tensor, it is expected - to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported. - - Args: - brightness (float or tuple of float (min, max)): How much to jitter brightness. - brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] - or the given [min, max]. Should be non negative numbers. - contrast (float or tuple of float (min, max)): How much to jitter contrast. - contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] - or the given [min, max]. Should be non negative numbers. - saturation (float or tuple of float (min, max)): How much to jitter saturation. - saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] - or the given [min, max]. Should be non negative numbers. - hue (float or tuple of float (min, max)): How much to jitter hue. - hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. - Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 
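The whitening recipe described in the LinearTransformation docstring can be made concrete as follows. This is only an illustrative sketch: dataset_tensor is a hypothetical (N, C, H, W) stack of training images, and the ZCA-style whitening matrix built from the SVD is one common choice for transformation_matrix, not something prescribed by this module:

import torch
from deepsea.segmentation_transforms import LinearTransformation

# dataset_tensor: (N, C, H, W) stack of training images, assumed to exist.
N, C, H, W = dataset_tensor.shape
X = dataset_tensor.reshape(N, -1)                 # (N, D) with D = C * H * W
mean_vector = X.mean(dim=0)                       # (D,)
X_centered = X - mean_vector
cov = torch.mm(X_centered.t(), X_centered) / N    # (D, D) data covariance
U, S, _ = torch.svd(cov)
# ZCA whitening matrix: U diag(1/sqrt(S)) U^T
transformation_matrix = torch.mm(U, torch.mm(torch.diag(1.0 / torch.sqrt(S + 1e-6)), U.t()))

whiten = LinearTransformation(transformation_matrix, mean_vector)
whitened = whiten(dataset_tensor[0])              # same shape as the input image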
- """ - - def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): - super().__init__() - self.brightness = self._check_input(brightness, 'brightness') - self.contrast = self._check_input(contrast, 'contrast') - self.saturation = self._check_input(saturation, 'saturation') - self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), - clip_first_on_zero=False) - - @torch.jit.unused - def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): - if isinstance(value, numbers.Number): - if value < 0: - raise ValueError("If {} is a single number, it must be non negative.".format(name)) - value = [center - float(value), center + float(value)] - if clip_first_on_zero: - value[0] = max(value[0], 0.0) - elif isinstance(value, (tuple, list)) and len(value) == 2: - if not bound[0] <= value[0] <= value[1] <= bound[1]: - raise ValueError("{} values should be between {}".format(name, bound)) - else: - raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name)) - - # if value is 0 or (1., 1.) for brightness/contrast/saturation - # or (0., 0.) for hue, do nothing - if value[0] == value[1] == center: - value = None - return value - - @staticmethod - def get_params(brightness: Optional[List[float]], - contrast: Optional[List[float]], - saturation: Optional[List[float]], - hue: Optional[List[float]] - ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: - """Get the parameters for the randomized transform to be applied on image. - - Args: - brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen - uniformly. Pass None to turn off the transformation. - contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen - uniformly. Pass None to turn off the transformation. - saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen - uniformly. Pass None to turn off the transformation. - hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly. - Pass None to turn off the transformation. - - Returns: - tuple: The parameters used to apply the randomized transform - along with their random order. - """ - fn_idx = torch.randperm(4) - - b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1])) - c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1])) - s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1])) - h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) - - return fn_idx, b, c, s, h - - def forward(self, img,mask,wmap): - """ - Args: - img (PIL Image or Tensor): Input image. - - Returns: - PIL Image or Tensor: Color jittered image. 
- """ - fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ - self.get_params(self.brightness, self.contrast, self.saturation, self.hue) - - for fn_id in fn_idx: - if fn_id == 0 and brightness_factor is not None: - img = F.adjust_brightness(img, brightness_factor) - elif fn_id == 1 and contrast_factor is not None: - img = F.adjust_contrast(img, contrast_factor) - elif fn_id == 2 and saturation_factor is not None: - img = F.adjust_saturation(img, saturation_factor) - elif fn_id == 3 and hue_factor is not None: - img = F.adjust_hue(img, hue_factor) - - return img,mask,wmap - - def __repr__(self): - format_string = self.__class__.__name__ + '(' - format_string += 'brightness={0}'.format(self.brightness) - format_string += ', contrast={0}'.format(self.contrast) - format_string += ', saturation={0}'.format(self.saturation) - format_string += ', hue={0})'.format(self.hue) - return format_string - -class RandomRotation(torch.nn.Module): - """Rotate the image by angle. - The image can be a PIL Image or a Tensor, in which case it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. - - Args: - degrees (sequence or float or int): Range of degrees to select from. - If degrees is a number instead of sequence like (min, max), the range of degrees - will be (-degrees, +degrees). - resample (int, optional): An optional resampling filter. See `filters`_ for more information. - If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. - If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. - expand (bool, optional): Optional expansion flag. - If true, expands the output to make it large enough to hold the entire rotated image. - If false or omitted, make the output image the same size as the input image. - Note that the expand flag assumes rotation around the center and no translation. - center (list or tuple, optional): Optional center of rotation, (x, y). Origin is the upper left corner. - Default is the center of the image. - fill (n-tuple or int or float): Pixel fill value for area outside the rotated - image. If int or float, the value is used for all bands respectively. - Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0. - This option is not supported for Tensor input. Fill value for the area outside the transform in the output - image is always 0. - - .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters - - """ - - def __init__(self, degrees, resample=Image.NEAREST, expand=False, center=None, fill=0): - super().__init__() - self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) - - if center is not None: - _check_sequence_input(center, "center", req_sizes=(2, )) - - self.center = center - - self.resample = resample - self.expand = expand - self.fill = fill - - @staticmethod - def get_params(degrees: List[float]) -> float: - """Get parameters for ``rotate`` for a random rotation. - - Returns: - float: angle parameter to be passed to ``rotate`` for random rotation. - """ - angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) - return angle - - def forward(self, img,mask,wmap): - """ - Args: - img (PIL Image or Tensor): Image to be rotated. - - Returns: - PIL Image or Tensor: Rotated image. 
- """ - angle = self.get_params(self.degrees) - return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill),F.rotate(mask, angle, self.resample, self.expand, self.center, self.fill),F.rotate(wmap, angle, self.resample, self.expand, self.center, self.fill) - - - - -# class RandomAffine(torch.nn.Module): -# """Random affine transformation of the image keeping center invariant. -# If the image is torch Tensor, it is expected -# to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. -# -# Args: -# degrees (sequence or number): Range of degrees to select from. -# If degrees is a number instead of sequence like (min, max), the range of degrees -# will be (-degrees, +degrees). Set to 0 to deactivate rotations. -# translate (tuple, optional): tuple of maximum absolute fraction for horizontal -# and vertical translations. For example translate=(a, b), then horizontal shift -# is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is -# randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. -# scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is -# randomly sampled from the range a <= scale <= b. Will keep original scale by default. -# shear (sequence or number, optional): Range of degrees to select from. -# If shear is a number, a shear parallel to the x axis in the range (-shear, +shear) -# will be applied. Else if shear is a sequence of 2 values a shear parallel to the x axis in the -# range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values, -# a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. -# Will not apply shear by default. -# interpolation (InterpolationMode): Desired interpolation enum defined by -# :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. -# If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. -# For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. -# fill (sequence or number): Pixel fill value for the area outside the transformed -# image. Default is ``0``. If given a number, the value is used for all bands respectively. -# fillcolor (sequence or number, optional): deprecated argument and will be removed since v0.10.0. -# Please use the ``fill`` parameter instead. -# resample (int, optional): deprecated argument and will be removed since v0.10.0. -# Please use the ``interpolation`` parameter instead. -# -# .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters -# -# """ -# -# def __init__( -# self, degrees, translate=None, scale=None, shear=None, interpolation=InterpolationMode.NEAREST, fill=0, -# fillcolor=None, resample=None -# ): -# super().__init__() -# if resample is not None: -# warnings.warn( -# "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" -# ) -# interpolation = _interpolation_modes_from_int(resample) -# -# # Backward compatibility with integer value -# if isinstance(interpolation, int): -# warnings.warn( -# "Argument interpolation should be of type InterpolationMode instead of int. " -# "Please, use InterpolationMode enum." -# ) -# interpolation = _interpolation_modes_from_int(interpolation) -# -# if fillcolor is not None: -# warnings.warn( -# "Argument fillcolor is deprecated and will be removed since v0.10.0. 
Please, use fill instead" -# ) -# fill = fillcolor -# -# self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) -# -# if translate is not None: -# _check_sequence_input(translate, "translate", req_sizes=(2, )) -# for t in translate: -# if not (0.0 <= t <= 1.0): -# raise ValueError("translation values should be between 0 and 1") -# self.translate = translate -# -# if scale is not None: -# _check_sequence_input(scale, "scale", req_sizes=(2, )) -# for s in scale: -# if s <= 0: -# raise ValueError("scale values should be positive") -# self.scale = scale -# -# if shear is not None: -# self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4)) -# else: -# self.shear = shear -# -# self.resample = self.interpolation = interpolation -# -# if fill is None: -# fill = 0 -# elif not isinstance(fill, (Sequence, numbers.Number)): -# raise TypeError("Fill should be either a sequence or a number.") -# -# self.fillcolor = self.fill = fill -# -# @staticmethod -# def get_params( -# degrees: List[float], -# translate: Optional[List[float]], -# scale_ranges: Optional[List[float]], -# shears: Optional[List[float]], -# img_size: List[int] -# ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]: -# """Get parameters for affine transformation -# -# Returns: -# params to be passed to the affine transformation -# """ -# angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) -# if translate is not None: -# max_dx = float(translate[0] * img_size[0]) -# max_dy = float(translate[1] * img_size[1]) -# tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())) -# ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())) -# translations = (tx, ty) -# else: -# translations = (0, 0) -# -# if scale_ranges is not None: -# scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item()) -# else: -# scale = 1.0 -# -# shear_x = shear_y = 0.0 -# if shears is not None: -# shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item()) -# if len(shears) == 4: -# shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item()) -# -# shear = (shear_x, shear_y) -# -# return angle, translations, scale, shear -# -# def forward(self, img): -# """ -# img (PIL Image or Tensor): Image to be transformed. -# -# Returns: -# PIL Image or Tensor: Affine transformed image. -# """ -# fill = self.fill -# if isinstance(img, Tensor): -# if isinstance(fill, (int, float)): -# fill = [float(fill)] * F._get_image_num_channels(img) -# else: -# fill = [float(f) for f in fill] -# -# img_size = F._get_image_size(img) -# -# ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size) -# -# return F.affine(img, *ret, interpolation=self.interpolation, fill=fill) -# -# def __repr__(self): -# s = '{name}(degrees={degrees}' -# if self.translate is not None: -# s += ', translate={translate}' -# if self.scale is not None: -# s += ', scale={scale}' -# if self.shear is not None: -# s += ', shear={shear}' -# if self.interpolation != InterpolationMode.NEAREST: -# s += ', interpolation={interpolation}' -# if self.fill != 0: -# s += ', fill={fill}' -# s += ')' -# d = dict(self.__dict__) -# d['interpolation'] = self.interpolation.value -# return s.format(name=self.__class__.__name__, **d) - - -class Grayscale(torch.nn.Module): - """Convert image to grayscale. - If the image is torch Tensor, it is expected - to have [..., 3, H, W] shape, where ... 
means an arbitrary number of leading dimensions - - Args: - num_output_channels (int): (1 or 3) number of channels desired for output image - - Returns: - PIL Image: Grayscale version of the input. - - - If ``num_output_channels == 1`` : returned image is single channel - - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b - - """ - - def __init__(self, num_output_channels=1): - super().__init__() - self.num_output_channels = num_output_channels - - def forward(self, img,mask,wmap): - """ - Args: - img (PIL Image or Tensor): Image to be converted to grayscale. - - Returns: - PIL Image or Tensor: Grayscaled image. - """ - return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels),F.rgb_to_grayscale(mask, num_output_channels=self.num_output_channels),F.rgb_to_grayscale(wmap, num_output_channels=self.num_output_channels) - - def __repr__(self): - return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) - -class GetBoundingBoxes(torch.nn.Module): - - def __init__(self): - super().__init__() - - def forward(self, img,mask): - import numpy as np - - A = np.array([ - [0, 0, 0, 0, 0, 0, 0], - [0, 1, 1, 0, 2, 2, 0], - [0, 1, 1, 0, 2, 2, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 4, 4, 0, 3, 3, 0], - [0, 4, 4, 0, 3, 3, 0], - [0, 0, 0, 0, 0, 0, 0] - ]) - - bboxCorners = {} - for i in range(1, A.max() + 1): - B = np.argwhere(A == i) - bboxCorners[i] = B.min(0), B.max(0) - - print(bboxCorners) - return img - return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels),F.rgb_to_grayscale(mask, num_output_channels=self.num_output_channels) - - def __repr__(self): - return self.__class__.__name__ - - -class RandomGrayscale(torch.nn.Module): - """Randomly convert image to grayscale with a probability of p (default 0.1). - If the image is torch Tensor, it is expected - to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions - - Args: - p (float): probability that image should be converted to grayscale. - - Returns: - PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged - with probability (1-p). - - If input image is 1 channel: grayscale version is 1 channel - - If input image is 3 channel: grayscale version is 3 channel with r == g == b - - """ - - def __init__(self, p=0.1): - super().__init__() - self.p = p - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be converted to grayscale. - - Returns: - PIL Image or Tensor: Randomly grayscaled image. - """ - num_output_channels = F._get_image_num_channels(img) - if torch.rand(1) < self.p: - return F.rgb_to_grayscale(img, num_output_channels=num_output_channels) - return img - - def __repr__(self): - return self.__class__.__name__ + '(p={0})'.format(self.p) - - -class RandomErasing(torch.nn.Module): - """ Randomly selects a rectangle region in an torch Tensor image and erases its pixels. - This transform does not support PIL Image. - 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896 - - Args: - p: probability that the random erasing operation will be performed. - scale: range of proportion of erased area against input image. - ratio: range of aspect ratio of erased area. - value: erasing value. Default is 0. If a single int, it is used to - erase all pixels. If a tuple of length 3, it is used to erase - R, G, B channels respectively. - If a str of 'random', erasing each pixel with random values. 
- inplace: boolean to make this transform inplace. Default set to False. - - Returns: - Erased Image. - - Example: - >>> transform = transforms.Compose([ - >>> transforms.RandomHorizontalFlip(), - >>> transforms.ToTensor(), - >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - >>> transforms.RandomErasing(), - >>> ]) - """ - - def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): - super().__init__() - if not isinstance(value, (numbers.Number, str, tuple, list)): - raise TypeError("Argument value should be either a number or str or a sequence") - if isinstance(value, str) and value != "random": - raise ValueError("If value is str, it should be 'random'") - if not isinstance(scale, (tuple, list)): - raise TypeError("Scale should be a sequence") - if not isinstance(ratio, (tuple, list)): - raise TypeError("Ratio should be a sequence") - if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): - warnings.warn("Scale and ratio should be of kind (min, max)") - if scale[0] < 0 or scale[1] > 1: - raise ValueError("Scale should be between 0 and 1") - if p < 0 or p > 1: - raise ValueError("Random erasing probability should be between 0 and 1") - - self.p = p - self.scale = scale - self.ratio = ratio - self.value = value - self.inplace = inplace - - @staticmethod - def get_params( - img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None - ) -> Tuple[int, int, int, int, Tensor]: - """Get parameters for ``erase`` for a random erasing. - - Args: - img (Tensor): Tensor image to be erased. - scale (sequence): range of proportion of erased area against input image. - ratio (sequence): range of aspect ratio of erased area. - value (list, optional): erasing value. If None, it is interpreted as "random" - (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number, - i.e. ``value[0]``. - - Returns: - tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing. - """ - img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1] - area = img_h * img_w - - log_ratio = torch.log(torch.tensor(ratio)) - for _ in range(10): - erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() - aspect_ratio = torch.exp( - torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) - ).item() - - h = int(round(math.sqrt(erase_area * aspect_ratio))) - w = int(round(math.sqrt(erase_area / aspect_ratio))) - if not (h < img_h and w < img_w): - continue - - if value is None: - v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() - else: - v = torch.tensor(value)[:, None, None] - - i = torch.randint(0, img_h - h + 1, size=(1, )).item() - j = torch.randint(0, img_w - w + 1, size=(1, )).item() - return i, j, h, w, v - - # Return original image - return 0, 0, img_h, img_w, img - - def forward(self, img): - """ - Args: - img (Tensor): Tensor image to be erased. - - Returns: - img (Tensor): Erased Tensor image. 
- """ - if torch.rand(1) < self.p: - - # cast self.value to script acceptable type - if isinstance(self.value, (int, float)): - value = [self.value, ] - elif isinstance(self.value, str): - value = None - elif isinstance(self.value, tuple): - value = list(self.value) - else: - value = self.value - - if value is not None and not (len(value) in (1, img.shape[-3])): - raise ValueError( - "If value is a sequence, it should have either a single value or " - "{} (number of input channels)".format(img.shape[-3]) - ) - - x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value) - return F.erase(img, x, y, h, w, v, self.inplace) - return img - - def __repr__(self): - s = '(p={}, '.format(self.p) - s += 'scale={}, '.format(self.scale) - s += 'ratio={}, '.format(self.ratio) - s += 'value={}, '.format(self.value) - s += 'inplace={})'.format(self.inplace) - return self.__class__.__name__ + s - - -class GaussianBlur(torch.nn.Module): - """Blurs image with randomly chosen Gaussian blur. - If the image is torch Tensor, it is expected - to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions. - - Args: - kernel_size (int or sequence): Size of the Gaussian kernel. - sigma (float or tuple of float (min, max)): Standard deviation to be used for - creating kernel to perform blurring. If float, sigma is fixed. If it is tuple - of float (min, max), sigma is chosen uniformly at random to lie in the - given range. - - Returns: - PIL Image or Tensor: Gaussian blurred version of the input image. - - """ - - def __init__(self, kernel_size, sigma=(0.1, 2.0)): - super().__init__() - self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") - for ks in self.kernel_size: - if ks <= 0 or ks % 2 == 0: - raise ValueError("Kernel size value should be an odd and positive number.") - - if isinstance(sigma, numbers.Number): - if sigma <= 0: - raise ValueError("If sigma is a single number, it must be positive.") - sigma = (sigma, sigma) - elif isinstance(sigma, Sequence) and len(sigma) == 2: - if not 0. < sigma[0] <= sigma[1]: - raise ValueError("sigma values should be positive and of the form (min, max).") - else: - raise ValueError("sigma should be a single number or a list/tuple with length 2.") - - self.sigma = sigma - - @staticmethod - def get_params(sigma_min: float, sigma_max: float) -> float: - """Choose sigma for random gaussian blurring. - - Args: - sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel. - sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel. - - Returns: - float: Standard deviation to be passed to calculate kernel for gaussian blurring. - """ - return torch.empty(1).uniform_(sigma_min, sigma_max).item() - - def forward(self, img: Tensor,mask: Tensor,wmap: Tensor) -> Tensor: - """ - Args: - img (PIL Image or Tensor): image to be blurred. 
- - Returns: - PIL Image or Tensor: Gaussian blurred image - """ - sigma = self.get_params(self.sigma[0], self.sigma[1]) - return F.gaussian_blur(img, self.kernel_size, [sigma, sigma]),mask,wmap - - def __repr__(self): - s = '(kernel_size={}, '.format(self.kernel_size) - s += 'sigma={})'.format(self.sigma) - return self.__class__.__name__ + s - - -def _setup_size(size, error_msg): - if isinstance(size, numbers.Number): - return int(size), int(size) - - if isinstance(size, Sequence) and len(size) == 1: - return size[0], size[0] - - if len(size) != 2: - raise ValueError(error_msg) - - return size - - -def _check_sequence_input(x, name, req_sizes): - msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes]) - if not isinstance(x, Sequence): - raise TypeError("{} should be a sequence of length {}.".format(name, msg)) - if len(x) not in req_sizes: - raise ValueError("{} should be sequence of length {}.".format(name, msg)) - - -def _setup_angle(x, name, req_sizes=(2, )): - if isinstance(x, numbers.Number): - if x < 0: - raise ValueError("If {} is a single number, it must be positive.".format(name)) - x = [-x, x] - else: - _check_sequence_input(x, name, req_sizes) - - return [float(d) for d in x] - - -class RandomInvert(torch.nn.Module): - """Inverts the colors of the given image randomly with a given probability. - If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, - where ... means it can have an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - p (float): probability of the image being color inverted. Default value is 0.5 - """ - - def __init__(self, p=0.5): - super().__init__() - self.p = p - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be inverted. - - Returns: - PIL Image or Tensor: Randomly color inverted image. - """ - if torch.rand(1).item() < self.p: - return F.invert(img) - return img - - def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) - - -class RandomPosterize(torch.nn.Module): - """Posterize the image randomly with a given probability by reducing the - number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8, - and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - bits (int): number of bits to keep for each channel (0-8) - p (float): probability of the image being color inverted. Default value is 0.5 - """ - - def __init__(self, bits, p=0.5): - super().__init__() - self.bits = bits - self.p = p - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be posterized. - - Returns: - PIL Image or Tensor: Randomly posterized image. - """ - if torch.rand(1).item() < self.p: - return F.posterize(img, self.bits) - return img - - def __repr__(self): - return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p) - - -class RandomSolarize(torch.nn.Module): - """Solarize the image randomly with a given probability by inverting all pixel - values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, - where ... means it can have an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - threshold (float): all pixels equal or above this value are inverted. 
- p (float): probability of the image being color inverted. Default value is 0.5 - """ - - def __init__(self, threshold, p=0.5): - super().__init__() - self.threshold = threshold - self.p = p - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be solarized. - - Returns: - PIL Image or Tensor: Randomly solarized image. - """ - if torch.rand(1).item() < self.p: - return F.solarize(img, self.threshold) - return img - - def __repr__(self): - return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p) - - -class RandomAdjustSharpness(torch.nn.Module): - """Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor, - it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - - Args: - sharpness_factor (float): How much to adjust the sharpness. Can be - any non negative number. 0 gives a blurred image, 1 gives the - original image while 2 increases the sharpness by a factor of 2. - p (float): probability of the image being color inverted. Default value is 0.5 - """ - - def __init__(self, sharpness_factor, p=0.5): - super().__init__() - self.sharpness_factor = sharpness_factor - self.p = p - - def forward(self, img,mask,wmap): - """ - Args: - img (PIL Image or Tensor): Image to be sharpened. - - Returns: - PIL Image or Tensor: Randomly sharpened image. - """ - if torch.rand(1).item() < self.p: - return F.adjust_sharpness(img, self.sharpness_factor),mask,wmap - return img,mask,wmap - - def __repr__(self): - return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p) - - -class RandomAutocontrast(torch.nn.Module): - """Autocontrast the pixels of the given image randomly with a given probability. - If the image is torch Tensor, it is expected - to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - p (float): probability of the image being autocontrasted. Default value is 0.5 - """ - - def __init__(self, p=0.5): - super().__init__() - self.p = p - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be autocontrasted. - - Returns: - PIL Image or Tensor: Randomly autocontrasted image. - """ - if torch.rand(1).item() < self.p: - return F.autocontrast(img) - return img - - def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) - - -class RandomEqualize(torch.nn.Module): - """Equalize the histogram of the given image randomly with a given probability. - If the image is torch Tensor, it is expected - to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "P", "L" or "RGB". - - Args: - p (float): probability of the image being equalized. Default value is 0.5 - """ - - def __init__(self, p=0.5): - super().__init__() - self.p = p - - def forward(self, img,mask,wmap): - """ - Args: - img (PIL Image or Tensor): Image to be equalized. - - Returns: - PIL Image or Tensor: Randomly equalized image. 
- """ - if torch.rand(1).item() < self.p: - return F.equalize(img),mask,wmap - return img,mask,wmap - - def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) +import math +import numbers +import random +import warnings +from collections.abc import Sequence +from typing import Tuple, List, Optional +from PIL import Image +import torch +from torch import Tensor + +try: + import accimage +except ImportError: + accimage = None + + +from torchvision.transforms import functional as F + +__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", + "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", + "RandomHorizontalFlip", "RandomVerticalFlip", "TenCrop", + "LinearTransformation", "ColorJitter", "RandomRotation", "Grayscale", "RandomGrayscale", + "RandomPerspective", "RandomErasing", "GaussianBlur"] + + +class Compose: + """Composes several transforms together. This transform does not support torchscript. + Please, see the note below. + + Args: + transforms (list of ``Transform`` objects): list of transforms to compose. + + Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + + .. note:: + In order to script the transformations, please use ``torch.nn.Sequential`` as below. + + >>> transforms = torch.nn.Sequential( + >>> transforms.CenterCrop(10), + >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + >>> ) + >>> scripted_transforms = torch.jit.script(transforms) + + Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require + `lambda` functions or ``PIL.Image``. + + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img,mask,wmap): + for t in self.transforms: + img,mask,wmap = t(img,mask,wmap) + return img,mask,wmap + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class ToTensor: + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript. + + Converts a PIL Image or numpy.ndarray (H x W x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] + if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) + or if the numpy.ndarray has dtype = np.uint8 + + In the other cases, tensors are returned without scaling. + + .. note:: + Because the input image is scaled to [0.0, 1.0], this transformation should not be used when + transforming target image masks. See the `references`_ for implementing the transforms for image masks. + + .. _references: https://github.com/pytorch/vision/tree/master/references/segmentation + """ + + def __call__(self, img,mask,wmap): + """ + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + return F.to_tensor(img),F.to_tensor(mask),F.to_tensor(wmap) + + def __repr__(self): + return self.__class__.__name__ + '()' + + +class PILToTensor: + """Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript. + + Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W). + """ + + def __call__(self, pic): + """ + Args: + pic (PIL Image): Image to be converted to tensor. + + Returns: + Tensor: Converted image. 
+ """ + return F.pil_to_tensor(pic) + + def __repr__(self): + return self.__class__.__name__ + '()' + + +class ConvertImageDtype(torch.nn.Module): + """Convert a tensor image to the given ``dtype`` and scale the values accordingly + This function does not support PIL Image. + + Args: + dtype (torch.dtype): Desired data type of the output + + .. note:: + + When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. + If converted back and forth, this mismatch has no effect. + + Raises: + RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as + well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to + overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range + of the integer ``dtype``. + """ + + def __init__(self, dtype: torch.dtype) -> None: + super().__init__() + self.dtype = dtype + + def forward(self, image): + return F.convert_image_dtype(image, self.dtype) + + +class ToPILImage: + """Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript. + + Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape + H x W x C to a PIL Image while preserving the value range. + + Args: + mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). + If ``mode`` is ``None`` (default) there are some assumptions made about the input data: + - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``. + - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``. + - If the input has 2 channels, the ``mode`` is assumed to be ``LA``. + - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``, + ``short``). + + .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes + """ + def __init__(self, mode=None): + self.mode = mode + + def __call__(self, pic): + """ + Args: + pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. + + Returns: + PIL Image: Image converted to PIL Image. + + """ + return F.to_pil_image(pic, self.mode) + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + if self.mode is not None: + format_string += 'mode={0}'.format(self.mode) + format_string += ')' + return format_string + + +class Normalize(torch.nn.Module): + """Normalize a tensor image with mean and standard deviation. + This transform does not support PIL Image. + Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` + channels, this transform will normalize each channel of the input + ``torch.*Tensor`` i.e., + ``output[channel] = (input[channel] - mean[channel]) / std[channel]`` + + .. note:: + This transform acts out of place, i.e., it does not mutate the input tensor. + + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + inplace(bool,optional): Bool to make this operation in-place. + + """ + + def __init__(self, mean, std, inplace=False): + super().__init__() + self.mean = mean + self.std = std + self.inplace = inplace + + def forward(self, img: Tensor,mask: Tensor,wmap: Tensor) -> Tensor: + """ + Args: + tensor (Tensor): Tensor image to be normalized. + + Returns: + Tensor: Normalized Tensor image. 
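The Compose and ToTensor variants added here differ from their torchvision counterparts in that every transform receives and returns an (image, mask, weight-map) triplet. A minimal usage sketch, assuming the file is importable as deepsea.transforms and using placeholder file names that are not part of this diff:

from PIL import Image
import deepsea.transforms as T  # assumed import path for this new module

# Placeholder inputs: a frame plus its segmentation mask and UNet weight map.
img = Image.open('frame_0001.png')
mask = Image.open('frame_0001_mask.png')
wmap = Image.open('frame_0001_wmap.png')

pipeline = T.Compose([T.ToTensor()])
# All three come back as float tensors, scaled to [0, 1] for 8-bit inputs.
img_t, mask_t, wmap_t = pipeline(img, mask, wmap)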
+ """
+ return F.normalize(img, self.mean, self.std, self.inplace),mask,wmap
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
+
+
+class Resize(torch.nn.Module):
+ """Resize the input image to the given size.
+ The image can be a PIL Image or a torch Tensor, in which case it is expected
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
+
+ Args:
+ size (sequence or int): Desired output size. If size is a sequence like
+ (h, w), output size will be matched to this. If size is an int,
+ smaller edge of the image will be matched to this number.
+ i.e., if height > width, then image will be rescaled to
+ (size * height / width, size).
+ In torchscript mode size as a single int is not supported, use a tuple or
+ list of length 1: ``[size, ]``.
+ img_interpolation (int, optional): Desired interpolation for the image, defined by `filters`_.
+ Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR``
+ and ``PIL.Image.BICUBIC`` are supported.
+ mask_interpolation (int, optional): Desired interpolation for the mask and weight map.
+ Default is ``PIL.Image.NEAREST`` so that label values are not interpolated.
+ """
+
+ def __init__(self, size, img_interpolation=Image.BILINEAR,mask_interpolation=Image.NEAREST):
+ super().__init__()
+ if not isinstance(size, (int, Sequence)):
+ raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
+ if isinstance(size, Sequence) and len(size) not in (1, 2):
+ raise ValueError("If size is a sequence, it should have 1 or 2 values")
+ self.size = size
+ self.img_interpolation = img_interpolation
+ self.mask_interpolation = mask_interpolation
+
+ def forward(self, img,mask,wmap):
+ """
+ Args:
+ img (PIL Image or Tensor): Image to be scaled.
+
+ Returns:
+ PIL Image or Tensor: Rescaled image.
+ """
+ return F.resize(img, self.size, self.img_interpolation),F.resize(mask, self.size, self.mask_interpolation),F.resize(wmap, self.size, self.mask_interpolation)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(size={0}, img_interpolation={1}, mask_interpolation={2})'.format(
+ self.size, self.img_interpolation, self.mask_interpolation)
+
+
+class Scale(Resize):
+ """
+ Note: This transform is deprecated in favor of Resize.
+ """
+ def __init__(self, *args, **kwargs):
+ warnings.warn("The use of the transforms.Scale transform is deprecated, " +
+ "please use transforms.Resize instead.")
+ super(Scale, self).__init__(*args, **kwargs)
+
+
+class CenterCrop(torch.nn.Module):
+ """Crops the given image at the center.
+ If the image is torch Tensor, it is expected
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+ If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
+
+ Args:
+ size (sequence or int): Desired output size of the crop. If size is an
+ int instead of sequence like (h, w), a square crop (size, size) is
+ made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
+ """
+
+ def __init__(self, size):
+ super().__init__()
+ self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
+
+ def forward(self, img):
+ """
+ Args:
+ img (PIL Image or Tensor): Image to be cropped.
+
+ Returns:
+ PIL Image or Tensor: Cropped image.
+ """
+ return F.center_crop(img, self.size)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(size={0})'.format(self.size)
+
+
+class Pad(torch.nn.Module):
+ """Pad the given image on all sides with the given "pad" value.
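Resize takes separate interpolation modes for the image and for the mask/weight map; keeping the mask at nearest-neighbour prevents label values from being blended at object boundaries. A short sketch of the intended call, with an arbitrary example size and the same assumed import path as above:

from PIL import Image
import deepsea.transforms as T  # assumed import path

resize = T.Resize(
    size=(383, 512),                    # (h, w); example value only
    img_interpolation=Image.BILINEAR,   # smooth resampling for the intensity image
    mask_interpolation=Image.NEAREST,   # no interpolation across label boundaries
)
img_r, mask_r, wmap_r = resize(img, mask, wmap)  # img/mask/wmap as in the first sketch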
+ If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric, + at most 3 leading dimensions for mode edge, + and an arbitrary number of leading dimensions for mode constant + + Args: + padding (int or sequence): Padding on each border. If a single int is provided this + is used to pad all borders. If sequence of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a sequence of length 4 is provided + this is the padding for the left, top, right and bottom borders respectively. + + .. note:: + In torchscript mode padding as single int is not supported, use a sequence of + length 1: ``[padding, ]``. + fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant. + Only number is supported for torch Tensor. + Only int or str or tuple value is supported for PIL Image. + padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. + Default is constant. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value at the edge of the image. + If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2 + + - reflect: pads with reflection of image without repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + """ + + def __init__(self, padding, fill=0, padding_mode="constant"): + super().__init__() + if not isinstance(padding, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate padding arg") + + if not isinstance(fill, (numbers.Number, str, tuple)): + raise TypeError("Got inappropriate fill arg") + + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") + + if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]: + raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + + "{} element tuple".format(len(padding))) + + self.padding = padding + self.fill = fill + self.padding_mode = padding_mode + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be padded. + + Returns: + PIL Image or Tensor: Padded image. + """ + return F.pad(img, self.padding, self.fill, self.padding_mode) + + def __repr__(self): + return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ + format(self.padding, self.fill, self.padding_mode) + + +class Lambda: + """Apply a user-defined lambda as a transform. This transform does not support torchscript. + + Args: + lambd (function): Lambda/function to be used for transform. 
+ """ + + def __init__(self, lambd): + if not callable(lambd): + raise TypeError("Argument lambd should be callable, got {}".format(repr(type(lambd).__name__))) + self.lambd = lambd + + def __call__(self, img): + return self.lambd(img) + + def __repr__(self): + return self.__class__.__name__ + '()' + + +class RandomTransforms: + """Base class for a list of transformations with randomness + + Args: + transforms (sequence): list of transformations + """ + + def __init__(self, transforms): + if not isinstance(transforms, Sequence): + raise TypeError("Argument transforms should be a sequence") + self.transforms = transforms + + def __call__(self, *args, **kwargs): + raise NotImplementedError() + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class RandomApply(torch.nn.Module): + """Apply randomly a list of transformations with a given probability. + + .. note:: + In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of + transforms as shown below: + + >>> transforms = transforms.RandomApply(torch.nn.ModuleList([ + >>> transforms.ColorJitter(), + >>> ]), p=0.3) + >>> scripted_transforms = torch.jit.script(transforms) + + Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require + `lambda` functions or ``PIL.Image``. + + Args: + transforms (sequence or torch.nn.Module): list of transformations + p (float): probability + """ + + def __init__(self, transforms, p=0.5): + super().__init__() + self.transforms = transforms + self.p = p + + def forward(self, img,mask,wmap): + if self.p < torch.rand(1): + return img,mask,wmap + for t in self.transforms: + img,mask,wmap = t(img,mask,wmap) + return img,mask,wmap + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + format_string += '\n p={}'.format(self.p) + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class RandomOrder(RandomTransforms): + """Apply a list of transformations in a random order. This transform does not support torchscript. + """ + def __call__(self, img,mask,wmap): + order = list(range(len(self.transforms))) + random.shuffle(order) + for i in order: + img,mask,wmap = self.transforms[i](img,mask,wmap) + return img,mask,wmap + + +class RandomChoice(RandomTransforms): + """Apply single transformation randomly picked from a list. This transform does not support torchscript. + """ + def __call__(self, img,mask,wmap): + t = random.choice(self.transforms) + return t(img,mask,wmap) + + +class RandomCrop(torch.nn.Module): + """Crop the given image at a random location. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions, + but if non-constant padding is used, the input is expected to have at most 2 leading dimensions + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + padding (int or sequence, optional): Optional padding on each border + of the image. Default is None. If a single int is provided this + is used to pad all borders. 
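RandomApply, RandomOrder and RandomChoice keep the triplet signature, so a whole sub-pipeline can be gated behind a single probability draw. A sketch with arbitrary example settings (the wrapped transform and the value of p are illustrative, not taken from this diff):

import deepsea.transforms as T  # assumed import path

# Apply brightness/contrast jitter 30% of the time; mask and wmap pass through untouched.
maybe_jitter = T.RandomApply([T.ColorJitter(brightness=0.3, contrast=0.3)], p=0.3)
img_a, mask_a, wmap_a = maybe_jitter(img, mask, wmap)  # img/mask/wmap as in the first sketch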
If sequence of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a sequence of length 4 is provided + this is the padding for the left, top, right and bottom borders respectively. + + .. note:: + In torchscript mode padding as single int is not supported, use a sequence of + length 1: ``[padding, ]``. + pad_if_needed (boolean): It will pad the image if smaller than the + desired size to avoid raising an exception. Since cropping is done + after padding, the padding seems to be done at a random offset. + fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant. + Only number is supported for torch Tensor. + Only int or str or tuple value is supported for PIL Image. + padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. + Default is constant. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value at the edge of the image. + If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2 + + - reflect: pads with reflection of image without repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + """ + + @staticmethod + def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random crop. + + Args: + img (PIL Image or Tensor): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. + """ + w, h = F._get_image_size(img) + th, tw = output_size + + if h + 1 < th or w + 1 < tw: + raise ValueError( + "Required crop size {} is larger then input image size {}".format((th, tw), (h, w)) + ) + + if w == tw and h == th: + return 0, 0, h, w + + i = torch.randint(0, h - th + 1, size=(1, )).item() + j = torch.randint(0, w - tw + 1, size=(1, )).item() + return i, j, th, tw + + def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): + super().__init__() + + self.size = tuple(_setup_size( + size, error_msg="Please provide only two dimensions (h, w) for size." + )) + + self.padding = padding + self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode + + def forward(self, img,mask,wmap): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + PIL Image or Tensor: Cropped image. 
+ """
+ if self.padding is not None:
+ # pad the mask and weight map together with the image so all three stay aligned
+ img = F.pad(img, self.padding, self.fill, self.padding_mode)
+ mask = F.pad(mask, self.padding, self.fill, self.padding_mode)
+ wmap = F.pad(wmap, self.padding, self.fill, self.padding_mode)
+
+ width, height = F._get_image_size(img)
+ # pad the width if needed
+ if self.pad_if_needed and width < self.size[1]:
+ padding = [self.size[1] - width, 0]
+ img = F.pad(img, padding, self.fill, self.padding_mode)
+ mask = F.pad(mask, padding, self.fill, self.padding_mode)
+ wmap = F.pad(wmap, padding, self.fill, self.padding_mode)
+ # pad the height if needed
+ if self.pad_if_needed and height < self.size[0]:
+ padding = [0, self.size[0] - height]
+ img = F.pad(img, padding, self.fill, self.padding_mode)
+ mask = F.pad(mask, padding, self.fill, self.padding_mode)
+ wmap = F.pad(wmap, padding, self.fill, self.padding_mode)
+
+ i, j, h, w = self.get_params(img, self.size)
+
+ return F.crop(img, i, j, h, w),F.crop(mask, i, j, h, w),F.crop(wmap, i, j, h, w)
+
+ def __repr__(self):
+ return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding)
+
+
+class Pass(torch.nn.Module):
+ """
+ No Transforms
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, img,mask,wmap):
+ return img,mask,wmap
+
+ def __repr__(self):
+ return self.__class__.__name__
+
+class RandomHorizontalFlip(torch.nn.Module):
+ """Horizontally flip the given image randomly with a given probability.
+ If the image is torch Tensor, it is expected
+ to have [..., H, W] shape, where ... means an arbitrary number of leading
+ dimensions
+
+ Args:
+ p (float): probability of the image being flipped. Default value is 0.5
+ """
+
+ def __init__(self, p=0.5):
+ super().__init__()
+ self.p = p
+
+ def forward(self, img,mask,wmap):
+ """
+ Args:
+ img (PIL Image or Tensor): Image to be flipped.
+
+ Returns:
+ PIL Image or Tensor: Randomly flipped image.
+ """
+ if torch.rand(1) < self.p:
+ return F.hflip(img),F.hflip(mask),F.hflip(wmap)
+ return img,mask,wmap
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+class RandomVerticalFlip(torch.nn.Module):
+ """Vertically flip the given image randomly with a given probability.
+ If the image is torch Tensor, it is expected
+ to have [..., H, W] shape, where ... means an arbitrary number of leading
+ dimensions
+
+ Args:
+ p (float): probability of the image being flipped. Default value is 0.5
+ """
+
+ def __init__(self, p=0.5):
+ super().__init__()
+ self.p = p
+
+ def forward(self, img,mask,wmap):
+ """
+ Args:
+ img (PIL Image or Tensor): Image to be flipped.
+
+ Returns:
+ PIL Image or Tensor: Randomly flipped image.
+ """
+ if torch.rand(1) < self.p:
+ return F.vflip(img),F.vflip(mask),F.vflip(wmap)
+ return img,mask,wmap
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+class RandomPerspective(torch.nn.Module):
+ """Performs a random perspective transformation of the given image with a given probability.
+ The image can be a PIL Image or a Tensor, in which case it is expected
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+
+ Args:
+ distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
+ Default is 0.5.
+ p (float): probability of the image being transformed. Default is 0.5.
+ interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and
+ ``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for PIL images and Tensors.
+ fill (n-tuple or int or float): Pixel fill value for area outside the rotated
+ image. If int or float, the value is used for all bands respectively. Default is 0.
+ This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor
+ input. Fill value for the area outside the transform in the output image is always 0.
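The flip and crop transforms draw their random parameters once and apply the same operation to all three inputs, which is what keeps the mask and weight map registered to the image. A sketch chaining them, with an arbitrary crop size and the frames assumed to be at least that large:

import deepsea.transforms as T  # assumed import path

geometric = T.Compose([
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.RandomCrop(size=(256, 256)),  # the same (i, j, h, w) window is cut from img, mask and wmap
])
img_g, mask_g, wmap_g = geometric(img, mask, wmap)  # img/mask/wmap as in the first sketch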
+ + """ + + def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BILINEAR, fill=0): + super().__init__() + self.p = p + self.interpolation = interpolation + self.distortion_scale = distortion_scale + self.fill = fill + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be Perspectively transformed. + + Returns: + PIL Image or Tensor: Randomly transformed image. + """ + if torch.rand(1) < self.p: + width, height = F._get_image_size(img) + startpoints, endpoints = self.get_params(width, height, self.distortion_scale) + return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill) + return img + + @staticmethod + def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]: + """Get parameters for ``perspective`` for a random perspective transform. + + Args: + width (int): width of the image. + height (int): height of the image. + distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1. + + Returns: + List containing [top-left, top-right, bottom-right, bottom-left] of the original image, + List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image. + """ + half_height = height // 2 + half_width = width // 2 + topleft = [ + int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), + int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) + ] + topright = [ + int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), + int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) + ] + botright = [ + int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), + int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) + ] + botleft = [ + int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), + int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) + ] + startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] + endpoints = [topleft, topright, botright, botleft] + return startpoints, endpoints + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) + + +# class RandomResizedCrop(torch.nn.Module): +# """Crop a random portion of image and resize it to a given size. +# +# If the image is torch Tensor, it is expected +# to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions +# +# A crop of the original image is made: the crop has a random area (H * W) +# and a random aspect ratio. This crop is finally resized to the given +# size. This is popularly used to train the Inception networks. +# +# Args: +# size (int or sequence): expected output size of the crop, for each edge. If size is an +# int instead of sequence like (h, w), a square output size ``(size, size)`` is +# made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). +# +# .. note:: +# In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. +# scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop, +# before resizing. The scale is defined with respect to the area of the original image. +# ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before +# resizing. 
+# interpolation (InterpolationMode): Desired interpolation enum defined by +# :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. +# If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and +# ``InterpolationMode.BICUBIC`` are supported. +# For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. +# +# """ +# +# def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationMode.BILINEAR): +# super().__init__() +# self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") +# +# if not isinstance(scale, Sequence): +# raise TypeError("Scale should be a sequence") +# if not isinstance(ratio, Sequence): +# raise TypeError("Ratio should be a sequence") +# if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): +# warnings.warn("Scale and ratio should be of kind (min, max)") +# +# # Backward compatibility with integer value +# if isinstance(interpolation, int): +# warnings.warn( +# "Argument interpolation should be of type InterpolationMode instead of int. " +# "Please, use InterpolationMode enum." +# ) +# interpolation = _interpolation_modes_from_int(interpolation) +# +# self.interpolation = interpolation +# self.scale = scale +# self.ratio = ratio +# +# @staticmethod +# def get_params( +# img: Tensor, scale: List[float], ratio: List[float] +# ) -> Tuple[int, int, int, int]: +# """Get parameters for ``crop`` for a random sized crop. +# +# Args: +# img (PIL Image or Tensor): Input image. +# scale (list): range of scale of the origin size cropped +# ratio (list): range of aspect ratio of the origin aspect ratio cropped +# +# Returns: +# tuple: params (i, j, h, w) to be passed to ``crop`` for a random +# sized crop. +# """ +# width, height = F._get_image_size(img) +# area = height * width +# +# log_ratio = torch.log(torch.tensor(ratio)) +# for _ in range(10): +# target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() +# aspect_ratio = torch.exp( +# torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) +# ).item() +# +# w = int(round(math.sqrt(target_area * aspect_ratio))) +# h = int(round(math.sqrt(target_area / aspect_ratio))) +# +# if 0 < w <= width and 0 < h <= height: +# i = torch.randint(0, height - h + 1, size=(1,)).item() +# j = torch.randint(0, width - w + 1, size=(1,)).item() +# return i, j, h, w +# +# # Fallback to central crop +# in_ratio = float(width) / float(height) +# if in_ratio < min(ratio): +# w = width +# h = int(round(w / min(ratio))) +# elif in_ratio > max(ratio): +# h = height +# w = int(round(h * max(ratio))) +# else: # whole image +# w = width +# h = height +# i = (height - h) // 2 +# j = (width - w) // 2 +# return i, j, h, w +# +# def forward(self, img): +# """ +# Args: +# img (PIL Image or Tensor): Image to be cropped and resized. +# +# Returns: +# PIL Image or Tensor: Randomly cropped and resized image. 
+# """ +# i, j, h, w = self.get_params(img, self.scale, self.ratio) +# return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) +# +# def __repr__(self): +# interpolate_str = self.interpolation.value +# format_string = self.__class__.__name__ + '(size={0}'.format(self.size) +# format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) +# format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) +# format_string += ', interpolation={0})'.format(interpolate_str) +# return format_string + + +# class RandomSizedCrop(RandomResizedCrop): +# """ +# Note: This transform is deprecated in favor of RandomResizedCrop. +# """ +# def __init__(self, *args, **kwargs): +# warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + +# "please use transforms.RandomResizedCrop instead.") +# super(RandomSizedCrop, self).__init__(*args, **kwargs) +# +# +# class FiveCrop(torch.nn.Module): +# """Crop the given image into four corners and the central crop. +# If the image is torch Tensor, it is expected +# to have [..., H, W] shape, where ... means an arbitrary number of leading +# dimensions +# +# .. Note:: +# This transform returns a tuple of images and there may be a mismatch in the number of +# inputs and targets your Dataset returns. See below for an example of how to deal with +# this. +# +# Args: +# size (sequence or int): Desired output size of the crop. If size is an ``int`` +# instead of sequence like (h, w), a square crop of size (size, size) is made. +# If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). +# +# Example: +# >>> transform = Compose([ +# >>> FiveCrop(size), # this is a list of PIL Images +# >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor +# >>> ]) +# >>> #In your test loop you can do the following: +# >>> input, target = batch # input is a 5d tensor, target is 2d +# >>> bs, ncrops, c, h, w = input.size() +# >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops +# >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops +# """ +# +# def __init__(self, size): +# super().__init__() +# self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") +# +# def forward(self, img): +# """ +# Args: +# img (PIL Image or Tensor): Image to be cropped. +# +# Returns: +# tuple of 5 images. Image can be PIL Image or Tensor +# """ +# return F.five_crop(img, self.size) +# +# def __repr__(self): +# return self.__class__.__name__ + '(size={0})'.format(self.size) + + +class TenCrop(torch.nn.Module): + """Crop the given image into four corners and the central crop plus the flipped version of + these (horizontal flipping is used by default). + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions + + .. Note:: + This transform returns a tuple of images and there may be a mismatch in the number of + inputs and targets your Dataset returns. See below for an example of how to deal with + this. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). 
+ vertical_flip (bool): Use vertical flipping instead of horizontal + + Example: + >>> transform = Compose([ + >>> TenCrop(size), # this is a list of PIL Images + >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor + >>> ]) + >>> #In your test loop you can do the following: + >>> input, target = batch # input is a 5d tensor, target is 2d + >>> bs, ncrops, c, h, w = input.size() + >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops + >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops + """ + + def __init__(self, size, vertical_flip=False): + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + self.vertical_flip = vertical_flip + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + tuple of 10 images. Image can be PIL Image or Tensor + """ + return F.ten_crop(img, self.size, self.vertical_flip) + + def __repr__(self): + return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) + + +class LinearTransformation(torch.nn.Module): + """Transform a tensor image with a square transformation matrix and a mean_vector computed + offline. + This transform does not support PIL Image. + Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and + subtract mean_vector from it which is then followed by computing the dot + product with the transformation matrix and then reshaping the tensor to its + original shape. + + Applications: + whitening transformation: Suppose X is a column vector zero-centered data. + Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X), + perform SVD on this matrix and pass it as transformation_matrix. + + Args: + transformation_matrix (Tensor): tensor [D x D], D = C x H x W + mean_vector (Tensor): tensor [D], D = C x H x W + """ + + def __init__(self, transformation_matrix, mean_vector): + super().__init__() + if transformation_matrix.size(0) != transformation_matrix.size(1): + raise ValueError("transformation_matrix should be square. Got " + + "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) + + if mean_vector.size(0) != transformation_matrix.size(0): + raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) + + " as any one of the dimensions of the transformation_matrix [{}]" + .format(tuple(transformation_matrix.size()))) + + if transformation_matrix.device != mean_vector.device: + raise ValueError("Input tensors should be on the same device. Got {} and {}" + .format(transformation_matrix.device, mean_vector.device)) + + self.transformation_matrix = transformation_matrix + self.mean_vector = mean_vector + + def forward(self, tensor: Tensor) -> Tensor: + """ + Args: + tensor (Tensor): Tensor image to be whitened. + + Returns: + Tensor: Transformed image. + """ + shape = tensor.shape + n = shape[-3] * shape[-2] * shape[-1] + if n != self.transformation_matrix.shape[0]: + raise ValueError("Input tensor and transformation matrix have incompatible shape." + + "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) + + "{}".format(self.transformation_matrix.shape[0])) + + if tensor.device.type != self.mean_vector.device.type: + raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. 
" + "Got {} vs {}".format(tensor.device, self.mean_vector.device)) + + flat_tensor = tensor.view(-1, n) - self.mean_vector + transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) + tensor = transformed_tensor.view(shape) + return tensor + + def __repr__(self): + format_string = self.__class__.__name__ + '(transformation_matrix=' + format_string += (str(self.transformation_matrix.tolist()) + ')') + format_string += (", (mean_vector=" + str(self.mean_vector.tolist()) + ')') + return format_string + + +class ColorJitter(torch.nn.Module): + """Randomly change the brightness, contrast, saturation and hue of an image. + If the image is torch Tensor, it is expected + to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported. + + Args: + brightness (float or tuple of float (min, max)): How much to jitter brightness. + brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] + or the given [min, max]. Should be non negative numbers. + contrast (float or tuple of float (min, max)): How much to jitter contrast. + contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] + or the given [min, max]. Should be non negative numbers. + saturation (float or tuple of float (min, max)): How much to jitter saturation. + saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] + or the given [min, max]. Should be non negative numbers. + hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. + Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. + """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + super().__init__() + self.brightness = self._check_input(brightness, 'brightness') + self.contrast = self._check_input(contrast, 'contrast') + self.saturation = self._check_input(saturation, 'saturation') + self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), + clip_first_on_zero=False) + + @torch.jit.unused + def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError("If {} is a single number, it must be non negative.".format(name)) + value = [center - float(value), center + float(value)] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError("{} values should be between {}".format(name, bound)) + else: + raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name)) + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + return value + + @staticmethod + def get_params(brightness: Optional[List[float]], + contrast: Optional[List[float]], + saturation: Optional[List[float]], + hue: Optional[List[float]] + ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: + """Get the parameters for the randomized transform to be applied on image. + + Args: + brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen + uniformly. Pass None to turn off the transformation. 
+ contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen + uniformly. Pass None to turn off the transformation. + saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen + uniformly. Pass None to turn off the transformation. + hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly. + Pass None to turn off the transformation. + + Returns: + tuple: The parameters used to apply the randomized transform + along with their random order. + """ + fn_idx = torch.randperm(4) + + b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1])) + c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1])) + s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1])) + h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) + + return fn_idx, b, c, s, h + + def forward(self, img,mask,wmap): + """ + Args: + img (PIL Image or Tensor): Input image. + + Returns: + PIL Image or Tensor: Color jittered image. + """ + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ + self.get_params(self.brightness, self.contrast, self.saturation, self.hue) + + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, hue_factor) + + return img,mask,wmap + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + format_string += 'brightness={0}'.format(self.brightness) + format_string += ', contrast={0}'.format(self.contrast) + format_string += ', saturation={0}'.format(self.saturation) + format_string += ', hue={0})'.format(self.hue) + return format_string + +class RandomRotation(torch.nn.Module): + """Rotate the image by angle. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + degrees (sequence or float or int): Range of degrees to select from. + If degrees is a number instead of sequence like (min, max), the range of degrees + will be (-degrees, +degrees). + resample (int, optional): An optional resampling filter. See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. + If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. + expand (bool, optional): Optional expansion flag. + If true, expands the output to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center and no translation. + center (list or tuple, optional): Optional center of rotation, (x, y). Origin is the upper left corner. + Default is the center of the image. + fill (n-tuple or int or float): Pixel fill value for area outside the rotated + image. If int or float, the value is used for all bands respectively. + Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0. + This option is not supported for Tensor input. 
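ColorJitter.get_params draws the four adjustment factors and a random application order once per call, and forward then applies them to the image only, returning the mask and weight map unchanged. A quick sketch of inspecting the drawn parameters (same assumed import path; the printed values are only an example):

import deepsea.transforms as T  # assumed import path

jitter = T.ColorJitter(brightness=0.2, contrast=0.2)
fn_idx, b, c, s, h = jitter.get_params(jitter.brightness, jitter.contrast,
                                       jitter.saturation, jitter.hue)
print(fn_idx.tolist(), b, c, s, h)  # e.g. [2, 0, 3, 1] 1.07 0.91 None None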
Fill value for the area outside the transform in the output + image is always 0. + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + + def __init__(self, degrees, resample=Image.NEAREST, expand=False, center=None, fill=0): + super().__init__() + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) + + if center is not None: + _check_sequence_input(center, "center", req_sizes=(2, )) + + self.center = center + + self.resample = resample + self.expand = expand + self.fill = fill + + @staticmethod + def get_params(degrees: List[float]) -> float: + """Get parameters for ``rotate`` for a random rotation. + + Returns: + float: angle parameter to be passed to ``rotate`` for random rotation. + """ + angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) + return angle + + def forward(self, img,mask,wmap): + """ + Args: + img (PIL Image or Tensor): Image to be rotated. + + Returns: + PIL Image or Tensor: Rotated image. + """ + angle = self.get_params(self.degrees) + return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill),F.rotate(mask, angle, self.resample, self.expand, self.center, self.fill),F.rotate(wmap, angle, self.resample, self.expand, self.center, self.fill) + + + + +# class RandomAffine(torch.nn.Module): +# """Random affine transformation of the image keeping center invariant. +# If the image is torch Tensor, it is expected +# to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. +# +# Args: +# degrees (sequence or number): Range of degrees to select from. +# If degrees is a number instead of sequence like (min, max), the range of degrees +# will be (-degrees, +degrees). Set to 0 to deactivate rotations. +# translate (tuple, optional): tuple of maximum absolute fraction for horizontal +# and vertical translations. For example translate=(a, b), then horizontal shift +# is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is +# randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. +# scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is +# randomly sampled from the range a <= scale <= b. Will keep original scale by default. +# shear (sequence or number, optional): Range of degrees to select from. +# If shear is a number, a shear parallel to the x axis in the range (-shear, +shear) +# will be applied. Else if shear is a sequence of 2 values a shear parallel to the x axis in the +# range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values, +# a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. +# Will not apply shear by default. +# interpolation (InterpolationMode): Desired interpolation enum defined by +# :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. +# If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. +# For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. +# fill (sequence or number): Pixel fill value for the area outside the transformed +# image. Default is ``0``. If given a number, the value is used for all bands respectively. +# fillcolor (sequence or number, optional): deprecated argument and will be removed since v0.10.0. +# Please use the ``fill`` parameter instead. 
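RandomRotation samples one angle and rotates the image, mask and weight map with it; the default Image.NEAREST resampling keeps rotated mask pixels at their original label values. A sketch with an arbitrary degree range:

import deepsea.transforms as T  # assumed import path

rotate = T.RandomRotation(degrees=15)  # angle drawn uniformly from [-15, 15] for all three inputs
img_r, mask_r, wmap_r = rotate(img, mask, wmap)  # img/mask/wmap as in the first sketch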
+# resample (int, optional): deprecated argument and will be removed since v0.10.0. +# Please use the ``interpolation`` parameter instead. +# +# .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters +# +# """ +# +# def __init__( +# self, degrees, translate=None, scale=None, shear=None, interpolation=InterpolationMode.NEAREST, fill=0, +# fillcolor=None, resample=None +# ): +# super().__init__() +# if resample is not None: +# warnings.warn( +# "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" +# ) +# interpolation = _interpolation_modes_from_int(resample) +# +# # Backward compatibility with integer value +# if isinstance(interpolation, int): +# warnings.warn( +# "Argument interpolation should be of type InterpolationMode instead of int. " +# "Please, use InterpolationMode enum." +# ) +# interpolation = _interpolation_modes_from_int(interpolation) +# +# if fillcolor is not None: +# warnings.warn( +# "Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead" +# ) +# fill = fillcolor +# +# self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) +# +# if translate is not None: +# _check_sequence_input(translate, "translate", req_sizes=(2, )) +# for t in translate: +# if not (0.0 <= t <= 1.0): +# raise ValueError("translation values should be between 0 and 1") +# self.translate = translate +# +# if scale is not None: +# _check_sequence_input(scale, "scale", req_sizes=(2, )) +# for s in scale: +# if s <= 0: +# raise ValueError("scale values should be positive") +# self.scale = scale +# +# if shear is not None: +# self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4)) +# else: +# self.shear = shear +# +# self.resample = self.interpolation = interpolation +# +# if fill is None: +# fill = 0 +# elif not isinstance(fill, (Sequence, numbers.Number)): +# raise TypeError("Fill should be either a sequence or a number.") +# +# self.fillcolor = self.fill = fill +# +# @staticmethod +# def get_params( +# degrees: List[float], +# translate: Optional[List[float]], +# scale_ranges: Optional[List[float]], +# shears: Optional[List[float]], +# img_size: List[int] +# ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]: +# """Get parameters for affine transformation +# +# Returns: +# params to be passed to the affine transformation +# """ +# angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) +# if translate is not None: +# max_dx = float(translate[0] * img_size[0]) +# max_dy = float(translate[1] * img_size[1]) +# tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())) +# ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())) +# translations = (tx, ty) +# else: +# translations = (0, 0) +# +# if scale_ranges is not None: +# scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item()) +# else: +# scale = 1.0 +# +# shear_x = shear_y = 0.0 +# if shears is not None: +# shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item()) +# if len(shears) == 4: +# shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item()) +# +# shear = (shear_x, shear_y) +# +# return angle, translations, scale, shear +# +# def forward(self, img): +# """ +# img (PIL Image or Tensor): Image to be transformed. +# +# Returns: +# PIL Image or Tensor: Affine transformed image. 
+#         """
+#         fill = self.fill
+#         if isinstance(img, Tensor):
+#             if isinstance(fill, (int, float)):
+#                 fill = [float(fill)] * F._get_image_num_channels(img)
+#             else:
+#                 fill = [float(f) for f in fill]
+#
+#         img_size = F._get_image_size(img)
+#
+#         ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
+#
+#         return F.affine(img, *ret, interpolation=self.interpolation, fill=fill)
+#
+#     def __repr__(self):
+#         s = '{name}(degrees={degrees}'
+#         if self.translate is not None:
+#             s += ', translate={translate}'
+#         if self.scale is not None:
+#             s += ', scale={scale}'
+#         if self.shear is not None:
+#             s += ', shear={shear}'
+#         if self.interpolation != InterpolationMode.NEAREST:
+#             s += ', interpolation={interpolation}'
+#         if self.fill != 0:
+#             s += ', fill={fill}'
+#         s += ')'
+#         d = dict(self.__dict__)
+#         d['interpolation'] = self.interpolation.value
+#         return s.format(name=self.__class__.__name__, **d)
+
+
+class Grayscale(torch.nn.Module):
+    """Convert image to grayscale.
+    If the image is torch Tensor, it is expected
+    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
+
+    Args:
+        num_output_channels (int): (1 or 3) number of channels desired for output image
+
+    Returns:
+        PIL Image: Grayscale version of the input.
+
+        - If ``num_output_channels == 1`` : returned image is single channel
+        - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b
+
+    """
+
+    def __init__(self, num_output_channels=1):
+        super().__init__()
+        self.num_output_channels = num_output_channels
+
+    def forward(self, img,mask,wmap):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be converted to grayscale.
+
+        Returns:
+            PIL Image or Tensor: Grayscaled image.
+        """
+        return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels),F.rgb_to_grayscale(mask, num_output_channels=self.num_output_channels),F.rgb_to_grayscale(wmap, num_output_channels=self.num_output_channels)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels)
+
+class GetBoundingBoxes(torch.nn.Module):
+    """Collect the bounding-box corners of every labelled object in the mask."""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, img,mask):
+        import numpy as np
+
+        # The mask is expected to encode each object as a distinct positive
+        # integer label; np.array handles both PIL Images and ndarrays.
+        A = np.array(mask)
+
+        # For every label, record the (min_row, min_col) and (max_row, max_col)
+        # corners of its bounding box.
+        bboxCorners = {}
+        for i in range(1, A.max() + 1):
+            B = np.argwhere(A == i)
+            if B.size > 0:
+                bboxCorners[i] = B.min(0), B.max(0)
+
+        print(bboxCorners)
+        return img
+
+    def __repr__(self):
+        return self.__class__.__name__
+
+
+class RandomGrayscale(torch.nn.Module):
+    """Randomly convert image to grayscale with a probability of p (default 0.1).
+    If the image is torch Tensor, it is expected
+    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
+
+    Args:
+        p (float): probability that image should be converted to grayscale.
+
+    Returns:
+        PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
+        with probability (1-p).
+ - If input image is 1 channel: grayscale version is 1 channel + - If input image is 3 channel: grayscale version is 3 channel with r == g == b + + """ + + def __init__(self, p=0.1): + super().__init__() + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be converted to grayscale. + + Returns: + PIL Image or Tensor: Randomly grayscaled image. + """ + num_output_channels = F._get_image_num_channels(img) + if torch.rand(1) < self.p: + return F.rgb_to_grayscale(img, num_output_channels=num_output_channels) + return img + + def __repr__(self): + return self.__class__.__name__ + '(p={0})'.format(self.p) + + +class RandomErasing(torch.nn.Module): + """ Randomly selects a rectangle region in an torch Tensor image and erases its pixels. + This transform does not support PIL Image. + 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896 + + Args: + p: probability that the random erasing operation will be performed. + scale: range of proportion of erased area against input image. + ratio: range of aspect ratio of erased area. + value: erasing value. Default is 0. If a single int, it is used to + erase all pixels. If a tuple of length 3, it is used to erase + R, G, B channels respectively. + If a str of 'random', erasing each pixel with random values. + inplace: boolean to make this transform inplace. Default set to False. + + Returns: + Erased Image. + + Example: + >>> transform = transforms.Compose([ + >>> transforms.RandomHorizontalFlip(), + >>> transforms.ToTensor(), + >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + >>> transforms.RandomErasing(), + >>> ]) + """ + + def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): + super().__init__() + if not isinstance(value, (numbers.Number, str, tuple, list)): + raise TypeError("Argument value should be either a number or str or a sequence") + if isinstance(value, str) and value != "random": + raise ValueError("If value is str, it should be 'random'") + if not isinstance(scale, (tuple, list)): + raise TypeError("Scale should be a sequence") + if not isinstance(ratio, (tuple, list)): + raise TypeError("Ratio should be a sequence") + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("Scale and ratio should be of kind (min, max)") + if scale[0] < 0 or scale[1] > 1: + raise ValueError("Scale should be between 0 and 1") + if p < 0 or p > 1: + raise ValueError("Random erasing probability should be between 0 and 1") + + self.p = p + self.scale = scale + self.ratio = ratio + self.value = value + self.inplace = inplace + + @staticmethod + def get_params( + img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None + ) -> Tuple[int, int, int, int, Tensor]: + """Get parameters for ``erase`` for a random erasing. + + Args: + img (Tensor): Tensor image to be erased. + scale (sequence): range of proportion of erased area against input image. + ratio (sequence): range of aspect ratio of erased area. + value (list, optional): erasing value. If None, it is interpreted as "random" + (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number, + i.e. ``value[0]``. + + Returns: + tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing. 
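+
+        Example (illustrative sketch only; assumes the module-level ``F`` alias
+        for ``torchvision.transforms.functional`` and a random input tensor):
+
+            >>> img = torch.rand(3, 64, 64)
+            >>> i, j, h, w, v = RandomErasing.get_params(
+            ...     img, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=[0.0])
+            >>> erased = F.erase(img, i, j, h, w, v)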
+ """ + img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1] + area = img_h * img_w + + log_ratio = torch.log(torch.tensor(ratio)) + for _ in range(10): + erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() + aspect_ratio = torch.exp( + torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) + ).item() + + h = int(round(math.sqrt(erase_area * aspect_ratio))) + w = int(round(math.sqrt(erase_area / aspect_ratio))) + if not (h < img_h and w < img_w): + continue + + if value is None: + v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() + else: + v = torch.tensor(value)[:, None, None] + + i = torch.randint(0, img_h - h + 1, size=(1, )).item() + j = torch.randint(0, img_w - w + 1, size=(1, )).item() + return i, j, h, w, v + + # Return original image + return 0, 0, img_h, img_w, img + + def forward(self, img): + """ + Args: + img (Tensor): Tensor image to be erased. + + Returns: + img (Tensor): Erased Tensor image. + """ + if torch.rand(1) < self.p: + + # cast self.value to script acceptable type + if isinstance(self.value, (int, float)): + value = [self.value, ] + elif isinstance(self.value, str): + value = None + elif isinstance(self.value, tuple): + value = list(self.value) + else: + value = self.value + + if value is not None and not (len(value) in (1, img.shape[-3])): + raise ValueError( + "If value is a sequence, it should have either a single value or " + "{} (number of input channels)".format(img.shape[-3]) + ) + + x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value) + return F.erase(img, x, y, h, w, v, self.inplace) + return img + + def __repr__(self): + s = '(p={}, '.format(self.p) + s += 'scale={}, '.format(self.scale) + s += 'ratio={}, '.format(self.ratio) + s += 'value={}, '.format(self.value) + s += 'inplace={})'.format(self.inplace) + return self.__class__.__name__ + s + + +class GaussianBlur(torch.nn.Module): + """Blurs image with randomly chosen Gaussian blur. + If the image is torch Tensor, it is expected + to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + kernel_size (int or sequence): Size of the Gaussian kernel. + sigma (float or tuple of float (min, max)): Standard deviation to be used for + creating kernel to perform blurring. If float, sigma is fixed. If it is tuple + of float (min, max), sigma is chosen uniformly at random to lie in the + given range. + + Returns: + PIL Image or Tensor: Gaussian blurred version of the input image. + + """ + + def __init__(self, kernel_size, sigma=(0.1, 2.0)): + super().__init__() + self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") + for ks in self.kernel_size: + if ks <= 0 or ks % 2 == 0: + raise ValueError("Kernel size value should be an odd and positive number.") + + if isinstance(sigma, numbers.Number): + if sigma <= 0: + raise ValueError("If sigma is a single number, it must be positive.") + sigma = (sigma, sigma) + elif isinstance(sigma, Sequence) and len(sigma) == 2: + if not 0. < sigma[0] <= sigma[1]: + raise ValueError("sigma values should be positive and of the form (min, max).") + else: + raise ValueError("sigma should be a single number or a list/tuple with length 2.") + + self.sigma = sigma + + @staticmethod + def get_params(sigma_min: float, sigma_max: float) -> float: + """Choose sigma for random gaussian blurring. + + Args: + sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel. 
+            sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.
+
+        Returns:
+            float: Standard deviation to be passed to calculate kernel for gaussian blurring.
+        """
+        return torch.empty(1).uniform_(sigma_min, sigma_max).item()
+
+    def forward(self, img: Tensor,mask: Tensor,wmap: Tensor) -> Tensor:
+        """
+        Args:
+            img (PIL Image or Tensor): image to be blurred.
+
+        Returns:
+            PIL Image or Tensor: Gaussian blurred image
+        """
+        sigma = self.get_params(self.sigma[0], self.sigma[1])
+        return F.gaussian_blur(img, self.kernel_size, [sigma, sigma]),mask,wmap
+
+    def __repr__(self):
+        s = '(kernel_size={}, '.format(self.kernel_size)
+        s += 'sigma={})'.format(self.sigma)
+        return self.__class__.__name__ + s
+
+
+def _setup_size(size, error_msg):
+    if isinstance(size, numbers.Number):
+        return int(size), int(size)
+
+    if isinstance(size, Sequence) and len(size) == 1:
+        return size[0], size[0]
+
+    if len(size) != 2:
+        raise ValueError(error_msg)
+
+    return size
+
+
+def _check_sequence_input(x, name, req_sizes):
+    msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
+    if not isinstance(x, Sequence):
+        raise TypeError("{} should be a sequence of length {}.".format(name, msg))
+    if len(x) not in req_sizes:
+        raise ValueError("{} should be sequence of length {}.".format(name, msg))
+
+
+def _setup_angle(x, name, req_sizes=(2, )):
+    if isinstance(x, numbers.Number):
+        if x < 0:
+            raise ValueError("If {} is a single number, it must be positive.".format(name))
+        x = [-x, x]
+    else:
+        _check_sequence_input(x, name, req_sizes)
+
+    return [float(d) for d in x]
+
+
+class RandomInvert(torch.nn.Module):
+    """Inverts the colors of the given image randomly with a given probability.
+    If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+    where ... means it can have an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "L" or "RGB".
+
+    Args:
+        p (float): probability of the image being color inverted. Default value is 0.5
+    """
+
+    def __init__(self, p=0.5):
+        super().__init__()
+        self.p = p
+
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be inverted.
+
+        Returns:
+            PIL Image or Tensor: Randomly color inverted image.
+        """
+        if torch.rand(1).item() < self.p:
+            return F.invert(img)
+        return img
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+class RandomPosterize(torch.nn.Module):
+    """Posterize the image randomly with a given probability by reducing the
+    number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8,
+    and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "L" or "RGB".
+
+    Args:
+        bits (int): number of bits to keep for each channel (0-8)
+        p (float): probability of the image being posterized. Default value is 0.5
+    """
+
+    def __init__(self, bits, p=0.5):
+        super().__init__()
+        self.bits = bits
+        self.p = p
+
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be posterized.
+
+        Returns:
+            PIL Image or Tensor: Randomly posterized image.
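+
+        Example (illustrative sketch; assumes a ``torch.uint8`` tensor image,
+        which is what ``F.posterize`` expects for tensor inputs):
+
+            >>> img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
+            >>> out = RandomPosterize(bits=2, p=1.0)(img)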
+        """
+        if torch.rand(1).item() < self.p:
+            return F.posterize(img, self.bits)
+        return img
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p)
+
+
+class RandomSolarize(torch.nn.Module):
+    """Solarize the image randomly with a given probability by inverting all pixel
+    values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+    where ... means it can have an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "L" or "RGB".
+
+    Args:
+        threshold (float): all pixels equal or above this value are inverted.
+        p (float): probability of the image being solarized. Default value is 0.5
+    """
+
+    def __init__(self, threshold, p=0.5):
+        super().__init__()
+        self.threshold = threshold
+        self.p = p
+
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be solarized.
+
+        Returns:
+            PIL Image or Tensor: Randomly solarized image.
+        """
+        if torch.rand(1).item() < self.p:
+            return F.solarize(img, self.threshold)
+        return img
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p)
+
+
+class RandomAdjustSharpness(torch.nn.Module):
+    """Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor,
+    it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+
+    Args:
+        sharpness_factor (float):  How much to adjust the sharpness. Can be
+            any non negative number. 0 gives a blurred image, 1 gives the
+            original image while 2 increases the sharpness by a factor of 2.
+        p (float): probability of the image being sharpened. Default value is 0.5
+    """
+
+    def __init__(self, sharpness_factor, p=0.5):
+        super().__init__()
+        self.sharpness_factor = sharpness_factor
+        self.p = p
+
+    def forward(self, img,mask,wmap):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be sharpened.
+
+        Returns:
+            PIL Image or Tensor: Randomly sharpened image.
+        """
+        if torch.rand(1).item() < self.p:
+            return F.adjust_sharpness(img, self.sharpness_factor),mask,wmap
+        return img,mask,wmap
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p)
+
+
+class RandomAutocontrast(torch.nn.Module):
+    """Autocontrast the pixels of the given image randomly with a given probability.
+    If the image is torch Tensor, it is expected
+    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "L" or "RGB".
+
+    Args:
+        p (float): probability of the image being autocontrasted. Default value is 0.5
+    """
+
+    def __init__(self, p=0.5):
+        super().__init__()
+        self.p = p
+
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be autocontrasted.
+
+        Returns:
+            PIL Image or Tensor: Randomly autocontrasted image.
+        """
+        if torch.rand(1).item() < self.p:
+            return F.autocontrast(img)
+        return img
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+class RandomEqualize(torch.nn.Module):
+    """Equalize the histogram of the given image randomly with a given probability.
+    If the image is torch Tensor, it is expected
+    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
+
+    Args:
+        p (float): probability of the image being equalized.
Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + self.p = p + + def forward(self, img,mask,wmap): + """ + Args: + img (PIL Image or Tensor): Image to be equalized. + + Returns: + PIL Image or Tensor: Randomly equalized image. + """ + if torch.rand(1).item() < self.p: + return F.equalize(img),mask,wmap + return img,mask,wmap + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) diff --git a/test_single_image_segmentation.py b/deepsea/test_single_image_segmentation.py similarity index 100% rename from test_single_image_segmentation.py rename to deepsea/test_single_image_segmentation.py diff --git a/test_single_set_tracking.py b/deepsea/test_single_set_tracking.py similarity index 100% rename from test_single_set_tracking.py rename to deepsea/test_single_set_tracking.py diff --git a/tracker_transforms.py b/deepsea/tracker_transforms.py similarity index 97% rename from tracker_transforms.py rename to deepsea/tracker_transforms.py index 6f6da7c..95725dc 100644 --- a/tracker_transforms.py +++ b/deepsea/tracker_transforms.py @@ -1,1906 +1,1906 @@ -import math -import numbers -import random -import warnings -from collections.abc import Sequence -from typing import Tuple, List, Optional -from PIL import Image -import torch -from torch import Tensor - -try: - import accimage -except ImportError: - accimage = None - - -from torchvision.transforms import functional as F - - -__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", - "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", - "RandomHorizontalFlip", "RandomVerticalFlip", "TenCrop", - "LinearTransformation", "ColorJitter", "RandomRotation", "Grayscale", "RandomGrayscale", - "RandomPerspective", "RandomErasing", "GaussianBlur"] - - -class Compose: - """Composes several transforms together. This transform does not support torchscript. - Please, see the note below. - - Args: - transforms (list of ``Transform`` objects): list of transforms to compose. - - Example: - >>> transforms.Compose([ - >>> transforms.CenterCrop(10), - >>> transforms.ToTensor(), - >>> ]) - - .. note:: - In order to script the transformations, please use ``torch.nn.Sequential`` as below. - - >>> transforms = torch.nn.Sequential( - >>> transforms.CenterCrop(10), - >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - >>> ) - >>> scripted_transforms = torch.jit.script(transforms) - - Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require - `lambda` functions or ``PIL.Image``. - - """ - - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, img,mask,wmap): - for t in self.transforms: - img,mask,wmap = t(img,mask,wmap) - return img,mask,wmap - - def __repr__(self): - format_string = self.__class__.__name__ + '(' - for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' - return format_string - - -class ToTensor: - """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript. - - Converts a PIL Image or numpy.ndarray (H x W x C) in the range - [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] - if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) - or if the numpy.ndarray has dtype = np.uint8 - - In the other cases, tensors are returned without scaling. - - .. 
note:: - Because the input image is scaled to [0.0, 1.0], this transformation should not be used when - transforming target image masks. See the `references`_ for implementing the transforms for image masks. - - .. _references: https://github.com/pytorch/vision/tree/master/references/segmentation - """ - - def __call__(self, img_prev,img_curr,mask): - """ - Args: - pic (PIL Image or numpy.ndarray): Image to be converted to tensor. - - Returns: - Tensor: Converted image. - """ - return F.to_tensor(img_prev),F.to_tensor(img_curr),F.to_tensor(mask) - - def __repr__(self): - return self.__class__.__name__ + '()' - - -class PILToTensor: - """Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript. - - Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W). - """ - - def __call__(self, pic): - """ - Args: - pic (PIL Image): Image to be converted to tensor. - - Returns: - Tensor: Converted image. - """ - return F.pil_to_tensor(pic) - - def __repr__(self): - return self.__class__.__name__ + '()' - - -class ConvertImageDtype(torch.nn.Module): - """Convert a tensor image to the given ``dtype`` and scale the values accordingly - This function does not support PIL Image. - - Args: - dtype (torch.dtype): Desired data type of the output - - .. note:: - - When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. - If converted back and forth, this mismatch has no effect. - - Raises: - RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as - well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to - overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range - of the integer ``dtype``. - """ - - def __init__(self, dtype: torch.dtype) -> None: - super().__init__() - self.dtype = dtype - - def forward(self, image): - return F.convert_image_dtype(image, self.dtype) - - -class ToPILImage: - """Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript. - - Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape - H x W x C to a PIL Image while preserving the value range. - - Args: - mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). - If ``mode`` is ``None`` (default) there are some assumptions made about the input data: - - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``. - - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``. - - If the input has 2 channels, the ``mode`` is assumed to be ``LA``. - - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``, - ``short``). - - .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes - """ - def __init__(self, mode=None): - self.mode = mode - - def __call__(self, pic): - """ - Args: - pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. - - Returns: - PIL Image: Image converted to PIL Image. - - """ - return F.to_pil_image(pic, self.mode) - - def __repr__(self): - format_string = self.__class__.__name__ + '(' - if self.mode is not None: - format_string += 'mode={0}'.format(self.mode) - format_string += ')' - return format_string - - -class Normalize(torch.nn.Module): - """Normalize a tensor image with mean and standard deviation. - This transform does not support PIL Image. 
- Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` - channels, this transform will normalize each channel of the input - ``torch.*Tensor`` i.e., - ``output[channel] = (input[channel] - mean[channel]) / std[channel]`` - - .. note:: - This transform acts out of place, i.e., it does not mutate the input tensor. - - Args: - mean (sequence): Sequence of means for each channel. - std (sequence): Sequence of standard deviations for each channel. - inplace(bool,optional): Bool to make this operation in-place. - - """ - - def __init__(self, mean, std, inplace=False): - super().__init__() - self.mean = mean - self.std = std - self.inplace = inplace - - def forward(self, img_prev: Tensor,img_curr: Tensor,mask: Tensor) -> Tensor: - """ - Args: - tensor (Tensor): Tensor image to be normalized. - - Returns: - Tensor: Normalized Tensor image. - """ - return F.normalize(img_prev, self.mean, self.std, self.inplace),F.normalize(img_curr, self.mean, self.std, self.inplace),mask - - def __repr__(self): - return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) - - -class Resize(torch.nn.Module): - """Resize the input image to the given size. - The image can be a PIL Image or a torch Tensor, in which case it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions - - Args: - size (sequence or int): Desired output size. If size is a sequence like - (h, w), output size will be matched to this. If size is an int, - smaller edge of the image will be matched to this number. - i.e, if height > width, then image will be rescaled to - (size * height / width, size). - In torchscript mode padding as single int is not supported, use a tuple or - list of length 1: ``[size, ]``. - interpolation (int, optional): Desired interpolation enum defined by `filters`_. - Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` - and ``PIL.Image.BICUBIC`` are supported. - """ - - def __init__(self, size, img_interpolation=Image.BILINEAR,mask_interpolation=Image.NEAREST): - super().__init__() - if not isinstance(size, (int, Sequence)): - raise TypeError("Size should be int or sequence. Got {}".format(type(size))) - if isinstance(size, Sequence) and len(size) not in (1, 2): - raise ValueError("If size is a sequence, it should have 1 or 2 values") - self.size = size - self.img_interpolation = img_interpolation - self.mask_interpolation = mask_interpolation - - def forward(self, img_prev,img_curr,mask): - """ - Args: - img (PIL Image or Tensor): Image to be scaled. - - Returns: - PIL Image or Tensor: Rescaled image. - """ - - - return F.resize(img_prev, self.size, self.img_interpolation),F.resize(img_curr, self.size, self.img_interpolation),F.resize(mask, self.size, self.mask_interpolation) - - def __repr__(self): - interpolate_str = self.interpolation.value - return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2}, antialias={3})'.format( - self.size, interpolate_str, self.max_size, self.antialias) - - -class Scale(Resize): - """ - Note: This transform is deprecated in favor of Resize. - """ - def __init__(self, *args, **kwargs): - warnings.warn("The use of the transforms.Scale transform is deprecated, " + - "please use transforms.Resize instead.") - super(Scale, self).__init__(*args, **kwargs) - - -class CenterCrop(torch.nn.Module): - """Crops the given image at the center. - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... 
means an arbitrary number of leading dimensions. - If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. - - Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). - """ - - def __init__(self, size): - super().__init__() - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be cropped. - - Returns: - PIL Image or Tensor: Cropped image. - """ - return F.center_crop(img, self.size) - - def __repr__(self): - return self.__class__.__name__ + '(size={0})'.format(self.size) - - -class Pad(torch.nn.Module): - """Pad the given image on all sides with the given "pad" value. - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric, - at most 3 leading dimensions for mode edge, - and an arbitrary number of leading dimensions for mode constant - - Args: - padding (int or sequence): Padding on each border. If a single int is provided this - is used to pad all borders. If sequence of length 2 is provided this is the padding - on left/right and top/bottom respectively. If a sequence of length 4 is provided - this is the padding for the left, top, right and bottom borders respectively. - - .. note:: - In torchscript mode padding as single int is not supported, use a sequence of - length 1: ``[padding, ]``. - fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of - length 3, it is used to fill R, G, B channels respectively. - This value is only used when the padding_mode is constant. - Only number is supported for torch Tensor. - Only int or str or tuple value is supported for PIL Image. - padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. - Default is constant. - - - constant: pads with a constant value, this value is specified with fill - - - edge: pads with the last value at the edge of the image. - If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2 - - - reflect: pads with reflection of image without repeating the last value on the edge. - For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode - will result in [3, 2, 1, 2, 3, 4, 3, 2] - - - symmetric: pads with reflection of image repeating the last value on the edge. 
- For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode - will result in [2, 1, 1, 2, 3, 4, 4, 3] - """ - - def __init__(self, padding, fill=0, padding_mode="constant"): - super().__init__() - if not isinstance(padding, (numbers.Number, tuple, list)): - raise TypeError("Got inappropriate padding arg") - - if not isinstance(fill, (numbers.Number, str, tuple)): - raise TypeError("Got inappropriate fill arg") - - if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: - raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") - - if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]: - raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + - "{} element tuple".format(len(padding))) - - self.padding = padding - self.fill = fill - self.padding_mode = padding_mode - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be padded. - - Returns: - PIL Image or Tensor: Padded image. - """ - return F.pad(img, self.padding, self.fill, self.padding_mode) - - def __repr__(self): - return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ - format(self.padding, self.fill, self.padding_mode) - - -class Lambda: - """Apply a user-defined lambda as a transform. This transform does not support torchscript. - - Args: - lambd (function): Lambda/function to be used for transform. - """ - - def __init__(self, lambd): - if not callable(lambd): - raise TypeError("Argument lambd should be callable, got {}".format(repr(type(lambd).__name__))) - self.lambd = lambd - - def __call__(self, img): - return self.lambd(img) - - def __repr__(self): - return self.__class__.__name__ + '()' - - -class RandomTransforms: - """Base class for a list of transformations with randomness - - Args: - transforms (sequence): list of transformations - """ - - def __init__(self, transforms): - if not isinstance(transforms, Sequence): - raise TypeError("Argument transforms should be a sequence") - self.transforms = transforms - - def __call__(self, *args, **kwargs): - raise NotImplementedError() - - def __repr__(self): - format_string = self.__class__.__name__ + '(' - for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' - return format_string - - -class RandomApply(torch.nn.Module): - """Apply randomly a list of transformations with a given probability. - - .. note:: - In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of - transforms as shown below: - - >>> transforms = transforms.RandomApply(torch.nn.ModuleList([ - >>> transforms.ColorJitter(), - >>> ]), p=0.3) - >>> scripted_transforms = torch.jit.script(transforms) - - Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require - `lambda` functions or ``PIL.Image``. 
- - Args: - transforms (sequence or torch.nn.Module): list of transformations - p (float): probability - """ - - def __init__(self, transforms, p=0.5): - super().__init__() - self.transforms = transforms - self.p = p - - def forward(self, img_prev,img_curr,mask): - if self.p < torch.rand(1): - return img_prev,img_curr,mask - for t in self.transforms: - img_prev,img_curr,mask = t(img_prev,img_curr,mask) - return img_prev,img_curr,mask - - def __repr__(self): - format_string = self.__class__.__name__ + '(' - format_string += '\n p={}'.format(self.p) - for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' - return format_string - - -class RandomOrder(RandomTransforms): - """Apply a list of transformations in a random order. This transform does not support torchscript. - """ - def __call__(self, img,mask,wmap): - order = list(range(len(self.transforms))) - random.shuffle(order) - for i in order: - img,mask,wmap = self.transforms[i](img,mask,wmap) - return img,mask,wmap - - -class RandomChoice(RandomTransforms): - """Apply single transformation randomly picked from a list. This transform does not support torchscript. - """ - def __call__(self, img,mask,wmap): - t = random.choice(self.transforms) - return t(img,mask,wmap) - - -class RandomCrop(torch.nn.Module): - """Crop the given image at a random location. - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions, - but if non-constant padding is used, the input is expected to have at most 2 leading dimensions - - Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). - padding (int or sequence, optional): Optional padding on each border - of the image. Default is None. If a single int is provided this - is used to pad all borders. If sequence of length 2 is provided this is the padding - on left/right and top/bottom respectively. If a sequence of length 4 is provided - this is the padding for the left, top, right and bottom borders respectively. - - .. note:: - In torchscript mode padding as single int is not supported, use a sequence of - length 1: ``[padding, ]``. - pad_if_needed (boolean): It will pad the image if smaller than the - desired size to avoid raising an exception. Since cropping is done - after padding, the padding seems to be done at a random offset. - fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of - length 3, it is used to fill R, G, B channels respectively. - This value is only used when the padding_mode is constant. - Only number is supported for torch Tensor. - Only int or str or tuple value is supported for PIL Image. - padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. - Default is constant. - - - constant: pads with a constant value, this value is specified with fill - - - edge: pads with the last value at the edge of the image. - If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2 - - - reflect: pads with reflection of image without repeating the last value on the edge. - For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode - will result in [3, 2, 1, 2, 3, 4, 3, 2] - - - symmetric: pads with reflection of image repeating the last value on the edge. 
- For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode - will result in [2, 1, 1, 2, 3, 4, 4, 3] - """ - - @staticmethod - def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]: - """Get parameters for ``crop`` for a random crop. - - Args: - img (PIL Image or Tensor): Image to be cropped. - output_size (tuple): Expected output size of the crop. - - Returns: - tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. - """ - w, h = F._get_image_size(img) - th, tw = output_size - - if h + 1 < th or w + 1 < tw: - raise ValueError( - "Required crop size {} is larger then input image size {}".format((th, tw), (h, w)) - ) - - if w == tw and h == th: - return 0, 0, h, w - - i = torch.randint(0, h - th + 1, size=(1, )).item() - j = torch.randint(0, w - tw + 1, size=(1, )).item() - return i, j, th, tw - - def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): - super().__init__() - - self.size = tuple(_setup_size( - size, error_msg="Please provide only two dimensions (h, w) for size." - )) - - self.padding = padding - self.pad_if_needed = pad_if_needed - self.fill = fill - self.padding_mode = padding_mode - - def forward(self, img,mask,wmap): - """ - Args: - img (PIL Image or Tensor): Image to be cropped. - - Returns: - PIL Image or Tensor: Cropped image. - """ - if self.padding is not None: - img = F.pad(img, self.padding, self.fill, self.padding_mode) - - width, height = F._get_image_size(img) - # pad the width if needed - if self.pad_if_needed and width < self.size[1]: - padding = [self.size[1] - width, 0] - img = F.pad(img, padding, self.fill, self.padding_mode) - # pad the height if needed - if self.pad_if_needed and height < self.size[0]: - padding = [0, self.size[0] - height] - img = F.pad(img, padding, self.fill, self.padding_mode) - - i, j, h, w = self.get_params(img, self.size) - - return F.crop(img, i, j, h, w),F.crop(mask, i, j, h, w),F.crop(wmap, i, j, h, w) - - def __repr__(self): - return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding) - - -class Pass(torch.nn.Module): - """ - No Transforms - """ - - def __init__(self): - super().__init__() - - def forward(self, img,mask,wmap): - return img,mask,wmap - - def __repr__(self): - return self.__class__.__name__ - -class RandomHorizontalFlip(torch.nn.Module): - """Horizontally flip the given image randomly with a given probability. - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading - dimensions - - Args: - p (float): probability of the image being flipped. Default value is 0.5 - """ - - def __init__(self, p=0.5): - super().__init__() - self.p = p - - def forward(self, img_prev,img_curr,mask): - """ - Args: - img (PIL Image or Tensor): Image to be flipped. - - Returns: - PIL Image or Tensor: Randomly flipped image. - """ - if torch.rand(1) < self.p: - return F.hflip(img_prev),F.hflip(img_curr),F.hflip(mask) - return img_prev,img_curr,mask - - def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) - - -class RandomVerticalFlip(torch.nn.Module): - """Vertically flip the given image randomly with a given probability. - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading - dimensions - - Args: - p (float): probability of the image being flipped. 
Default value is 0.5 - """ - - def __init__(self, p=0.5): - super().__init__() - self.p = p - - def forward(self, img_prev,img_curr,mask): - """ - Args: - img (PIL Image or Tensor): Image to be flipped. - - Returns: - PIL Image or Tensor: Randomly flipped image. - """ - if torch.rand(1) < self.p: - return F.vflip(img_prev),F.vflip(img_curr),F.vflip(mask) - return img_prev,img_curr,mask - - def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) - - -class RandomPerspective(torch.nn.Module): - """Performs a random perspective transformation of the given image with a given probability. - The image can be a PIL Image or a Tensor, in which case it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. - - Args: - distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1. - Default is 0.5. - p (float): probability of the image being transformed. Default is 0.5. - interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and - ``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for PIL images and Tensors. - fill (n-tuple or int or float): Pixel fill value for area outside the rotated - image. If int or float, the value is used for all bands respectively. Default is 0. - This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor - input. Fill value for the area outside the transform in the output image is always 0. - - """ - - def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BILINEAR, fill=0): - super().__init__() - self.p = p - self.interpolation = interpolation - self.distortion_scale = distortion_scale - self.fill = fill - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be Perspectively transformed. - - Returns: - PIL Image or Tensor: Randomly transformed image. - """ - if torch.rand(1) < self.p: - width, height = F._get_image_size(img) - startpoints, endpoints = self.get_params(width, height, self.distortion_scale) - return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill) - return img - - @staticmethod - def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]: - """Get parameters for ``perspective`` for a random perspective transform. - - Args: - width (int): width of the image. - height (int): height of the image. - distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1. - - Returns: - List containing [top-left, top-right, bottom-right, bottom-left] of the original image, - List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image. 
- """ - half_height = height // 2 - half_width = width // 2 - topleft = [ - int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), - int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) - ] - topright = [ - int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), - int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) - ] - botright = [ - int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), - int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) - ] - botleft = [ - int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), - int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) - ] - startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] - endpoints = [topleft, topright, botright, botleft] - return startpoints, endpoints - - def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) - - -# class RandomResizedCrop(torch.nn.Module): -# """Crop a random portion of image and resize it to a given size. -# -# If the image is torch Tensor, it is expected -# to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions -# -# A crop of the original image is made: the crop has a random area (H * W) -# and a random aspect ratio. This crop is finally resized to the given -# size. This is popularly used to train the Inception networks. -# -# Args: -# size (int or sequence): expected output size of the crop, for each edge. If size is an -# int instead of sequence like (h, w), a square output size ``(size, size)`` is -# made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). -# -# .. note:: -# In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. -# scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop, -# before resizing. The scale is defined with respect to the area of the original image. -# ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before -# resizing. -# interpolation (InterpolationMode): Desired interpolation enum defined by -# :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. -# If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and -# ``InterpolationMode.BICUBIC`` are supported. -# For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. -# -# """ -# -# def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationMode.BILINEAR): -# super().__init__() -# self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") -# -# if not isinstance(scale, Sequence): -# raise TypeError("Scale should be a sequence") -# if not isinstance(ratio, Sequence): -# raise TypeError("Ratio should be a sequence") -# if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): -# warnings.warn("Scale and ratio should be of kind (min, max)") -# -# # Backward compatibility with integer value -# if isinstance(interpolation, int): -# warnings.warn( -# "Argument interpolation should be of type InterpolationMode instead of int. " -# "Please, use InterpolationMode enum." 
-# ) -# interpolation = _interpolation_modes_from_int(interpolation) -# -# self.interpolation = interpolation -# self.scale = scale -# self.ratio = ratio -# -# @staticmethod -# def get_params( -# img: Tensor, scale: List[float], ratio: List[float] -# ) -> Tuple[int, int, int, int]: -# """Get parameters for ``crop`` for a random sized crop. -# -# Args: -# img (PIL Image or Tensor): Input image. -# scale (list): range of scale of the origin size cropped -# ratio (list): range of aspect ratio of the origin aspect ratio cropped -# -# Returns: -# tuple: params (i, j, h, w) to be passed to ``crop`` for a random -# sized crop. -# """ -# width, height = F._get_image_size(img) -# area = height * width -# -# log_ratio = torch.log(torch.tensor(ratio)) -# for _ in range(10): -# target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() -# aspect_ratio = torch.exp( -# torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) -# ).item() -# -# w = int(round(math.sqrt(target_area * aspect_ratio))) -# h = int(round(math.sqrt(target_area / aspect_ratio))) -# -# if 0 < w <= width and 0 < h <= height: -# i = torch.randint(0, height - h + 1, size=(1,)).item() -# j = torch.randint(0, width - w + 1, size=(1,)).item() -# return i, j, h, w -# -# # Fallback to central crop -# in_ratio = float(width) / float(height) -# if in_ratio < min(ratio): -# w = width -# h = int(round(w / min(ratio))) -# elif in_ratio > max(ratio): -# h = height -# w = int(round(h * max(ratio))) -# else: # whole image -# w = width -# h = height -# i = (height - h) // 2 -# j = (width - w) // 2 -# return i, j, h, w -# -# def forward(self, img): -# """ -# Args: -# img (PIL Image or Tensor): Image to be cropped and resized. -# -# Returns: -# PIL Image or Tensor: Randomly cropped and resized image. -# """ -# i, j, h, w = self.get_params(img, self.scale, self.ratio) -# return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) -# -# def __repr__(self): -# interpolate_str = self.interpolation.value -# format_string = self.__class__.__name__ + '(size={0}'.format(self.size) -# format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) -# format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) -# format_string += ', interpolation={0})'.format(interpolate_str) -# return format_string - - -# class RandomSizedCrop(RandomResizedCrop): -# """ -# Note: This transform is deprecated in favor of RandomResizedCrop. -# """ -# def __init__(self, *args, **kwargs): -# warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + -# "please use transforms.RandomResizedCrop instead.") -# super(RandomSizedCrop, self).__init__(*args, **kwargs) -# -# -# class FiveCrop(torch.nn.Module): -# """Crop the given image into four corners and the central crop. -# If the image is torch Tensor, it is expected -# to have [..., H, W] shape, where ... means an arbitrary number of leading -# dimensions -# -# .. Note:: -# This transform returns a tuple of images and there may be a mismatch in the number of -# inputs and targets your Dataset returns. See below for an example of how to deal with -# this. -# -# Args: -# size (sequence or int): Desired output size of the crop. If size is an ``int`` -# instead of sequence like (h, w), a square crop of size (size, size) is made. -# If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). 
-# -# Example: -# >>> transform = Compose([ -# >>> FiveCrop(size), # this is a list of PIL Images -# >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor -# >>> ]) -# >>> #In your test loop you can do the following: -# >>> input, target = batch # input is a 5d tensor, target is 2d -# >>> bs, ncrops, c, h, w = input.size() -# >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops -# >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops -# """ -# -# def __init__(self, size): -# super().__init__() -# self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") -# -# def forward(self, img): -# """ -# Args: -# img (PIL Image or Tensor): Image to be cropped. -# -# Returns: -# tuple of 5 images. Image can be PIL Image or Tensor -# """ -# return F.five_crop(img, self.size) -# -# def __repr__(self): -# return self.__class__.__name__ + '(size={0})'.format(self.size) - - -class TenCrop(torch.nn.Module): - """Crop the given image into four corners and the central crop plus the flipped version of - these (horizontal flipping is used by default). - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading - dimensions - - .. Note:: - This transform returns a tuple of images and there may be a mismatch in the number of - inputs and targets your Dataset returns. See below for an example of how to deal with - this. - - Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). - vertical_flip (bool): Use vertical flipping instead of horizontal - - Example: - >>> transform = Compose([ - >>> TenCrop(size), # this is a list of PIL Images - >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor - >>> ]) - >>> #In your test loop you can do the following: - >>> input, target = batch # input is a 5d tensor, target is 2d - >>> bs, ncrops, c, h, w = input.size() - >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops - >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops - """ - - def __init__(self, size, vertical_flip=False): - super().__init__() - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - self.vertical_flip = vertical_flip - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be cropped. - - Returns: - tuple of 10 images. Image can be PIL Image or Tensor - """ - return F.ten_crop(img, self.size, self.vertical_flip) - - def __repr__(self): - return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) - - -class LinearTransformation(torch.nn.Module): - """Transform a tensor image with a square transformation matrix and a mean_vector computed - offline. - This transform does not support PIL Image. - Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and - subtract mean_vector from it which is then followed by computing the dot - product with the transformation matrix and then reshaping the tensor to its - original shape. - - Applications: - whitening transformation: Suppose X is a column vector zero-centered data. 
- Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X), - perform SVD on this matrix and pass it as transformation_matrix. - - Args: - transformation_matrix (Tensor): tensor [D x D], D = C x H x W - mean_vector (Tensor): tensor [D], D = C x H x W - """ - - def __init__(self, transformation_matrix, mean_vector): - super().__init__() - if transformation_matrix.size(0) != transformation_matrix.size(1): - raise ValueError("transformation_matrix should be square. Got " + - "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) - - if mean_vector.size(0) != transformation_matrix.size(0): - raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) + - " as any one of the dimensions of the transformation_matrix [{}]" - .format(tuple(transformation_matrix.size()))) - - if transformation_matrix.device != mean_vector.device: - raise ValueError("Input tensors should be on the same device. Got {} and {}" - .format(transformation_matrix.device, mean_vector.device)) - - self.transformation_matrix = transformation_matrix - self.mean_vector = mean_vector - - def forward(self, tensor: Tensor) -> Tensor: - """ - Args: - tensor (Tensor): Tensor image to be whitened. - - Returns: - Tensor: Transformed image. - """ - shape = tensor.shape - n = shape[-3] * shape[-2] * shape[-1] - if n != self.transformation_matrix.shape[0]: - raise ValueError("Input tensor and transformation matrix have incompatible shape." + - "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) + - "{}".format(self.transformation_matrix.shape[0])) - - if tensor.device.type != self.mean_vector.device.type: - raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. " - "Got {} vs {}".format(tensor.device, self.mean_vector.device)) - - flat_tensor = tensor.view(-1, n) - self.mean_vector - transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) - tensor = transformed_tensor.view(shape) - return tensor - - def __repr__(self): - format_string = self.__class__.__name__ + '(transformation_matrix=' - format_string += (str(self.transformation_matrix.tolist()) + ')') - format_string += (", (mean_vector=" + str(self.mean_vector.tolist()) + ')') - return format_string - - -class ColorJitter(torch.nn.Module): - """Randomly change the brightness, contrast, saturation and hue of an image. - If the image is torch Tensor, it is expected - to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported. - - Args: - brightness (float or tuple of float (min, max)): How much to jitter brightness. - brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] - or the given [min, max]. Should be non negative numbers. - contrast (float or tuple of float (min, max)): How much to jitter contrast. - contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] - or the given [min, max]. Should be non negative numbers. - saturation (float or tuple of float (min, max)): How much to jitter saturation. - saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] - or the given [min, max]. Should be non negative numbers. - hue (float or tuple of float (min, max)): How much to jitter hue. - hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. - Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 
- """ - - def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): - super().__init__() - self.brightness = self._check_input(brightness, 'brightness') - self.contrast = self._check_input(contrast, 'contrast') - self.saturation = self._check_input(saturation, 'saturation') - self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), - clip_first_on_zero=False) - - @torch.jit.unused - def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): - if isinstance(value, numbers.Number): - if value < 0: - raise ValueError("If {} is a single number, it must be non negative.".format(name)) - value = [center - float(value), center + float(value)] - if clip_first_on_zero: - value[0] = max(value[0], 0.0) - elif isinstance(value, (tuple, list)) and len(value) == 2: - if not bound[0] <= value[0] <= value[1] <= bound[1]: - raise ValueError("{} values should be between {}".format(name, bound)) - else: - raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name)) - - # if value is 0 or (1., 1.) for brightness/contrast/saturation - # or (0., 0.) for hue, do nothing - if value[0] == value[1] == center: - value = None - return value - - @staticmethod - def get_params(brightness: Optional[List[float]], - contrast: Optional[List[float]], - saturation: Optional[List[float]], - hue: Optional[List[float]] - ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: - """Get the parameters for the randomized transform to be applied on image. - - Args: - brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen - uniformly. Pass None to turn off the transformation. - contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen - uniformly. Pass None to turn off the transformation. - saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen - uniformly. Pass None to turn off the transformation. - hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly. - Pass None to turn off the transformation. - - Returns: - tuple: The parameters used to apply the randomized transform - along with their random order. - """ - fn_idx = torch.randperm(4) - - b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1])) - c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1])) - s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1])) - h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) - - return fn_idx, b, c, s, h - - def forward(self, img_prev,img_curr,mask): - """ - Args: - img (PIL Image or Tensor): Input image. - - Returns: - PIL Image or Tensor: Color jittered image. 
- """ - fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ - self.get_params(self.brightness, self.contrast, self.saturation, self.hue) - - for fn_id in fn_idx: - if fn_id == 0 and brightness_factor is not None: - img_prev = F.adjust_brightness(img_prev, brightness_factor) - img_curr = F.adjust_brightness(img_curr, brightness_factor) - elif fn_id == 1 and contrast_factor is not None: - img_prev = F.adjust_contrast(img_prev, contrast_factor) - img_curr = F.adjust_contrast(img_curr, contrast_factor) - elif fn_id == 2 and saturation_factor is not None: - img_prev = F.adjust_saturation(img_prev, saturation_factor) - img_curr = F.adjust_saturation(img_curr, saturation_factor) - elif fn_id == 3 and hue_factor is not None: - img_prev = F.adjust_hue(img_prev, hue_factor) - img_curr = F.adjust_hue(img_curr, hue_factor) - - return img_prev,img_curr,mask - - def __repr__(self): - format_string = self.__class__.__name__ + '(' - format_string += 'brightness={0}'.format(self.brightness) - format_string += ', contrast={0}'.format(self.contrast) - format_string += ', saturation={0}'.format(self.saturation) - format_string += ', hue={0})'.format(self.hue) - return format_string - -class RandomRotation(torch.nn.Module): - """Rotate the image by angle. - The image can be a PIL Image or a Tensor, in which case it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. - - Args: - degrees (sequence or float or int): Range of degrees to select from. - If degrees is a number instead of sequence like (min, max), the range of degrees - will be (-degrees, +degrees). - resample (int, optional): An optional resampling filter. See `filters`_ for more information. - If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. - If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. - expand (bool, optional): Optional expansion flag. - If true, expands the output to make it large enough to hold the entire rotated image. - If false or omitted, make the output image the same size as the input image. - Note that the expand flag assumes rotation around the center and no translation. - center (list or tuple, optional): Optional center of rotation, (x, y). Origin is the upper left corner. - Default is the center of the image. - fill (n-tuple or int or float): Pixel fill value for area outside the rotated - image. If int or float, the value is used for all bands respectively. - Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0. - This option is not supported for Tensor input. Fill value for the area outside the transform in the output - image is always 0. - - .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters - - """ - - def __init__(self, degrees, resample=Image.NEAREST, expand=False, center=None, fill=0): - super().__init__() - self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) - - if center is not None: - _check_sequence_input(center, "center", req_sizes=(2, )) - - self.center = center - - self.resample = resample - self.expand = expand - self.fill = fill - - @staticmethod - def get_params(degrees: List[float]) -> float: - """Get parameters for ``rotate`` for a random rotation. - - Returns: - float: angle parameter to be passed to ``rotate`` for random rotation. 
- """ - angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) - return angle - - def forward(self, img,mask,wmap): - """ - Args: - img (PIL Image or Tensor): Image to be rotated. - - Returns: - PIL Image or Tensor: Rotated image. - """ - angle = self.get_params(self.degrees) - return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill),F.rotate(mask, angle, self.resample, self.expand, self.center, self.fill),F.rotate(wmap, angle, self.resample, self.expand, self.center, self.fill) - - - - -# class RandomAffine(torch.nn.Module): -# """Random affine transformation of the image keeping center invariant. -# If the image is torch Tensor, it is expected -# to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. -# -# Args: -# degrees (sequence or number): Range of degrees to select from. -# If degrees is a number instead of sequence like (min, max), the range of degrees -# will be (-degrees, +degrees). Set to 0 to deactivate rotations. -# translate (tuple, optional): tuple of maximum absolute fraction for horizontal -# and vertical translations. For example translate=(a, b), then horizontal shift -# is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is -# randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. -# scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is -# randomly sampled from the range a <= scale <= b. Will keep original scale by default. -# shear (sequence or number, optional): Range of degrees to select from. -# If shear is a number, a shear parallel to the x axis in the range (-shear, +shear) -# will be applied. Else if shear is a sequence of 2 values a shear parallel to the x axis in the -# range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values, -# a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. -# Will not apply shear by default. -# interpolation (InterpolationMode): Desired interpolation enum defined by -# :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. -# If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. -# For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. -# fill (sequence or number): Pixel fill value for the area outside the transformed -# image. Default is ``0``. If given a number, the value is used for all bands respectively. -# fillcolor (sequence or number, optional): deprecated argument and will be removed since v0.10.0. -# Please use the ``fill`` parameter instead. -# resample (int, optional): deprecated argument and will be removed since v0.10.0. -# Please use the ``interpolation`` parameter instead. -# -# .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters -# -# """ -# -# def __init__( -# self, degrees, translate=None, scale=None, shear=None, interpolation=InterpolationMode.NEAREST, fill=0, -# fillcolor=None, resample=None -# ): -# super().__init__() -# if resample is not None: -# warnings.warn( -# "Argument resample is deprecated and will be removed since v0.10.0. 
Please, use interpolation instead" -# ) -# interpolation = _interpolation_modes_from_int(resample) -# -# # Backward compatibility with integer value -# if isinstance(interpolation, int): -# warnings.warn( -# "Argument interpolation should be of type InterpolationMode instead of int. " -# "Please, use InterpolationMode enum." -# ) -# interpolation = _interpolation_modes_from_int(interpolation) -# -# if fillcolor is not None: -# warnings.warn( -# "Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead" -# ) -# fill = fillcolor -# -# self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) -# -# if translate is not None: -# _check_sequence_input(translate, "translate", req_sizes=(2, )) -# for t in translate: -# if not (0.0 <= t <= 1.0): -# raise ValueError("translation values should be between 0 and 1") -# self.translate = translate -# -# if scale is not None: -# _check_sequence_input(scale, "scale", req_sizes=(2, )) -# for s in scale: -# if s <= 0: -# raise ValueError("scale values should be positive") -# self.scale = scale -# -# if shear is not None: -# self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4)) -# else: -# self.shear = shear -# -# self.resample = self.interpolation = interpolation -# -# if fill is None: -# fill = 0 -# elif not isinstance(fill, (Sequence, numbers.Number)): -# raise TypeError("Fill should be either a sequence or a number.") -# -# self.fillcolor = self.fill = fill -# -# @staticmethod -# def get_params( -# degrees: List[float], -# translate: Optional[List[float]], -# scale_ranges: Optional[List[float]], -# shears: Optional[List[float]], -# img_size: List[int] -# ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]: -# """Get parameters for affine transformation -# -# Returns: -# params to be passed to the affine transformation -# """ -# angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) -# if translate is not None: -# max_dx = float(translate[0] * img_size[0]) -# max_dy = float(translate[1] * img_size[1]) -# tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())) -# ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())) -# translations = (tx, ty) -# else: -# translations = (0, 0) -# -# if scale_ranges is not None: -# scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item()) -# else: -# scale = 1.0 -# -# shear_x = shear_y = 0.0 -# if shears is not None: -# shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item()) -# if len(shears) == 4: -# shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item()) -# -# shear = (shear_x, shear_y) -# -# return angle, translations, scale, shear -# -# def forward(self, img): -# """ -# img (PIL Image or Tensor): Image to be transformed. -# -# Returns: -# PIL Image or Tensor: Affine transformed image. 
-# """ -# fill = self.fill -# if isinstance(img, Tensor): -# if isinstance(fill, (int, float)): -# fill = [float(fill)] * F._get_image_num_channels(img) -# else: -# fill = [float(f) for f in fill] -# -# img_size = F._get_image_size(img) -# -# ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size) -# -# return F.affine(img, *ret, interpolation=self.interpolation, fill=fill) -# -# def __repr__(self): -# s = '{name}(degrees={degrees}' -# if self.translate is not None: -# s += ', translate={translate}' -# if self.scale is not None: -# s += ', scale={scale}' -# if self.shear is not None: -# s += ', shear={shear}' -# if self.interpolation != InterpolationMode.NEAREST: -# s += ', interpolation={interpolation}' -# if self.fill != 0: -# s += ', fill={fill}' -# s += ')' -# d = dict(self.__dict__) -# d['interpolation'] = self.interpolation.value -# return s.format(name=self.__class__.__name__, **d) - - -class Grayscale(torch.nn.Module): - """Convert image to grayscale. - If the image is torch Tensor, it is expected - to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions - - Args: - num_output_channels (int): (1 or 3) number of channels desired for output image - - Returns: - PIL Image: Grayscale version of the input. - - - If ``num_output_channels == 1`` : returned image is single channel - - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b - - """ - - def __init__(self, num_output_channels=1): - super().__init__() - self.num_output_channels = num_output_channels - - def forward(self, img_prev,img_curr,mask): - """ - Args: - img (PIL Image or Tensor): Image to be converted to grayscale. - - Returns: - PIL Image or Tensor: Grayscaled image. - """ - return F.rgb_to_grayscale(img_prev, num_output_channels=self.num_output_channels),F.rgb_to_grayscale(img_curr, num_output_channels=self.num_output_channels),F.rgb_to_grayscale(mask, num_output_channels=self.num_output_channels) - - def __repr__(self): - return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) - -class GetBoundingBoxes(torch.nn.Module): - - def __init__(self): - super().__init__() - - def forward(self, img,mask): - import numpy as np - - A = np.array([ - [0, 0, 0, 0, 0, 0, 0], - [0, 1, 1, 0, 2, 2, 0], - [0, 1, 1, 0, 2, 2, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 4, 4, 0, 3, 3, 0], - [0, 4, 4, 0, 3, 3, 0], - [0, 0, 0, 0, 0, 0, 0] - ]) - - bboxCorners = {} - for i in range(1, A.max() + 1): - B = np.argwhere(A == i) - bboxCorners[i] = B.min(0), B.max(0) - - print(bboxCorners) - return img - return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels),F.rgb_to_grayscale(mask, num_output_channels=self.num_output_channels) - - def __repr__(self): - return self.__class__.__name__ - - -class RandomGrayscale(torch.nn.Module): - """Randomly convert image to grayscale with a probability of p (default 0.1). - If the image is torch Tensor, it is expected - to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions - - Args: - p (float): probability that image should be converted to grayscale. - - Returns: - PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged - with probability (1-p). 
- - If input image is 1 channel: grayscale version is 1 channel - - If input image is 3 channel: grayscale version is 3 channel with r == g == b - - """ - - def __init__(self, p=0.1): - super().__init__() - self.p = p - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be converted to grayscale. - - Returns: - PIL Image or Tensor: Randomly grayscaled image. - """ - num_output_channels = F._get_image_num_channels(img) - if torch.rand(1) < self.p: - return F.rgb_to_grayscale(img, num_output_channels=num_output_channels) - return img - - def __repr__(self): - return self.__class__.__name__ + '(p={0})'.format(self.p) - - -class RandomErasing(torch.nn.Module): - """ Randomly selects a rectangle region in an torch Tensor image and erases its pixels. - This transform does not support PIL Image. - 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896 - - Args: - p: probability that the random erasing operation will be performed. - scale: range of proportion of erased area against input image. - ratio: range of aspect ratio of erased area. - value: erasing value. Default is 0. If a single int, it is used to - erase all pixels. If a tuple of length 3, it is used to erase - R, G, B channels respectively. - If a str of 'random', erasing each pixel with random values. - inplace: boolean to make this transform inplace. Default set to False. - - Returns: - Erased Image. - - Example: - >>> transform = transforms.Compose([ - >>> transforms.RandomHorizontalFlip(), - >>> transforms.ToTensor(), - >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - >>> transforms.RandomErasing(), - >>> ]) - """ - - def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): - super().__init__() - if not isinstance(value, (numbers.Number, str, tuple, list)): - raise TypeError("Argument value should be either a number or str or a sequence") - if isinstance(value, str) and value != "random": - raise ValueError("If value is str, it should be 'random'") - if not isinstance(scale, (tuple, list)): - raise TypeError("Scale should be a sequence") - if not isinstance(ratio, (tuple, list)): - raise TypeError("Ratio should be a sequence") - if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): - warnings.warn("Scale and ratio should be of kind (min, max)") - if scale[0] < 0 or scale[1] > 1: - raise ValueError("Scale should be between 0 and 1") - if p < 0 or p > 1: - raise ValueError("Random erasing probability should be between 0 and 1") - - self.p = p - self.scale = scale - self.ratio = ratio - self.value = value - self.inplace = inplace - - @staticmethod - def get_params( - img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None - ) -> Tuple[int, int, int, int, Tensor]: - """Get parameters for ``erase`` for a random erasing. - - Args: - img (Tensor): Tensor image to be erased. - scale (sequence): range of proportion of erased area against input image. - ratio (sequence): range of aspect ratio of erased area. - value (list, optional): erasing value. If None, it is interpreted as "random" - (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number, - i.e. ``value[0]``. - - Returns: - tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing. 
- """ - img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1] - area = img_h * img_w - - log_ratio = torch.log(torch.tensor(ratio)) - for _ in range(10): - erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() - aspect_ratio = torch.exp( - torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) - ).item() - - h = int(round(math.sqrt(erase_area * aspect_ratio))) - w = int(round(math.sqrt(erase_area / aspect_ratio))) - if not (h < img_h and w < img_w): - continue - - if value is None: - v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() - else: - v = torch.tensor(value)[:, None, None] - - i = torch.randint(0, img_h - h + 1, size=(1, )).item() - j = torch.randint(0, img_w - w + 1, size=(1, )).item() - return i, j, h, w, v - - # Return original image - return 0, 0, img_h, img_w, img - - def forward(self, img): - """ - Args: - img (Tensor): Tensor image to be erased. - - Returns: - img (Tensor): Erased Tensor image. - """ - if torch.rand(1) < self.p: - - # cast self.value to script acceptable type - if isinstance(self.value, (int, float)): - value = [self.value, ] - elif isinstance(self.value, str): - value = None - elif isinstance(self.value, tuple): - value = list(self.value) - else: - value = self.value - - if value is not None and not (len(value) in (1, img.shape[-3])): - raise ValueError( - "If value is a sequence, it should have either a single value or " - "{} (number of input channels)".format(img.shape[-3]) - ) - - x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value) - return F.erase(img, x, y, h, w, v, self.inplace) - return img - - def __repr__(self): - s = '(p={}, '.format(self.p) - s += 'scale={}, '.format(self.scale) - s += 'ratio={}, '.format(self.ratio) - s += 'value={}, '.format(self.value) - s += 'inplace={})'.format(self.inplace) - return self.__class__.__name__ + s - - -class GaussianBlur(torch.nn.Module): - """Blurs image with randomly chosen Gaussian blur. - If the image is torch Tensor, it is expected - to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions. - - Args: - kernel_size (int or sequence): Size of the Gaussian kernel. - sigma (float or tuple of float (min, max)): Standard deviation to be used for - creating kernel to perform blurring. If float, sigma is fixed. If it is tuple - of float (min, max), sigma is chosen uniformly at random to lie in the - given range. - - Returns: - PIL Image or Tensor: Gaussian blurred version of the input image. - - """ - - def __init__(self, kernel_size, sigma=(0.1, 2.0)): - super().__init__() - self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") - for ks in self.kernel_size: - if ks <= 0 or ks % 2 == 0: - raise ValueError("Kernel size value should be an odd and positive number.") - - if isinstance(sigma, numbers.Number): - if sigma <= 0: - raise ValueError("If sigma is a single number, it must be positive.") - sigma = (sigma, sigma) - elif isinstance(sigma, Sequence) and len(sigma) == 2: - if not 0. < sigma[0] <= sigma[1]: - raise ValueError("sigma values should be positive and of the form (min, max).") - else: - raise ValueError("sigma should be a single number or a list/tuple with length 2.") - - self.sigma = sigma - - @staticmethod - def get_params(sigma_min: float, sigma_max: float) -> float: - """Choose sigma for random gaussian blurring. - - Args: - sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel. 
- sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel. - - Returns: - float: Standard deviation to be passed to calculate kernel for gaussian blurring. - """ - return torch.empty(1).uniform_(sigma_min, sigma_max).item() - - def forward(self, img_prev: Tensor,img_curr: Tensor,mask: Tensor) -> Tensor: - """ - Args: - img (PIL Image or Tensor): image to be blurred. - - Returns: - PIL Image or Tensor: Gaussian blurred image - """ - sigma = self.get_params(self.sigma[0], self.sigma[1]) - return F.gaussian_blur(img_prev, self.kernel_size, [sigma, sigma]),F.gaussian_blur(img_curr, self.kernel_size, [sigma, sigma]),mask - - def __repr__(self): - s = '(kernel_size={}, '.format(self.kernel_size) - s += 'sigma={})'.format(self.sigma) - return self.__class__.__name__ + s - - -def _setup_size(size, error_msg): - if isinstance(size, numbers.Number): - return int(size), int(size) - - if isinstance(size, Sequence) and len(size) == 1: - return size[0], size[0] - - if len(size) != 2: - raise ValueError(error_msg) - - return size - - -def _check_sequence_input(x, name, req_sizes): - msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes]) - if not isinstance(x, Sequence): - raise TypeError("{} should be a sequence of length {}.".format(name, msg)) - if len(x) not in req_sizes: - raise ValueError("{} should be sequence of length {}.".format(name, msg)) - - -def _setup_angle(x, name, req_sizes=(2, )): - if isinstance(x, numbers.Number): - if x < 0: - raise ValueError("If {} is a single number, it must be positive.".format(name)) - x = [-x, x] - else: - _check_sequence_input(x, name, req_sizes) - - return [float(d) for d in x] - - -class RandomInvert(torch.nn.Module): - """Inverts the colors of the given image randomly with a given probability. - If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, - where ... means it can have an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - p (float): probability of the image being color inverted. Default value is 0.5 - """ - - def __init__(self, p=0.5): - super().__init__() - self.p = p - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be inverted. - - Returns: - PIL Image or Tensor: Randomly color inverted image. - """ - if torch.rand(1).item() < self.p: - return F.invert(img) - return img - - def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) - - -class RandomPosterize(torch.nn.Module): - """Posterize the image randomly with a given probability by reducing the - number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8, - and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - bits (int): number of bits to keep for each channel (0-8) - p (float): probability of the image being color inverted. Default value is 0.5 - """ - - def __init__(self, bits, p=0.5): - super().__init__() - self.bits = bits - self.p = p - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be posterized. - - Returns: - PIL Image or Tensor: Randomly posterized image. 
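# Small illustrative check (assumes 8-bit input, values made up): F.posterize keeps only
# the `bits` most significant bits of each channel, i.e. value & ~(2 ** (8 - bits) - 1);
# with bits=3 the 256 grey levels collapse to 8.
import torch
from torchvision.transforms import functional as F

x = torch.arange(256, dtype=torch.uint8).reshape(1, 16, 16)  # one-channel uint8 image
assert torch.equal(F.posterize(x, bits=3), x & 0b11100000)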
- """ - if torch.rand(1).item() < self.p: - return F.posterize(img, self.bits) - return img - - def __repr__(self): - return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p) - - -class RandomSolarize(torch.nn.Module): - """Solarize the image randomly with a given probability by inverting all pixel - values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, - where ... means it can have an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - threshold (float): all pixels equal or above this value are inverted. - p (float): probability of the image being color inverted. Default value is 0.5 - """ - - def __init__(self, threshold, p=0.5): - super().__init__() - self.threshold = threshold - self.p = p - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be solarized. - - Returns: - PIL Image or Tensor: Randomly solarized image. - """ - if torch.rand(1).item() < self.p: - return F.solarize(img, self.threshold) - return img - - def __repr__(self): - return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p) - - -class RandomAdjustSharpness(torch.nn.Module): - """Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor, - it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - - Args: - sharpness_factor (float): How much to adjust the sharpness. Can be - any non negative number. 0 gives a blurred image, 1 gives the - original image while 2 increases the sharpness by a factor of 2. - p (float): probability of the image being color inverted. Default value is 0.5 - """ - - def __init__(self, sharpness_factor, p=0.5): - super().__init__() - self.sharpness_factor = sharpness_factor - self.p = p - - def forward(self, img,mask,wmap): - """ - Args: - img (PIL Image or Tensor): Image to be sharpened. - - Returns: - PIL Image or Tensor: Randomly sharpened image. - """ - if torch.rand(1).item() < self.p: - return F.adjust_sharpness(img, self.sharpness_factor),mask,wmap - return img,mask,wmap - - def __repr__(self): - return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p) - - -class RandomAutocontrast(torch.nn.Module): - """Autocontrast the pixels of the given image randomly with a given probability. - If the image is torch Tensor, it is expected - to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - p (float): probability of the image being autocontrasted. Default value is 0.5 - """ - - def __init__(self, p=0.5): - super().__init__() - self.p = p - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be autocontrasted. - - Returns: - PIL Image or Tensor: Randomly autocontrasted image. - """ - if torch.rand(1).item() < self.p: - return F.autocontrast(img) - return img - - def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) - - -class RandomEqualize(torch.nn.Module): - """Equalize the histogram of the given image randomly with a given probability. - If the image is torch Tensor, it is expected - to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "P", "L" or "RGB". - - Args: - p (float): probability of the image being equalized. 
Default value is 0.5 - """ - - def __init__(self, p=0.5): - super().__init__() - self.p = p - - def forward(self, img,mask,wmap): - """ - Args: - img (PIL Image or Tensor): Image to be equalized. - - Returns: - PIL Image or Tensor: Randomly equalized image. - """ - if torch.rand(1).item() < self.p: - return F.equalize(img),mask,wmap - return img,mask,wmap - - def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) +import math +import numbers +import random +import warnings +from collections.abc import Sequence +from typing import Tuple, List, Optional +from PIL import Image +import torch +from torch import Tensor + +try: + import accimage +except ImportError: + accimage = None + + +from torchvision.transforms import functional as F + + +__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", + "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", + "RandomHorizontalFlip", "RandomVerticalFlip", "TenCrop", + "LinearTransformation", "ColorJitter", "RandomRotation", "Grayscale", "RandomGrayscale", + "RandomPerspective", "RandomErasing", "GaussianBlur"] + + +class Compose: + """Composes several transforms together. This transform does not support torchscript. + Please, see the note below. + + Args: + transforms (list of ``Transform`` objects): list of transforms to compose. + + Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + + .. note:: + In order to script the transformations, please use ``torch.nn.Sequential`` as below. + + >>> transforms = torch.nn.Sequential( + >>> transforms.CenterCrop(10), + >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + >>> ) + >>> scripted_transforms = torch.jit.script(transforms) + + Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require + `lambda` functions or ``PIL.Image``. + + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img,mask,wmap): + for t in self.transforms: + img,mask,wmap = t(img,mask,wmap) + return img,mask,wmap + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class ToTensor: + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript. + + Converts a PIL Image or numpy.ndarray (H x W x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] + if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) + or if the numpy.ndarray has dtype = np.uint8 + + In the other cases, tensors are returned without scaling. + + .. note:: + Because the input image is scaled to [0.0, 1.0], this transformation should not be used when + transforming target image masks. See the `references`_ for implementing the transforms for image masks. + + .. _references: https://github.com/pytorch/vision/tree/master/references/segmentation + """ + + def __call__(self, img_prev,img_curr,mask): + """ + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. 
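# Hedged sketch of how the re-written Compose above is meant to be used (sizes and
# probabilities are made-up values): every transform in the list receives and returns
# the full three-tuple, so a joint pipeline for an image, its mask and its weight map
# (or a frame pair plus mask) composes exactly like the stock torchvision one.
train_transforms = Compose([
    Resize((512, 512)),
    RandomHorizontalFlip(p=0.5),
    ToTensor(),
])
# tensor_img, tensor_mask, tensor_wmap = train_transforms(pil_img, pil_mask, pil_wmap)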
+ """ + return F.to_tensor(img_prev),F.to_tensor(img_curr),F.to_tensor(mask) + + def __repr__(self): + return self.__class__.__name__ + '()' + + +class PILToTensor: + """Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript. + + Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W). + """ + + def __call__(self, pic): + """ + Args: + pic (PIL Image): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + return F.pil_to_tensor(pic) + + def __repr__(self): + return self.__class__.__name__ + '()' + + +class ConvertImageDtype(torch.nn.Module): + """Convert a tensor image to the given ``dtype`` and scale the values accordingly + This function does not support PIL Image. + + Args: + dtype (torch.dtype): Desired data type of the output + + .. note:: + + When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. + If converted back and forth, this mismatch has no effect. + + Raises: + RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as + well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to + overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range + of the integer ``dtype``. + """ + + def __init__(self, dtype: torch.dtype) -> None: + super().__init__() + self.dtype = dtype + + def forward(self, image): + return F.convert_image_dtype(image, self.dtype) + + +class ToPILImage: + """Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript. + + Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape + H x W x C to a PIL Image while preserving the value range. + + Args: + mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). + If ``mode`` is ``None`` (default) there are some assumptions made about the input data: + - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``. + - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``. + - If the input has 2 channels, the ``mode`` is assumed to be ``LA``. + - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``, + ``short``). + + .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes + """ + def __init__(self, mode=None): + self.mode = mode + + def __call__(self, pic): + """ + Args: + pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. + + Returns: + PIL Image: Image converted to PIL Image. + + """ + return F.to_pil_image(pic, self.mode) + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + if self.mode is not None: + format_string += 'mode={0}'.format(self.mode) + format_string += ')' + return format_string + + +class Normalize(torch.nn.Module): + """Normalize a tensor image with mean and standard deviation. + This transform does not support PIL Image. + Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` + channels, this transform will normalize each channel of the input + ``torch.*Tensor`` i.e., + ``output[channel] = (input[channel] - mean[channel]) / std[channel]`` + + .. note:: + This transform acts out of place, i.e., it does not mutate the input tensor. + + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. 
+ inplace(bool,optional): Bool to make this operation in-place. + + """ + + def __init__(self, mean, std, inplace=False): + super().__init__() + self.mean = mean + self.std = std + self.inplace = inplace + + def forward(self, img_prev: Tensor,img_curr: Tensor,mask: Tensor) -> Tensor: + """ + Args: + tensor (Tensor): Tensor image to be normalized. + + Returns: + Tensor: Normalized Tensor image. + """ + return F.normalize(img_prev, self.mean, self.std, self.inplace),F.normalize(img_curr, self.mean, self.std, self.inplace),mask + + def __repr__(self): + return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) + + +class Resize(torch.nn.Module): + """Resize the input image to the given size. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions + + Args: + size (sequence or int): Desired output size. If size is a sequence like + (h, w), output size will be matched to this. If size is an int, + smaller edge of the image will be matched to this number. + i.e, if height > width, then image will be rescaled to + (size * height / width, size). + In torchscript mode padding as single int is not supported, use a tuple or + list of length 1: ``[size, ]``. + interpolation (int, optional): Desired interpolation enum defined by `filters`_. + Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` + and ``PIL.Image.BICUBIC`` are supported. + """ + + def __init__(self, size, img_interpolation=Image.BILINEAR,mask_interpolation=Image.NEAREST): + super().__init__() + if not isinstance(size, (int, Sequence)): + raise TypeError("Size should be int or sequence. Got {}".format(type(size))) + if isinstance(size, Sequence) and len(size) not in (1, 2): + raise ValueError("If size is a sequence, it should have 1 or 2 values") + self.size = size + self.img_interpolation = img_interpolation + self.mask_interpolation = mask_interpolation + + def forward(self, img_prev,img_curr,mask): + """ + Args: + img (PIL Image or Tensor): Image to be scaled. + + Returns: + PIL Image or Tensor: Rescaled image. + """ + + + return F.resize(img_prev, self.size, self.img_interpolation),F.resize(img_curr, self.size, self.img_interpolation),F.resize(mask, self.size, self.mask_interpolation) + + def __repr__(self): + interpolate_str = self.interpolation.value + return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2}, antialias={3})'.format( + self.size, interpolate_str, self.max_size, self.antialias) + + +class Scale(Resize): + """ + Note: This transform is deprecated in favor of Resize. + """ + def __init__(self, *args, **kwargs): + warnings.warn("The use of the transforms.Scale transform is deprecated, " + + "please use transforms.Resize instead.") + super(Scale, self).__init__(*args, **kwargs) + + +class CenterCrop(torch.nn.Module): + """Crops the given image at the center. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). 
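# Hedged aside (the helper name below is invented): the Resize variant above resamples
# the images with BILINEAR but the mask with NEAREST so label values are never blended;
# its __repr__, however, still reads self.interpolation, self.max_size and
# self.antialias, none of which this class sets, so printing the transform would raise
# AttributeError. A repr consistent with the attributes that are actually defined could
# look like:
def _resize_repr(self):
    return 'Resize(size={0}, img_interpolation={1}, mask_interpolation={2})'.format(
        self.size, self.img_interpolation, self.mask_interpolation)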
+ """ + + def __init__(self, size): + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + PIL Image or Tensor: Cropped image. + """ + return F.center_crop(img, self.size) + + def __repr__(self): + return self.__class__.__name__ + '(size={0})'.format(self.size) + + +class Pad(torch.nn.Module): + """Pad the given image on all sides with the given "pad" value. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric, + at most 3 leading dimensions for mode edge, + and an arbitrary number of leading dimensions for mode constant + + Args: + padding (int or sequence): Padding on each border. If a single int is provided this + is used to pad all borders. If sequence of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a sequence of length 4 is provided + this is the padding for the left, top, right and bottom borders respectively. + + .. note:: + In torchscript mode padding as single int is not supported, use a sequence of + length 1: ``[padding, ]``. + fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant. + Only number is supported for torch Tensor. + Only int or str or tuple value is supported for PIL Image. + padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. + Default is constant. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value at the edge of the image. + If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2 + + - reflect: pads with reflection of image without repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + """ + + def __init__(self, padding, fill=0, padding_mode="constant"): + super().__init__() + if not isinstance(padding, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate padding arg") + + if not isinstance(fill, (numbers.Number, str, tuple)): + raise TypeError("Got inappropriate fill arg") + + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") + + if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]: + raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + + "{} element tuple".format(len(padding))) + + self.padding = padding + self.fill = fill + self.padding_mode = padding_mode + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be padded. + + Returns: + PIL Image or Tensor: Padded image. + """ + return F.pad(img, self.padding, self.fill, self.padding_mode) + + def __repr__(self): + return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ + format(self.padding, self.fill, self.padding_mode) + + +class Lambda: + """Apply a user-defined lambda as a transform. 
This transform does not support torchscript. + + Args: + lambd (function): Lambda/function to be used for transform. + """ + + def __init__(self, lambd): + if not callable(lambd): + raise TypeError("Argument lambd should be callable, got {}".format(repr(type(lambd).__name__))) + self.lambd = lambd + + def __call__(self, img): + return self.lambd(img) + + def __repr__(self): + return self.__class__.__name__ + '()' + + +class RandomTransforms: + """Base class for a list of transformations with randomness + + Args: + transforms (sequence): list of transformations + """ + + def __init__(self, transforms): + if not isinstance(transforms, Sequence): + raise TypeError("Argument transforms should be a sequence") + self.transforms = transforms + + def __call__(self, *args, **kwargs): + raise NotImplementedError() + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class RandomApply(torch.nn.Module): + """Apply randomly a list of transformations with a given probability. + + .. note:: + In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of + transforms as shown below: + + >>> transforms = transforms.RandomApply(torch.nn.ModuleList([ + >>> transforms.ColorJitter(), + >>> ]), p=0.3) + >>> scripted_transforms = torch.jit.script(transforms) + + Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require + `lambda` functions or ``PIL.Image``. + + Args: + transforms (sequence or torch.nn.Module): list of transformations + p (float): probability + """ + + def __init__(self, transforms, p=0.5): + super().__init__() + self.transforms = transforms + self.p = p + + def forward(self, img_prev,img_curr,mask): + if self.p < torch.rand(1): + return img_prev,img_curr,mask + for t in self.transforms: + img_prev,img_curr,mask = t(img_prev,img_curr,mask) + return img_prev,img_curr,mask + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + format_string += '\n p={}'.format(self.p) + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class RandomOrder(RandomTransforms): + """Apply a list of transformations in a random order. This transform does not support torchscript. + """ + def __call__(self, img,mask,wmap): + order = list(range(len(self.transforms))) + random.shuffle(order) + for i in order: + img,mask,wmap = self.transforms[i](img,mask,wmap) + return img,mask,wmap + + +class RandomChoice(RandomTransforms): + """Apply single transformation randomly picked from a list. This transform does not support torchscript. + """ + def __call__(self, img,mask,wmap): + t = random.choice(self.transforms) + return t(img,mask,wmap) + + +class RandomCrop(torch.nn.Module): + """Crop the given image at a random location. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions, + but if non-constant padding is used, the input is expected to have at most 2 leading dimensions + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + padding (int or sequence, optional): Optional padding on each border + of the image. 
Default is None. If a single int is provided this + is used to pad all borders. If sequence of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a sequence of length 4 is provided + this is the padding for the left, top, right and bottom borders respectively. + + .. note:: + In torchscript mode padding as single int is not supported, use a sequence of + length 1: ``[padding, ]``. + pad_if_needed (boolean): It will pad the image if smaller than the + desired size to avoid raising an exception. Since cropping is done + after padding, the padding seems to be done at a random offset. + fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant. + Only number is supported for torch Tensor. + Only int or str or tuple value is supported for PIL Image. + padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. + Default is constant. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value at the edge of the image. + If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2 + + - reflect: pads with reflection of image without repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + """ + + @staticmethod + def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random crop. + + Args: + img (PIL Image or Tensor): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. + """ + w, h = F._get_image_size(img) + th, tw = output_size + + if h + 1 < th or w + 1 < tw: + raise ValueError( + "Required crop size {} is larger then input image size {}".format((th, tw), (h, w)) + ) + + if w == tw and h == th: + return 0, 0, h, w + + i = torch.randint(0, h - th + 1, size=(1, )).item() + j = torch.randint(0, w - tw + 1, size=(1, )).item() + return i, j, th, tw + + def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): + super().__init__() + + self.size = tuple(_setup_size( + size, error_msg="Please provide only two dimensions (h, w) for size." + )) + + self.padding = padding + self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode + + def forward(self, img,mask,wmap): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + PIL Image or Tensor: Cropped image. 
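# Minimal sketch (the helper name is hypothetical): the crop window is sampled once via
# get_params and the same (i, j, h, w) is applied to the image, the mask and the weight
# map, exactly as the forward below does.
from torchvision.transforms import functional as F

def crop_triplet(img, mask, wmap, size=(256, 256)):
    i, j, h, w = RandomCrop.get_params(img, output_size=size)
    return F.crop(img, i, j, h, w), F.crop(mask, i, j, h, w), F.crop(wmap, i, j, h, w)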
+ """ + if self.padding is not None: + img = F.pad(img, self.padding, self.fill, self.padding_mode) + + width, height = F._get_image_size(img) + # pad the width if needed + if self.pad_if_needed and width < self.size[1]: + padding = [self.size[1] - width, 0] + img = F.pad(img, padding, self.fill, self.padding_mode) + # pad the height if needed + if self.pad_if_needed and height < self.size[0]: + padding = [0, self.size[0] - height] + img = F.pad(img, padding, self.fill, self.padding_mode) + + i, j, h, w = self.get_params(img, self.size) + + return F.crop(img, i, j, h, w),F.crop(mask, i, j, h, w),F.crop(wmap, i, j, h, w) + + def __repr__(self): + return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding) + + +class Pass(torch.nn.Module): + """ + No Transforms + """ + + def __init__(self): + super().__init__() + + def forward(self, img,mask,wmap): + return img,mask,wmap + + def __repr__(self): + return self.__class__.__name__ + +class RandomHorizontalFlip(torch.nn.Module): + """Horizontally flip the given image randomly with a given probability. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions + + Args: + p (float): probability of the image being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + self.p = p + + def forward(self, img_prev,img_curr,mask): + """ + Args: + img (PIL Image or Tensor): Image to be flipped. + + Returns: + PIL Image or Tensor: Randomly flipped image. + """ + if torch.rand(1) < self.p: + return F.hflip(img_prev),F.hflip(img_curr),F.hflip(mask) + return img_prev,img_curr,mask + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) + + +class RandomVerticalFlip(torch.nn.Module): + """Vertically flip the given image randomly with a given probability. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions + + Args: + p (float): probability of the image being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + self.p = p + + def forward(self, img_prev,img_curr,mask): + """ + Args: + img (PIL Image or Tensor): Image to be flipped. + + Returns: + PIL Image or Tensor: Randomly flipped image. + """ + if torch.rand(1) < self.p: + return F.vflip(img_prev),F.vflip(img_curr),F.vflip(mask) + return img_prev,img_curr,mask + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) + + +class RandomPerspective(torch.nn.Module): + """Performs a random perspective transformation of the given image with a given probability. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1. + Default is 0.5. + p (float): probability of the image being transformed. Default is 0.5. + interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and + ``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for PIL images and Tensors. + fill (n-tuple or int or float): Pixel fill value for area outside the rotated + image. If int or float, the value is used for all bands respectively. Default is 0. + This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor + input. 
Fill value for the area outside the transform in the output image is always 0. + + """ + + def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BILINEAR, fill=0): + super().__init__() + self.p = p + self.interpolation = interpolation + self.distortion_scale = distortion_scale + self.fill = fill + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be Perspectively transformed. + + Returns: + PIL Image or Tensor: Randomly transformed image. + """ + if torch.rand(1) < self.p: + width, height = F._get_image_size(img) + startpoints, endpoints = self.get_params(width, height, self.distortion_scale) + return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill) + return img + + @staticmethod + def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]: + """Get parameters for ``perspective`` for a random perspective transform. + + Args: + width (int): width of the image. + height (int): height of the image. + distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1. + + Returns: + List containing [top-left, top-right, bottom-right, bottom-left] of the original image, + List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image. + """ + half_height = height // 2 + half_width = width // 2 + topleft = [ + int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), + int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) + ] + topright = [ + int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), + int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) + ] + botright = [ + int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), + int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) + ] + botleft = [ + int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), + int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) + ] + startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] + endpoints = [topleft, topright, botright, botleft] + return startpoints, endpoints + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) + + +# class RandomResizedCrop(torch.nn.Module): +# """Crop a random portion of image and resize it to a given size. +# +# If the image is torch Tensor, it is expected +# to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions +# +# A crop of the original image is made: the crop has a random area (H * W) +# and a random aspect ratio. This crop is finally resized to the given +# size. This is popularly used to train the Inception networks. +# +# Args: +# size (int or sequence): expected output size of the crop, for each edge. If size is an +# int instead of sequence like (h, w), a square output size ``(size, size)`` is +# made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). +# +# .. note:: +# In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. +# scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop, +# before resizing. The scale is defined with respect to the area of the original image. 
+# ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before +# resizing. +# interpolation (InterpolationMode): Desired interpolation enum defined by +# :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. +# If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and +# ``InterpolationMode.BICUBIC`` are supported. +# For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. +# +# """ +# +# def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationMode.BILINEAR): +# super().__init__() +# self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") +# +# if not isinstance(scale, Sequence): +# raise TypeError("Scale should be a sequence") +# if not isinstance(ratio, Sequence): +# raise TypeError("Ratio should be a sequence") +# if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): +# warnings.warn("Scale and ratio should be of kind (min, max)") +# +# # Backward compatibility with integer value +# if isinstance(interpolation, int): +# warnings.warn( +# "Argument interpolation should be of type InterpolationMode instead of int. " +# "Please, use InterpolationMode enum." +# ) +# interpolation = _interpolation_modes_from_int(interpolation) +# +# self.interpolation = interpolation +# self.scale = scale +# self.ratio = ratio +# +# @staticmethod +# def get_params( +# img: Tensor, scale: List[float], ratio: List[float] +# ) -> Tuple[int, int, int, int]: +# """Get parameters for ``crop`` for a random sized crop. +# +# Args: +# img (PIL Image or Tensor): Input image. +# scale (list): range of scale of the origin size cropped +# ratio (list): range of aspect ratio of the origin aspect ratio cropped +# +# Returns: +# tuple: params (i, j, h, w) to be passed to ``crop`` for a random +# sized crop. +# """ +# width, height = F._get_image_size(img) +# area = height * width +# +# log_ratio = torch.log(torch.tensor(ratio)) +# for _ in range(10): +# target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() +# aspect_ratio = torch.exp( +# torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) +# ).item() +# +# w = int(round(math.sqrt(target_area * aspect_ratio))) +# h = int(round(math.sqrt(target_area / aspect_ratio))) +# +# if 0 < w <= width and 0 < h <= height: +# i = torch.randint(0, height - h + 1, size=(1,)).item() +# j = torch.randint(0, width - w + 1, size=(1,)).item() +# return i, j, h, w +# +# # Fallback to central crop +# in_ratio = float(width) / float(height) +# if in_ratio < min(ratio): +# w = width +# h = int(round(w / min(ratio))) +# elif in_ratio > max(ratio): +# h = height +# w = int(round(h * max(ratio))) +# else: # whole image +# w = width +# h = height +# i = (height - h) // 2 +# j = (width - w) // 2 +# return i, j, h, w +# +# def forward(self, img): +# """ +# Args: +# img (PIL Image or Tensor): Image to be cropped and resized. +# +# Returns: +# PIL Image or Tensor: Randomly cropped and resized image. 
+# """ +# i, j, h, w = self.get_params(img, self.scale, self.ratio) +# return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) +# +# def __repr__(self): +# interpolate_str = self.interpolation.value +# format_string = self.__class__.__name__ + '(size={0}'.format(self.size) +# format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) +# format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) +# format_string += ', interpolation={0})'.format(interpolate_str) +# return format_string + + +# class RandomSizedCrop(RandomResizedCrop): +# """ +# Note: This transform is deprecated in favor of RandomResizedCrop. +# """ +# def __init__(self, *args, **kwargs): +# warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + +# "please use transforms.RandomResizedCrop instead.") +# super(RandomSizedCrop, self).__init__(*args, **kwargs) +# +# +# class FiveCrop(torch.nn.Module): +# """Crop the given image into four corners and the central crop. +# If the image is torch Tensor, it is expected +# to have [..., H, W] shape, where ... means an arbitrary number of leading +# dimensions +# +# .. Note:: +# This transform returns a tuple of images and there may be a mismatch in the number of +# inputs and targets your Dataset returns. See below for an example of how to deal with +# this. +# +# Args: +# size (sequence or int): Desired output size of the crop. If size is an ``int`` +# instead of sequence like (h, w), a square crop of size (size, size) is made. +# If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). +# +# Example: +# >>> transform = Compose([ +# >>> FiveCrop(size), # this is a list of PIL Images +# >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor +# >>> ]) +# >>> #In your test loop you can do the following: +# >>> input, target = batch # input is a 5d tensor, target is 2d +# >>> bs, ncrops, c, h, w = input.size() +# >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops +# >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops +# """ +# +# def __init__(self, size): +# super().__init__() +# self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") +# +# def forward(self, img): +# """ +# Args: +# img (PIL Image or Tensor): Image to be cropped. +# +# Returns: +# tuple of 5 images. Image can be PIL Image or Tensor +# """ +# return F.five_crop(img, self.size) +# +# def __repr__(self): +# return self.__class__.__name__ + '(size={0})'.format(self.size) + + +class TenCrop(torch.nn.Module): + """Crop the given image into four corners and the central crop plus the flipped version of + these (horizontal flipping is used by default). + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions + + .. Note:: + This transform returns a tuple of images and there may be a mismatch in the number of + inputs and targets your Dataset returns. See below for an example of how to deal with + this. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). 
+ vertical_flip (bool): Use vertical flipping instead of horizontal + + Example: + >>> transform = Compose([ + >>> TenCrop(size), # this is a list of PIL Images + >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor + >>> ]) + >>> #In your test loop you can do the following: + >>> input, target = batch # input is a 5d tensor, target is 2d + >>> bs, ncrops, c, h, w = input.size() + >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops + >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops + """ + + def __init__(self, size, vertical_flip=False): + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + self.vertical_flip = vertical_flip + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + tuple of 10 images. Image can be PIL Image or Tensor + """ + return F.ten_crop(img, self.size, self.vertical_flip) + + def __repr__(self): + return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) + + +class LinearTransformation(torch.nn.Module): + """Transform a tensor image with a square transformation matrix and a mean_vector computed + offline. + This transform does not support PIL Image. + Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and + subtract mean_vector from it which is then followed by computing the dot + product with the transformation matrix and then reshaping the tensor to its + original shape. + + Applications: + whitening transformation: Suppose X is a column vector zero-centered data. + Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X), + perform SVD on this matrix and pass it as transformation_matrix. + + Args: + transformation_matrix (Tensor): tensor [D x D], D = C x H x W + mean_vector (Tensor): tensor [D], D = C x H x W + """ + + def __init__(self, transformation_matrix, mean_vector): + super().__init__() + if transformation_matrix.size(0) != transformation_matrix.size(1): + raise ValueError("transformation_matrix should be square. Got " + + "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) + + if mean_vector.size(0) != transformation_matrix.size(0): + raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) + + " as any one of the dimensions of the transformation_matrix [{}]" + .format(tuple(transformation_matrix.size()))) + + if transformation_matrix.device != mean_vector.device: + raise ValueError("Input tensors should be on the same device. Got {} and {}" + .format(transformation_matrix.device, mean_vector.device)) + + self.transformation_matrix = transformation_matrix + self.mean_vector = mean_vector + + def forward(self, tensor: Tensor) -> Tensor: + """ + Args: + tensor (Tensor): Tensor image to be whitened. + + Returns: + Tensor: Transformed image. + """ + shape = tensor.shape + n = shape[-3] * shape[-2] * shape[-1] + if n != self.transformation_matrix.shape[0]: + raise ValueError("Input tensor and transformation matrix have incompatible shape." + + "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) + + "{}".format(self.transformation_matrix.shape[0])) + + if tensor.device.type != self.mean_vector.device.type: + raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. 
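# Hedged sketch (illustration only, not part of the diff): building the whitening
# matrix and mean vector that LinearTransformation expects, following the procedure
# described in its docstring. Dimensions are toy values chosen for illustration.
import torch

N, C, H, W = 256, 1, 8, 8                       # hypothetical batch of flattened images, D = C*H*W
X = torch.randn(N, C * H * W)
mean_vector = X.mean(dim=0)                     # [D]
Xc = X - mean_vector                            # zero-centre the data
cov = torch.mm(Xc.t(), Xc) / (N - 1)            # [D x D] covariance matrix
U, S, _ = torch.svd(cov)
eps = 1e-5                                      # guards against division by tiny singular values
transformation_matrix = torch.mm(U, torch.mm(torch.diag(1.0 / torch.sqrt(S + eps)), U.t()))
whiten = LinearTransformation(transformation_matrix, mean_vector)
whitened = whiten(torch.randn(C, H, W))         # any tensor with C*H*W == D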
" + "Got {} vs {}".format(tensor.device, self.mean_vector.device)) + + flat_tensor = tensor.view(-1, n) - self.mean_vector + transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) + tensor = transformed_tensor.view(shape) + return tensor + + def __repr__(self): + format_string = self.__class__.__name__ + '(transformation_matrix=' + format_string += (str(self.transformation_matrix.tolist()) + ')') + format_string += (", (mean_vector=" + str(self.mean_vector.tolist()) + ')') + return format_string + + +class ColorJitter(torch.nn.Module): + """Randomly change the brightness, contrast, saturation and hue of an image. + If the image is torch Tensor, it is expected + to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported. + + Args: + brightness (float or tuple of float (min, max)): How much to jitter brightness. + brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] + or the given [min, max]. Should be non negative numbers. + contrast (float or tuple of float (min, max)): How much to jitter contrast. + contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] + or the given [min, max]. Should be non negative numbers. + saturation (float or tuple of float (min, max)): How much to jitter saturation. + saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] + or the given [min, max]. Should be non negative numbers. + hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. + Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. + """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + super().__init__() + self.brightness = self._check_input(brightness, 'brightness') + self.contrast = self._check_input(contrast, 'contrast') + self.saturation = self._check_input(saturation, 'saturation') + self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), + clip_first_on_zero=False) + + @torch.jit.unused + def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError("If {} is a single number, it must be non negative.".format(name)) + value = [center - float(value), center + float(value)] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError("{} values should be between {}".format(name, bound)) + else: + raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name)) + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + return value + + @staticmethod + def get_params(brightness: Optional[List[float]], + contrast: Optional[List[float]], + saturation: Optional[List[float]], + hue: Optional[List[float]] + ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: + """Get the parameters for the randomized transform to be applied on image. + + Args: + brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen + uniformly. Pass None to turn off the transformation. 
+ contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen + uniformly. Pass None to turn off the transformation. + saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen + uniformly. Pass None to turn off the transformation. + hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly. + Pass None to turn off the transformation. + + Returns: + tuple: The parameters used to apply the randomized transform + along with their random order. + """ + fn_idx = torch.randperm(4) + + b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1])) + c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1])) + s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1])) + h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) + + return fn_idx, b, c, s, h + + def forward(self, img_prev,img_curr,mask): + """ + Args: + img (PIL Image or Tensor): Input image. + + Returns: + PIL Image or Tensor: Color jittered image. + """ + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ + self.get_params(self.brightness, self.contrast, self.saturation, self.hue) + + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img_prev = F.adjust_brightness(img_prev, brightness_factor) + img_curr = F.adjust_brightness(img_curr, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img_prev = F.adjust_contrast(img_prev, contrast_factor) + img_curr = F.adjust_contrast(img_curr, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img_prev = F.adjust_saturation(img_prev, saturation_factor) + img_curr = F.adjust_saturation(img_curr, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img_prev = F.adjust_hue(img_prev, hue_factor) + img_curr = F.adjust_hue(img_curr, hue_factor) + + return img_prev,img_curr,mask + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + format_string += 'brightness={0}'.format(self.brightness) + format_string += ', contrast={0}'.format(self.contrast) + format_string += ', saturation={0}'.format(self.saturation) + format_string += ', hue={0})'.format(self.hue) + return format_string + +class RandomRotation(torch.nn.Module): + """Rotate the image by angle. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + degrees (sequence or float or int): Range of degrees to select from. + If degrees is a number instead of sequence like (min, max), the range of degrees + will be (-degrees, +degrees). + resample (int, optional): An optional resampling filter. See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. + If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. + expand (bool, optional): Optional expansion flag. + If true, expands the output to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center and no translation. + center (list or tuple, optional): Optional center of rotation, (x, y). Origin is the upper left corner. + Default is the center of the image. 
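# Hedged usage sketch (illustration only): this ColorJitter variant operates on a
# frame pair plus its mask, drawing one set of jitter factors and applying it to both
# frames so consecutive frames stay photometrically consistent; the mask passes
# through unchanged. File names are placeholders.
from PIL import Image

img_prev = Image.open('frame_000.png').convert('RGB')
img_curr = Image.open('frame_001.png').convert('RGB')
mask     = Image.open('frame_001_mask.png')
jitter = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0)
img_prev, img_curr, mask = jitter(img_prev, img_curr, mask)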
+ fill (n-tuple or int or float): Pixel fill value for area outside the rotated + image. If int or float, the value is used for all bands respectively. + Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0. + This option is not supported for Tensor input. Fill value for the area outside the transform in the output + image is always 0. + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + + def __init__(self, degrees, resample=Image.NEAREST, expand=False, center=None, fill=0): + super().__init__() + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) + + if center is not None: + _check_sequence_input(center, "center", req_sizes=(2, )) + + self.center = center + + self.resample = resample + self.expand = expand + self.fill = fill + + @staticmethod + def get_params(degrees: List[float]) -> float: + """Get parameters for ``rotate`` for a random rotation. + + Returns: + float: angle parameter to be passed to ``rotate`` for random rotation. + """ + angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) + return angle + + def forward(self, img,mask,wmap): + """ + Args: + img (PIL Image or Tensor): Image to be rotated. + + Returns: + PIL Image or Tensor: Rotated image. + """ + angle = self.get_params(self.degrees) + return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill),F.rotate(mask, angle, self.resample, self.expand, self.center, self.fill),F.rotate(wmap, angle, self.resample, self.expand, self.center, self.fill) + + + + +# class RandomAffine(torch.nn.Module): +# """Random affine transformation of the image keeping center invariant. +# If the image is torch Tensor, it is expected +# to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. +# +# Args: +# degrees (sequence or number): Range of degrees to select from. +# If degrees is a number instead of sequence like (min, max), the range of degrees +# will be (-degrees, +degrees). Set to 0 to deactivate rotations. +# translate (tuple, optional): tuple of maximum absolute fraction for horizontal +# and vertical translations. For example translate=(a, b), then horizontal shift +# is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is +# randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. +# scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is +# randomly sampled from the range a <= scale <= b. Will keep original scale by default. +# shear (sequence or number, optional): Range of degrees to select from. +# If shear is a number, a shear parallel to the x axis in the range (-shear, +shear) +# will be applied. Else if shear is a sequence of 2 values a shear parallel to the x axis in the +# range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values, +# a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. +# Will not apply shear by default. +# interpolation (InterpolationMode): Desired interpolation enum defined by +# :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. +# If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. +# For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. +# fill (sequence or number): Pixel fill value for the area outside the transformed +# image. 
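# Hedged usage sketch (illustration only): the RandomRotation defined above samples a
# single angle and rotates image, mask and weight map together, so pixel-wise labels
# stay aligned with the image; with the default resample=Image.NEAREST, label values
# are not blended during rotation. File names are placeholders.
from PIL import Image

img  = Image.open('cell.png')
mask = Image.open('cell_mask.png')
wmap = Image.open('cell_wmap.png')
rotate = RandomRotation(degrees=15)              # a scalar expands to the range (-15, +15) degrees
img, mask, wmap = rotate(img, mask, wmap)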
Default is ``0``. If given a number, the value is used for all bands respectively. +# fillcolor (sequence or number, optional): deprecated argument and will be removed since v0.10.0. +# Please use the ``fill`` parameter instead. +# resample (int, optional): deprecated argument and will be removed since v0.10.0. +# Please use the ``interpolation`` parameter instead. +# +# .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters +# +# """ +# +# def __init__( +# self, degrees, translate=None, scale=None, shear=None, interpolation=InterpolationMode.NEAREST, fill=0, +# fillcolor=None, resample=None +# ): +# super().__init__() +# if resample is not None: +# warnings.warn( +# "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" +# ) +# interpolation = _interpolation_modes_from_int(resample) +# +# # Backward compatibility with integer value +# if isinstance(interpolation, int): +# warnings.warn( +# "Argument interpolation should be of type InterpolationMode instead of int. " +# "Please, use InterpolationMode enum." +# ) +# interpolation = _interpolation_modes_from_int(interpolation) +# +# if fillcolor is not None: +# warnings.warn( +# "Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead" +# ) +# fill = fillcolor +# +# self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) +# +# if translate is not None: +# _check_sequence_input(translate, "translate", req_sizes=(2, )) +# for t in translate: +# if not (0.0 <= t <= 1.0): +# raise ValueError("translation values should be between 0 and 1") +# self.translate = translate +# +# if scale is not None: +# _check_sequence_input(scale, "scale", req_sizes=(2, )) +# for s in scale: +# if s <= 0: +# raise ValueError("scale values should be positive") +# self.scale = scale +# +# if shear is not None: +# self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4)) +# else: +# self.shear = shear +# +# self.resample = self.interpolation = interpolation +# +# if fill is None: +# fill = 0 +# elif not isinstance(fill, (Sequence, numbers.Number)): +# raise TypeError("Fill should be either a sequence or a number.") +# +# self.fillcolor = self.fill = fill +# +# @staticmethod +# def get_params( +# degrees: List[float], +# translate: Optional[List[float]], +# scale_ranges: Optional[List[float]], +# shears: Optional[List[float]], +# img_size: List[int] +# ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]: +# """Get parameters for affine transformation +# +# Returns: +# params to be passed to the affine transformation +# """ +# angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) +# if translate is not None: +# max_dx = float(translate[0] * img_size[0]) +# max_dy = float(translate[1] * img_size[1]) +# tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())) +# ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())) +# translations = (tx, ty) +# else: +# translations = (0, 0) +# +# if scale_ranges is not None: +# scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item()) +# else: +# scale = 1.0 +# +# shear_x = shear_y = 0.0 +# if shears is not None: +# shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item()) +# if len(shears) == 4: +# shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item()) +# +# shear = (shear_x, shear_y) +# +# return angle, translations, scale, shear +# +# def forward(self, img): +# """ +# img (PIL Image or Tensor): 
Image to be transformed. +# +# Returns: +# PIL Image or Tensor: Affine transformed image. +# """ +# fill = self.fill +# if isinstance(img, Tensor): +# if isinstance(fill, (int, float)): +# fill = [float(fill)] * F._get_image_num_channels(img) +# else: +# fill = [float(f) for f in fill] +# +# img_size = F._get_image_size(img) +# +# ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size) +# +# return F.affine(img, *ret, interpolation=self.interpolation, fill=fill) +# +# def __repr__(self): +# s = '{name}(degrees={degrees}' +# if self.translate is not None: +# s += ', translate={translate}' +# if self.scale is not None: +# s += ', scale={scale}' +# if self.shear is not None: +# s += ', shear={shear}' +# if self.interpolation != InterpolationMode.NEAREST: +# s += ', interpolation={interpolation}' +# if self.fill != 0: +# s += ', fill={fill}' +# s += ')' +# d = dict(self.__dict__) +# d['interpolation'] = self.interpolation.value +# return s.format(name=self.__class__.__name__, **d) + + +class Grayscale(torch.nn.Module): + """Convert image to grayscale. + If the image is torch Tensor, it is expected + to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions + + Args: + num_output_channels (int): (1 or 3) number of channels desired for output image + + Returns: + PIL Image: Grayscale version of the input. + + - If ``num_output_channels == 1`` : returned image is single channel + - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b + + """ + + def __init__(self, num_output_channels=1): + super().__init__() + self.num_output_channels = num_output_channels + + def forward(self, img_prev,img_curr,mask): + """ + Args: + img (PIL Image or Tensor): Image to be converted to grayscale. + + Returns: + PIL Image or Tensor: Grayscaled image. + """ + return F.rgb_to_grayscale(img_prev, num_output_channels=self.num_output_channels),F.rgb_to_grayscale(img_curr, num_output_channels=self.num_output_channels),F.rgb_to_grayscale(mask, num_output_channels=self.num_output_channels) + + def __repr__(self): + return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) + +class GetBoundingBoxes(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, img,mask): + import numpy as np + + A = np.array([ + [0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 0, 2, 2, 0], + [0, 1, 1, 0, 2, 2, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 4, 4, 0, 3, 3, 0], + [0, 4, 4, 0, 3, 3, 0], + [0, 0, 0, 0, 0, 0, 0] + ]) + + bboxCorners = {} + for i in range(1, A.max() + 1): + B = np.argwhere(A == i) + bboxCorners[i] = B.min(0), B.max(0) + + print(bboxCorners) + return img + return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels),F.rgb_to_grayscale(mask, num_output_channels=self.num_output_channels) + + def __repr__(self): + return self.__class__.__name__ + + +class RandomGrayscale(torch.nn.Module): + """Randomly convert image to grayscale with a probability of p (default 0.1). + If the image is torch Tensor, it is expected + to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions + + Args: + p (float): probability that image should be converted to grayscale. + + Returns: + PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged + with probability (1-p). 
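# Hedged sketch (illustration only): GetBoundingBoxes above demonstrates the idea on a
# hard-coded demo array; the same min/max-of-argwhere trick applied to an instance
# label mask looks like this.
import numpy as np

labels = np.array([[0, 1, 1, 0],
                   [0, 1, 1, 2],
                   [0, 0, 0, 2]])                       # toy instance-label mask
bbox_corners = {}
for i in range(1, labels.max() + 1):
    coords = np.argwhere(labels == i)                    # (row, col) of every pixel of instance i
    bbox_corners[i] = (coords.min(0), coords.max(0))     # top-left and bottom-right corners
# bbox_corners[1] -> (array([0, 1]), array([1, 2])); bbox_corners[2] -> (array([1, 3]), array([2, 3]))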
+ - If input image is 1 channel: grayscale version is 1 channel + - If input image is 3 channel: grayscale version is 3 channel with r == g == b + + """ + + def __init__(self, p=0.1): + super().__init__() + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be converted to grayscale. + + Returns: + PIL Image or Tensor: Randomly grayscaled image. + """ + num_output_channels = F._get_image_num_channels(img) + if torch.rand(1) < self.p: + return F.rgb_to_grayscale(img, num_output_channels=num_output_channels) + return img + + def __repr__(self): + return self.__class__.__name__ + '(p={0})'.format(self.p) + + +class RandomErasing(torch.nn.Module): + """ Randomly selects a rectangle region in an torch Tensor image and erases its pixels. + This transform does not support PIL Image. + 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896 + + Args: + p: probability that the random erasing operation will be performed. + scale: range of proportion of erased area against input image. + ratio: range of aspect ratio of erased area. + value: erasing value. Default is 0. If a single int, it is used to + erase all pixels. If a tuple of length 3, it is used to erase + R, G, B channels respectively. + If a str of 'random', erasing each pixel with random values. + inplace: boolean to make this transform inplace. Default set to False. + + Returns: + Erased Image. + + Example: + >>> transform = transforms.Compose([ + >>> transforms.RandomHorizontalFlip(), + >>> transforms.ToTensor(), + >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + >>> transforms.RandomErasing(), + >>> ]) + """ + + def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): + super().__init__() + if not isinstance(value, (numbers.Number, str, tuple, list)): + raise TypeError("Argument value should be either a number or str or a sequence") + if isinstance(value, str) and value != "random": + raise ValueError("If value is str, it should be 'random'") + if not isinstance(scale, (tuple, list)): + raise TypeError("Scale should be a sequence") + if not isinstance(ratio, (tuple, list)): + raise TypeError("Ratio should be a sequence") + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("Scale and ratio should be of kind (min, max)") + if scale[0] < 0 or scale[1] > 1: + raise ValueError("Scale should be between 0 and 1") + if p < 0 or p > 1: + raise ValueError("Random erasing probability should be between 0 and 1") + + self.p = p + self.scale = scale + self.ratio = ratio + self.value = value + self.inplace = inplace + + @staticmethod + def get_params( + img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None + ) -> Tuple[int, int, int, int, Tensor]: + """Get parameters for ``erase`` for a random erasing. + + Args: + img (Tensor): Tensor image to be erased. + scale (sequence): range of proportion of erased area against input image. + ratio (sequence): range of aspect ratio of erased area. + value (list, optional): erasing value. If None, it is interpreted as "random" + (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number, + i.e. ``value[0]``. + + Returns: + tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing. 
+ """ + img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1] + area = img_h * img_w + + log_ratio = torch.log(torch.tensor(ratio)) + for _ in range(10): + erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() + aspect_ratio = torch.exp( + torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) + ).item() + + h = int(round(math.sqrt(erase_area * aspect_ratio))) + w = int(round(math.sqrt(erase_area / aspect_ratio))) + if not (h < img_h and w < img_w): + continue + + if value is None: + v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() + else: + v = torch.tensor(value)[:, None, None] + + i = torch.randint(0, img_h - h + 1, size=(1, )).item() + j = torch.randint(0, img_w - w + 1, size=(1, )).item() + return i, j, h, w, v + + # Return original image + return 0, 0, img_h, img_w, img + + def forward(self, img): + """ + Args: + img (Tensor): Tensor image to be erased. + + Returns: + img (Tensor): Erased Tensor image. + """ + if torch.rand(1) < self.p: + + # cast self.value to script acceptable type + if isinstance(self.value, (int, float)): + value = [self.value, ] + elif isinstance(self.value, str): + value = None + elif isinstance(self.value, tuple): + value = list(self.value) + else: + value = self.value + + if value is not None and not (len(value) in (1, img.shape[-3])): + raise ValueError( + "If value is a sequence, it should have either a single value or " + "{} (number of input channels)".format(img.shape[-3]) + ) + + x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value) + return F.erase(img, x, y, h, w, v, self.inplace) + return img + + def __repr__(self): + s = '(p={}, '.format(self.p) + s += 'scale={}, '.format(self.scale) + s += 'ratio={}, '.format(self.ratio) + s += 'value={}, '.format(self.value) + s += 'inplace={})'.format(self.inplace) + return self.__class__.__name__ + s + + +class GaussianBlur(torch.nn.Module): + """Blurs image with randomly chosen Gaussian blur. + If the image is torch Tensor, it is expected + to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + kernel_size (int or sequence): Size of the Gaussian kernel. + sigma (float or tuple of float (min, max)): Standard deviation to be used for + creating kernel to perform blurring. If float, sigma is fixed. If it is tuple + of float (min, max), sigma is chosen uniformly at random to lie in the + given range. + + Returns: + PIL Image or Tensor: Gaussian blurred version of the input image. + + """ + + def __init__(self, kernel_size, sigma=(0.1, 2.0)): + super().__init__() + self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") + for ks in self.kernel_size: + if ks <= 0 or ks % 2 == 0: + raise ValueError("Kernel size value should be an odd and positive number.") + + if isinstance(sigma, numbers.Number): + if sigma <= 0: + raise ValueError("If sigma is a single number, it must be positive.") + sigma = (sigma, sigma) + elif isinstance(sigma, Sequence) and len(sigma) == 2: + if not 0. < sigma[0] <= sigma[1]: + raise ValueError("sigma values should be positive and of the form (min, max).") + else: + raise ValueError("sigma should be a single number or a list/tuple with length 2.") + + self.sigma = sigma + + @staticmethod + def get_params(sigma_min: float, sigma_max: float) -> float: + """Choose sigma for random gaussian blurring. + + Args: + sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel. 
+ sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel. + + Returns: + float: Standard deviation to be passed to calculate kernel for gaussian blurring. + """ + return torch.empty(1).uniform_(sigma_min, sigma_max).item() + + def forward(self, img_prev: Tensor,img_curr: Tensor,mask: Tensor) -> Tensor: + """ + Args: + img (PIL Image or Tensor): image to be blurred. + + Returns: + PIL Image or Tensor: Gaussian blurred image + """ + sigma = self.get_params(self.sigma[0], self.sigma[1]) + return F.gaussian_blur(img_prev, self.kernel_size, [sigma, sigma]),F.gaussian_blur(img_curr, self.kernel_size, [sigma, sigma]),mask + + def __repr__(self): + s = '(kernel_size={}, '.format(self.kernel_size) + s += 'sigma={})'.format(self.sigma) + return self.__class__.__name__ + s + + +def _setup_size(size, error_msg): + if isinstance(size, numbers.Number): + return int(size), int(size) + + if isinstance(size, Sequence) and len(size) == 1: + return size[0], size[0] + + if len(size) != 2: + raise ValueError(error_msg) + + return size + + +def _check_sequence_input(x, name, req_sizes): + msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes]) + if not isinstance(x, Sequence): + raise TypeError("{} should be a sequence of length {}.".format(name, msg)) + if len(x) not in req_sizes: + raise ValueError("{} should be sequence of length {}.".format(name, msg)) + + +def _setup_angle(x, name, req_sizes=(2, )): + if isinstance(x, numbers.Number): + if x < 0: + raise ValueError("If {} is a single number, it must be positive.".format(name)) + x = [-x, x] + else: + _check_sequence_input(x, name, req_sizes) + + return [float(d) for d in x] + + +class RandomInvert(torch.nn.Module): + """Inverts the colors of the given image randomly with a given probability. + If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + p (float): probability of the image being color inverted. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be inverted. + + Returns: + PIL Image or Tensor: Randomly color inverted image. + """ + if torch.rand(1).item() < self.p: + return F.invert(img) + return img + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) + + +class RandomPosterize(torch.nn.Module): + """Posterize the image randomly with a given probability by reducing the + number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8, + and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + bits (int): number of bits to keep for each channel (0-8) + p (float): probability of the image being color inverted. Default value is 0.5 + """ + + def __init__(self, bits, p=0.5): + super().__init__() + self.bits = bits + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be posterized. + + Returns: + PIL Image or Tensor: Randomly posterized image. 
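# Hedged usage sketch (illustration only): the GaussianBlur above draws one sigma per
# call and blurs both frames of the pair with it, leaving the mask untouched; the
# tracker training pipeline uses it with kernel (3, 3) and sigma in (0.1, 0.5).
# img_prev, img_curr and mask are assumed to be PIL images, e.g. as in the ColorJitter
# sketch further up.
blur = GaussianBlur((3, 3), sigma=(0.1, 0.5))
img_prev, img_curr, mask = blur(img_prev, img_curr, mask)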
+ """ + if torch.rand(1).item() < self.p: + return F.posterize(img, self.bits) + return img + + def __repr__(self): + return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p) + + +class RandomSolarize(torch.nn.Module): + """Solarize the image randomly with a given probability by inverting all pixel + values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + threshold (float): all pixels equal or above this value are inverted. + p (float): probability of the image being color inverted. Default value is 0.5 + """ + + def __init__(self, threshold, p=0.5): + super().__init__() + self.threshold = threshold + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be solarized. + + Returns: + PIL Image or Tensor: Randomly solarized image. + """ + if torch.rand(1).item() < self.p: + return F.solarize(img, self.threshold) + return img + + def __repr__(self): + return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p) + + +class RandomAdjustSharpness(torch.nn.Module): + """Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor, + it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + sharpness_factor (float): How much to adjust the sharpness. Can be + any non negative number. 0 gives a blurred image, 1 gives the + original image while 2 increases the sharpness by a factor of 2. + p (float): probability of the image being color inverted. Default value is 0.5 + """ + + def __init__(self, sharpness_factor, p=0.5): + super().__init__() + self.sharpness_factor = sharpness_factor + self.p = p + + def forward(self, img,mask,wmap): + """ + Args: + img (PIL Image or Tensor): Image to be sharpened. + + Returns: + PIL Image or Tensor: Randomly sharpened image. + """ + if torch.rand(1).item() < self.p: + return F.adjust_sharpness(img, self.sharpness_factor),mask,wmap + return img,mask,wmap + + def __repr__(self): + return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p) + + +class RandomAutocontrast(torch.nn.Module): + """Autocontrast the pixels of the given image randomly with a given probability. + If the image is torch Tensor, it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + p (float): probability of the image being autocontrasted. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be autocontrasted. + + Returns: + PIL Image or Tensor: Randomly autocontrasted image. + """ + if torch.rand(1).item() < self.p: + return F.autocontrast(img) + return img + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) + + +class RandomEqualize(torch.nn.Module): + """Equalize the histogram of the given image randomly with a given probability. + If the image is torch Tensor, it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "P", "L" or "RGB". + + Args: + p (float): probability of the image being equalized. 
Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + self.p = p + + def forward(self, img,mask,wmap): + """ + Args: + img (PIL Image or Tensor): Image to be equalized. + + Returns: + PIL Image or Tensor: Randomly equalized image. + """ + if torch.rand(1).item() < self.p: + return F.equalize(img),mask,wmap + return img,mask,wmap + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) diff --git a/train_segmentation.py b/deepsea/train_segmentation.py similarity index 97% rename from train_segmentation.py rename to deepsea/train_segmentation.py index c099d3b..8e4c500 100644 --- a/train_segmentation.py +++ b/deepsea/train_segmentation.py @@ -1,157 +1,157 @@ -import os -import argparse -from model import DeepSeaSegmentation -from data import BasicSegmentationDataset -import torch.nn as nn -from evaluate import evaluate_segmentation -from loss import dice_loss -import torch.optim as optim -import torch.optim.lr_scheduler as lr_scheduler -import torch.utils.data as data -import torch.nn.functional as F -import segmentation_transforms as transforms -import torch -import numpy as np -import os -import random -from tqdm import tqdm -import logging - -SEED = 42 -random.seed(SEED) -np.random.seed(SEED) -torch.manual_seed(SEED) -torch.cuda.manual_seed(SEED) -torch.backends.cudnn.deterministic = True - -def train(args,image_size = [383,512],image_means = [0.5],image_stds= [0.5],valid_ratio = 0.8,save_checkpoint=True,if_train_aug=False,train_aug_iter=3,patience=5): - - logging.basicConfig(filename=os.path.join(args.output_dir, 'train.log'), filemode='w',format='%(asctime)s - %(message)s', level=logging.INFO) - logging.info('>>>> image size=(%d,%d) , learning rate=%f , batch size=%d' % (image_size[0], image_size[1],args.lr,args.batch_size)) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - if if_train_aug: - train_transforms = transforms.Compose([ - transforms.Grayscale(num_output_channels=1), - transforms.RandomApply([ - transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0), - transforms.GaussianBlur((3, 3), sigma=(0.1, 1.0)), - transforms.RandomHorizontalFlip(0.5), - transforms.RandomVerticalFlip(0.5), - ],p=1-1/train_aug_iter), - transforms.Resize(image_size), - transforms.ToTensor(), - transforms.Normalize(mean = image_means,std = image_stds) - ]) - else: - train_transforms = transforms.Compose([ - transforms.Resize(image_size), - transforms.Grayscale(num_output_channels=1), - transforms.ToTensor(), - transforms.Normalize(mean=image_means, - std=image_stds) - ]) - - - train_data = BasicSegmentationDataset(os.path.join(args.train_set_dir, 'images'), os.path.join(args.train_set_dir, 'masks'),os.path.join(args.train_set_dir, 'wmaps'),transforms=train_transforms,if_train_aug=if_train_aug,train_aug_iter=train_aug_iter) - - n_train_examples = int(len(train_data) * valid_ratio) - n_valid_examples = len(train_data) - n_train_examples - - train_data, valid_data = data.random_split(train_data,[n_train_examples, n_valid_examples],generator=torch.Generator().manual_seed(SEED)) - - train_iterator = data.DataLoader(train_data,shuffle = True,batch_size = args.batch_size) - - valid_iterator = data.DataLoader(valid_data,batch_size = args.batch_size) - - model=DeepSeaSegmentation(n_channels=1, n_classes=2, bilinear=True) - - optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9) - - STEPS_PER_EPOCH = len(train_iterator) - TOTAL_STEPS = args.max_epoch * STEPS_PER_EPOCH - MAX_LRS = 
[p['lr'] for p in optimizer.param_groups] - scheduler = lr_scheduler.OneCycleLR(optimizer,max_lr=MAX_LRS,total_steps=TOTAL_STEPS) - grad_scaler = torch.cuda.amp.GradScaler(enabled=True) - criterion = nn.CrossEntropyLoss() - model = model.to(device) - criterion = criterion.to(device) - nstop=0 - avg_precision_best=0 - logging.info('>>>> Start training') - print('INFO: Start training ...') - for epoch in range(args.max_epoch): - model.train() - epoch_loss = 0 - with tqdm(total=n_train_examples, desc=f'Epoch {epoch + 1}/{args.max_epoch}', unit='img') as pbar: - for step,batch in enumerate(train_iterator): - images = batch['image'] - true_masks = batch['mask'] - true_wmaps = batch['wmap'] - assert images.shape[1] == model.n_channels, \ - f'Network has been defined with {model.n_channels} input channels, ' \ - f'but loaded images have {images.shape[1]} channels. Please check that ' \ - 'the images are loaded correctly.' - - images = images.to(device=device, dtype=torch.float32) - true_masks = true_masks.to(device=device, dtype=torch.long) - true_wmaps = true_wmaps.to(device=device, dtype=torch.long) - true_masks=torch.squeeze(true_masks, dim=1) - true_wmaps = torch.squeeze(true_wmaps, dim=1) - with torch.cuda.amp.autocast(enabled=True): - masks_preds,wmap_preds = model(images) - loss = criterion(masks_preds, true_masks) \ - +1*criterion(wmap_preds, true_wmaps) \ - + dice_loss(F.softmax(masks_preds, dim=1).float(), - F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(), - multiclass=True) - - optimizer.zero_grad(set_to_none=True) - grad_scaler.scale(loss).backward() - grad_scaler.step(optimizer) - grad_scaler.update() - pbar.update(images.shape[0]) - epoch_loss += loss.item() - pbar.set_postfix(**{'loss': epoch_loss/(step+1)}) - - # Evaluation round - val_score,avg_precision,_,_ = evaluate_segmentation(model, valid_iterator, device,n_valid_examples,is_avg_prec=((1+epoch)%2==0),prec_thresholds=[0.5]) - if avg_precision is not None: - logging.info('>>>> Epoch:%d , loss=%f , valid score=%f , avg precision=%f' % ( - epoch, epoch_loss / (step+1), val_score, avg_precision[0])) - else: - logging.info('>>>> Epoch:%d , loss=%f , valid score=%f' % ( - epoch, epoch_loss / (step + 1), val_score)) - - ## Save best checkpoint corresponding the best average precision - if avg_precision is not None and avg_precision>avg_precision_best: - avg_precision_best=avg_precision - states = model.state_dict() - if save_checkpoint: - logging.info('>>>> save model to %s'%(os.path.join(args.output_dir,'segmentation.pth'))) - torch.save(states, os.path.join(args.output_dir,'segmentation.pth')) - nstop=0 - elif avg_precision is not None and avg_precision<=avg_precision_best: - nstop+=1 - if nstop==patience:#Early Stopping - print('INFO: Early Stopping met ...') - print('INFO: Finish training process') - break - scheduler.step() - - - - -if __name__ == "__main__": - ap = argparse.ArgumentParser() - ap.add_argument("--train_set_dir",required=True,type=str,help="path for the train dataset") - ap.add_argument("--lr", default=1e-3,type=float, help="learning rate") - ap.add_argument("--max_epoch", default=200, type=int, help="maximum epoch to train model") - ap.add_argument("--batch_size", default=16, type=int, help="train batch size") - ap.add_argument("--output_dir", required=True, type=str, help="path for saving the train log and best checkpoint") - - args = ap.parse_args() - assert os.path.isdir(args.train_set_dir), 'No such file or directory: ' + args.train_set_dir - if not os.path.isdir(args.output_dir): - 
os.makedirs(args.output_dir) - +import os +import argparse +from model import DeepSeaSegmentation +from data import BasicSegmentationDataset +import torch.nn as nn +from evaluate import evaluate_segmentation +from loss import dice_loss +import torch.optim as optim +import torch.optim.lr_scheduler as lr_scheduler +import torch.utils.data as data +import torch.nn.functional as F +import segmentation_transforms as transforms +import torch +import numpy as np +import os +import random +from tqdm import tqdm +import logging + +SEED = 42 +random.seed(SEED) +np.random.seed(SEED) +torch.manual_seed(SEED) +torch.cuda.manual_seed(SEED) +torch.backends.cudnn.deterministic = True + +def train(args,image_size = [383,512],image_means = [0.5],image_stds= [0.5],valid_ratio = 0.8,save_checkpoint=True,if_train_aug=False,train_aug_iter=3,patience=5): + + logging.basicConfig(filename=os.path.join(args.output_dir, 'train.log'), filemode='w',format='%(asctime)s - %(message)s', level=logging.INFO) + logging.info('>>>> image size=(%d,%d) , learning rate=%f , batch size=%d' % (image_size[0], image_size[1],args.lr,args.batch_size)) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if if_train_aug: + train_transforms = transforms.Compose([ + transforms.Grayscale(num_output_channels=1), + transforms.RandomApply([ + transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0), + transforms.GaussianBlur((3, 3), sigma=(0.1, 1.0)), + transforms.RandomHorizontalFlip(0.5), + transforms.RandomVerticalFlip(0.5), + ],p=1-1/train_aug_iter), + transforms.Resize(image_size), + transforms.ToTensor(), + transforms.Normalize(mean = image_means,std = image_stds) + ]) + else: + train_transforms = transforms.Compose([ + transforms.Resize(image_size), + transforms.Grayscale(num_output_channels=1), + transforms.ToTensor(), + transforms.Normalize(mean=image_means, + std=image_stds) + ]) + + + train_data = BasicSegmentationDataset(os.path.join(args.train_set_dir, 'images'), os.path.join(args.train_set_dir, 'masks'),os.path.join(args.train_set_dir, 'wmaps'),transforms=train_transforms,if_train_aug=if_train_aug,train_aug_iter=train_aug_iter) + + n_train_examples = int(len(train_data) * valid_ratio) + n_valid_examples = len(train_data) - n_train_examples + + train_data, valid_data = data.random_split(train_data,[n_train_examples, n_valid_examples],generator=torch.Generator().manual_seed(SEED)) + + train_iterator = data.DataLoader(train_data,shuffle = True,batch_size = args.batch_size) + + valid_iterator = data.DataLoader(valid_data,batch_size = args.batch_size) + + model=DeepSeaSegmentation(n_channels=1, n_classes=2, bilinear=True) + + optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9) + + STEPS_PER_EPOCH = len(train_iterator) + TOTAL_STEPS = args.max_epoch * STEPS_PER_EPOCH + MAX_LRS = [p['lr'] for p in optimizer.param_groups] + scheduler = lr_scheduler.OneCycleLR(optimizer,max_lr=MAX_LRS,total_steps=TOTAL_STEPS) + grad_scaler = torch.cuda.amp.GradScaler(enabled=True) + criterion = nn.CrossEntropyLoss() + model = model.to(device) + criterion = criterion.to(device) + nstop=0 + avg_precision_best=0 + logging.info('>>>> Start training') + print('INFO: Start training ...') + for epoch in range(args.max_epoch): + model.train() + epoch_loss = 0 + with tqdm(total=n_train_examples, desc=f'Epoch {epoch + 1}/{args.max_epoch}', unit='img') as pbar: + for step,batch in enumerate(train_iterator): + images = batch['image'] + true_masks = batch['mask'] + true_wmaps = 
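# Hedged note (illustration only): with if_train_aug enabled, the RandomApply block
# above fires with probability p = 1 - 1/train_aug_iter, so a 1/train_aug_iter
# fraction of the (replicated) training samples is expected to pass through without
# the photometric augmentations. For the default train_aug_iter = 3:
p = 1 - 1 / 3      # = 0.666..., i.e. roughly two out of three samples get augmented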
batch['wmap'] + assert images.shape[1] == model.n_channels, \ + f'Network has been defined with {model.n_channels} input channels, ' \ + f'but loaded images have {images.shape[1]} channels. Please check that ' \ + 'the images are loaded correctly.' + + images = images.to(device=device, dtype=torch.float32) + true_masks = true_masks.to(device=device, dtype=torch.long) + true_wmaps = true_wmaps.to(device=device, dtype=torch.long) + true_masks=torch.squeeze(true_masks, dim=1) + true_wmaps = torch.squeeze(true_wmaps, dim=1) + with torch.cuda.amp.autocast(enabled=True): + masks_preds,wmap_preds = model(images) + loss = criterion(masks_preds, true_masks) \ + +1*criterion(wmap_preds, true_wmaps) \ + + dice_loss(F.softmax(masks_preds, dim=1).float(), + F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(), + multiclass=True) + + optimizer.zero_grad(set_to_none=True) + grad_scaler.scale(loss).backward() + grad_scaler.step(optimizer) + grad_scaler.update() + pbar.update(images.shape[0]) + epoch_loss += loss.item() + pbar.set_postfix(**{'loss': epoch_loss/(step+1)}) + + # Evaluation round + val_score,avg_precision,_,_ = evaluate_segmentation(model, valid_iterator, device,n_valid_examples,is_avg_prec=((1+epoch)%2==0),prec_thresholds=[0.5]) + if avg_precision is not None: + logging.info('>>>> Epoch:%d , loss=%f , valid score=%f , avg precision=%f' % ( + epoch, epoch_loss / (step+1), val_score, avg_precision[0])) + else: + logging.info('>>>> Epoch:%d , loss=%f , valid score=%f' % ( + epoch, epoch_loss / (step + 1), val_score)) + + ## Save best checkpoint corresponding the best average precision + if avg_precision is not None and avg_precision>avg_precision_best: + avg_precision_best=avg_precision + states = model.state_dict() + if save_checkpoint: + logging.info('>>>> save model to %s'%(os.path.join(args.output_dir,'segmentation.pth'))) + torch.save(states, os.path.join(args.output_dir,'segmentation.pth')) + nstop=0 + elif avg_precision is not None and avg_precision<=avg_precision_best: + nstop+=1 + if nstop==patience:#Early Stopping + print('INFO: Early Stopping met ...') + print('INFO: Finish training process') + break + scheduler.step() + + + + +if __name__ == "__main__": + ap = argparse.ArgumentParser() + ap.add_argument("--train_set_dir",required=True,type=str,help="path for the train dataset") + ap.add_argument("--lr", default=1e-3,type=float, help="learning rate") + ap.add_argument("--max_epoch", default=200, type=int, help="maximum epoch to train model") + ap.add_argument("--batch_size", default=16, type=int, help="train batch size") + ap.add_argument("--output_dir", required=True, type=str, help="path for saving the train log and best checkpoint") + + args = ap.parse_args() + assert os.path.isdir(args.train_set_dir), 'No such file or directory: ' + args.train_set_dir + if not os.path.isdir(args.output_dir): + os.makedirs(args.output_dir) + train(args) \ No newline at end of file diff --git a/train_tracker.py b/deepsea/train_tracker.py similarity index 97% rename from train_tracker.py rename to deepsea/train_tracker.py index 2276fc1..e5f1c88 100644 --- a/train_tracker.py +++ b/deepsea/train_tracker.py @@ -1,158 +1,158 @@ -import os -import argparse -from model import DeepSeaTracker -from data import BasicTrackerDataset -import torch.nn as nn -from evaluate import evaluate_tracker -from loss import dice_loss -import torch.optim as optim -import torch.optim.lr_scheduler as lr_scheduler -import torch.utils.data as data -import torch.nn.functional as F -import tracker_transforms as 
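# Hedged sketch (illustration only): loss.dice_loss is imported from loss.py, which is
# not part of this diff. A generic soft multiclass Dice loss along the lines the call
# above expects (softmax probabilities vs. one-hot targets, both [N, C, H, W]) could
# look like the following; the repository's actual implementation may differ, e.g. it
# also accepts a multiclass flag that this sketch omits.
import torch

def dice_loss_sketch(probs: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # probs, target: [N, C, H, W]; Dice is computed per class and averaged
    dims = (0, 2, 3)
    intersection = (probs * target).sum(dims)
    cardinality = probs.sum(dims) + target.sum(dims)
    dice = (2.0 * intersection + eps) / (cardinality + eps)
    return 1.0 - dice.mean()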
transforms -import torch -import numpy as np -import os -import random -from tqdm import tqdm -import logging - -SEED = 42 -random.seed(SEED) -np.random.seed(SEED) -torch.manual_seed(SEED) -torch.cuda.manual_seed(SEED) -torch.backends.cudnn.deterministic = True - -def train(args,image_size = [128,128],image_means = [0.5],image_stds= [0.5],valid_ratio = 0.8,save_checkpoint=True,if_train_aug=True,train_aug_iter=1,patience=5): - - logging.basicConfig(filename=os.path.join(args.output_dir, 'train.log'), filemode='w',format='%(asctime)s - %(message)s', level=logging.INFO) - logging.info('>>>> image size=(%d,%d) , learning rate=%f , batch size=%d' % (image_size[0], image_size[1],args.lr,args.batch_size)) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - if if_train_aug: - train_transforms = transforms.Compose([ - transforms.Grayscale(num_output_channels=1), - transforms.RandomApply([ - transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0), - transforms.GaussianBlur((3, 3), sigma=(0.1, 0.5)), - transforms.RandomHorizontalFlip(0.5), - transforms.RandomVerticalFlip(0.5), - ],p=1-1/train_aug_iter), - transforms.Resize(image_size), - transforms.ToTensor(), - transforms.Normalize(mean = image_means,std = image_stds) - ]) - else: - train_transforms = transforms.Compose([ - transforms.Resize(image_size), - transforms.Grayscale(num_output_channels=1), - transforms.ToTensor(), - transforms.Normalize(mean=image_means, - std=image_stds) - ]) - - - train_data = BasicTrackerDataset(os.path.join(args.train_set_dir),transforms=train_transforms,if_train_aug=if_train_aug,train_aug_iter=train_aug_iter) - - n_train_examples = int(len(train_data) * valid_ratio) - n_valid_examples = len(train_data) - n_train_examples - - train_data, valid_data = data.random_split(train_data,[n_train_examples, n_valid_examples],generator=torch.Generator().manual_seed(SEED)) - - - train_iterator = data.DataLoader(train_data,shuffle = True,batch_size = args.batch_size) - - valid_iterator = data.DataLoader(valid_data,batch_size = args.batch_size) - - model=DeepSeaTracker(n_channels=1, n_classes=2, bilinear=True) - - optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9) - - STEPS_PER_EPOCH = len(train_iterator) - TOTAL_STEPS = args.max_epoch * STEPS_PER_EPOCH - MAX_LRS = [p['lr'] for p in optimizer.param_groups] - scheduler = lr_scheduler.OneCycleLR(optimizer,max_lr=MAX_LRS,total_steps=TOTAL_STEPS) - grad_scaler = torch.cuda.amp.GradScaler(enabled=True) - criterion = nn.CrossEntropyLoss() - model = model.to(device) - criterion = criterion.to(device) - nstop=0 - avg_precision_best=0 - - logging.info('>>>> Start training') - print('INFO: Start training ...') - for epoch in range(args.max_epoch): - model.train() - epoch_loss = 0 - with tqdm(total=n_train_examples, desc=f'Epoch {epoch + 1}/{args.max_epoch}', unit='img') as pbar: - for step,batch in enumerate(train_iterator): - img_prev,img_curr,mask = batch['image_prev'],batch['image_curr'],batch['mask'] - - assert img_prev.shape[1] == model.n_channels, \ - f'Network has been defined with {model.n_channels} input channels, ' \ - f'but loaded images have {img_prev.shape[1]} channels. Please check that ' \ - 'the images are loaded correctly.' 
- - img_prev = img_prev.to(device=device, dtype=torch.float32) - img_curr = img_curr.to(device=device, dtype=torch.float32) - mask = mask.to(device=device, dtype=torch.long) - true_masks=torch.squeeze(mask, dim=1) - with torch.cuda.amp.autocast(enabled=True): - masks_preds = model(img_prev,img_curr) - loss = criterion(masks_preds, true_masks) \ - + dice_loss(F.softmax(masks_preds, dim=1).float(), - F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(), - multiclass=True) - - optimizer.zero_grad(set_to_none=True) - grad_scaler.scale(loss).backward() - grad_scaler.step(optimizer) - grad_scaler.update() - pbar.update(img_prev.shape[0]) - epoch_loss += loss.item() - pbar.set_postfix(**{'loss': epoch_loss/(step+1)}) - - # Evaluation round - val_score,avg_precision,single,mitosis = evaluate_tracker(model, valid_iterator, device,n_valid_examples,is_avg_prec=((1+epoch)%2==0),prec_thresholds=[0.5]) - - if avg_precision is not None: - logging.info('>>>> Epoch:%d , loss=%f , valid score=%f , avg precision=%f' % ( - epoch, epoch_loss / (step+1), val_score, avg_precision[0])) - else: - logging.info('>>>> Epoch:%d , loss=%f , valid score=%f' % ( - epoch, epoch_loss / (step + 1), val_score)) - ## Save best checkpoint corresponding the best average precision - if avg_precision is not None and avg_precision>avg_precision_best: - avg_precision_best=avg_precision - states = model.state_dict() - if save_checkpoint: - logging.info('>>>> save model to %s'%(os.path.join(args.output_dir,'tracker.pth'))) - torch.save(states, os.path.join(args.output_dir,'tracker.pth')) - - nstop=0 - elif avg_precision is not None and avg_precision<=avg_precision_best: - nstop+=1 - if nstop==patience:#Early Stopping - print('INFO: Early Stopping met ...') - print('INFO: Finish training process') - break - scheduler.step() - - - - -if __name__ == "__main__": - ap = argparse.ArgumentParser() - ap.add_argument("--train_set_dir",required=True,type=str,help="path for the train set") - ap.add_argument("--lr", default=1e-3,type=float, help="learning rate") - ap.add_argument("--max_epoch", default=100, type=int, help="maximum epoch to train model") - ap.add_argument("--batch_size", default=16, type=int, help="train batch size") - ap.add_argument("--output_dir", required=True, type=str, help="path for saving the train log and best model") - - args = ap.parse_args() - - assert os.path.isdir(args.train_set_dir), 'No such file or directory: ' + args.train_set_dir - if not os.path.isdir(args.output_dir): - os.makedirs(args.output_dir) - +import os +import argparse +from model import DeepSeaTracker +from data import BasicTrackerDataset +import torch.nn as nn +from evaluate import evaluate_tracker +from loss import dice_loss +import torch.optim as optim +import torch.optim.lr_scheduler as lr_scheduler +import torch.utils.data as data +import torch.nn.functional as F +import tracker_transforms as transforms +import torch +import numpy as np +import os +import random +from tqdm import tqdm +import logging + +SEED = 42 +random.seed(SEED) +np.random.seed(SEED) +torch.manual_seed(SEED) +torch.cuda.manual_seed(SEED) +torch.backends.cudnn.deterministic = True + +def train(args,image_size = [128,128],image_means = [0.5],image_stds= [0.5],valid_ratio = 0.8,save_checkpoint=True,if_train_aug=True,train_aug_iter=1,patience=5): + + logging.basicConfig(filename=os.path.join(args.output_dir, 'train.log'), filemode='w',format='%(asctime)s - %(message)s', level=logging.INFO) + logging.info('>>>> image size=(%d,%d) , learning rate=%f , batch size=%d' 
% (image_size[0], image_size[1],args.lr,args.batch_size)) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if if_train_aug: + train_transforms = transforms.Compose([ + transforms.Grayscale(num_output_channels=1), + transforms.RandomApply([ + transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0), + transforms.GaussianBlur((3, 3), sigma=(0.1, 0.5)), + transforms.RandomHorizontalFlip(0.5), + transforms.RandomVerticalFlip(0.5), + ],p=1-1/train_aug_iter), + transforms.Resize(image_size), + transforms.ToTensor(), + transforms.Normalize(mean = image_means,std = image_stds) + ]) + else: + train_transforms = transforms.Compose([ + transforms.Resize(image_size), + transforms.Grayscale(num_output_channels=1), + transforms.ToTensor(), + transforms.Normalize(mean=image_means, + std=image_stds) + ]) + + + train_data = BasicTrackerDataset(os.path.join(args.train_set_dir),transforms=train_transforms,if_train_aug=if_train_aug,train_aug_iter=train_aug_iter) + + n_train_examples = int(len(train_data) * valid_ratio) + n_valid_examples = len(train_data) - n_train_examples + + train_data, valid_data = data.random_split(train_data,[n_train_examples, n_valid_examples],generator=torch.Generator().manual_seed(SEED)) + + + train_iterator = data.DataLoader(train_data,shuffle = True,batch_size = args.batch_size) + + valid_iterator = data.DataLoader(valid_data,batch_size = args.batch_size) + + model=DeepSeaTracker(n_channels=1, n_classes=2, bilinear=True) + + optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9) + + STEPS_PER_EPOCH = len(train_iterator) + TOTAL_STEPS = args.max_epoch * STEPS_PER_EPOCH + MAX_LRS = [p['lr'] for p in optimizer.param_groups] + scheduler = lr_scheduler.OneCycleLR(optimizer,max_lr=MAX_LRS,total_steps=TOTAL_STEPS) + grad_scaler = torch.cuda.amp.GradScaler(enabled=True) + criterion = nn.CrossEntropyLoss() + model = model.to(device) + criterion = criterion.to(device) + nstop=0 + avg_precision_best=0 + + logging.info('>>>> Start training') + print('INFO: Start training ...') + for epoch in range(args.max_epoch): + model.train() + epoch_loss = 0 + with tqdm(total=n_train_examples, desc=f'Epoch {epoch + 1}/{args.max_epoch}', unit='img') as pbar: + for step,batch in enumerate(train_iterator): + img_prev,img_curr,mask = batch['image_prev'],batch['image_curr'],batch['mask'] + + assert img_prev.shape[1] == model.n_channels, \ + f'Network has been defined with {model.n_channels} input channels, ' \ + f'but loaded images have {img_prev.shape[1]} channels. Please check that ' \ + 'the images are loaded correctly.' 
+ + img_prev = img_prev.to(device=device, dtype=torch.float32) + img_curr = img_curr.to(device=device, dtype=torch.float32) + mask = mask.to(device=device, dtype=torch.long) + true_masks=torch.squeeze(mask, dim=1) + with torch.cuda.amp.autocast(enabled=True): + masks_preds = model(img_prev,img_curr) + loss = criterion(masks_preds, true_masks) \ + + dice_loss(F.softmax(masks_preds, dim=1).float(), + F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(), + multiclass=True) + + optimizer.zero_grad(set_to_none=True) + grad_scaler.scale(loss).backward() + grad_scaler.step(optimizer) + grad_scaler.update() + pbar.update(img_prev.shape[0]) + epoch_loss += loss.item() + pbar.set_postfix(**{'loss': epoch_loss/(step+1)}) + + # Evaluation round + val_score,avg_precision,single,mitosis = evaluate_tracker(model, valid_iterator, device,n_valid_examples,is_avg_prec=((1+epoch)%2==0),prec_thresholds=[0.5]) + + if avg_precision is not None: + logging.info('>>>> Epoch:%d , loss=%f , valid score=%f , avg precision=%f' % ( + epoch, epoch_loss / (step+1), val_score, avg_precision[0])) + else: + logging.info('>>>> Epoch:%d , loss=%f , valid score=%f' % ( + epoch, epoch_loss / (step + 1), val_score)) + ## Save best checkpoint corresponding the best average precision + if avg_precision is not None and avg_precision>avg_precision_best: + avg_precision_best=avg_precision + states = model.state_dict() + if save_checkpoint: + logging.info('>>>> save model to %s'%(os.path.join(args.output_dir,'tracker.pth'))) + torch.save(states, os.path.join(args.output_dir,'tracker.pth')) + + nstop=0 + elif avg_precision is not None and avg_precision<=avg_precision_best: + nstop+=1 + if nstop==patience:#Early Stopping + print('INFO: Early Stopping met ...') + print('INFO: Finish training process') + break + scheduler.step() + + + + +if __name__ == "__main__": + ap = argparse.ArgumentParser() + ap.add_argument("--train_set_dir",required=True,type=str,help="path for the train set") + ap.add_argument("--lr", default=1e-3,type=float, help="learning rate") + ap.add_argument("--max_epoch", default=100, type=int, help="maximum epoch to train model") + ap.add_argument("--batch_size", default=16, type=int, help="train batch size") + ap.add_argument("--output_dir", required=True, type=str, help="path for saving the train log and best model") + + args = ap.parse_args() + + assert os.path.isdir(args.train_set_dir), 'No such file or directory: ' + args.train_set_dir + if not os.path.isdir(args.output_dir): + os.makedirs(args.output_dir) + train(args) \ No newline at end of file diff --git a/deepsea/trained_models/__init__.py b/deepsea/trained_models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/trained_models/segmentation.pth b/deepsea/trained_models/segmentation.pth similarity index 100% rename from trained_models/segmentation.pth rename to deepsea/trained_models/segmentation.pth diff --git a/trained_models/tracker.pth b/deepsea/trained_models/tracker.pth similarity index 100% rename from trained_models/tracker.pth rename to deepsea/trained_models/tracker.pth diff --git a/utils.py b/deepsea/utils.py similarity index 97% rename from utils.py rename to deepsea/utils.py index 781648b..e53a8ce 100644 --- a/utils.py +++ b/deepsea/utils.py @@ -558,4 +558,4 @@ def track_cells(pred_list,img_list,tracking_model,device,transforms): masks_prev = masks_curr pbar.update(1) - return cell_labels,cell_centroids,tracked_imgs + return cell_labels,cell_centroids,tracked_imgs \ No newline at end of file diff --git 
a/evaluate_test_set_segmentation.py b/evaluate_test_set_segmentation.py deleted file mode 100644 index d6a6d3e..0000000 --- a/evaluate_test_set_segmentation.py +++ /dev/null @@ -1,62 +0,0 @@ -import torch.utils.data as data -import segmentation_transforms as transforms -import numpy as np -import argparse -import os -import random -from model import DeepSeaSegmentation -from data import BasicSegmentationDataset -import torch -from evaluate import evaluate_segmentation -from utils import get_n_params - -SEED = 1234 -random.seed(SEED) -np.random.seed(SEED) -torch.manual_seed(SEED) -torch.cuda.manual_seed(SEED) -torch.backends.cudnn.deterministic = True - - -def test(args,image_size = [383,512],image_means = [0.5],image_stds= [0.5],batch_size=1): - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - test_transforms = transforms.Compose([ - transforms.Resize(image_size), - transforms.Grayscale(num_output_channels=1), - transforms.ToTensor(), - transforms.Normalize(mean = image_means, - std = image_stds) - ]) - - - test_data = BasicSegmentationDataset(os.path.join(args.test_set_dir, 'images'), os.path.join(args.test_set_dir, 'masks'),os.path.join(args.test_set_dir, 'wmaps'),transforms=test_transforms) - - test_iterator = data.DataLoader(test_data,batch_size = batch_size,shuffle=False) - - model=DeepSeaSegmentation(n_channels=1, n_classes=2, bilinear=True) - print('INFO: Num of model parameters:',get_n_params(model)) - model.load_state_dict(torch.load(args.ckpt_dir)) - model = model.to(device) - - test_score, test_avg_precision,test_easy_avg_precision,test_hard_avg_precision = evaluate_segmentation(model, test_iterator, device,len(test_data),is_avg_prec=True,prec_thresholds=[0.5,0.6,0.7,0.8,0.9],output_dir=args.output_dir) - print('INFO: Dice score:', test_score) - print('INFO: Average precision at ordered thresholds:', test_avg_precision) - print('INFO: Easy samples average precision at ordered thresholds:', test_easy_avg_precision) - print('INFO: Hard samples average precision at ordered thresholds:', test_hard_avg_precision) - -if __name__ == "__main__": - ap = argparse.ArgumentParser() - ap.add_argument("--test_set_dir",required=True,type=str,help="path for the test dataset") - ap.add_argument("--ckpt_dir",required=True,type=str,help="path for the checkpoint of segmentation model to test") - ap.add_argument("--output_dir", required=True, type=str, help="path for saving the test outputs") - - args = ap.parse_args() - - assert os.path.isdir(args.test_set_dir), 'No such file or directory: ' + args.test_set_dir - if not os.path.isdir(os.path.join(args.output_dir,'input_segmentation_images')): - os.makedirs(os.path.join(args.output_dir,'input_segmentation_images')) - if not os.path.isdir(os.path.join(args.output_dir,'segmentation_predictions')): - os.makedirs(os.path.join(args.output_dir,'segmentation_predictions')) - - test(args) \ No newline at end of file diff --git a/evaluate_test_set_tracking.py b/evaluate_test_set_tracking.py deleted file mode 100644 index 25c59f7..0000000 --- a/evaluate_test_set_tracking.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -import torch.utils.data as data -import tracker_transforms as transforms -import numpy as np -import argparse -import random -from model import DeepSeaTracker -from data import BasicTrackerDataset -import torch -from evaluate import evaluate_tracker -from utils import get_n_params - -SEED = 1234 -random.seed(SEED) -np.random.seed(SEED) -torch.manual_seed(SEED) -torch.cuda.manual_seed(SEED) 
-torch.backends.cudnn.deterministic = True - - - -def test(args,image_size = [128,128],image_means = [0.5],image_stds= [0.5],batch_size=1): - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - test_transforms = transforms.Compose([ - transforms.Resize(image_size), - transforms.Grayscale(num_output_channels=1), - transforms.ToTensor(), - transforms.Normalize(mean = image_means, - std = image_stds) - ]) - - - test_data = BasicTrackerDataset(os.path.join(args.test_set_dir), transforms=test_transforms,if_test=True) - test_iterator = data.DataLoader(test_data,batch_size = batch_size) - - model=DeepSeaTracker(n_channels=1, n_classes=2, bilinear=True) - print('INFO: Num of model parameters:',get_n_params(model)) - - model.load_state_dict(torch.load(args.ckpt_dir)) - model = model.to(device) - - test_score, test_avg_precision,test_single_cell_avg_precision,test_mitosis_avg_precision = evaluate_tracker(model, test_iterator, device,len(test_data),is_avg_prec=True,prec_thresholds=[0.2,0.6,0.7,0.8,0.9],output_dir=args.output_dir) - - print('INFO: Dice score:', test_score) - print('INFO: Average precision:', test_avg_precision) - print('INFO: Single cells average precision:', test_single_cell_avg_precision) - print('INFO: Mitosis average precision:', test_mitosis_avg_precision) - - -if __name__ == "__main__": - ap = argparse.ArgumentParser() - ap.add_argument("--test_set_dir",required=True,type=str,help="path for the test dataset") - ap.add_argument("--ckpt_dir",required=True,type=str,help="path for the checkpoint of tracking model to test") - ap.add_argument("--output_dir", required=True, type=str, help="path for saving the test outputs") - - args = ap.parse_args() - - assert os.path.isdir(args.test_set_dir), 'No such file or directory: ' + args.test_set_dir - if not os.path.isdir(os.path.join(args.output_dir,'input_crops')): - os.makedirs(os.path.join(args.output_dir,'input_crops')) - if not os.path.isdir(os.path.join(args.output_dir,'tracking_predictions')): - os.makedirs(os.path.join(args.output_dir,'tracking_predictions')) - - test(args) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..098a99b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,17 @@ +# pyproject.toml +[build-system] +requires = [ + "setuptools>=45", + "wheel", + "setuptools_scm[toml]>=6.2" +] +build-backend = "setuptools.build_meta" + +[tool.setuptools_scm] +write_to = "deepsea/_version.py" + +# Ignore DeprecationWarnings +[tool.pytest.ini_options] +filterwarnings = [ + "ignore::DeprecationWarning" +] diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..1e08337 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,51 @@ +[metadata] +name = deepsea +author = Abolfazl Zargari and Ali Shariati +author_email = alish@ucsc.edu +description = An efficient deep learning model for single-cell segmentation and tracking of time-lapse microscopy images +keywords = + live-cell imaging + cell segmentation + cell tracking + image analysis +long_description = file: README.md +long_description_content_type = text/markdown +url = https://github.com/SchmollerLab/Cell_ACDC +project_urls = + Shariati lab = https://shariatilab.sites.ucsc.edu/ +classifiers = + Development Status :: 3 - Alpha + Programming Language :: Python :: 3 :: Only + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + License :: OSI Approved :: BSD License + Intended Audience :: Education + Intended Audience :: Science/Research + 
Operating System :: Microsoft :: Windows + Operating System :: POSIX + Operating System :: Unix + Operating System :: MacOS + Topic :: Scientific/Engineering + Topic :: Scientific/Engineering :: Bio-Informatics + Topic :: Scientific/Engineering :: Information Analysis + Topic :: Scientific/Engineering :: Image Processing + Topic :: Scientific/Engineering :: Visualization + Topic :: Utilities + +[options] +packages = find: +python_requires = + >=3.8 +include_package_data = True +install_requires = + opencv-python-headless>=4.5.1.48 + numpy>=1.19.5 + torch>=1.7.1 + Pillow>=8.1.0 + tqdm>=4.56.0 + scipy>=1.3.3 + scikit-image>=0.17.2 + munkres>=1.1.4 + torchvision>=0.8.2 + setuptools-scm \ No newline at end of file
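The new deepsea/train_tracker.py above carves a validation set out of the training data with a seeded random_split, so the split is reproducible across runs; note that, as written in the hunk, valid_ratio is actually the fraction kept for training. A minimal sketch of that wiring, assuming any torch Dataset in place of BasicTrackerDataset:

import torch
import torch.utils.data as data

SEED = 42

def make_loaders(dataset, batch_size, train_fraction=0.8):
    # Deterministic split: the generator seed fixes which samples land in validation.
    n_train = int(len(dataset) * train_fraction)
    n_valid = len(dataset) - n_train
    train_data, valid_data = data.random_split(
        dataset, [n_train, n_valid],
        generator=torch.Generator().manual_seed(SEED))
    train_iterator = data.DataLoader(train_data, shuffle=True, batch_size=batch_size)
    valid_iterator = data.DataLoader(valid_data, batch_size=batch_size)
    return train_iterator, valid_iterator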
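The learning-rate schedule in the same hunk builds a OneCycleLR with one max_lr per optimizer parameter group and total_steps sized per batch. A small sketch of that setup, under the same RMSprop settings, follows; note that PyTorch documents OneCycleLR as being stepped once per batch, whereas the hunk calls scheduler.step() once per epoch, so the configured cycle would never finish.

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

def build_optimizer_and_scheduler(model, train_iterator, lr, max_epoch):
    optimizer = optim.RMSprop(model.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # One peak learning rate per parameter group, as in the hunk above.
    max_lrs = [group['lr'] for group in optimizer.param_groups]
    # total_steps counts optimizer updates (batches), not epochs.
    total_steps = max_epoch * len(train_iterator)
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=max_lrs, total_steps=total_steps)
    return optimizer, scheduler

# Usage sketch: step the scheduler after every optimizer update.
# for epoch in range(max_epoch):
#     for batch in train_iterator:
#         ...forward / backward / optimizer step...
#         scheduler.step()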
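The mixed-precision training step added above combines cross-entropy with a soft Dice term computed on the softmaxed predictions and one-hot targets. A minimal, self-contained sketch of that step is below; simple_dice_loss is a simplified stand-in for the repository's loss.dice_loss, and model is assumed to expose n_classes the way DeepSeaTracker does.

import torch
import torch.nn.functional as F

def simple_dice_loss(probs, targets, eps=1e-6):
    # Soft Dice averaged over classes; probs and targets are (N, C, H, W) floats.
    dims = (0, 2, 3)
    intersection = (probs * targets).sum(dims)
    cardinality = probs.sum(dims) + targets.sum(dims)
    return 1.0 - ((2.0 * intersection + eps) / (cardinality + eps)).mean()

def train_step(model, batch, criterion, optimizer, grad_scaler, device):
    # Move the frame pair and target mask to the training device.
    img_prev = batch['image_prev'].to(device=device, dtype=torch.float32)
    img_curr = batch['image_curr'].to(device=device, dtype=torch.float32)
    true_masks = batch['mask'].to(device=device, dtype=torch.long).squeeze(1)

    # Forward pass and combined loss under autocast (mixed precision).
    with torch.cuda.amp.autocast(enabled=True):
        preds = model(img_prev, img_curr)
        loss = criterion(preds, true_masks) + simple_dice_loss(
            F.softmax(preds, dim=1).float(),
            F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float())

    # Scaled backward pass and optimizer update.
    optimizer.zero_grad(set_to_none=True)
    grad_scaler.scale(loss).backward()
    grad_scaler.step(optimizer)
    grad_scaler.update()
    return loss.item()

Here criterion would be an nn.CrossEntropyLoss() instance and grad_scaler a torch.cuda.amp.GradScaler(enabled=True), matching the hunk above.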
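The end-of-epoch checkpointing in the hunk keeps the weights with the best validation average precision and stops training once `patience` evaluation rounds pass without improvement. A compact sketch of that bookkeeping, with illustrative names:

import os
import torch

def update_best_and_early_stop(model, avg_precision, best, nstop, output_dir, patience=5):
    # Returns the new best score, the no-improvement counter, and a stop flag.
    stop = False
    if avg_precision is not None and avg_precision > best:
        best = avg_precision
        # Save the best checkpoint so far.
        torch.save(model.state_dict(), os.path.join(output_dir, 'tracker.pth'))
        nstop = 0
    elif avg_precision is not None:
        nstop += 1
        stop = (nstop == patience)  # early stopping
    return best, nstop, stop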