-
Notifications
You must be signed in to change notification settings - Fork 2
/
demo.py
137 lines (108 loc) · 5.01 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Main demo script for running an experiment
import argparse
from skimage import data, segmentation
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import os
import glob
import skimage.io as sio
import tqdm
import datetime
import time
import copy
import random
from skimage.segmentation import watershed, slic
from src.segment.watershed import watershed_segmentation
from src.segment.watershed import watershed_segmentation
from src.rag import RAG
from src.nng import NNG
from src.utils import visualize_rag, quantize_image, visualize_nng
from src.sprt import compute_averageColor, compute_conditionalProbability, compute_averageColor, sample_from_region, sprt_test
from src.region_merging import RegionMerging, NNGRegionMerging
def plot_segmentationLabels(img, labels):
    """
    Show two arrays side by side in a single matplotlib figure.

    `img` is drawn in the left panel and `labels` in the right panel of a
    1x2 subplot grid on figure number 0 (10x7 inches). The call ends with
    `plt.show()`.
    """
    plt.figure(0, figsize=(10, 7))
    # Panel positions are 1-based in plt.subplot, hence start=1.
    for position, panel in enumerate((img, labels), start=1):
        plt.subplot(1, 2, position)
        plt.imshow(panel)
    plt.show()
def main(args):
    """
    Run the region-merging demo on one image and save the results.

    Main function takes the arguments and does the following:
    1. Resolve the input image: if `args.input_path` is a directory, pick a
       random `.jpg` inside it; if it is a file, use that file.
    2. Load the image, convert it to RGB and downscale it by a factor of two.
    3. Compute the initial labels using either watershed or SLIC.
    4. Build the NNGRegionMerging object and snapshot its initial state.
    5. Run the region merging process for at most `max_iters` iterations.
    6. Create the output directory `<args.output>/<image name>/` and store
       the original image and the initial/final label maps.
    7. Save the visualizations for the initial and final RAG and NNG.
    8. Optionally display the initial and final labels side by side.

    Parameters
    ----------
    args : argparse.Namespace (or any object with the same attributes)
        Required attributes: `input_path`, `output`, `watershed`.
        Optional attributes, with the fallback used when they are absent:
        `lambda1` (0.95), `lambda2` (0.1), `max_iters` (100),
        `visualize` (False).

    Raises
    ------
    FileNotFoundError
        If `input_path` does not exist, or is a directory that contains no
        `.jpg` files.
    OSError
        If the selected file cannot be decoded as an image.
    """
    input_path = args.input_path
    output_path = args.output

    # Resolve the image to process.
    if os.path.isdir(input_path):
        # os.path.join keeps the pattern valid whether or not the directory
        # was given with a trailing separator.
        files = glob.glob(os.path.join(input_path, '*.jpg'))
        if not files:
            raise FileNotFoundError(
                "No .jpg images found in directory: {}".format(input_path))
        file = random.choice(files)
    elif os.path.isfile(input_path):
        file = input_path
    else:
        raise FileNotFoundError(
            "Input path does not exist: {}".format(input_path))

    # File name without directory and extension, independent of the OS
    # path separator and of the extension length.
    name = os.path.splitext(os.path.basename(file))[0]
    print("Operating on file: ", name)

    # Create the output directory. Parents are created as needed and an
    # existing directory is reused, so re-running on the same image works.
    save_path = os.path.join(output_path, '{}'.format(name))
    os.makedirs(save_path, exist_ok=True)

    # cv2.imread signals failure by returning None instead of raising.
    image = cv2.imread(file)
    if image is None:
        raise OSError("Could not read image file: {}".format(file))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Halve the resolution. image.shape is (rows, cols, ...) while
    # cv2.resize expects the target size as (width, height).
    half_height, half_width = image.shape[0] // 2, image.shape[1] // 2
    image = cv2.resize(image, (half_width, half_height))

    # Compute the labels using either slic or watershed
    if args.watershed:
        gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        labels = watershed_segmentation(gray_image)
    else:
        labels = slic(image, compactness=10, n_segments=300, start_label=1)

    # Generate the NNG graph and start running the algorithm. The command
    # line values are honoured; the fallbacks are the values this script
    # used when they were hard-coded.
    # NOTE(review): args.alpha and args.beta are parsed by the CLI but are
    # not consumed here - confirm whether NNGRegionMerging should get them.
    drm = NNGRegionMerging(
        image,
        labels,
        lambda1=getattr(args, 'lambda1', 0.95),
        lambda2=getattr(args, 'lambda2', 0.1),
    )

    start = time.time()

    # Snapshot the initial state; deep copies are taken because the
    # originals are still needed after the merging step below.
    initial_labels = drm.get_labels()
    initial_graph = copy.deepcopy(drm.graph)
    initial_nng_graph = copy.deepcopy(drm.nng)
    print("Number of initial regions: ", len(drm.graph.nodes))

    # Run the region merging process for max_iters number of times
    drm.run_region_merging(getattr(args, 'max_iters', 100))
    new_labels = drm.get_labels()
    print("Number of final regions: ", len(drm.graph.nodes))

    end = time.time()
    print('Segmentation takes: {:.4f} secs'.format((end - start)))

    # Save the outputs
    plt.imsave(os.path.join(save_path, 'Original Image.png'), image)
    plt.imsave(os.path.join(save_path, 'Initial Labels.png'), initial_labels)
    plt.imsave(os.path.join(save_path, 'Final labels.png'), new_labels)

    # Save the outputs for initial RAG and NNG, and final RAG and NNG.
    initial_label_img = visualize_rag(image, initial_labels, initial_graph,
                                      method='rag', time='initial', path=save_path)
    _ = visualize_nng(image, initial_labels, initial_nng_graph,
                      method='nng', time='initial', path=save_path)
    final_label_image = visualize_rag(image, new_labels, drm.graph,
                                      method='rag', time='final', path=save_path)
    _ = visualize_nng(image, new_labels, drm.nng,
                      method='nng', time='final', path=save_path)
    plt.imsave(os.path.join(save_path, 'Initial Label Image.png'), initial_label_img)
    plt.imsave(os.path.join(save_path, 'Final label Image.png'), final_label_image)

    # Interactive display comes last so every output is already on disk
    # before the call to plt.show() inside plot_segmentationLabels.
    if getattr(args, 'visualize', False):
        plot_segmentationLabels(initial_labels, new_labels)
if __name__ == '__main__':
    # Command-line entry point: build the parser, parse sys.argv and hand
    # the resulting namespace to main().
    #
    # The description is written as adjacent string literals. A backslash
    # continuation inside a single literal either fuses the two words at
    # the line break or embeds the source indentation in the help text.
    parser = argparse.ArgumentParser(
        description="Segmentation parameters Parser. Pass the parameters following instructions given "
                    "below to run the demo experiment.")

    # Input and output paths
    parser.add_argument('--input_path', default="./img/", type=str, help='Path to image for processing')
    parser.add_argument('--output', default="./output/", type=str, help="Path to folder for storing output")

    # RAG/NNG parameters
    parser.add_argument('--watershed', default=False, action='store_true', help="Whether to use watershed for initial segmentation")
    # Default is 0.95, the lambda1 value main() passes to NNGRegionMerging,
    # so the advertised default matches what the demo runs with.
    parser.add_argument('--lambda1', default=0.95, type=float, help="Value for Lambda1 parameter")
    parser.add_argument('--lambda2', default=0.1, type=float, help="Value for Lambda2 parameter")
    parser.add_argument('--alpha', default=0.05, type=float, help="Parameter value for alpha")
    parser.add_argument('--beta', default=0.05, type=float, help="Parameter value for beta")
    parser.add_argument('--visualize', default=False, action='store_true', help='Visualize the outputs of the algorithm')
    parser.add_argument('--max_iters', default=100, type=int, help='Max iterations for the Region merging loop')

    # Parse the arguments
    args = parser.parse_args()

    # Call the main function
    main(args)