-
Notifications
You must be signed in to change notification settings - Fork 9
/
utils.py
81 lines (61 loc) · 2.19 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import math
from math import floor
import numpy as np
import torch
import torch.nn as nn
from cross_correlation import xcorr_torch
from spectral_tools import gen_mtf
def net_scope(kernel_size):
    """
    Compute the network scope.

    The scope is the sum over all layers of the kernel half-width
    ``k // 2``. Presumably this is the number of border pixels consumed on
    each side by a stack of unpadded, stride-1 convolutions; confirm this
    against the callers.

    Parameters
    ----------
    kernel_size : List[int]
        A list containing the kernel size of each layer of the network.

    Return
    ------
    scope : int
        The scope of the network (``0`` for an empty list).
    """
    # Integer floor division replaces math.floor(k / 2). Wrapping the builtin
    # sum in int() guarantees a plain Python int. np.sum returned a NumPy
    # scalar, and a float64 0.0 for an empty list.
    return int(sum(k // 2 for k in kernel_size))
def local_corr_mask(img_in, ratio, sensor, device, kernel=8):
    """
    Compute the threshold mask for the structural loss.

    The last channel of ``img_in`` is treated as PAN and all preceding
    channels as MS. The steps are:

    1. Filter the PAN channel with a 2-D kernel taken from
       ``gen_mtf(ratio, sensor)``. Reflection padding keeps the spatial
       size for square odd kernels.
    2. Compute ``xcorr_torch`` between the filtered PAN and the MS bands.
    3. Return ``1 - xcorr``.

    Parameters
    ----------
    img_in : Torch Tensor
        The test image, already normalized and with the MS part upsampled with
        ideal interpolator. Indexed as (Batch, Channels, H, W), with PAN as the
        last channel.
    ratio : int
        The resolution scale which elapses between MS and PAN.
    sensor : str
        The name of the satellites which has provided the images.
    device : Torch device
        The device forwarded to ``xcorr_torch``. The MTF filtering itself runs
        on the device of ``img_in``.
    kernel : int
        The semi-width for local cross-correlation computation.
        (See the cross-correlation function for more details)

    Return
    ------
    mask : PyTorch Tensor
        ``1 - xcorr_torch(filtered PAN, MS, kernel, device)``. The original
        documentation describes this as a local correlation field stack of
        dimensions (Batch, B, H, W), with B the number of MS bands. The exact
        shape is determined by ``xcorr_torch`` (TODO confirm).
    """
    # Split the stack: the last channel is PAN (kept 4-D via unsqueeze),
    # the rest are MS.
    I_PAN = torch.unsqueeze(img_in[:, -1, :, :], dim=1)
    I_MS = img_in[:, :-1, :, :]

    # Use the first 2-D slice (index 0 on the last axis) of the gen_mtf
    # output. Reshape it to a (1, 1, h, w) conv weight.
    MTF_kern = gen_mtf(ratio, sensor)[:, :, 0]
    MTF_kern = np.expand_dims(MTF_kern, axis=(0, 1))
    # Build the weight on the input's device and dtype. The former CPU-only,
    # float32-only nn.Conv2d weight failed for CUDA or non-float32 inputs.
    MTF_kern = torch.from_numpy(MTF_kern).to(device=I_PAN.device,
                                             dtype=I_PAN.dtype)

    # "Same"-size filtering. The pad is derived from the kernel width only,
    # so this assumes a square kernel, as the original did.
    pad = floor((MTF_kern.shape[-1] - 1) / 2)
    I_PAN = nn.ReflectionPad2d(pad)(I_PAN)
    # Plain stride-1, unpadded, bias-free convolution. This is equivalent to
    # the previous frozen nn.Conv2d whose weight was overwritten with MTF_kern.
    I_PAN = nn.functional.conv2d(I_PAN, MTF_kern)

    mask = xcorr_torch(I_PAN, I_MS, kernel, device)
    mask = 1.0 - mask
    return mask