-
Notifications
You must be signed in to change notification settings - Fork 24
/
inference_single-coil.py
138 lines (115 loc) · 5.51 KB
/
inference_single-coil.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from pathlib import Path
from models import utils as mutils
from sde_lib import VESDE
from sampling import (ReverseDiffusionPredictor,
LangevinCorrector,
get_pc_fouriercs_RI)
from models import ncsnpp
import time
from utils import fft2, ifft2, get_mask, get_data_scaler, get_data_inverse_scaler, restore_checkpoint
import torch
import torch.nn as nn
import numpy as np
from models.ema import ExponentialMovingAverage
import matplotlib.pyplot as plt
import importlib
import argparse
def _save_npy_and_png(out_dir, stem, array):
    """Save ``array`` under ``out_dir`` as ``<stem>.npy`` and a grayscale magnitude ``<stem>.png``."""
    # Plain string concatenation (not Path.with_suffix) so a stem that
    # itself contains dots is kept intact.
    target = str(out_dir / stem)
    np.save(target + '.npy', array)
    plt.imsave(target + '.png', np.abs(array), cmap='gray')


def main():
    """Reconstruct one single-coil sample with the predictor-corrector sampler.

    Steps performed:
      1. Parse CLI arguments and load ``./samples/single-coil/<data>.npy``.
      2. Build an undersampling mask, restore the score model from
         ``./weights/checkpoint_95.pth`` and copy the EMA weights into it.
      3. Retrospectively undersample the k-space of the loaded image and run
         the sampler returned by ``get_pc_fouriercs_RI``.
      4. Save input, mask, label and reconstruction under
         ``<save_dir>/single-coil/{input,label,recon}``.

    The ``--save_dir`` option (default ``./results``) selects the output
    root; with the default the outputs land in ``./results/single-coil``.
    """
    ###############################################
    # 1. Configurations
    ###############################################
    args = create_argparser().parse_args()
    N = args.N
    m = args.m
    fname = args.data
    filename = f'./samples/single-coil/{fname}.npy'

    print('initaializing...')
    configs = importlib.import_module("configs.ve.fastmri_knee_320_ncsnpp_continuous")
    config = configs.get_config()
    img_size = config.data.image_size
    batch_size = 1

    # Read data: the sample is cast to complex64 and reshaped to a fixed
    # (1, 1, 320, 320) tensor, so the file must hold exactly 320*320 values.
    img = torch.from_numpy(np.load(filename).astype(np.complex64))
    img = img.view(1, 1, 320, 320)
    img = img.to(config.device)

    mask = get_mask(img, img_size, batch_size,
                    type=args.mask_type,
                    acc_factor=args.acc_factor,
                    center_fraction=args.center_fraction)

    ckpt_filename = "./weights/checkpoint_95.pth"
    sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=N)

    config.training.batch_size = batch_size
    predictor = ReverseDiffusionPredictor
    corrector = LangevinCorrector
    probability_flow = False
    snr = 0.16

    scaler = get_data_scaler(config)
    inverse_scaler = get_data_inverse_scaler(config)

    # Create the model, restore the checkpoint, then overwrite the model
    # parameters with their EMA counterparts.
    score_model = mutils.create_model(config)
    ema = ExponentialMovingAverage(score_model.parameters(),
                                   decay=config.model.ema_rate)
    state = dict(step=0, model=score_model, ema=ema)
    state = restore_checkpoint(ckpt_filename, state, config.device, skip_sigma=True)
    ema.copy_to(score_model.parameters())

    # Output directories. The root honours --save_dir; its default
    # ('./results') reproduces the previously hard-coded './results/single-coil'.
    save_root = Path(args.save_dir) / 'single-coil'
    save_root.mkdir(parents=True, exist_ok=True)
    for sub_dir in ('input', 'recon', 'recon_progress', 'label'):
        (save_root / sub_dir).mkdir(parents=True, exist_ok=True)

    ###############################################
    # 2. Inference
    ###############################################
    pc_fouriercs = get_pc_fouriercs_RI(sde,
                                       predictor, corrector,
                                       inverse_scaler,
                                       snr=snr,
                                       n_steps=m,
                                       probability_flow=probability_flow,
                                       continuous=config.training.continuous,
                                       denoise=True)
    # Retrospective undersampling: mask the k-space of the loaded image and
    # transform back to obtain the sampler's starting image.
    kspace = fft2(img)
    under_kspace = kspace * mask
    under_img = ifft2(under_kspace)

    print('Beginning inference')
    tic = time.time()
    x = pc_fouriercs(score_model, scaler(under_img), mask, Fy=under_kspace)
    toc = time.time() - tic
    print(f'Time took for recon: {toc} secs.')

    ###############################################
    # 3. Saving recon
    ###############################################
    under_np = under_img.squeeze().cpu().detach().numpy()
    label_np = img.squeeze().cpu().detach().numpy()
    mask_np = mask.squeeze().cpu().detach().numpy()
    recon_np = x.squeeze().cpu().detach().numpy()

    _save_npy_and_png(save_root / 'input', fname, under_np)
    # The mask is stored as raw data only (no PNG preview).
    np.save(str(save_root / 'input' / (fname + '_mask')) + '.npy', mask_np)
    _save_npy_and_png(save_root / 'label', fname, label_np)
    _save_npy_and_png(save_root / 'recon', fname, recon_np)
def create_argparser():
    """Build the command-line parser for the single-coil inference script.

    Options:
      --data             (required) sample name passed by the caller.
      --mask_type        one of 'gaussian1d', 'uniform1d', 'gaussian2d'
                         (default 'gaussian1d').
      --acc_factor       int, default 4.
      --center_fraction  float, default 0.08.
      --save_dir         str, default './results'.
      --N                int, default 2000.
      --m                int, default 1.

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, required=True,
                        help='which data to use for reconstruction')
    # Adjacent string literals are concatenated verbatim, so each
    # continuation starts with a space to keep the sentences apart in --help.
    parser.add_argument('--mask_type', type=str, default='gaussian1d',
                        choices=['gaussian1d', 'uniform1d', 'gaussian2d'],
                        help='which mask to use for retrospective undersampling.'
                             ' (NOTE) only used for retrospective model!')
    parser.add_argument('--acc_factor', type=int, default=4,
                        help='Acceleration factor for Fourier undersampling.'
                             ' (NOTE) only used for retrospective model!')
    parser.add_argument('--center_fraction', type=float, default=0.08,
                        help='Fraction of ACS region to keep.'
                             ' (NOTE) only used for retrospective model!')
    parser.add_argument('--save_dir', type=str, default='./results',
                        help='Root directory for saved results')
    parser.add_argument('--N', type=int, default=2000,
                        help='Number of iterations for score-POCS sampling')
    parser.add_argument('--m', type=int, default=1,
                        help='Number of corrector step per single predictor step.'
                             ' It is advised not to change this default value.')
    return parser


if __name__ == "__main__":
    main()