-
Notifications
You must be signed in to change notification settings - Fork 13
/
HeatMap.py
115 lines (102 loc) · 4.04 KB
/
HeatMap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import io
import os
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as ndimage
from PIL import Image
class HeatMap:
    """Overlay a smoothed heat map on top of an image and plot/save/export it."""

    def __init__(self, image, heat_map, gaussian_std=10):
        """Load the base image and prepare an upscaled, Gaussian-smoothed heat map.

        Args:
            image: Either a numpy array (its first two dimensions are taken as
                height, width) or anything ``PIL.Image.open`` accepts (e.g. a
                file path).
            heat_map: Array-like 2-D heat map. It is multiplied by 255 here
                and divided by 255 again when drawn, so values are presumably
                in [0, 1] — confirm against callers.
            gaussian_std: Standard deviation (in pixels of the resized map)
                of the Gaussian smoothing filter, applied on both axes.
        """
        if isinstance(image, np.ndarray):
            height, width = image.shape[:2]
        else:
            # PIL opens the image lazily; its size is (width, height).
            image = Image.open(image)
            width, height = image.size
        self.image = image

        # Scale in float64, then hand Pillow a float32 array ('F' mode). This
        # matches what Pillow does internally for float64 input, and it avoids
        # two failure modes: integer overflow (e.g. uint8 * 255 wrapping) and
        # dtypes Pillow cannot convert (int64, bool, plain lists).
        scaled = (np.asarray(heat_map, dtype=np.float64) * 255).astype(np.float32)
        # Round-trip through PIL purely to reuse its resize to the image size.
        resized = np.asarray(Image.fromarray(scaled).resize((width, height)))
        smoothed = ndimage.gaussian_filter(resized,
                                           sigma=(gaussian_std, gaussian_std),
                                           order=0)
        self.heat_map = np.asarray(smoothed)

    def _draw_overlay(self, transparency, color_map, show_axis,
                      show_original, show_colorbar, width_pad):
        """Draw the image (+ optional original) and heat map on the current figure."""
        if show_original:
            # Left panel: the untouched original image.
            plt.subplot(1, 2, 1)
            if not show_axis:
                plt.axis('off')
            plt.imshow(self.image)
            n_cols, position = 2, 2
        else:
            n_cols, position = 1, 1

        # Overlay panel: original image with the heat map blended on top.
        plt.subplot(1, n_cols, position)
        if not show_axis:
            plt.axis('off')
        plt.imshow(self.image)
        plt.imshow(self.heat_map / 255, alpha=transparency, cmap=color_map)
        if show_colorbar:
            plt.colorbar()
        plt.tight_layout(w_pad=width_pad)

    def plot(self, transparency=0.7, color_map='bwr',
             show_axis=False, show_original=False, show_colorbar=False, width_pad=0):
        """Draw the overlay on the current figure and display it via ``plt.show()``.

        Args:
            transparency: Alpha of the heat-map layer.
            color_map: Matplotlib colormap name for the heat map.
            show_axis: Keep axis ticks/frames when True.
            show_original: Also draw the original image in a left subplot.
            show_colorbar: Add a colorbar for the heat-map layer.
            width_pad: ``w_pad`` passed to ``plt.tight_layout``.
        """
        self._draw_overlay(transparency, color_map, show_axis,
                           show_original, show_colorbar, width_pad)
        plt.show()

    def save(self, filename, format='png', save_path=None,
             transparency=0.7, color_map='bwr', width_pad=-10,
             show_axis=False, show_original=False, show_colorbar=False, **kwargs):
        """Render the overlay and save it to ``<save_path>/<filename>.<format>``.

        Args:
            filename: Output file name without extension.
            format: File format/extension passed to ``plt.savefig``.
            save_path: Target directory. Defaults to the current working
                directory *at call time* (previously it was frozen at import).
            transparency, color_map, show_axis, show_original, show_colorbar,
            width_pad: Same meaning as in :meth:`plot`.
            **kwargs: Forwarded to ``plt.savefig``.
        """
        if save_path is None:
            save_path = os.getcwd()
        # Draw on the current figure (so a caller-created figure/figsize is
        # honoured) and close it afterwards so repeated saves do not stack
        # images and colorbars onto the same figure.
        fig = plt.gcf()
        try:
            self._draw_overlay(transparency, color_map, show_axis,
                               show_original, show_colorbar, width_pad)
            plt.savefig(os.path.join(save_path, filename + '.' + format),
                        format=format,
                        bbox_inches='tight',
                        pad_inches=0, **kwargs)
        finally:
            plt.close(fig)
        print('{}.{} has been successfully saved to {}'.format(filename, format, save_path))

    def get_heatmap_array(self, transparency=0.7, color_map='bwr') -> np.ndarray:
        """Render the overlay off-screen and return it as an RGB numpy array.

        The rendered size follows matplotlib's default figure size, not the
        input image size.

        Args:
            transparency: Alpha of the heat-map layer.
            color_map: Matplotlib colormap name for the heat map.

        Returns:
            A (rows, cols, 3) uint8 array decoded from an in-memory PNG.
        """
        fig = plt.figure()
        try:
            self._draw_overlay(transparency, color_map, show_axis=False,
                               show_original=False, show_colorbar=False,
                               width_pad=0)
            # https://stackoverflow.com/questions/8598673/how-to-save-a-pylab-figure-into-in-memory-file-which-can-be-read-into-pil-image
            with io.BytesIO() as buf:
                fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
                # Rewind so PIL reads the PNG from its first byte.
                buf.seek(0)
                with Image.open(buf) as im:
                    heatmap_array = np.array(im.convert('RGB'))
        finally:
            # Close only the figure created here, not the caller's figures.
            plt.close(fig)
        return heatmap_array