-
Notifications
You must be signed in to change notification settings - Fork 12
/
differential_color_functions.py
161 lines (130 loc) · 5.28 KB
/
differential_color_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import numpy as np
import torch
def rgb2xyz(rgb_image, device):
    """Convert a batch of sRGB images to CIE XYZ with Y scaled to 100.

    Args:
        rgb_image: Tensor indexed as (batch, 3, height, width). The sRGB
            transfer function used here assumes values scaled to [0, 1].
        device: Device on which the conversion is carried out; the input is
            moved there before any arithmetic.

    Returns:
        Tensor with the same shape as ``rgb_image`` whose channel axis holds
        X, Y and Z, in the floating dtype of the linearised image.
    """
    rgb_image = rgb_image.to(device)

    # Inverse sRGB companding: linear segment up to 0.04045, power law above.
    # The base of the power branch is clamped at zero so that inputs below
    # -0.055 cannot produce NaN (values or gradients) in the branch that
    # torch.where discards for them.
    gamma_base = torch.clamp((rgb_image + 0.055) / 1.055, min=0.0)
    linear = torch.where(rgb_image > 0.04045,
                         gamma_base ** 2.4,
                         rgb_image / 12.92)
    linear = 100 * linear

    # Linear RGB -> XYZ matrix, created in the image's dtype/device so that
    # float64 (or half) inputs do not hit a dtype mismatch inside matmul.
    mt = torch.tensor([[0.4124, 0.3576, 0.1805],
                       [0.2126, 0.7152, 0.0722],
                       [0.0193, 0.1192, 0.9504]],
                      dtype=linear.dtype, device=linear.device)

    # Put channels first and flatten everything else so that one matmul
    # converts every pixel, then restore the (batch, 3, height, width) layout.
    flat = linear.permute(1, 0, 2, 3).contiguous().view(3, -1)
    res = torch.matmul(mt, flat)
    res = res.view(3, rgb_image.size(0), rgb_image.size(2), rgb_image.size(3))
    return res.permute(1, 0, 2, 3)
def xyz_lab(xyz_image, device):
    """Apply the CIELAB companding function f(t) element-wise.

    f(t) is the cube root of t above 0.008856 and the linear segment
    ``7.787 * t + 16 / 116`` at or below it.

    Args:
        xyz_image: Tensor of tristimulus values already divided by the
            reference white component.
        device: Device on which the computation is carried out.

    Returns:
        Tensor of the same shape holding f(t). An input of exactly zero maps
        to 16 / 116, so that ``116 * f(0) - 16`` is 0.
    """
    xyz_image = xyz_image.to(device)

    # Clamp inside the cube-root branch only: entries at or below the
    # threshold take the linear branch, but torch.where still evaluates both
    # sides, and the cube root has an unbounded derivative at 0 and is NaN
    # for negative inputs. The clamp keeps values and gradients finite there.
    cube_root = torch.clamp(xyz_image, min=0.008856) ** (1 / 3)
    linear = (7.787 * xyz_image) + (16 / 116)
    return torch.where(xyz_image > 0.008856, cube_root, linear)
def rgb2lab_diff(rgb_image, device):
    '''
    Function to convert a batch of image tensors from RGB space to CIELAB space.

    rgb_image is indexed as (batch, channel, height, width) with the three
    RGB channels on axis 1; device is where the computation is carried out.

    xn, yn, zn are the CIE XYZ tristimulus values of the reference white
    point. They are fixed inside the function to the standard Illuminant D65
    with normalization Y = 100.

    Returns a tensor with L, a and b stacked on axis 1.
    '''
    rgb_image = rgb_image.to(device)
    xyz_image = rgb2xyz(rgb_image, device)

    xn = 95.0489
    yn = 100
    zn = 108.8840

    # Compand each white-normalised channel exactly once; f(Y/Yn) feeds all
    # three of L, a and b.
    fx = xyz_lab(xyz_image[:, 0, :, :] / xn, device)
    fy = xyz_lab(xyz_image[:, 1, :, :] / yn, device)
    fz = xyz_lab(xyz_image[:, 2, :, :] / zn, device)

    L = 116 * fy - 16
    a = 500 * (fx - fy)
    b = 200 * (fy - fz)

    # Stack instead of writing into a zeros_like(rgb_image) buffer, so the
    # result keeps the dtype of the computed channels rather than being cast
    # back to the input's dtype.
    return torch.stack((L, a, b), dim=1)
def degrees(n):
    """Convert an angle (scalar, array or tensor) from radians to degrees."""
    rad_to_deg = 180. / np.pi
    return n * rad_to_deg
def radians(n):
    """Convert an angle (scalar, array or tensor) from degrees to radians."""
    deg_to_rad = np.pi / 180.
    return n * deg_to_rad
def hpf_diff(x, y):
    """Return atan2(x, y) in degrees, with negative angles shifted up by 360.

    Entries where x and y are both exactly zero are multiplied by zero before
    the atan2 call.
    """
    both_zero = ((x == 0) & (y == 0)).float()
    keep = 1 - both_zero

    angle = degrees(torch.atan2(x * keep, y * keep))

    # Exactly one of the two masks below is 1 for every non-NaN angle.
    kept_as_is = angle * (angle >= 0).float()
    shifted_up = (360 + angle) * (angle < 0).float()
    return kept_as_is + shifted_up
def dhpf_diff(c1, c2, h1p, h2p):
    """Return the hue difference h2p - h1p (degrees) folded into [-180, 180].

    The result is 0 wherever the product c1 * c2 is exactly zero.
    """
    delta = h2p - h1p
    chromatic = 1 - ((c1 * c2) == 0).float()

    # At most one of the three branches is non-zero for any given element.
    direct = delta * chromatic * (torch.abs(delta) <= 180).float()
    wrapped_down = (delta - 360) * (delta > 180).float() * chromatic
    wrapped_up = (delta + 360) * (delta < -180).float() * chromatic
    return direct + wrapped_down + wrapped_up
def ahpf_diff(c1, c2, h1p, h2p):
    """Return the mean hue (degrees) of two hue angles, element-wise.

    Args:
        c1, c2: Chroma tensors; only whether their product is exactly zero
            is used.
        h1p, h2p: Hue angles in degrees.

    Returns:
        Where ``c1 * c2 == 0``: ``h1p + h2p`` (not halved).
        Otherwise half of ``h1p + h2p`` when the hues are at most 180 apart,
        half of ``h1p + h2p + 360`` when they are further apart and their sum
        is below 360, and half of ``h1p + h2p - 360`` in the remaining case.
    """
    achromatic = ((c1 * c2) == 0).float()
    chromatic = 1 - achromatic
    close = (torch.abs(h2p - h1p) <= 180).float()
    far = 1 - close
    low_sum = (torch.abs(h2p + h1p) < 360).float()
    high_sum = 1 - low_sum

    hue_sum = h1p + h2p
    res1 = hue_sum * chromatic * close
    res2 = (hue_sum + 360.) * chromatic * far * low_sum
    res3 = (hue_sum - 360.) * chromatic * far * high_sum

    # The chromatic branches are halved; the zero-chroma branch is added
    # separately and un-halved. It cannot be derived from res1..res3, because
    # each of those already carries the `chromatic` factor and is therefore
    # zero exactly where `achromatic` is one.
    return (res1 + res2 + res3) * 0.5 + hue_sum * achromatic
def ciede2000_diff(lab1, lab2,device):
    '''
    CIEDE2000 metric to calculate the color distance map for a batch of image tensors defined in CIELAB space.

    lab1 and lab2 are indexed as [batch, channel, row, col] with L, a, b in
    channels 0, 1, 2. Both are moved to `device`. The returned tensor has one
    distance per pixel (the channel axis is consumed by the indexing below).
    '''
    lab1=lab1.to(device)
    lab2=lab2.to(device)
    # Split both inputs into their L, a and b planes.
    L1 = lab1[:,0,:,:]
    A1 = lab1[:,1,:,:]
    B1 = lab1[:,2,:,:]
    L2 = lab2[:,0,:,:]
    A2 = lab2[:,1,:,:]
    B2 = lab2[:,2,:,:]
    # Weighting factors for the lightness, chroma and hue terms.
    kL = 1
    kC = 1
    kH = 1
    # Masks that are 1 where a pixel has a == 0 and b == 0 exactly.
    mask_value_0_input1=((A1==0)*(B1==0)).float()
    mask_value_0_input2=((A2==0)*(B2==0)).float()
    mask_value_0_input1_no=1-mask_value_0_input1
    mask_value_0_input2_no=1-mask_value_0_input2
    # Nudge b away from zero on those pixels.
    # NOTE(review): presumably so the sqrt/atan2 below have finite gradients
    # at a == b == 0 -- confirm.
    B1=B1+0.0001*mask_value_0_input1
    B2=B2+0.0001*mask_value_0_input2
    # Chroma of each input and their mean.
    C1 = torch.sqrt((A1 ** 2.) + (B1 ** 2.))
    C2 = torch.sqrt((A2 ** 2.) + (B2 ** 2.))
    aC1C2 = (C1 + C2) / 2.
    # G rescales the a axis; a1P/a2P and c1P/c2P are the adjusted a and chroma.
    G = 0.5 * (1. - torch.sqrt((aC1C2 ** 7.) / ((aC1C2 ** 7.) + (25 ** 7.))))
    a1P = (1. + G) * A1
    a2P = (1. + G) * A2
    c1P = torch.sqrt((a1P ** 2.) + (B1 ** 2.))
    c2P = torch.sqrt((a2P ** 2.) + (B2 ** 2.))
    # Hue angles in degrees; forced to 0 on the pixels flagged above.
    h1P = hpf_diff(B1, a1P)
    h2P = hpf_diff(B2, a2P)
    h1P=h1P*mask_value_0_input1_no
    h2P=h2P*mask_value_0_input2_no
    # Lightness, chroma and hue differences.
    dLP = L2 - L1
    dCP = c2P - c1P
    # Note: the zero-chroma test inside dhpf_diff/ahpf_diff receives C1, C2
    # (the unadjusted chroma), not c1P, c2P.
    dhP = dhpf_diff(C1, C2, h1P, h2P)
    dHP = 2. * torch.sqrt(c1P * c2P) * torch.sin(radians(dhP) / 2.)
    # mask_0_no is 0 where either input pixel was flagged (a == b == 0).
    mask_0_no=1-torch.max(mask_value_0_input1,mask_value_0_input2)
    dHP=dHP*mask_0_no
    # Mean lightness, adjusted chroma and hue.
    aL = (L1 + L2) / 2.
    aCP = (c1P + c2P) / 2.
    aHP = ahpf_diff(C1, C2, h1P, h2P)
    # Hue-dependent weighting T, rotation angle dRO, and the sL/sC/sH scale
    # factors with the rotation term rT.
    T = 1. - 0.17 * torch.cos(radians(aHP - 39)) + 0.24 * torch.cos(radians(2. * aHP)) + 0.32 * torch.cos(radians(3. * aHP + 6.)) - 0.2 * torch.cos(radians(4. * aHP - 63.))
    dRO = 30. * torch.exp(-1. * (((aHP - 275.) / 25.) ** 2.))
    rC = torch.sqrt((aCP ** 7.) / ((aCP ** 7.) + (25. ** 7.)))
    sL = 1. + ((0.015 * ((aL - 50.) ** 2.)) / torch.sqrt(20. + ((aL - 50.) ** 2.)))
    sC = 1. + 0.045 * aCP
    sH = 1. + 0.015 * aCP * T
    rT = -2. * rC * torch.sin(radians(2. * dRO))
    # The same expression without the masks, kept for reference:
    # res_square=((dLP / (sL * kL)) ** 2.) + ((dCP / (sC * kC)) ** 2.) + ((dHP / (sH * kH)) ** 2.) + rT * (dCP / (sC * kC)) * (dHP / (sH * kH))
    # Active version: the chroma, hue and rotation terms are multiplied by
    # mask_0_no, so only the lightness term remains wherever either input
    # pixel has a == b == 0.
    # NOTE(review): this drops the chroma difference between a grey pixel and
    # a coloured one -- confirm this is intended.
    res_square=((dLP / (sL * kL)) ** 2.) + ((dCP / (sC * kC)) ** 2.)*mask_0_no + ((dHP / (sH * kH)) ** 2.)*mask_0_no + rT * (dCP / (sC * kC)) * (dHP / (sH * kH))*mask_0_no
    # mask_0_no is rebound here: it now marks strictly positive res_square.
    # Non-positive entries get 0.0001 added before the sqrt and are then
    # zeroed in the output.
    # NOTE(review): presumably to avoid the unbounded sqrt gradient at 0 --
    # confirm.
    mask_0=(res_square<=0).float()
    mask_0_no=1-mask_0
    res_square=res_square+0.0001*mask_0
    res=torch.sqrt(res_square)
    res=res*mask_0_no
    return res