Skip to content

Commit

Permalink
Merge pull request py4dstem#577 from py4dstem/phase_contrast
Browse files Browse the repository at this point in the history
Thankfully these phase_contrast changes are as easy as pie
  • Loading branch information
bsavitzky authored Nov 22, 2023
2 parents 22ffa92 + 2a3a4a8 commit 3397349
Show file tree
Hide file tree
Showing 9 changed files with 459 additions and 183 deletions.
59 changes: 49 additions & 10 deletions py4DSTEM/process/phase/iterative_base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,7 @@ def _normalize_diffraction_intensities(

diffraction_intensities = self._asnumpy(diffraction_intensities)
if positions_mask is not None:
number_of_patterns = np.count_nonzero(self._positions_mask.ravel())
number_of_patterns = np.count_nonzero(positions_mask.ravel())
else:
number_of_patterns = np.prod(diffraction_intensities.shape[:2])

Expand Down Expand Up @@ -1217,7 +1217,7 @@ def _normalize_diffraction_intensities(
for rx in range(diffraction_intensities.shape[0]):
for ry in range(diffraction_intensities.shape[1]):
if positions_mask is not None:
if not self._positions_mask[rx, ry]:
if not positions_mask[rx, ry]:
continue
intensities = get_shifted_ar(
diffraction_intensities[rx, ry],
Expand Down Expand Up @@ -1348,6 +1348,14 @@ def to_h5(self, group):
data=metadata,
)

# saving multiple None positions_mask fix
if self._positions_mask is None:
positions_mask = None
elif self._positions_mask[0] is None:
positions_mask = None
else:
positions_mask = self._positions_mask

# preprocessing metadata
self.metadata = Metadata(
name="preprocess_metadata",
Expand All @@ -1359,7 +1367,7 @@ def to_h5(self, group):
"num_diffraction_patterns": self._num_diffraction_patterns,
"sampling": self.sampling,
"angular_sampling": self.angular_sampling,
"positions_mask": self._positions_mask,
"positions_mask": positions_mask,
},
)

Expand Down Expand Up @@ -2146,6 +2154,7 @@ def plot_position_correction(
def _return_fourier_probe(
self,
probe=None,
remove_initial_probe_aberrations=False,
):
"""
Returns complex fourier probe shifted to center of array from
Expand All @@ -2155,6 +2164,8 @@ def _return_fourier_probe(
----------
probe: complex array, optional
if None is specified, uses self._probe
remove_initial_probe_aberrations: bool, optional
If True, removes initial probe aberrations from Fourier probe
Returns
-------
Expand All @@ -2168,11 +2179,17 @@ def _return_fourier_probe(
else:
probe = xp.asarray(probe, dtype=xp.complex64)

return xp.fft.fftshift(xp.fft.fft2(probe), axes=(-2, -1))
fourier_probe = xp.fft.fft2(probe)

if remove_initial_probe_aberrations:
fourier_probe *= xp.conjugate(self._known_aberrations_array)

return xp.fft.fftshift(fourier_probe, axes=(-2, -1))

def _return_fourier_probe_from_centered_probe(
self,
probe=None,
remove_initial_probe_aberrations=False,
):
"""
Returns complex fourier probe shifted to center of array from
Expand All @@ -2182,14 +2199,19 @@ def _return_fourier_probe_from_centered_probe(
----------
probe: complex array, optional
if None is specified, uses self._probe
remove_initial_probe_aberrations: bool, optional
If True, removes initial probe aberrations from Fourier probe
Returns
-------
fourier_probe: np.ndarray
Fourier-transformed and center-shifted probe.
"""
xp = self._xp
return self._return_fourier_probe(xp.fft.ifftshift(probe, axes=(-2, -1)))
return self._return_fourier_probe(
xp.fft.ifftshift(probe, axes=(-2, -1)),
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
)

def _return_centered_probe(
self,
Expand Down Expand Up @@ -2482,6 +2504,7 @@ def show_uncertainty_visualization(
def show_fourier_probe(
self,
probe=None,
remove_initial_probe_aberrations=False,
cbar=True,
scalebar=True,
pixelsize=None,
Expand All @@ -2495,6 +2518,8 @@ def show_fourier_probe(
----------
probe: complex array, optional
if None is specified, uses the `probe_fourier` property
remove_initial_probe_aberrations: bool, optional
If True, removes initial probe aberrations from Fourier probe
cbar: bool, optional
if True, adds colorbar
scalebar: bool, optional
Expand All @@ -2506,18 +2531,19 @@ def show_fourier_probe(
"""
asnumpy = self._asnumpy

if probe is None:
probe = self.probe_fourier
else:
probe = asnumpy(self._return_fourier_probe(probe))
probe = asnumpy(
self._return_fourier_probe(
probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations
)
)

if pixelsize is None:
pixelsize = self._reciprocal_sampling[1]
if pixelunits is None:
pixelunits = r"$\AA^{-1}$"

figsize = kwargs.pop("figsize", (6, 6))
chroma_boost = kwargs.pop("chroma_boost", 2)
chroma_boost = kwargs.pop("chroma_boost", 1)

fig, ax = plt.subplots(figsize=figsize)
show_complex(
Expand Down Expand Up @@ -2570,6 +2596,19 @@ def probe_fourier(self):
asnumpy = self._asnumpy
return asnumpy(self._return_fourier_probe(self._probe))

@property
def probe_fourier_residual(self):
"""Current probe estimate in Fourier space"""
if not hasattr(self, "_probe"):
return None

asnumpy = self._asnumpy
return asnumpy(
self._return_fourier_probe(
self._probe, remove_initial_probe_aberrations=True
)
)

@property
def probe_centered(self):
"""Current probe estimate shifted to the center"""
Expand Down
Loading

0 comments on commit 3397349

Please sign in to comment.