-
Notifications
You must be signed in to change notification settings - Fork 115
/
camera.py
307 lines (279 loc) · 11.5 KB
/
camera.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import collections
from easydict import EasyDict as edict
import util
from util import log,debug
class Pose():
    """
    Helpers for camera poses stored as PyTorch tensors of shape [...,3,4].

    A pose is the column-wise concatenation [R|t] of a rotation block R [...,3,3]
    and a translation t [...,3]. An instance is callable and acts as a constructor.
    """
    def __call__(self,R=None,t=None):
        """
        Assemble a float32 pose [...,3,4] from R and/or t.

        Non-tensor inputs are converted with torch.tensor. If R is omitted it becomes
        the identity (batched like t); if t is omitted it becomes zeros (batched like R).
        At least one of R, t must be given.
        """
        assert(R is not None or t is not None)
        if R is not None and not isinstance(R,torch.Tensor):
            R = torch.tensor(R)
        if t is not None and not isinstance(t,torch.Tensor):
            t = torch.tensor(t)
        if R is None:
            # identity rotation with the same leading batch shape as t
            R = torch.eye(3,device=t.device).repeat(*t.shape[:-1],1,1)
        elif t is None:
            t = torch.zeros(R.shape[:-1],device=R.device)
        else:
            assert(R.shape[:-1]==t.shape and R.shape[-2:]==(3,3))
        pose = torch.cat([R.float(),t.float()[...,None]],dim=-1) # [...,3,4]
        assert(pose.shape[-2:]==(3,4))
        return pose
    def invert(self,pose,use_inverse=False):
        """
        Inverse pose [R^-1 | -R^-1 t].

        By default R^-1 is taken as R^T (valid for orthonormal R); pass use_inverse=True
        to use a general matrix inverse instead.
        """
        R,t = pose[...,:3],pose[...,3:]
        if use_inverse:
            R_inv = R.inverse()
        else:
            R_inv = R.transpose(-1,-2)
        t_inv = (-R_inv@t)[...,0]
        return self(R=R_inv,t=t_inv)
    def compose(self,pose_list):
        """
        Chain poses left to right: the result maps x to poseN(...pose2(pose1(x))).
        """
        result = pose_list[0]
        for nxt in pose_list[1:]:
            result = self.compose_pair(result,nxt)
        return result
    def compose_pair(self,pose_a,pose_b):
        """
        Compose two poses: the result maps x to pose_b(pose_a(x)).
        """
        R_a,t_a = pose_a[...,:3],pose_a[...,3:]
        R_b,t_b = pose_b[...,:3],pose_b[...,3:]
        return self(R=R_b@R_a,t=(R_b@t_a+t_b)[...,0])
class Lie():
    """
    Lie algebra for SO(3) and SE(3) operations in PyTorch.

    Converts between vectors in the algebra (so(3) w [...,3], se(3) wu [...,6])
    and group elements (rotation matrices [...,3,3], rigid transforms [...,3,4]).
    The sin/cos ratios in the closed-form maps are evaluated with truncated Taylor
    series (taylor_A/B/C), which remain finite at theta=0.
    """
    def so3_to_SO3(self,w): # [...,3]
        # exponential map so(3) -> SO(3) (Rodrigues' formula):
        #   R = I + A(theta)*[w]x + B(theta)*[w]x^2 with theta = |w|,
        #   A = sin(theta)/theta, B = (1-cos(theta))/theta^2
        wx = self.skew_symmetric(w)
        theta = w.norm(dim=-1)[...,None,None]
        I = torch.eye(3,device=w.device,dtype=torch.float32)
        A = self.taylor_A(theta)
        B = self.taylor_B(theta)
        R = I+A*wx+B*wx@wx
        return R
    def SO3_to_so3(self,R,eps=1e-7): # [...,3,3]
        # logarithm map SO(3) -> so(3): returns w [...,3]
        # rotation angle from the trace: cos(theta) = (trace(R)-1)/2
        trace = R[...,0,0]+R[...,1,1]+R[...,2,2]
        # the clamp keeps theta strictly inside (0,pi), where A(theta) used below is non-zero
        # NOTE(review): acos already returns [0,pi] and the clamp keeps theta below pi, so the
        # trailing %np.pi looks like a no-op -- confirm before removing
        theta = ((trace-1)/2).clamp(-1+eps,1-eps).acos_()[...,None,None]%np.pi # ln(R) will explode if theta==pi
        # from so3_to_SO3, R-R^T = 2*A(theta)*[w]x, hence [w]x = (R-R^T)/(2*A(theta)); 1e-8 guards the division
        lnR = 1/(2*self.taylor_A(theta)+1e-8)*(R-R.transpose(-2,-1)) # FIXME: wei-chiu finds it weird
        # read w back out of the skew-symmetric matrix (inverse of skew_symmetric)
        w0,w1,w2 = lnR[...,2,1],lnR[...,0,2],lnR[...,1,0]
        w = torch.stack([w0,w1,w2],dim=-1)
        return w
    def se3_to_SE3(self,wu): # [...,3]
        # exponential map se(3) -> SE(3): wu [...,6] is split into rotation part w and translation part u
        #   R = I + A*[w]x + B*[w]x^2,  V = I + B*[w]x + C*[w]x^2,  result = [R | V@u]  [...,3,4]
        w,u = wu.split([3,3],dim=-1)
        wx = self.skew_symmetric(w)
        theta = w.norm(dim=-1)[...,None,None]
        I = torch.eye(3,device=w.device,dtype=torch.float32)
        A = self.taylor_A(theta)
        B = self.taylor_B(theta)
        C = self.taylor_C(theta)
        R = I+A*wx+B*wx@wx
        V = I+B*wx+C*wx@wx
        Rt = torch.cat([R,(V@u[...,None])],dim=-1)
        return Rt
    def SE3_to_se3(self,Rt,eps=1e-8): # [...,3,4]
        # logarithm map SE(3) -> se(3): returns wu [...,6] = [w,u]
        R,t = Rt.split([3,1],dim=-1)
        w = self.SO3_to_so3(R)
        wx = self.skew_symmetric(w)
        theta = w.norm(dim=-1)[...,None,None]
        I = torch.eye(3,device=w.device,dtype=torch.float32)
        A = self.taylor_A(theta)
        B = self.taylor_B(theta)
        # closed-form inverse of V: I - 1/2*[w]x + (1-A/(2B))/theta^2*[w]x^2; eps guards theta=0
        invV = I-0.5*wx+(1-A/(2*B))/(theta**2+eps)*wx@wx
        u = (invV@t)[...,0]
        wu = torch.cat([w,u],dim=-1)
        return wu
    def skew_symmetric(self,w):
        # cross-product matrix [w]x of w [...,3], returned as [...,3,3]
        w0,w1,w2 = w.unbind(dim=-1)
        O = torch.zeros_like(w0)
        wx = torch.stack([torch.stack([O,-w2,w1],dim=-1),
                          torch.stack([w2,O,-w0],dim=-1),
                          torch.stack([-w1,w0,O],dim=-1)],dim=-2)
        return wx
    def taylor_A(self,x,nth=10):
        # Taylor expansion of sin(x)/x
        # sum_{i=0..nth} (-1)^i x^(2i)/(2i+1)!  (denom holds (2i+1)! after step i)
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth+1):
            if i>0: denom *= (2*i)*(2*i+1)
            ans = ans+(-1)**i*x**(2*i)/denom
        return ans
    def taylor_B(self,x,nth=10):
        # Taylor expansion of (1-cos(x))/x**2
        # sum_{i=0..nth} (-1)^i x^(2i)/(2i+2)!  (denom holds (2i+2)! after step i)
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth+1):
            denom *= (2*i+1)*(2*i+2)
            ans = ans+(-1)**i*x**(2*i)/denom
        return ans
    def taylor_C(self,x,nth=10):
        # Taylor expansion of (x-sin(x))/x**3
        # sum_{i=0..nth} (-1)^i x^(2i)/(2i+3)!  (denom holds (2i+3)! after step i)
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth+1):
            denom *= (2*i+2)*(2*i+3)
            ans = ans+(-1)**i*x**(2*i)/denom
        return ans
class Quaternion():
    """
    Quaternion operations in PyTorch.

    Quaternions are tensors [...,4] ordered (a,b,c,d), where a is the scalar part,
    i.e. q = a + b*i + c*j + d*k.
    """
    def q_to_R(self,q):
        # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
        # quaternion [...,4] -> rotation matrix [...,3,3]; the formula assumes |q|=1
        # (q is not renormalized here)
        qa,qb,qc,qd = q.unbind(dim=-1)
        R = torch.stack([torch.stack([1-2*(qc**2+qd**2),2*(qb*qc-qa*qd),2*(qa*qc+qb*qd)],dim=-1),
                         torch.stack([2*(qb*qc+qa*qd),1-2*(qb**2+qd**2),2*(qc*qd-qa*qb)],dim=-1),
                         torch.stack([2*(qb*qd-qa*qc),2*(qa*qb+qc*qd),1-2*(qb**2+qc**2)],dim=-1)],dim=-2)
        return R
    def R_to_q(self,R,eps=1e-8): # [...,3,3]
        """
        Convert rotation matrices [...,3,3] into unit quaternions [...,4] (a,b,c,d) with a>=0.

        Uses Shepperd's method. Every entry of the outer product 4*q*q^T is a linear
        function of R's entries. The row whose diagonal entry 4*q_k^2 is largest is
        divided by 4*q_k. Because the four diagonal entries always sum to 4, that pivot
        is >= 1, so the division is well conditioned for every rotation, including the
        180-degree rotations where R-R^T vanishes.

        eps: lower bound applied to the pivot before the square root (kept for
             backward compatibility; it never binds for valid rotation matrices).
        """
        R00,R01,R02 = R[...,0,0],R[...,0,1],R[...,0,2]
        R10,R11,R12 = R[...,1,0],R[...,1,1],R[...,1,2]
        R20,R21,R22 = R[...,2,0],R[...,2,1],R[...,2,2]
        # M = 4*q*q^T written in terms of R (consistent with q_to_R above)
        M = torch.stack([torch.stack([1+R00+R11+R22,R21-R12,R02-R20,R10-R01],dim=-1),
                         torch.stack([R21-R12,1+R00-R11-R22,R01+R10,R02+R20],dim=-1),
                         torch.stack([R02-R20,R01+R10,1-R00+R11-R22,R12+R21],dim=-1),
                         torch.stack([R10-R01,R02+R20,R12+R21,1-R00-R11+R22],dim=-1)],dim=-2) # [...,4,4]
        diag = M.diagonal(dim1=-2,dim2=-1) # [...,4], entries 4*q_k^2
        idx = diag.argmax(dim=-1,keepdim=True) # [...,1], pivot component k
        row = M.gather(-2,idx[...,None].expand(*idx.shape,4)).squeeze(-2) # [...,4] = 4*q_k*q
        pivot = diag.gather(-1,idx).clamp(min=eps).sqrt() # [...,1] = 2*|q_k|
        q = row/(2*pivot)
        # remove drift for slightly non-orthonormal inputs
        q = q/q.norm(dim=-1,keepdim=True)
        # canonical sign (q and -q are the same rotation): non-negative scalar part
        q = torch.where(q[...,:1]<0,-q,q)
        return q
    def invert(self,q):
        # inverse quaternion conj(q)/|q|^2 (equal to the conjugate for unit quaternions)
        qa,qb,qc,qd = q.unbind(dim=-1)
        norm = q.norm(dim=-1,keepdim=True)
        q_inv = torch.stack([qa,-qb,-qc,-qd],dim=-1)/norm**2
        return q_inv
    def product(self,q1,q2): # [...,4]
        # Hamilton product q1*q2
        q1a,q1b,q1c,q1d = q1.unbind(dim=-1)
        q2a,q2b,q2c,q2d = q2.unbind(dim=-1)
        hamil_prod = torch.stack([q1a*q2a-q1b*q2b-q1c*q2c-q1d*q2d,
                                  q1a*q2b+q1b*q2a+q1c*q2d-q1d*q2c,
                                  q1a*q2c-q1b*q2d+q1c*q2a+q1d*q2b,
                                  q1a*q2d+q1b*q2c-q1c*q2b+q1d*q2a],dim=-1)
        return hamil_prod
# shared module-level instances through which the pose / Lie-algebra / quaternion helpers are called
pose,lie,quaternion = Pose(),Lie(),Quaternion()
def to_hom(X):
    """
    Homogeneous coordinates: append a 1 along the last axis of X (e.g. [...,3] -> [...,4]).
    The appended entries share X's dtype and device.
    """
    ones = torch.ones_like(X[...,:1])
    return torch.cat((X,ones),dim=-1)
# basic operations of transforming 3D points between world/camera/image coordinates
def world2cam(X,pose): # [B,N,3]
    """
    Map world points X [...,N,3] to camera coordinates with the extrinsics pose [...,3,4].
    Row-wise this computes R @ x + t.
    """
    pose_T = pose.transpose(-1,-2) # [...,4,3]
    return to_hom(X)@pose_T
def cam2img(X,cam_intr):
    """
    Apply the intrinsics cam_intr [...,3,3] to camera points X [...,3] (row-wise K @ x).
    No perspective division is performed.
    """
    return torch.matmul(X,cam_intr.transpose(-2,-1))
def img2cam(X,cam_intr):
    """
    Back-project homogeneous image coordinates X [...,3] into camera space.
    Row-wise this computes K^-1 @ x, with K = cam_intr [...,3,3].
    """
    K_inv = torch.inverse(cam_intr)
    return X@K_inv.transpose(-1,-2)
def cam2world(X,pose):
    """
    Map camera points X [...,N,3] back to world coordinates.
    Uses the inverse of the extrinsics pose [...,3,4].
    """
    pose_inv_T = Pose().invert(pose).transpose(-1,-2) # [...,4,3]
    return to_hom(X)@pose_inv_T
def angle_to_rotation_matrix(a,axis):
    """
    Rotation matrices [...,3,3] for angles a [...] (radians) about one axis: "X", "Y" or "Z".

    A rotation about Z is built first. Its rows and columns are then cyclically shifted
    so the fixed (identity) entry lands on the requested axis. An unknown axis raises KeyError.
    """
    shift = {"X":1,"Y":2,"Z":0}[axis]
    cos,sin = a.cos(),a.sin()
    zero,one = torch.zeros_like(a),torch.ones_like(a)
    rows = [torch.stack([cos,-sin,zero],dim=-1),
            torch.stack([sin,cos,zero],dim=-1),
            torch.stack([zero,zero,one],dim=-1)]
    R_z = torch.stack(rows,dim=-2)
    return R_z.roll((shift,shift),dims=(-2,-1))
def get_center_and_ray(opt,pose,intr=None): # [HW,2]
    """
    Per-pixel ray origins and directions in world coordinates for perspective cameras.

    opt:  config object; reads opt.camera.model (must be "perspective"), opt.H, opt.W, opt.device
    pose: camera extrinsics [B,3,4], mapped camera->world through cam2world
    intr: camera intrinsics (presumably [B,3,3]); required despite the None default
    Returns (center_3D, ray), both [B,HW,3]:
      - center_3D: the camera center repeated for every pixel;
      - ray: the unnormalized direction from the center through each pixel center.
    Raises ValueError if intr is not given.
    """
    # given the intrinsic/extrinsic matrices, get the camera center and ray directions
    assert(opt.camera.model=="perspective")
    if intr is None:
        raise ValueError("get_center_and_ray requires camera intrinsics (intr)")
    with torch.no_grad():
        # compute image coordinate grid (pixel centers, hence the +0.5)
        y_range = torch.arange(opt.H,dtype=torch.float32,device=opt.device).add_(0.5)
        x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5)
        # Explicit broadcasting reproduces torch.meshgrid(y_range,x_range) with "ij" indexing.
        # It avoids the warning newer PyTorch emits when meshgrid is called without `indexing`,
        # and it still works on older versions that lack that argument.
        X = x_range[None,:].expand(opt.H,opt.W) # [H,W]
        Y = y_range[:,None].expand(opt.H,opt.W) # [H,W]
        xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2]
    # compute center and ray (outside no_grad, so gradients can reach pose and intr)
    batch_size = len(pose)
    # broadcast view instead of .repeat: no B-fold copy of the grid
    xy_grid = xy_grid.expand(batch_size,-1,-1) # [B,HW,2]
    grid_3D = img2cam(to_hom(xy_grid),intr) # [B,HW,3]
    center_3D = torch.zeros_like(grid_3D) # [B,HW,3]
    # transform from camera to world coordinates
    grid_3D = cam2world(grid_3D,pose) # [B,HW,3]
    center_3D = cam2world(center_3D,pose) # [B,HW,3]
    ray = grid_3D-center_3D # [B,HW,3]
    return center_3D,ray
def get_3D_points_from_depth(opt,center,ray,depth,multi_samples=False):
    """
    Points along rays: x = c + d*v.

    opt is unused. With multi_samples=True, center and ray get a singleton sample axis
    at dim 2 ([B,HW,3] -> [B,HW,1,3]). depth can then carry several samples per ray
    (presumably [B,HW,N,1] -- confirm against callers), giving [B,HW,N,3].
    """
    if not multi_samples:
        return center+ray*depth # [B,HW,3]/[N,3]
    origin,direction = center[:,:,None],ray[:,:,None] # [B,HW,1,3]
    return origin+direction*depth # [B,HW,N,3]
def convert_NDC(opt,center,ray,intr,near=1):
    """
    Convert rays (center/ray [B,HW,3]) to normalized device coordinates.

    Unlike conventional NDC, cameras are assumed to face the +z direction. Ray origins
    are first slid along their rays onto the plane z=near. The x/y scale per camera is
    intr[:,0,0]/intr[:,0,2] and intr[:,1,1]/intr[:,1,2]. opt is unused.
    Returns (center_ndc, ray_ndc), both [B,HW,3].
    """
    # slide each ray origin along its ray onto the plane z=near
    step = (near-center[...,2:])/ray[...,2:]
    center = center+step*ray
    cx,cy,cz = center.unbind(dim=-1) # [B,HW]
    rx,ry,rz = ray.unbind(dim=-1) # [B,HW]
    # per-camera scale factors, broadcast over the HW axis
    sx = (intr[:,0,0]/intr[:,0,2])[:,None] # [B,1]
    sy = (intr[:,1,1]/intr[:,1,2])[:,None] # [B,1]
    ox,oy = cx/cz,cy/cz
    center_ndc = torch.stack([sx*ox,sy*oy,1-2*near/cz],dim=-1) # [B,HW,3]
    ray_ndc = torch.stack([sx*(rx/rz-ox),sy*(ry/rz-oy),2*near/cz],dim=-1) # [B,HW,3]
    return center_ndc,ray_ndc
def rotation_distance(R1,R2,eps=1e-7):
    """
    Angle in radians of the relative rotation R1 @ R2^T, for rotation matrices [...,3,3].

    Computes arccos((trace(R1 R2^T)-1)/2). The cosine is clamped to [-1+eps,1-eps]
    for numerical stability near -1/+1, so identical rotations give a small positive
    angle rather than exactly 0.
    """
    # http://www.boris-belousov.net/2016/12/01/quat-dist/
    R_rel = R1@R2.transpose(-2,-1)
    tr = R_rel[...,0,0]+R_rel[...,1,1]+R_rel[...,2,2]
    cos_angle = ((tr-1)/2).clamp(-1+eps,1-eps)
    return torch.acos(cos_angle)
def procrustes_analysis(X0,X1): # [N,3]
    """
    Similarity (translation/scale/rotation) alignment of corresponding point sets X0, X1 [N,3].

    Returns an edict with:
      t0, t1 [3]: centroids of X0 and X1
      s0, s1 []:  RMS distance of each centered set from its centroid
      R [3,3]:    proper rotation (det +1) such that X1 is aligned onto X0 by
                  X1to0 = (X1-t1)/s1@R.t()*s0+t0
    """
    # translation
    t0 = X0.mean(dim=0,keepdim=True)
    t1 = X1.mean(dim=0,keepdim=True)
    X0c = X0-t0
    X1c = X1-t1
    # scale
    s0 = (X0c**2).sum(dim=-1).mean().sqrt()
    s1 = (X1c**2).sum(dim=-1).mean().sqrt()
    X0cs = X0c/s0
    X1cs = X1c/s1
    # rotation (use double for SVD, float loses precision)
    U,_,V = (X0cs.t()@X1cs).double().svd(some=True)
    # Kabsch reflection correction. If U@V^T is a reflection, flip the singular vector paired
    # with the smallest singular value (the last one; torch.svd sorts them in descending order)
    # *before* forming R, i.e. R = U@diag(1,1,-1)@V^T. Negating a row of the product afterwards
    # (the previous R[2] *= -1) also gives det(R)=+1, but it is not the best-fitting rotation.
    if (U@V.t()).det()<0:
        U = U*U.new_tensor([1.,1.,-1.])
    R = (U@V.t()).float()
    # align X1 to X0: X1to0 = (X1-t1)/s1@R.t()*s0+t0
    sim3 = edict(t0=t0[0],t1=t1[0],s0=s0,s1=s1,R=R)
    return sim3
def get_novel_view_poses(opt,pose_anchor,N=60,scale=1):
    """
    N poses [N,3,4] that circle in small oscillations around pose_anchor.

    Step i uses phase 2*pi*i/N, tilting by asin(0.05*sin(phase)) about X and
    asin(0.05*cos(phase)) about Y. In Pose.compose order, each pose is:
      translate (0,0,-4*scale) -> tilt R_y@R_x -> translate (0,0,3.8*scale) -> pose_anchor.
    The anchor is moved to the CPU. opt is unused.
    """
    # create circular viewpoints (small oscillations)
    phase = torch.arange(N)/N*2*np.pi
    R_x = angle_to_rotation_matrix((phase.sin()*0.05).asin(),"X")
    R_y = angle_to_rotation_matrix((phase.cos()*0.05).asin(),"Y")
    oscillation = pose.compose([pose(t=[0,0,-4*scale]),
                                pose(R=R_y@R_x),
                                pose(t=[0,0,3.8*scale])])
    return pose.compose([oscillation,pose_anchor.cpu()[None]])