# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import hashlib
import json
import logging
import math
import os
import pickle
import warnings
from collections import abc, defaultdict
from collections.abc import Generator, Iterable, Mapping, Sequence, Sized
from copy import deepcopy
from functools import reduce
from itertools import product, starmap, zip_longest
from pathlib import PurePath
from typing import Any
import numpy as np
import torch
from torch.utils.data._utils.collate import default_collate
from monai import config
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
from monai.data.meta_obj import MetaObj
from monai.utils import (
MAX_SEED,
BlendMode,
Method,
NumpyPadMode,
TraceKeys,
convert_data_type,
convert_to_dst_type,
ensure_tuple,
ensure_tuple_rep,
ensure_tuple_size,
fall_back_tuple,
first,
get_equivalent_dtype,
issequenceiterable,
look_up_option,
optional_import,
)
pd, _ = optional_import("pandas")
DataFrame, _ = optional_import("pandas", name="DataFrame")
nib, _ = optional_import("nibabel")
__all__ = [
"AFFINE_TOL",
"SUPPORTED_PICKLE_MOD",
"affine_to_spacing",
"compute_importance_map",
"compute_shape_offset",
"convert_tables_to_dicts",
"correct_nifti_header_if_necessary",
"create_file_basename",
"decollate_batch",
"dense_patch_slices",
"get_random_patch",
"get_valid_patch_size",
"is_supported_format",
"iter_patch",
"iter_patch_position",
"iter_patch_slices",
"json_hashing",
"list_data_collate",
"no_collation",
"orientation_ras_lps",
"pad_list_data_collate",
"partition_dataset",
"partition_dataset_classes",
"pickle_hashing",
"rectify_header_sform_qform",
"reorient_spatial_axes",
"resample_datalist",
"select_cross_validation_folds",
"set_rnd",
"sorted_dict",
"to_affine_nd",
"worker_init_fn",
"zoom_affine",
"remove_keys",
"remove_extra_metadata",
"get_extra_metadata_keys",
"PICKLE_KEY_SUFFIX",
"is_no_channel",
]
# module to be used by `torch.save`
SUPPORTED_PICKLE_MOD = {"pickle": pickle}
# tolerance for affine matrix computation
AFFINE_TOL = 1e-3
def get_random_patch(
dims: Sequence[int], patch_size: Sequence[int], rand_state: np.random.RandomState | None = None
) -> tuple[slice, ...]:
"""
Returns a tuple of slices to define a random patch in an array of shape `dims` with size `patch_size`, or as
close to it as possible within the given dimensions. It is expected that `patch_size` is a valid patch for a source
of shape `dims` as returned by `get_valid_patch_size`.
Args:
dims: shape of source array
patch_size: the size of the patch to generate
rand_state: a random state object to generate random numbers from
Returns:
(tuple of slice): a tuple of slice objects defining the patch
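For example, a minimal usage sketch (the array shape, patch size and seeded state are illustrative assumptions):
.. code-block:: python
    rs = np.random.RandomState(0)  # seeded state, for reproducibility only
    arr = np.zeros((10, 10))
    slices = get_random_patch(arr.shape, (4, 4), rand_state=rs)
    patch = arr[slices]  # a randomly placed 4x4 view into `arr`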
"""
# choose the minimal corner of the patch
rand_int = np.random.randint if rand_state is None else rand_state.randint
min_corner = tuple(rand_int(0, ms - ps + 1) if ms > ps else 0 for ms, ps in zip(dims, patch_size))
# create the slices for each dimension which define the patch in the source array
return tuple(slice(mc, mc + ps) for mc, ps in zip(min_corner, patch_size))
def iter_patch_slices(
image_size: Sequence[int],
patch_size: Sequence[int] | int,
start_pos: Sequence[int] = (),
overlap: Sequence[float] | float = 0.0,
padded: bool = True,
) -> Generator[tuple[slice, ...], None, None]:
"""
Yield successive tuples of slices defining patches of size `patch_size` from an array of dimensions `image_size`.
The iteration starts from position `start_pos` in the array, or starting at the origin if this isn't provided. Each
patch is chosen in a contiguous grid using a row-major ordering.
Args:
image_size: dimensions of array to iterate over
patch_size: size of patches to generate slices for, 0 or None selects whole dimension
start_pos: starting position in the array, default is 0 for each dimension
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
padded: if the image is padded so the patches can go beyond the borders. Defaults to True.
Yields:
Tuples of slice objects defining each patch
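For example, a sketch tiling an 8x8 image with non-overlapping 4x4 patches (the sizes are illustrative):
.. code-block:: python
    for sl in iter_patch_slices((8, 8), (4, 4), padded=False):
        print(sl)
    # (slice(0, 4), slice(0, 4)), (slice(0, 4), slice(4, 8)),
    # (slice(4, 8), slice(0, 4)), (slice(4, 8), slice(4, 8))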
"""
# ensure patch_size has the right length
patch_size_ = get_valid_patch_size(image_size, patch_size)
# create slices based on start position of each patch
for position in iter_patch_position(
image_size=image_size, patch_size=patch_size_, start_pos=start_pos, overlap=overlap, padded=padded
):
yield tuple(slice(s, s + p) for s, p in zip(position, patch_size_))
def dense_patch_slices(
image_size: Sequence[int], patch_size: Sequence[int], scan_interval: Sequence[int]
) -> list[tuple[slice, ...]]:
"""
Enumerate all slices defining ND patches of size `patch_size` from an `image_size` input image.
Args:
image_size: dimensions of image to iterate over
patch_size: size of patches to generate slices for
scan_interval: dense patch sampling interval
Returns:
a list of slice objects defining each patch
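For example, a sketch with an evenly divisible grid (all sizes are illustrative):
.. code-block:: python
    slices = dense_patch_slices((10, 10), (5, 5), (5, 5))
    # len(slices) == 4: four non-overlapping 5x5 patches covering the image,
    # e.g. slices[0] == (slice(0, 5), slice(0, 5))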
"""
num_spatial_dims = len(image_size)
patch_size = get_valid_patch_size(image_size, patch_size)
scan_interval = ensure_tuple_size(scan_interval, num_spatial_dims)
scan_num = []
for i in range(num_spatial_dims):
if scan_interval[i] == 0:
scan_num.append(1)
else:
num = int(math.ceil(float(image_size[i]) / scan_interval[i]))
scan_dim = first(d for d in range(num) if d * scan_interval[i] + patch_size[i] >= image_size[i])
scan_num.append(scan_dim + 1 if scan_dim is not None else 1)
starts = []
for dim in range(num_spatial_dims):
dim_starts = []
for idx in range(scan_num[dim]):
start_idx = idx * scan_interval[dim]
start_idx -= max(start_idx + patch_size[dim] - image_size[dim], 0)
dim_starts.append(start_idx)
starts.append(dim_starts)
out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T
return [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out]
def iter_patch_position(
image_size: Sequence[int],
patch_size: Sequence[int] | int | np.ndarray,
start_pos: Sequence[int] = (),
overlap: Sequence[float] | float = 0.0,
padded: bool = False,
):
"""
Yield successive tuples of upper left corner of patches of size `patch_size` from an array of dimensions `image_size`.
The iteration starts from position `start_pos` in the array, or starting at the origin if this isn't provided. Each
patch is chosen in a contiguous grid using a row-major ordering.
Args:
image_size: dimensions of array to iterate over
patch_size: size of patches to generate slices for, 0 or None selects whole dimension
start_pos: starting position in the array, default is 0 for each dimension
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
padded: if the image is padded so the patches can go beyond the borders. Defaults to False.
Yields:
Tuples of positions defining the upper left corner of each patch
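For example, a sketch with 50% overlap (the sizes are illustrative):
.. code-block:: python
    positions = list(iter_patch_position((8, 8), (4, 4), overlap=0.5))
    # the step is round(4 * (1.0 - 0.5)) = 2, giving nine corners:
    # (0, 0), (0, 2), (0, 4), (2, 0), ..., (4, 4)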
"""
# ensure patchSize and startPos are the right length
ndim = len(image_size)
patch_size_ = get_valid_patch_size(image_size, patch_size)
start_pos = ensure_tuple_size(start_pos, ndim)
overlap = ensure_tuple_rep(overlap, ndim)
# calculate steps, which depends on the amount of overlap
steps = tuple(round(p * (1.0 - o)) for p, o in zip(patch_size_, overlap))
# calculate the last starting location (depending on the padding)
end_pos = image_size if padded else tuple(s - round(p) + 1 for s, p in zip(image_size, patch_size_))
# collect the ranges to step over each dimension
ranges = starmap(range, zip(start_pos, end_pos, steps))
# choose patches by applying product to the ranges
return product(*ranges)
def iter_patch(
arr: np.ndarray,
patch_size: Sequence[int] | int = 0,
start_pos: Sequence[int] = (),
overlap: Sequence[float] | float = 0.0,
copy_back: bool = True,
mode: str | None = NumpyPadMode.WRAP,
**pad_opts: dict,
):
"""
Yield successive patches from `arr` of size `patch_size`. The iteration can start from position `start_pos` in `arr`
but draws from a padded array extended by `patch_size` in each dimension (so these coordinates can be negative
to start in the padded region). If `copy_back` is True the values from each patch are written back to `arr`.
Args:
arr: array to iterate over
patch_size: size of patches to generate slices for, 0 or None selects whole dimension.
For 0 or None, padding and overlap ratio of the corresponding dimension will be 0.
start_pos: starting position in the array, default is 0 for each dimension
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
copy_back: if True data from the yielded patches is copied back to `arr` once the generator completes
mode: One of the listed string values in ``monai.utils.NumpyPadMode`` or ``monai.utils.PytorchPadMode``,
or a user supplied function. If None, no wrapping is performed. Defaults to ``"wrap"``.
pad_opts: padding options, see `numpy.pad`
Yields:
Patches of array data from `arr`, which are views into a padded array and can be modified; if `copy_back` is
True these changes will be reflected in `arr` once the iteration completes.
Note:
coordinate format is:
[[1st_dim_start, 1st_dim_end],
[2nd_dim_start, 2nd_dim_end],
...,
[Nth_dim_start, Nth_dim_end]]
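For example, a minimal sketch without padding (``mode=None``; the array content is illustrative):
.. code-block:: python
    arr = np.arange(16).reshape(4, 4)
    for patch, coords in iter_patch(arr, patch_size=(2, 2), mode=None):
        print(patch.shape, coords.tolist())  # (2, 2) and [[r0, r1], [c0, c1]]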
"""
# ensure patchSize and startPos are the right length
patch_size_ = get_valid_patch_size(arr.shape, patch_size)
start_pos = ensure_tuple_size(start_pos, arr.ndim)
# set padded flag to false if pad mode is None
padded = bool(mode)
is_v = [bool(p) for p in ensure_tuple_size(patch_size, arr.ndim)] # whether a valid patch size provided
_pad_size = tuple(p if v and padded else 0 for p, v in zip(patch_size_, is_v)) # pad p if v else 0
_overlap = [op if v else 0.0 for op, v in zip(ensure_tuple_rep(overlap, arr.ndim), is_v)] # overlap if v else 0.0
# pad image by maximum values needed to ensure patches are taken from inside an image
if padded:
arrpad = np.pad(arr, tuple((p, p) for p in _pad_size), look_up_option(mode, NumpyPadMode).value, **pad_opts)
# choose a start position in the padded image
start_pos_padded = tuple(s + p for s, p in zip(start_pos, _pad_size))
# choose a size to iterate over which is smaller than the actual padded image to prevent producing
# patches which are only in the padded regions
iter_size = tuple(s + p for s, p in zip(arr.shape, _pad_size))
else:
arrpad = arr
start_pos_padded = start_pos
iter_size = arr.shape
for slices in iter_patch_slices(iter_size, patch_size_, start_pos_padded, _overlap, padded=padded):
# compensate original image padding
if padded:
coords_no_pad = tuple((coord.start - p, coord.stop - p) for coord, p in zip(slices, _pad_size))
else:
coords_no_pad = tuple((coord.start, coord.stop) for coord in slices)
yield arrpad[slices], np.asarray(coords_no_pad) # data and coords (in numpy; works with torch loader)
# copy back data from the padded image if required
if copy_back:
slices = tuple(slice(p, p + s) for p, s in zip(_pad_size, arr.shape))
arr[...] = arrpad[slices]
def get_valid_patch_size(image_size: Sequence[int], patch_size: Sequence[int] | int | np.ndarray) -> tuple[int, ...]:
"""
Given an image of dimensions `image_size`, return a patch size tuple taking the dimension from `patch_size` if this is
not 0/None. Otherwise, or if `patch_size` is shorter than `image_size`, the dimension from `image_size` is taken. This ensures
the returned patch size is within the bounds of `image_size`. If `patch_size` is a single number this is interpreted as a
patch of the same dimensionality of `image_size` with that size in each dimension.
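For example (illustrative sizes; a 0 entry falls back to the full image dimension):
.. code-block:: python
    get_valid_patch_size((64, 64, 32), (32, 0, 48))  # returns (32, 64, 32)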
"""
ndim = len(image_size)
patch_size_ = ensure_tuple_size(patch_size, ndim)
# ensure patch size dimensions are not larger than image dimension, if a dimension is None or 0 use whole dimension
return tuple(min(ms, ps or ms) for ms, ps in zip(image_size, patch_size_))
def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"):
"""
Recursively run collate logic and provide detailed loggings for debugging purposes.
It reports results at the 'critical' level and is therefore suitable in the context of exception handling.
Args:
batch: batch input to collate
level: current level of recursion for logging purposes
logger_name: name of logger to use for logging
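For example, a sketch of debugging a shape mismatch (the batch content is an illustrative assumption):
.. code-block:: python
    batch = [{"img": torch.zeros(1, 3)}, {"img": torch.zeros(1, 4)}]
    dev_collate(batch)  # logs the inconsistent shapes [(1, 3), (1, 4)] instead of raising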
See also: https://pytorch.org/docs/stable/data.html#working-with-collate-fn
"""
elem = batch[0]
elem_type = type(elem)
l_str = ">" * level
batch_str = f"{batch[:10]}{' ... ' if len(batch) > 10 else ''}"
if isinstance(elem, torch.Tensor):
try:
logging.getLogger(logger_name).critical(f"{l_str} collate/stack a list of tensors")
return torch.stack(batch, 0)
except TypeError as e:
logging.getLogger(logger_name).critical(
f"{l_str} E: {e}, type {[type(elem).__name__ for elem in batch]} in collate({batch_str})"
)
return
except RuntimeError as e:
logging.getLogger(logger_name).critical(
f"{l_str} E: {e}, shape {[elem.shape for elem in batch]} in collate({batch_str})"
)
return
elif elem_type.__module__ == "numpy" and elem_type.__name__ != "str_" and elem_type.__name__ != "string_":
if elem_type.__name__ in ["ndarray", "memmap"]:
logging.getLogger(logger_name).critical(f"{l_str} collate/stack a list of numpy arrays")
return dev_collate([torch.as_tensor(b) for b in batch], level=level, logger_name=logger_name)
elif elem.shape == (): # scalars
return batch
elif isinstance(elem, (float, int, str, bytes)):
return batch
elif isinstance(elem, abc.Mapping):
out = {}
for key in elem:
logging.getLogger(logger_name).critical(f'{l_str} collate dict key "{key}" out of {len(elem)} keys')
out[key] = dev_collate([d[key] for d in batch], level=level + 1, logger_name=logger_name)
return out
elif isinstance(elem, abc.Sequence):
it = iter(batch)
els = list(it)
try:
sizes = [len(elem) for elem in els] # may not have `len`
except TypeError:
types = [type(elem).__name__ for elem in els]
logging.getLogger(logger_name).critical(f"{l_str} E: type {types} in collate({batch_str})")
return
logging.getLogger(logger_name).critical(f"{l_str} collate list of sizes: {sizes}.")
if any(s != sizes[0] for s in sizes):
logging.getLogger(logger_name).critical(
f"{l_str} collate list inconsistent sizes, got size: {sizes}, in collate({batch_str})"
)
transposed = zip(*batch)
return [dev_collate(samples, level=level + 1, logger_name=logger_name) for samples in transposed]
logging.getLogger(logger_name).critical(f"{l_str} E: unsupported type in collate {batch_str}.")
return
PICKLE_KEY_SUFFIX = TraceKeys.KEY_SUFFIX
def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True):
"""
Applied_operations are dictionaries with varying sizes; this method converts them to bytes so that we can (de-)collate.
Args:
data: a list or dictionary with substructures to be pickled/unpickled.
key: the key suffix for the target substructures, defaults to "_transforms" (`data.utils.PICKLE_KEY_SUFFIX`).
is_encode: whether it's encoding using pickle.dumps (True) or decoding using pickle.loads (False).
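For example, a round-trip sketch (the dictionary content is an illustrative assumption):
.. code-block:: python
    data = {"image_transforms": [{"class": "RandFlip"}]}
    encoded = pickle_operations(data)  # the "image_transforms" value becomes bytes
    decoded = pickle_operations(encoded, is_encode=False)  # back to the original list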
"""
if isinstance(data, Mapping):
data = dict(data)
for k in data:
if f"{k}".endswith(key):
if is_encode and not isinstance(data[k], bytes):
data[k] = pickle.dumps(data[k], 0)
if not is_encode and isinstance(data[k], bytes):
data[k] = pickle.loads(data[k])
return {k: pickle_operations(v, key=key, is_encode=is_encode) for k, v in data.items()}
elif isinstance(data, (list, tuple)):
return [pickle_operations(item, key=key, is_encode=is_encode) for item in data]
return data
def collate_meta_tensor(batch):
"""collate a sequence of meta tensor sequences/dictionaries into
a single batched metatensor or a dictionary of batched metatensors."""
if not isinstance(batch, Sequence):
raise NotImplementedError()
elem_0 = first(batch)
if isinstance(elem_0, MetaObj):
collated = default_collate(batch)
collated.meta = default_collate([i.meta or TraceKeys.NONE for i in batch])
collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch]
collated.is_batch = True
return collated
if isinstance(elem_0, Mapping):
return {k: collate_meta_tensor([d[k] for d in batch]) for k in elem_0}
if isinstance(elem_0, (tuple, list)):
return [collate_meta_tensor([d[i] for d in batch]) for i in range(len(elem_0))]
# no more recursive search for MetaTensor
return default_collate(batch)
def list_data_collate(batch: Sequence):
"""
Enhancement for PyTorch DataLoader default collate.
If the dataset already returns a list of batch data generated by the transforms, all the data needs to be merged
into one list first; it is then handled the same way as the default collate behavior.
Note:
This collate function is needed if any of the applied transforms can generate batch data.
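For example, a sketch where each dataset item is a list of dictionaries (the shapes are illustrative):
.. code-block:: python
    batch = [[{"img": torch.zeros(1, 2)}], [{"img": torch.ones(1, 2)}]]
    out = list_data_collate(batch)  # {"img": <tensor of shape (2, 1, 2)>}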
"""
elem = batch[0]
data = [i for k in batch for i in k] if isinstance(elem, list) else batch
key = None
try:
if config.USE_META_DICT:
data = pickle_operations(data) # bc 0.9.0
if isinstance(elem, Mapping):
ret = {}
for k in elem:
key = k
data_for_batch = [d[key] for d in data]
ret[key] = collate_meta_tensor(data_for_batch)
else:
ret = collate_meta_tensor(data)
return ret
except RuntimeError as re:
re_str = str(re)
if "equal size" in re_str:
if key is not None:
re_str += f"\nCollate error on the key '{key}' of dictionary data."
re_str += (
"\n\nMONAI hint: if your transforms intentionally create images of different shapes, creating your "
+ "`DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its "
+ "documentation)."
)
_ = dev_collate(data)
raise RuntimeError(re_str) from re
except TypeError as re:
re_str = str(re)
if "numpy" in re_str and "Tensor" in re_str:
if key is not None:
re_str += f"\nCollate error on the key '{key}' of dictionary data."
re_str += (
"\n\nMONAI hint: if your transforms intentionally create mixtures of torch Tensor and numpy ndarray, "
+ "creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem "
+ "(check its documentation)."
)
_ = dev_collate(data)
raise TypeError(re_str) from re
def _non_zipping_check(batch_data: Mapping | Iterable, detach: bool, pad: bool, fill_value):
"""
Utility function based on `decollate_batch`, to identify the largest batch size from the collated data.
Returns batch_size, the list of non-iterable items, and the dictionary or list with their items decollated.
See `decollate_batch` for more details.
"""
_deco: Mapping | Sequence
if isinstance(batch_data, Mapping):
_deco = {key: decollate_batch(batch_data[key], detach, pad=pad, fill_value=fill_value) for key in batch_data}
elif isinstance(batch_data, Iterable):
_deco = [decollate_batch(b, detach, pad=pad, fill_value=fill_value) for b in batch_data]
else:
raise NotImplementedError(f"Unable to de-collate: {batch_data}, type: {type(batch_data)}.")
batch_size, non_iterable = 0, []
for k, v in _deco.items() if isinstance(_deco, Mapping) else enumerate(_deco):
if not isinstance(v, Iterable) or isinstance(v, (str, bytes)) or (isinstance(v, torch.Tensor) and v.ndim == 0):
# Not running the usual list decollate here:
# don't decollate ['test', 'test'] into [['t', 't'], ['e', 'e'], ['s', 's'], ['t', 't']]
# torch.tensor(0) is iterable but iter(torch.tensor(0)) raises TypeError: iteration over a 0-d tensor
non_iterable.append(k)
elif isinstance(v, Sized):
batch_size = max(batch_size, len(v))
return batch_size, non_iterable, _deco
def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
"""De-collate a batch of data (for example, as produced by a `DataLoader`).
Returns a list of structures with the original tensor's 0-th dimension sliced into elements using `torch.unbind`.
Images originally stored as (B,C,H,W,[D]) will be returned as (C,H,W,[D]). Other information,
such as metadata, may have been stored in a list (or a list inside nested dictionaries). In
this case we return the element of the list corresponding to the batch idx.
Return types aren't guaranteed to be the same as the original, since numpy arrays will have been
converted to torch.Tensor, sequences may be converted to lists of tensors,
mappings may be converted into dictionaries.
For example:
.. code-block:: python
batch_data = {
"image": torch.rand((2,1,10,10)),
DictPostFix.meta("image"): {"scl_slope": torch.Tensor([0.0, 0.0])}
}
out = decollate_batch(batch_data)
print(len(out))
>>> 2
print(out[0])
>>> {'image': tensor([[[4.3549e-01...43e-01]]]), DictPostFix.meta("image"): {'scl_slope': 0.0}}
batch_data = [torch.rand((2,1,10,10)), torch.rand((2,3,5,5))]
out = decollate_batch(batch_data)
print(out[0])
>>> [tensor([[[4.3549e-01...43e-01]]], tensor([[[5.3435e-01...45e-01]]])]
batch_data = torch.rand((2,1,10,10))
out = decollate_batch(batch_data)
print(out[0])
>>> tensor([[[4.3549e-01...43e-01]]])
batch_data = {
"image": [1, 2, 3], "meta": [4, 5], # undetermined batch size
}
out = decollate_batch(batch_data, pad=True, fill_value=0)
print(out)
>>> [{'image': 1, 'meta': 4}, {'image': 2, 'meta': 5}, {'image': 3, 'meta': 0}]
out = decollate_batch(batch_data, pad=False)
print(out)
>>> [{'image': 1, 'meta': 4}, {'image': 2, 'meta': 5}]
Args:
batch: data to be de-collated.
detach: whether to detach the tensors. Scalar tensors will be detached into number types
instead of torch tensors.
pad: when the items in a batch indicate different batch size, whether to pad all the sequences to the longest.
If False, the batch size will be the length of the shortest sequence.
fill_value: when `pad` is True, the `fillvalue` to use when padding, defaults to `None`.
"""
if batch is None:
return batch
if isinstance(batch, (float, int, str, bytes)) or (
type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
):
return batch
if isinstance(batch, torch.Tensor):
if detach:
batch = batch.detach()
if batch.ndim == 0:
return batch.item() if detach else batch
out_list = torch.unbind(batch, dim=0)
# if of type MetaObj, decollate the metadata
if isinstance(batch, MetaObj):
for t, m in zip(out_list, decollate_batch(batch.meta)):
if isinstance(t, MetaObj):
t.meta = m
t.is_batch = False
for t, m in zip(out_list, batch.applied_operations):
if isinstance(t, MetaObj):
t.applied_operations = m
t.is_batch = False
if out_list[0].ndim == 0 and detach:
return [t.item() for t in out_list]
return list(out_list)
b, non_iterable, deco = _non_zipping_check(batch, detach, pad, fill_value)
if b <= 0: # all non-iterable, single item "batch"? {"image": 1, "label": 1}
return deco
if pad: # duplicate non-iterable items to the longest batch
for k in non_iterable:
deco[k] = [deepcopy(deco[k]) for _ in range(b)]
if isinstance(deco, Mapping):
_gen = zip_longest(*deco.values(), fillvalue=fill_value) if pad else zip(*deco.values())
ret = [dict(zip(deco, item)) for item in _gen]
if not config.USE_META_DICT:
return ret
return pickle_operations(ret, is_encode=False) # bc 0.9.0
if isinstance(deco, Iterable):
_gen = zip_longest(*deco, fillvalue=fill_value) if pad else zip(*deco)
ret_list = [list(item) for item in _gen]
if not config.USE_META_DICT:
return ret_list
return pickle_operations(ret_list, is_encode=False) # bc 0.9.0
raise NotImplementedError(f"Unable to de-collate: {batch}, type: {type(batch)}.")
def pad_list_data_collate(batch: Sequence, method: str = Method.SYMMETRIC, mode: str = NumpyPadMode.CONSTANT, **kwargs):
"""
Function version of :py:class:`monai.transforms.croppad.batch.PadListDataCollate`.
Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest
tensor in each dimension. This transform is useful if some of the applied transforms generate batch data of
different sizes.
This can be used on both list and dictionary data.
Note that in the case of dictionary data, this collate function may add the transform information of
`PadListDataCollate` to the list of invertible transforms if the input batch has different spatial shapes, so the
static method `monai.transforms.croppad.batch.PadListDataCollate.inverse` needs to be called before inverting other transforms.
Args:
batch: batch of data to pad-collate
method: padding method (see :py:class:`monai.transforms.SpatialPad`)
mode: padding mode (see :py:class:`monai.transforms.SpatialPad`)
kwargs: other arguments for the `np.pad` or `torch.pad` function.
note that `np.pad` treats channel dimension as the first dimension.
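For example, a sketch with mismatched spatial shapes (the shapes are illustrative):
.. code-block:: python
    batch = [{"img": torch.zeros(1, 10, 10)}, {"img": torch.zeros(1, 12, 8)}]
    out = pad_list_data_collate(batch)  # "img" is padded per spatial dim and stacked to (2, 1, 12, 10)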
"""
from monai.transforms.croppad.batch import PadListDataCollate # needs to be here to avoid circular import
return PadListDataCollate(method=method, mode=mode, **kwargs)(batch)
def no_collation(x):
"""
No collation operation; returns the input as-is.
"""
return x
def worker_init_fn(worker_id: int) -> None:
"""
Callback function for PyTorch DataLoader `worker_init_fn`.
It can set different random seeds for the transforms in different workers.
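For example, a usage sketch (``dataset`` is an assumed placeholder):
.. code-block:: python
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=2, num_workers=4, worker_init_fn=worker_init_fn)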
"""
worker_info = torch.utils.data.get_worker_info()
set_rnd(worker_info.dataset, seed=worker_info.seed) # type: ignore[union-attr]
def set_rnd(obj, seed: int) -> int:
"""
Set seed or random state for all randomizable properties of obj.
Args:
obj: object to set seed or random state for.
seed: set the random state with an integer seed.
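For example, a sketch seeding a composed transform (the transform choice is illustrative):
.. code-block:: python
    from monai.transforms import Compose, RandFlip
    transform = Compose([RandFlip(prob=0.5)])
    set_rnd(transform, seed=42)  # returns the next unused seed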
"""
if isinstance(obj, (tuple, list)): # ZipDataset.data is a list
_seed = seed
for item in obj:
_seed = set_rnd(item, seed=seed)
return seed if _seed == seed else seed + 1 # return a different seed if there are randomizable items
if not hasattr(obj, "__dict__"):
return seed # no attribute
if hasattr(obj, "set_random_state"):
obj.set_random_state(seed=seed % MAX_SEED)
return seed + 1 # a different seed for the next component
for key in obj.__dict__:
if key.startswith("__"): # skip the private methods
continue
seed = set_rnd(obj.__dict__[key], seed=seed)
return seed
def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_zeros: bool = True) -> NdarrayTensor:
"""
Compute the current spacing from the affine matrix.
Args:
affine: a d x d affine matrix.
r: indexing based on the spatial rank, spacing is computed from `affine[:r, :r]`.
dtype: data type of the output.
suppress_zeros: whether to suppress the zeros with ones.
Returns:
an `r` dimensional vector of spacing.
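For example, a sketch with a diagonal affine (the values are illustrative; spacing is the column norm):
.. code-block:: python
    affine = np.diag([2.0, 3.0, 4.0, 1.0])
    affine_to_spacing(affine)  # array([2., 3., 4.])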
"""
if len(affine.shape) != 2 or affine.shape[0] != affine.shape[1]:
raise ValueError(f"affine must be a square matrix, got {affine.shape}.")
_affine, *_ = convert_to_dst_type(affine[:r, :r], dst=affine, dtype=dtype)
if isinstance(_affine, torch.Tensor):
spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0))
else:
spacing = np.sqrt(np.sum(_affine * _affine, axis=0))
if suppress_zeros:
spacing[spacing == 0] = 1.0
spacing_, *_ = convert_to_dst_type(spacing, dst=affine, dtype=dtype)
return spacing_
def correct_nifti_header_if_necessary(img_nii):
"""
Check the nifti object header's format and update the header if needed.
In the updated image, pixdim matches the affine.
Args:
img_nii: nifti image object
"""
if img_nii.header.get("dim") is None:
return img_nii # not nifti?
dim = img_nii.header["dim"][0]
if dim >= 5:
return img_nii # do nothing for high-dimensional array
# check that affine matches zooms
pixdim = np.asarray(img_nii.header.get_zooms())[:dim]
norm_affine = affine_to_spacing(img_nii.affine, r=dim)
if np.allclose(pixdim, norm_affine):
return img_nii
if hasattr(img_nii, "get_sform"):
return rectify_header_sform_qform(img_nii)
return img_nii
def rectify_header_sform_qform(img_nii):
"""
Look at the sform and qform of the nifti object and correct them if there are any
incompatibilities with the pixel dimensions
Adapted from https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/io/misc_io.py
Args:
img_nii: nifti image object
"""
d = img_nii.header["dim"][0]
pixdim = np.asarray(img_nii.header.get_zooms())[:d]
sform, qform = img_nii.get_sform(), img_nii.get_qform()
norm_sform = affine_to_spacing(sform, r=d)
norm_qform = affine_to_spacing(qform, r=d)
sform_mismatch = not np.allclose(norm_sform, pixdim)
qform_mismatch = not np.allclose(norm_qform, pixdim)
if img_nii.header["sform_code"] != 0:
if not sform_mismatch:
return img_nii
if not qform_mismatch:
img_nii.set_sform(img_nii.get_qform())
return img_nii
if img_nii.header["qform_code"] != 0:
if not qform_mismatch:
return img_nii
if not sform_mismatch:
img_nii.set_qform(img_nii.get_sform())
return img_nii
norm = affine_to_spacing(img_nii.affine, r=d)
warnings.warn(f"Modifying image pixdim from {pixdim} to {norm}")
img_nii.header.set_zooms(norm)
return img_nii
def zoom_affine(affine: np.ndarray, scale: np.ndarray | Sequence[float], diagonal: bool = True):
"""
Make the column norms of `affine` the same as `scale`. If diagonal is False,
returns an affine that combines orthogonal rotation and the new scale.
This is done by first decomposing `affine`, then setting the zoom factors to
`scale`, and composing a new affine; the shearing factors are removed. If
diagonal is True, returns a diagonal matrix, the scaling factors are set
to the diagonal elements. This function always returns an affine with zero
translations.
Args:
affine (nxn matrix): a square matrix.
scale: new scaling factor along each dimension. if the components of the `scale` are non-positive values,
will use the corresponding components of the original pixdim, which is computed from the `affine`.
diagonal: whether to return a diagonal scaling matrix.
Defaults to True.
Raises:
ValueError: When ``affine`` is not a square matrix.
ValueError: When ``scale`` contains a nonpositive scalar.
Returns:
the updated `n x n` affine.
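For example, a sketch rescaling a diagonal affine (the values are illustrative):
.. code-block:: python
    affine = np.diag([2.0, 2.0, 2.0, 1.0])
    zoom_affine(affine, scale=(1.0, 1.0, 1.0))  # np.diag([1., 1., 1., 1.])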
"""
affine = np.array(affine, dtype=float, copy=True)
if len(affine) != len(affine[0]):
raise ValueError(f"affine must be n x n, got {len(affine)} x {len(affine[0])}.")
scale_np = np.array(scale, dtype=float, copy=True)
d = len(affine) - 1
# compute original pixdim
norm = affine_to_spacing(affine, r=d)
if len(scale_np) < d: # defaults based on affine
scale_np = np.append(scale_np, norm[len(scale_np) :])
scale_np = scale_np[:d]
scale_np = np.asarray(fall_back_tuple(scale_np, norm))
scale_np[scale_np == 0] = 1.0
if diagonal:
return np.diag(np.append(scale_np, [1.0]))
rzs = affine[:-1, :-1] # rotation zoom scale
zs = np.linalg.cholesky(rzs.T @ rzs).T
rotation = rzs @ np.linalg.inv(zs)
s = np.sign(np.diag(zs)) * np.abs(scale_np)
# construct new affine with rotation and zoom
new_affine = np.eye(len(affine))
new_affine[:-1, :-1] = rotation @ np.diag(s)
return new_affine
def compute_shape_offset(
spatial_shape: np.ndarray | Sequence[int],
in_affine: NdarrayOrTensor,
out_affine: NdarrayOrTensor,
scale_extent: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
"""
Given input and output affine, compute appropriate shapes
in the output space based on the input array's shape.
This function also returns the offset to put the shape
in a good position with respect to the world coordinate system.
Args:
spatial_shape: input array's shape
in_affine (matrix): 2D affine matrix
out_affine (matrix): 2D affine matrix
scale_extent: whether the scale is computed based on the spacing or the full extent of voxels, for example, for
a factor of 0.5 scaling:
option 1, "o" represents a voxel, scaling the distance between voxels::
o--o--o
o-----o
option 2, each voxel has a physical extent, scaling the full voxel extent::
| voxel 1 | voxel 2 | voxel 3 | voxel 4 |
| voxel 1 | voxel 2 |
Option 1 may reduce the number of locations that require interpolation. Option 2 is more resolution
agnostic, that is, resampling coordinates depend on the scaling factor, not on the number of voxels.
Default is False, using option 1 to compute the shape and offset.
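For example, a sketch of halving the spacing (the affines and shape are illustrative assumptions):
.. code-block:: python
    in_aff = np.diag([2.0, 2.0, 1.0])   # 2D input with 2-unit spacing
    out_aff = np.diag([1.0, 1.0, 1.0])  # 1-unit target spacing
    shape, offset = compute_shape_offset((10, 10), in_aff, out_aff)
    # shape is [19 19] with the default scale_extent=False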
"""
shape = np.array(spatial_shape, copy=True, dtype=float)
sr = len(shape)
in_affine_ = convert_data_type(to_affine_nd(sr, in_affine), np.ndarray)[0]
out_affine_ = convert_data_type(to_affine_nd(sr, out_affine), np.ndarray)[0]
in_coords = [(-0.5, dim - 0.5) if scale_extent else (0.0, dim - 1.0) for dim in shape]
corners: np.ndarray = np.asarray(np.meshgrid(*in_coords, indexing="ij")).reshape((len(shape), -1))
corners = np.concatenate((corners, np.ones_like(corners[:1])))
try:
corners_out = np.linalg.solve(out_affine_, in_affine_) @ corners
except np.linalg.LinAlgError as e:
raise ValueError(f"Affine {out_affine_} is not invertible") from e
corners = in_affine_ @ corners
all_dist = corners_out[:-1].copy()
corners_out = corners_out[:-1] / corners_out[-1]
out_shape = np.round(corners_out.ptp(axis=1)) if scale_extent else np.round(corners_out.ptp(axis=1) + 1.0)
offset = None
for i in range(corners.shape[1]):
min_corner = np.min(all_dist - all_dist[:, i : i + 1], 1)
if np.allclose(min_corner, 0.0, rtol=AFFINE_TOL):
offset = corners[:-1, i] # corner is the smallest, shift the corner to origin
break
if offset is None: # otherwise make output image center aligned with the input image center
offset = in_affine_[:-1, :-1] @ (shape / 2.0) + in_affine_[:-1, -1] - out_affine_[:-1, :-1] @ (out_shape / 2.0)
if scale_extent:
in_offset = np.append(0.5 * (shape / out_shape - 1.0), 1.0)
offset = np.abs((in_affine_ @ in_offset / in_offset[-1])[:-1]) * np.sign(offset)
return out_shape.astype(int, copy=False), offset # type: ignore
def to_affine_nd(r: np.ndarray | int, affine: NdarrayTensor, dtype=np.float64) -> NdarrayTensor:
"""
Use elements from ``affine`` to create a new affine matrix by
assigning the rotation/zoom/scaling matrix and the translation vector.
When ``r`` is an integer, output is an (r+1)x(r+1) matrix,
where the top left kxk elements are copied from ``affine``,
the last column of the output affine is copied from ``affine``'s last column.
`k` is determined by `min(r, len(affine) - 1)`.
When ``r`` is an affine matrix, the output has the same shape as ``r``,
and the top left kxk elements are copied from ``affine``,
the last column of the output affine is copied from ``affine``'s last column.
`k` is determined by `min(len(r) - 1, len(affine) - 1)`.
Args:
r (int or matrix): number of spatial dimensions or an output affine to be filled.
affine (matrix): 2D affine matrix
dtype: data type of the output array.
Raises:
ValueError: When ``affine`` dimensions is not 2.
ValueError: When ``r`` is nonpositive.
Returns:
an (r+1) x (r+1) matrix (tensor or ndarray depends on the input ``affine`` data type)
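For example, a sketch promoting a 2D affine into 3D space (the matrix values are illustrative):
.. code-block:: python
    affine_2d = np.array([[2.0, 0.0, 5.0], [0.0, 3.0, 4.0], [0.0, 0.0, 1.0]])
    to_affine_nd(3, affine_2d)
    # 4x4 result: the top-left 2x2 block and translation column are copied, identity elsewhere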
"""
dtype = get_equivalent_dtype(dtype, np.ndarray)
affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0]
affine_np = affine_np.copy()
if affine_np.ndim != 2:
raise ValueError(f"affine must have 2 dimensions, got {affine_np.ndim}.")
new_affine = np.array(r, dtype=dtype, copy=True)
if new_affine.ndim == 0:
sr: int = int(new_affine.astype(np.uint))
if not np.isfinite(sr) or sr < 0:
raise ValueError(f"r must be positive, got {sr}.")
new_affine = np.eye(sr + 1, dtype=dtype)
d = max(min(len(new_affine) - 1, len(affine_np) - 1), 1)
new_affine[:d, :d] = affine_np[:d, :d]
if d > 1:
new_affine[:d, -1] = affine_np[:d, -1]
output, *_ = convert_to_dst_type(new_affine, affine, dtype=dtype)
return output
def reorient_spatial_axes(
data_shape: Sequence[int], init_affine: NdarrayOrTensor, target_affine: NdarrayOrTensor
) -> tuple[np.ndarray, NdarrayOrTensor]:
"""
Given the input ``init_affine``, compute the orientation transform between
it and ``target_affine`` by rearranging/flipping the axes.
Returns the orientation transform and the updated affine (tensor or ndarray
depends on the input ``affine`` data type).
Note that this function requires external module ``nibabel.orientations``.
"""
init_affine_, *_ = convert_data_type(init_affine, np.ndarray)
target_affine_, *_ = convert_data_type(target_affine, np.ndarray)
start_ornt = nib.orientations.io_orientation(init_affine_)
target_ornt = nib.orientations.io_orientation(target_affine_)
try:
ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt)
except ValueError as e:
raise ValueError(f"The input affine {init_affine} and target affine {target_affine} are not compatible.") from e
new_affine = init_affine_ @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape)
new_affine, *_ = convert_to_dst_type(new_affine, init_affine)
return ornt_transform, new_affine
def create_file_basename(
postfix: str,
input_file_name: PathLike,
folder_path: PathLike,
data_root_dir: PathLike = "",
separate_folder: bool = True,
patch_index=None,
makedirs: bool = True,
) -> str:
"""
Utility function to create the path to the output file based on the input
filename (file name extension is not added by this function).
When ``data_root_dir`` is not specified, the output file name is:
`folder_path/input_file_name (no ext.)/input_file_name (no ext.)[_postfix][_patch_index]`
otherwise the relative path with respect to ``data_root_dir`` will be inserted, for example:
.. code-block:: python
from monai.data import create_file_basename
create_file_basename(
postfix="seg",
input_file_name="/foo/bar/test1/image.png",
folder_path="/output",
data_root_dir="/foo/bar",
separate_folder=True,
makedirs=False)
# output: /output/test1/image/image_seg