forked from Hben-atya/P2T2-Robust-T2-estimation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
p2t2.py
130 lines (105 loc) · 3.62 KB
/
p2t2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from typing import Optional
def simulate(
    config_file: str = "config.yaml",
    model_type: str = "P2T2",
    min_te: float = 5.0,
    max_te: float = 15.0,
    n_echoes: int = 20,
    num_signals: int = 10000,
    out_folder: str = 'data'
):
    """Produce a simulated dataset by delegating to ``data_simulation.main``.

    Every argument is forwarded unchanged as a keyword argument; only
    ``config_file`` is renamed (to ``config_path``) on the way through.
    Nothing is returned.

    Args:
        config_file (str, optional): YAML configuration file. Defaults to "config.yaml".
        model_type (str, optional): Model type to simulate for. Defaults to "P2T2".
        min_te (float, optional): Smallest echo time. Defaults to 5.0.
        max_te (float, optional): Largest echo time. Defaults to 15.0.
        n_echoes (int, optional): Echo count. Defaults to 20.
        num_signals (int, optional): How many signals to generate. Defaults to 10000.
        out_folder (str, optional): Folder the simulated data is written to. Defaults to 'data'.
    """
    # Imported lazily so the other commands do not pull in the simulator.
    from data_simulation import main as run_simulation

    print(f"Simulating data to: {out_folder}")

    simulation_kwargs = dict(
        config_path=config_file,
        model_type=model_type,
        min_te=min_te,
        max_te=max_te,
        n_echoes=n_echoes,
        num_signals=num_signals,
        out_folder=out_folder,
    )
    run_simulation(**simulation_kwargs)
def train(
    config: str = 'config.yaml',
    data_folder: str = "data",
    output_path: str = "runs",
    model_type: str = 'P2T2-FC',
    min_te: float = 7.9,
    max_te: Optional[float] = None
):
    """Launch training through ``pt2_reconstruction_model_main.main``.

    ``output_path`` is created (including parents) if it does not exist.
    An ``argparse.Namespace`` is then assembled with the attributes
    ``config``, ``data_folder``, ``runs_outputs_path`` and ``dt_string``.
    It is passed to ``main`` together with ``model_type``, ``min_te`` and
    ``max_te`` as positional arguments. Nothing is returned.

    Args:
        config (str, optional): Configuration file path, stored as
            ``config``. Defaults to 'config.yaml'.
        data_folder (str, optional): Training-data folder, stored as
            ``data_folder``. Defaults to "data".
        output_path (str, optional): Run-output directory. It is created
            if missing and stored as ``runs_outputs_path``. Defaults to "runs".
        model_type (str, optional): Model type forwarded to ``main``.
            Defaults to 'P2T2-FC'.
        min_te (float, optional): Minimum TE forwarded to ``main``.
            Defaults to 7.9.
        max_te (float, optional): Maximum TE forwarded to ``main``.
            Defaults to None.
    """
    from argparse import Namespace
    from datetime import datetime
    from pathlib import Path

    from pt2_reconstruction_model_main import main as run_training

    print("Running training...")

    Path(output_path).mkdir(parents=True, exist_ok=True)

    # dt_string is a local-time (naive) timestamp such as "240131_14_05_09".
    # It is presumably used by main to label the run; confirm in main.
    run_args = Namespace(
        config=config,
        data_folder=data_folder,
        runs_outputs_path=output_path,
        dt_string=datetime.now().strftime("%y%m%d_%H_%M_%S"),
    )
    run_training(run_args, model_type, min_te, max_te)
def infer(
    model_type: str = "P2T2-FC",
    model_path: str = "model.pt",
    model_args_path: str = "model_args.yaml",
    data_dict: Optional[dict] = None,
    output_dir: str = "output",
    n_echoes: Optional[int] = None,
):
    """Run a trained model on MRI data and save its predictions as NIfTI files.

    The model arguments are read from a YAML file. The data are loaded with
    ``load_mri_data``, and the model is restored with ``load_model`` and
    applied with ``deploy_model``. The three returned tensors are written
    to ``output_dir`` as ``pt2_pred.nii.gz``, ``pred_mri.nii.gz`` and
    ``predicted_fa.nii.gz``. Nothing is returned.

    Args:
        model_type (str, optional): Model identifier forwarded to
            ``load_model`` and ``deploy_model``. Defaults to "P2T2-FC".
        model_path (str, optional): Saved model file passed to
            ``load_model``. Defaults to "model.pt".
        model_args_path (str, optional): YAML file with the model's
            arguments. Defaults to "model_args.yaml".
        data_dict (dict, optional): Input description forwarded to
            ``load_mri_data``; its schema is defined there. ``None`` (the
            default) is treated as an empty dict.
        output_dir (str, optional): Destination directory, created if
            missing. Defaults to "output".
        n_echoes (int, optional): Stored as ``model_args.num_echoes`` only
            when the model-args file has no ``max_echoes`` entry.
            Defaults to None.
    """
    import os
    from pathlib import Path

    import numpy as np
    import SimpleITK as sitk
    import torch
    import yaml
    from box import Box

    from pt2_reconstruction_model_deploy import load_model, deploy_model, load_mri_data

    # Fresh dict per call. A literal ``{}`` default is created once and
    # shared by every call, so any mutation by load_mri_data would leak
    # into later calls.
    if data_dict is None:
        data_dict = {}

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load the model config. The context manager closes the file handle,
    # which the bare open() previously leaked.
    with open(model_args_path, encoding="utf-8") as args_file:
        model_args = Box(yaml.safe_load(args_file))

    if hasattr(model_args, 'max_echoes'):
        model_args.num_echoes = model_args.max_echoes
    else:
        model_args.num_echoes = n_echoes
    # NOTE(review): this reads the config's own ``n_echoes`` key. It uses
    # neither the ``num_echoes`` value set just above nor the ``n_echoes``
    # argument. A config without that key will presumably fail here.
    # Confirm this is intended.
    num_used_echoes = model_args.n_echoes

    # Load data; the echo times it reports are handed to the model.
    data = load_mri_data(data_dict, num_used_echoes)
    model_args.TEs = data['TEs']

    model = load_model(model_path, device, model_type, model_args)
    pt2_pred, pred_mri, predicted_fa = deploy_model(model, data, device, model_type, model_args)

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # (file name, tensor, axis permutation). The permutations reverse the
    # axis order: SimpleITK reads numpy arrays as [z, y, x]. The (2, 1, 0)
    # and (1, 0) permutations imply 3-D and 2-D tensors respectively.
    outputs = (
        ('pt2_pred.nii.gz', pt2_pred, (2, 1, 0)),
        ('pred_mri.nii.gz', pred_mri, (2, 1, 0)),
        ('predicted_fa.nii.gz', predicted_fa, (1, 0)),
    )
    # Convert everything before writing anything, so a conversion failure
    # leaves no partial output. detach() keeps .numpy() from raising if a
    # tensor still requires grad.
    images = [
        (file_name, sitk.GetImageFromArray(np.transpose(tensor.detach().cpu().numpy(), axes)))
        for file_name, tensor, axes in outputs
    ]
    for file_name, image in images:
        sitk.WriteImage(image, os.path.join(output_dir, file_name))