Commit
Merge pull request #474 from dkuegler/feature/segstats/brainvolstats
Add brainvol stats to the segstats command
m-reuter authored Aug 21, 2024
2 parents bde5b9d + 0faf91b commit 494460b
Showing 18 changed files with 6,054 additions and 2,225 deletions.
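The brainvol measures themselves live in the segstats/brainvolstats changes, which are among the changed files not expanded below (only two of the 18 files are shown). For orientation only (this is not the PR's implementation; the function name and paths are made up for illustration), a label-based volume in mm³ is simply the voxel count of that label multiplied by the voxel volume:

# Illustrative sketch only, not FastSurfer's segstats code.
import numpy as np
import nibabel as nib

def label_volume_mm3(seg_path: str, label: int) -> float:
    """Volume of one segmentation label: voxel count times voxel volume."""
    seg = nib.load(seg_path)
    vox_vol = float(np.prod(seg.header.get_zooms()[:3]))  # mm^3 per voxel
    data = np.asanyarray(seg.dataobj)
    return float(np.count_nonzero(data == label)) * vox_vol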
325 changes: 97 additions & 228 deletions CerebNet/datasets/utils.py
@@ -14,15 +14,16 @@


# IMPORTS
from typing import Tuple, Union, Sequence, Optional, TypeVar
from typing import Tuple, Union, Sequence, Optional, TypeVar, TypedDict, Iterable, Type
from pathlib import Path

import nibabel as nib
import numpy as np
from numpy import typing as npt
import torch

from FastSurferCNN.data_loader.conform import getscale, scalecrop

# class names for network training and validation/testing
CLASS_NAMES = {
"Background": 0,
"Left_I_IV": 1,
@@ -54,11 +55,38 @@
"Right_Corpus_Medullare": 38,
}

# class names for network training and validation/testing
subseg_labels = {"cereb_subseg": np.array(list(CLASS_NAMES.values()))}

AT = TypeVar("AT", np.ndarray, torch.Tensor)


class LTADict(TypedDict):
type: int
nxforms: int
mean: list[float]
sigma: float
lta: npt.NDArray[float]
src_valid: int
src_filename: str
src_volume: list[int]
src_voxelsize: list[float]
src_xras: list[float]
src_yras: list[float]
src_zras: list[float]
src_cras: list[float]
dst_valid: int
dst_filename: str
dst_volume: list[int]
dst_voxelsize: list[float]
dst_xras: list[float]
dst_yras: list[float]
dst_zras: list[float]
dst_cras: list[float]
src: npt.NDArray[float]
dst: npt.NDArray[float]
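For reference, the keys above mirror the fields of a FreeSurfer LTA transform file. An abridged sample (all values purely illustrative) looks like this:

type      = 1 # LINEAR_RAS_TO_RAS
nxforms   = 1
mean      = 0.0000 0.0000 0.0000
sigma     = 1.0000
1 4 4
9.990000e-01 1.000000e-02 0.000000e+00 -1.200000e+00
-1.000000e-02 9.990000e-01 0.000000e+00 2.500000e+00
0.000000e+00 0.000000e+00 1.000000e+00 -7.000000e-01
0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00
src volume info
valid = 1  # volume info valid
filename = orig.mgz
volume = 256 256 256
voxelsize = 1.000000e+00 1.000000e+00 1.000000e+00
xras   = -1.000000e+00 0.000000e+00 0.000000e+00
yras   = 0.000000e+00 0.000000e+00 -1.000000e+00
zras   = 0.000000e+00 1.000000e+00 0.000000e+00
cras   = 0.000000e+00 0.000000e+00 0.000000e+00
dst volume info
(same fields as the src block)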


def define_size(mov_dim, ref_dim):
new_dim = np.zeros(len(mov_dim), dtype=int)
borders = np.zeros((len(mov_dim), 2), dtype=int)
@@ -167,7 +195,7 @@ def bounding_volume_offset(
if isinstance(img, np.ndarray):
from FastSurferCNN.data_loader.data_utils import bbox_3d

bbox = bbox_3d(img != 0)
bbox = bbox_3d(np.not_equal(img, 0))
bbox = bbox[::2] + bbox[1::2]
else:
bbox = img
@@ -325,237 +353,78 @@ def apply_warp_field(dform_field, img, interpol_order=3):
return deformed_img


def readLTA(file):
def read_lta(file: Path | str) -> LTADict:
"""Read the LTA info."""
import re
from functools import partial
import numpy as np
parameter_pattern = re.compile(r"^\s*([^=]+)\s*=\s*([^#]*)\s*(#.*)?$")
vol_info_pattern = re.compile("^(.*) volume info$")
shape_pattern = re.compile(r"^(\s*\d+)+$")
matrix_pattern = re.compile(r"^(-?\d+\.\S+\s+)+$")

_Type = TypeVar("_Type", bound=Type)

def _vector(_a: str, dtype: Type[_Type] = float, count: int = -1) -> list[_Type]:
return np.fromstring(_a, dtype=dtype, count=count, sep=" ").tolist()

parameters = {
"type": int,
"nxforms": int,
"mean": partial(_vector, dtype=float, count=3),
"sigma": float,
"subject": str,
"fscale": float,
}
vol_info_par = {
"valid": int,
"filename": str,
"volume": partial(_vector, dtype=int, count=3),
"voxelsize": partial(_vector, dtype=float, count=3),
**{f"{c}ras": partial(_vector, dtype=float) for c in "xyzc"}
}

with open(file, "r") as f:
lta = f.readlines()
d = dict()
i = 0
while i < len(lta):
if re.match("type", lta[i]) is not None:
d["type"] = int(
re.sub("=", "", re.sub("[a-z]+", "", re.sub("#.*", "", lta[i]))).strip()
)
i += 1
elif re.match("nxforms", lta[i]) is not None:
d["nxforms"] = int(
re.sub("=", "", re.sub("[a-z]+", "", re.sub("#.*", "", lta[i]))).strip()
)
i += 1
elif re.match("mean", lta[i]) is not None:
d["mean"] = [
float(x)
for x in re.split(
" +",
re.sub(
"=", "", re.sub("[a-z]+", "", re.sub("#.*", "", lta[i]))
).strip(),
)
]
i += 1
elif re.match("sigma", lta[i]) is not None:
d["sigma"] = float(
re.sub("=", "", re.sub("[a-z]+", "", re.sub("#.*", "", lta[i]))).strip()
)
i += 1
elif (
re.match(
"-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+", lta[i]
)
is not None
):
d["lta"] = np.array(
[
[
float(x)
for x in re.split(
" +",
re.match(
"-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+",
lta[i],
).string.strip(),
)
],
[
float(x)
for x in re.split(
" +",
re.match(
"-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+",
lta[i + 1],
).string.strip(),
)
],
[
float(x)
for x in re.split(
" +",
re.match(
"-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+",
lta[i + 2],
).string.strip(),
)
],
[
float(x)
for x in re.split(
" +",
re.match(
"-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+",
lta[i + 3],
).string.strip(),
)
],
]
)
i += 4
elif re.match("src volume info", lta[i]) is not None:
while i < len(lta) and re.match("dst volume info", lta[i]) is None:
if re.match("valid", lta[i]) is not None:
d["src_valid"] = int(
re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip()
)
elif re.match("filename", lta[i]) is not None:
d["src_filename"] = re.split(
" +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip()
)
elif re.match("volume", lta[i]) is not None:
d["src_volume"] = [
int(x)
for x in re.split(
" +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip()
)
]
elif re.match("voxelsize", lta[i]) is not None:
d["src_voxelsize"] = [
float(x)
for x in re.split(
" +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip()
)
]
elif re.match("xras", lta[i]) is not None:
d["src_xras"] = [
float(x)
for x in re.split(
" +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip()
)
]
elif re.match("yras", lta[i]) is not None:
d["src_yras"] = [
float(x)
for x in re.split(
" +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip()
)
]
elif re.match("zras", lta[i]) is not None:
d["src_zras"] = [
float(x)
for x in re.split(
" +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip()
)
]
elif re.match("cras", lta[i]) is not None:
d["src_cras"] = [
float(x)
for x in re.split(
" +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip()
)
]
i += 1
elif re.match("dst volume info", lta[i]) is not None:
while i < len(lta) and re.match("src volume info", lta[i]) is None:
if re.match("valid", lta[i]) is not None:
d["dst_valid"] = int(
re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip()
)
elif re.match("filename", lta[i]) is not None:
d["dst_filename"] = re.split(
" +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip()
)
elif re.match("volume", lta[i]) is not None:
d["dst_volume"] = [
int(x)
for x in re.split(
" +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip()
)
]
elif re.match("voxelsize", lta[i]) is not None:
d["dst_voxelsize"] = [
float(x)
for x in re.split(
" +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip()
)
]
elif re.match("xras", lta[i]) is not None:
d["dst_xras"] = [
float(x)
for x in re.split(
" +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip()
)
]
elif re.match("yras", lta[i]) is not None:
d["dst_yras"] = [
float(x)
for x in re.split(
" +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip()
)
]
elif re.match("zras", lta[i]) is not None:
d["dst_zras"] = [
float(x)
for x in re.split(
" +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip()
)
]
elif re.match("cras", lta[i]) is not None:
d["dst_cras"] = [
float(x)
for x in re.split(
" +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip()
)
]
i += 1
else:
i += 1
# create full transformation matrices
d["src"] = np.concatenate(
(
np.concatenate(
(
np.c_[d["src_xras"]],
np.c_[d["src_yras"]],
np.c_[d["src_zras"]],
np.c_[d["src_cras"]],
),
axis=1,
),
np.array([0.0, 0.0, 0.0, 1.0], ndmin=2),
),
axis=0,
)
d["dst"] = np.concatenate(
(
np.concatenate(
(
np.c_[d["dst_xras"]],
np.c_[d["dst_yras"]],
np.c_[d["dst_zras"]],
np.c_[d["dst_cras"]],
),
axis=1,
),
np.array([0.0, 0.0, 0.0, 1.0], ndmin=2),
),
axis=0,
)
# return
return d
lines = f.readlines()

items = []
shape_lines = []
matrix_lines = []
section = ""
for i, line in enumerate(lines):
if line.strip() == "":
continue
if hits := parameter_pattern.match(line):
name = hits.group(1).strip()
if section and name in vol_info_par:
items.append((f"{section}_{name}", vol_info_par[name](hits.group(2))))
elif name in parameters:
section = ""
items.append((name, parameters[name](hits.group(2))))
else:
raise NotImplementedError(f"Unrecognized type string in lta-file "
f"{file}:{i+1}: '{name}'")
elif hits := vol_info_pattern.match(line):
section = hits.group(1)
# not a parameter line
elif shape_pattern.search(line):
shape_lines.append(np.fromstring(line, dtype=int, count=-1, sep=" "))
elif matrix_pattern.search(line):
matrix_lines.append(np.fromstring(line, dtype=float, count=-1, sep=" "))

shape_lines = list(map(tuple, shape_lines))
lta = dict(items)
if lta["nxforms"] != len(shape_lines):
raise IOError("Inconsistent lta format: nxforms inconsistent with shapes.")
if len(shape_lines) > 1 and np.any(np.not_equal([shape_lines[0]], shape_lines[1:])):
raise IOError(f"Inconsistent lta format: shapes inconsistent {shape_lines}")
lta_matrix = np.asarray(matrix_lines).reshape((-1,) + shape_lines[0])
lta["lta"] = lta_matrix
return lta


def load_talairach_coordinates(tala_path, img_shape, vox2ras):
tala_lta = readLTA(tala_path)
tala_lta = read_lta(tala_path)
# create image grid p
x, y, z = np.meshgrid(
np.arange(img_shape[0]),
@@ -567,7 +436,7 @@ def load_talairach_coordinates(tala_path, img_shape, vox2ras):
p1 = np.concatenate((p, np.ones((p.shape[0], 1))), axis=1)

assert tala_lta["type"] == 1, "talairach not in ras2ras" # ras2ras
m = np.matmul(tala_lta["lta"], vox2ras)
m = np.matmul(tala_lta["lta"][0, 0], vox2ras)

tala_coordinates = np.matmul(m, p1.transpose()).transpose()
tala_coordinates = tala_coordinates[:, :-1]
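A minimal usage sketch for the rewritten reader (the path is a placeholder; only keys that read_lta actually populates are used):

# Sketch only; mirrors how load_talairach_coordinates consumes the result.
lta = read_lta("mri/transforms/talairach.xfm.lta")
assert lta["type"] == 1            # 1 == LINEAR_RAS_TO_RAS
ras2ras = lta["lta"][0, 0]         # 4x4 affine of the first transform block
print(lta["nxforms"], lta["src_volume"], lta["dst_voxelsize"])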
1 change: 1 addition & 0 deletions CerebNet/inference.py
@@ -291,6 +291,7 @@ def _get_ids_startswith(_label_map: Dict[int, str], prefix: str) -> List[int]:
table = pv_calc(
seg_data,
norm_data,
norm_data,
list(filter(lambda l: l != 0, label_map.keys())),
vox_vol=vox_vol,
threads=self.threads,