From 58fb941882b51033b1772ca2e786c34d70576997 Mon Sep 17 00:00:00 2001 From: Alex Ganose Date: Tue, 10 Oct 2023 18:34:32 +0100 Subject: [PATCH] Tidy annotations --- src/ifermi/analysis.py | 5 +- src/ifermi/brillouin_zone.py | 18 ++-- src/ifermi/interpolate.py | 14 ++-- src/ifermi/kpoints.py | 7 +- src/ifermi/plot.py | 157 ++++++++++++++++++----------------- src/ifermi/slice.py | 31 ++++--- src/ifermi/surface.py | 59 +++++++------ 7 files changed, 161 insertions(+), 130 deletions(-) diff --git a/src/ifermi/analysis.py b/src/ifermi/analysis.py index d0f1af42..205de7d5 100644 --- a/src/ifermi/analysis.py +++ b/src/ifermi/analysis.py @@ -1,7 +1,8 @@ """Isosurface and isoline analysis functions.""" +from __future__ import annotations + import warnings -from typing import Union import numpy as np @@ -41,7 +42,7 @@ def isosurface_area(vertices: np.ndarray, faces: np.ndarray) -> float: def average_properties( vertices: np.ndarray, faces: np.ndarray, properties: np.ndarray, norm: bool = False -) -> Union[float, np.ndarray]: +) -> float | np.ndarray: """Average property across an isosurface. Args: diff --git a/src/ifermi/brillouin_zone.py b/src/ifermi/brillouin_zone.py index 4ae49c33..5ff0f73f 100644 --- a/src/ifermi/brillouin_zone.py +++ b/src/ifermi/brillouin_zone.py @@ -1,12 +1,16 @@ """Brillouin zone and slice geometries.""" +from __future__ import annotations + import itertools from dataclasses import dataclass, field -from typing import Optional +from typing import TYPE_CHECKING import numpy as np from monty.json import MSONable -from pymatgen.core.structure import Structure + +if TYPE_CHECKING: + from pymatgen.core.structure import Structure __all__ = ["ReciprocalSlice", "ReciprocalCell", "WignerSeitzCell"] @@ -23,10 +27,10 @@ class ReciprocalSlice(MSONable): the 3D Brillouin zone to points on the 2D slice. """ - reciprocal_space: "ReciprocalCell" + reciprocal_space: ReciprocalCell vertices: np.ndarray transformation: np.ndarray - _edges: Optional[np.ndarray] = field(default=None, init=False) + _edges: np.ndarray | None = field(default=None, init=False) def __post_init__(self): """Ensure all inputs are numpy arrays.""" @@ -66,7 +70,7 @@ class ReciprocalCell(MSONable): faces: list[list[int]] centers: np.ndarray normals: np.ndarray - _edges: Optional[np.ndarray] = field(default=None, init=False) + _edges: np.ndarray | None = field(default=None, init=False) def __post_init__(self): """Ensure all inputs are numpy arrays.""" @@ -76,7 +80,7 @@ def __post_init__(self): self.normals = np.array(self.normals) @classmethod - def from_structure(cls, structure: Structure) -> "ReciprocalCell": + def from_structure(cls, structure: Structure) -> ReciprocalCell: """Initialise the reciprocal cell from a structure. Args: @@ -195,7 +199,7 @@ class WignerSeitzCell(ReciprocalCell): """ @classmethod - def from_structure(cls, structure: Structure) -> "WignerSeitzCell": + def from_structure(cls, structure: Structure) -> WignerSeitzCell: """Initialise the Wigner-Seitz cell from a structure. Args: diff --git a/src/ifermi/interpolate.py b/src/ifermi/interpolate.py index 9edb7198..b456fe05 100644 --- a/src/ifermi/interpolate.py +++ b/src/ifermi/interpolate.py @@ -1,10 +1,14 @@ """Tools for Fourier and linear interpolation.""" -from typing import Optional +from __future__ import annotations + +from typing import TYPE_CHECKING import numpy as np from pymatgen.electronic_structure.bandstructure import BandStructure -from pymatgen.electronic_structure.core import Spin + +if TYPE_CHECKING: + from pymatgen.electronic_structure.core import Spin __all__ = ["FourierInterpolator", "LinearInterpolator", "trim_bandstructure"] @@ -23,8 +27,8 @@ class FourierInterpolator: def __init__( self, band_structure: BandStructure, - magmom: Optional[np.ndarray] = None, - mommat: Optional[np.ndarray] = None, + magmom: np.ndarray | None = None, + mommat: np.ndarray | None = None, ): from BoltzTraP2.units import Angstrom from pymatgen.io.ase import AseAtomsAdaptor @@ -160,7 +164,7 @@ def __init__( kpoints: np.ndarray, energies: np.ndarray, lattice_matrix: np.ndarray, - mommat: Optional[np.ndarray] = None, + mommat: np.ndarray | None = None, ): self.kpoints = kpoints self.ebands = energies diff --git a/src/ifermi/kpoints.py b/src/ifermi/kpoints.py index c67f73c3..aff52885 100644 --- a/src/ifermi/kpoints.py +++ b/src/ifermi/kpoints.py @@ -1,12 +1,17 @@ """k-point manipulation functions.""" +from __future__ import annotations + import warnings +from typing import TYPE_CHECKING import numpy as np -from pymatgen.electronic_structure.bandstructure import BandStructure from ifermi.defaults import KTOL +if TYPE_CHECKING: + from pymatgen.electronic_structure.bandstructure import BandStructure + __all__ = [ "kpoints_to_first_bz", "kpoints_from_bandstructure", diff --git a/src/ifermi/plot.py b/src/ifermi/plot.py index 5072f371..3ec169f9 100644 --- a/src/ifermi/plot.py +++ b/src/ifermi/plot.py @@ -1,11 +1,11 @@ """Tools to plot FermiSurface and FermiSlice objects.""" +from __future__ import annotations + import os import warnings -from collections.abc import Collection from dataclasses import dataclass -from pathlib import Path -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any import numpy as np from matplotlib.colors import Colormap, Normalize @@ -13,9 +13,14 @@ from pymatgen.electronic_structure.core import Spin from ifermi.defaults import AZIMUTH, COLORMAP, ELEVATION, SCALE, SYMPREC, VECTOR_SPACING -from ifermi.slice import FermiSlice from ifermi.surface import FermiSurface +if TYPE_CHECKING: + from collections.abc import Collection + from pathlib import Path + + from ifermi.slice import FermiSlice + try: import mayavi.mlab as mlab except ImportError: @@ -125,10 +130,10 @@ class _FermiSurfacePlotData: colors: list[tuple[int, int, int]] properties: list[np.ndarray] arrows: list[tuple[np.ndarray, np.ndarray, np.ndarray]] - properties_colormap: Optional[Colormap] - arrow_colormap: Optional[Colormap] - cmin: Optional[float] - cmax: Optional[float] + properties_colormap: Colormap | None + arrow_colormap: Colormap | None + cmin: float | None + cmax: float | None hide_labels: bool hide_cell: bool @@ -139,10 +144,10 @@ class _FermiSlicePlotData: colors: list[tuple[int, int, int]] properties: list[np.ndarray] arrows: list[tuple[np.ndarray, np.ndarray, np.ndarray]] - properties_colormap: Optional[Colormap] - arrow_colormap: Optional[Colormap] - cmin: Optional[float] - cmax: Optional[float] + properties_colormap: Colormap | None + arrow_colormap: Colormap | None + cmin: float | None + cmax: float | None hide_labels: bool hide_cell: bool @@ -200,21 +205,21 @@ def get_symmetry_points( def get_plot( self, plot_type: str = "plotly", - spin: Optional[Spin] = None, - colors: Optional[Union[str, dict, list]] = None, + spin: Spin | None = None, + colors: str | dict | list | None = None, azimuth: float = AZIMUTH, elevation: float = ELEVATION, - color_properties: Union[str, bool] = True, - vector_properties: Union[str, bool] = False, - projection_axis: Optional[tuple[int, int, int]] = None, + color_properties: str | bool = True, + vector_properties: str | bool = False, + projection_axis: tuple[int, int, int] | None = None, vector_spacing: float = VECTOR_SPACING, - cmin: Optional[float] = None, - cmax: Optional[float] = None, - vnorm: Optional[float] = None, + cmin: float | None = None, + cmax: float | None = None, + vnorm: float | None = None, hide_surface: bool = False, hide_labels: bool = False, hide_cell: bool = False, - plot_index: Optional[Union[int, list, dict]] = None, + plot_index: int | list | dict | None = None, **plot_kwargs, ): """Plot the Fermi surface. @@ -333,17 +338,17 @@ def get_plot( def _get_plot_data( self, - spin: Optional[Spin] = None, + spin: Spin | None = None, azimuth: float = AZIMUTH, elevation: float = ELEVATION, - colors: Optional[Union[str, dict, list]] = None, - color_properties: Union[str, bool] = True, - vector_properties: Union[str, bool] = False, - projection_axis: Optional[tuple[int, int, int]] = None, + colors: str | dict | list | None = None, + color_properties: str | bool = True, + vector_properties: str | bool = False, + projection_axis: tuple[int, int, int] | None = None, vector_spacing: float = VECTOR_SPACING, - cmin: Optional[float] = None, - cmax: Optional[float] = None, - vnorm: Optional[float] = None, + cmin: float | None = None, + cmax: float | None = None, + vnorm: float | None = None, hide_surface: bool = False, hide_labels: bool = False, hide_cell: bool = False, @@ -424,13 +429,13 @@ def _get_plot_data( def _get_matplotlib_plot( self, plot_data: _FermiSurfacePlotData, - ax: Optional[Any] = None, - trisurf_kwargs: Optional[dict[str, Any]] = None, - cbar_kwargs: Optional[dict[str, Any]] = None, - quiver_kwargs: Optional[dict[str, Any]] = None, - bz_kwargs: Optional[dict[str, Any]] = None, - sym_pt_kwargs: Optional[dict[str, Any]] = None, - sym_label_kwargs: Optional[dict[str, Any]] = None, + ax: Any | None = None, + trisurf_kwargs: dict[str, Any] | None = None, + cbar_kwargs: dict[str, Any] | None = None, + quiver_kwargs: dict[str, Any] | None = None, + bz_kwargs: dict[str, Any] | None = None, + sym_pt_kwargs: dict[str, Any] | None = None, + sym_label_kwargs: dict[str, Any] | None = None, ): """Plot the Fermi surface using matplotlib. @@ -519,12 +524,12 @@ def _get_matplotlib_plot( def _get_plotly_plot( self, plot_data: _FermiSurfacePlotData, - mesh_kwargs: Optional[dict[str, Any]] = None, - arrow_line_kwargs: Optional[dict[str, Any]] = None, - arrow_cone_kwargs: Optional[dict[str, Any]] = None, - bz_kwargs: Optional[dict[str, Any]] = None, - sym_pt_kwargs: Optional[dict[str, Any]] = None, - sym_label_kwargs: Optional[dict[str, Any]] = None, + mesh_kwargs: dict[str, Any] | None = None, + arrow_line_kwargs: dict[str, Any] | None = None, + arrow_cone_kwargs: dict[str, Any] | None = None, + bz_kwargs: dict[str, Any] | None = None, + sym_pt_kwargs: dict[str, Any] | None = None, + sym_label_kwargs: dict[str, Any] | None = None, ): """Plot the Fermi surface using plotly. @@ -864,28 +869,28 @@ def get_symmetry_points( def get_plot( self, - ax: Optional[Any] = None, - spin: Optional[Spin] = None, - colors: Optional[Union[str, dict, list]] = None, - color_properties: Union[str, bool] = True, - vector_properties: Union[str, bool] = False, - projection_axis: Optional[tuple[int, int, int]] = None, - scale_linewidth: Union[bool, float] = False, + ax: Any | None = None, + spin: Spin | None = None, + colors: str | dict | list | None = None, + color_properties: str | bool = True, + vector_properties: str | bool = False, + projection_axis: tuple[int, int, int] | None = None, + scale_linewidth: bool | float = False, vector_spacing: float = VECTOR_SPACING, - cmin: Optional[float] = None, - cmax: Optional[float] = None, - vnorm: Optional[float] = None, + cmin: float | None = None, + cmax: float | None = None, + vnorm: float | None = None, hide_slice: bool = False, hide_labels: bool = False, hide_cell: bool = False, plot_index: list[int] = None, arrow_pivot: str = "tail", - slice_kwargs: Optional[dict[str, Any]] = None, - cbar_kwargs: Optional[dict[str, Any]] = None, - quiver_kwargs: Optional[dict[str, Any]] = None, - bz_kwargs: Optional[dict[str, Any]] = None, - sym_pt_kwargs: Optional[dict[str, Any]] = None, - sym_label_kwargs: Optional[dict[str, Any]] = None, + slice_kwargs: dict[str, Any] | None = None, + cbar_kwargs: dict[str, Any] | None = None, + quiver_kwargs: dict[str, Any] | None = None, + bz_kwargs: dict[str, Any] | None = None, + sym_pt_kwargs: dict[str, Any] | None = None, + sym_label_kwargs: dict[str, Any] | None = None, ): """Plot the Fermi slice. @@ -1097,15 +1102,15 @@ def get_plot( def _get_plot_data( self, - spin: Optional[Spin] = None, - colors: Optional[Union[str, dict, list]] = None, - color_properties: Union[str, bool] = True, - vector_properties: Union[str, bool] = False, - projection_axis: Optional[tuple[int, int, int]] = None, + spin: Spin | None = None, + colors: str | dict | list | None = None, + color_properties: str | bool = True, + vector_properties: str | bool = False, + projection_axis: tuple[int, int, int] | None = None, vector_spacing: float = VECTOR_SPACING, - cmin: Optional[float] = None, - cmax: Optional[float] = None, - vnorm: Optional[float] = None, + cmin: float | None = None, + cmax: float | None = None, + vnorm: float | None = None, hide_slice: bool = False, hide_labels: bool = False, hide_cell: bool = False, @@ -1199,7 +1204,7 @@ def show_plot(plot: Any): plot.show() -def save_plot(plot: Any, filename: Union[Path, str], scale: float = SCALE): +def save_plot(plot: Any, filename: Path | str, scale: float = SCALE): """Save a plot to file. Args: @@ -1255,8 +1260,8 @@ def get_plot_type(plot: Any) -> str: def get_isosurface_colors( - colors: Optional[Union[str, dict, list]], - fermi_object: Union[FermiSurface, FermiSlice], + colors: str | dict | list | None, + fermi_object: FermiSurface | FermiSlice, spins: list[Spin], ) -> list[tuple[float, float, float]]: """Get colors for each Fermi surface. @@ -1329,8 +1334,8 @@ def get_face_arrows( fermi_surface: FermiSurface, spins: list[Spin], vector_spacing: float, - vnorm: Optional[float], - projection_axis: Optional[tuple[int, int, int]], + vnorm: float | None, + projection_axis: tuple[int, int, int] | None, ) -> list[tuple[np.ndarray, np.ndarray, np.ndarray]]: """Get face arrows from vector properties. @@ -1398,8 +1403,8 @@ def get_segment_arrows( fermi_slice: FermiSlice, spins: Collection[Spin], vector_spacing: float, - vnorm: Optional[float], - projection_axis: Optional[tuple[int, int, int]], + vnorm: float | None, + projection_axis: tuple[int, int, int] | None, ) -> list[tuple[np.ndarray, np.ndarray, np.ndarray]]: """Get segment arrows from vector properties. @@ -1473,7 +1478,7 @@ def get_segment_arrows( def _get_properties_limits( - projections: list[np.ndarray], cmin: Optional[float], cmax: Optional[float] + projections: list[np.ndarray], cmin: float | None, cmax: float | None ) -> tuple[float, float]: """Get the min and max properties if they are not already set. @@ -1513,8 +1518,8 @@ def plotly_arrow( start: np.ndarray, stop: np.ndarray, color: tuple[float, float, float], - line_kwargs: Optional[dict[str, Any]] = None, - cone_kwargs: Optional[dict[str, Any]] = None, + line_kwargs: dict[str, Any] | None = None, + cone_kwargs: dict[str, Any] | None = None, ) -> tuple[Any, Any]: """Create an arrow object. diff --git a/src/ifermi/slice.py b/src/ifermi/slice.py index 88842161..78bec265 100644 --- a/src/ifermi/slice.py +++ b/src/ifermi/slice.py @@ -1,16 +1,23 @@ """Tools to generate Isolines and Fermi slices.""" + +from __future__ import annotations + import warnings -from collections.abc import Collection from dataclasses import dataclass -from typing import Optional, Union +from typing import TYPE_CHECKING import numpy as np from monty.json import MSONable -from pymatgen.core.structure import Structure from pymatgen.electronic_structure.core import Spin from ifermi.analysis import equivalent_vertices, longest_simple_paths -from ifermi.brillouin_zone import ReciprocalSlice + +if TYPE_CHECKING: + from collections.abc import Collection + + from pymatgen.core.structure import Structure + + from ifermi.brillouin_zone import ReciprocalSlice __all__ = ["Isoline", "FermiSlice", "process_lines", "interpolate_segments"] @@ -28,7 +35,7 @@ class Isoline(MSONable): segments: np.ndarray band_idx: int - properties: Optional[np.ndarray] = None + properties: np.ndarray | None = None def __post_init__(self): """Ensure all inputs are numpy arrays.""" @@ -161,8 +168,8 @@ def properties_ndim(self) -> int: def all_lines( self, - spins: Optional[Union[Spin, Collection[Spin]]] = None, - band_index: Optional[Union[int, list, dict]] = None, + spins: Spin | Collection[Spin] | None = None, + band_index: int | list | dict | None = None, ) -> list[np.ndarray]: """Get the segments for all isolines. @@ -210,9 +217,9 @@ def all_lines( def all_properties( self, - spins: Optional[Union[Spin, Collection[Spin]]] = None, - band_index: Optional[Union[int, list, dict]] = None, - projection_axis: Optional[tuple[int, int, int]] = None, + spins: Spin | Collection[Spin] | None = None, + band_index: int | list | dict | None = None, + projection_axis: tuple[int, int, int] | None = None, norm: bool = False, ) -> list[np.ndarray]: """Get the properties for all isolines. @@ -275,7 +282,7 @@ def from_fermi_surface( fermi_surface, plane_normal: tuple[int, int, int], distance: float = 0, - ) -> "FermiSlice": + ) -> FermiSlice: """Get a slice through the Fermi surface. The slice is defined by the intersection of a plane with the Fermi surface. @@ -337,7 +344,7 @@ def from_fermi_surface( return FermiSlice(dict(isolines), reciprocal_slice, fermi_surface.structure) @classmethod - def from_dict(cls, d) -> "FermiSlice": + def from_dict(cls, d) -> FermiSlice: """Return FermiSlice object from a dict.""" fs = super().from_dict(d) fs.isolines = {Spin(int(k)): v for k, v in fs.isolines.items()} diff --git a/src/ifermi/surface.py b/src/ifermi/surface.py index f141c13f..7f1457de 100644 --- a/src/ifermi/surface.py +++ b/src/ifermi/surface.py @@ -1,21 +1,26 @@ """Tools to generate isosurfaces and Fermi surfaces.""" +from __future__ import annotations + import warnings -from collections.abc import Collection from copy import deepcopy from dataclasses import dataclass -from typing import Optional, Union +from typing import TYPE_CHECKING import numpy as np from monty.dev import requires from monty.json import MSONable, jsanitize -from pymatgen.core.structure import Structure -from pymatgen.electronic_structure.bandstructure import BandStructure from pymatgen.electronic_structure.core import Spin from ifermi.brillouin_zone import ReciprocalCell, WignerSeitzCell from ifermi.interpolate import LinearInterpolator +if TYPE_CHECKING: + from collections.abc import Collection + + from pymatgen.core.structure import Structure + from pymatgen.electronic_structure.bandstructure import BandStructure + try: import mcubes except ImportError: @@ -54,9 +59,9 @@ class Isosurface(MSONable): vertices: np.ndarray faces: np.ndarray band_idx: int - properties: Optional[np.ndarray] = None - dimensionality: Optional[str] = None - orientation: Optional[tuple[int, int, int]] = None + properties: np.ndarray | None = None + dimensionality: str | None = None + orientation: tuple[int, int, int] | None = None def __post_init__(self): """Ensure all inputs are numpy arrays.""" @@ -98,8 +103,8 @@ def properties_ndim(self) -> int: return self.properties.ndim def average_properties( - self, norm: bool = False, projection_axis: Optional[tuple[int, int, int]] = None - ) -> Union[float, np.ndarray]: + self, norm: bool = False, projection_axis: tuple[int, int, int] | None = None + ) -> float | np.ndarray: """Average property across isosurface. Args: @@ -225,8 +230,8 @@ def has_properties(self) -> bool: return all(all(i.has_properties for i in s) for s in self.isosurfaces.values()) def average_properties( - self, norm: bool = False, projection_axis: Optional[tuple[int, int, int]] = None - ) -> Union[float, np.ndarray]: + self, norm: bool = False, projection_axis: tuple[int, int, int] | None = None + ) -> float | np.ndarray: """Average property across the full Fermi surface. Args: @@ -250,8 +255,8 @@ def average_properties( return scaled_average / total_area def average_properties_surfaces( - self, norm: bool = False, projection_axis: Optional[tuple[int, int, int]] = None - ) -> dict[Spin, list[Union[float, np.ndarray]]]: + self, norm: bool = False, projection_axis: tuple[int, int, int] | None = None + ) -> dict[Spin, list[float | np.ndarray]]: """Average property for each isosurface in the Fermi surface. Args: @@ -290,8 +295,8 @@ def spins(self): def all_vertices_faces( self, - spins: Optional[Union[Spin, Collection[Spin]]] = None, - band_index: Optional[Union[int, list, dict]] = None, + spins: Spin | Collection[Spin] | None = None, + band_index: int | list | dict | None = None, ) -> list[tuple[np.ndarray, np.ndarray]]: """Get the vertices and faces for all isosurfaces. @@ -339,9 +344,9 @@ def all_vertices_faces( def all_properties( self, - spins: Optional[Union[Spin, Collection[Spin]]] = None, - band_index: Optional[Union[int, list, dict]] = None, - projection_axis: Optional[tuple[int, int, int]] = None, + spins: Spin | Collection[Spin] | None = None, + band_index: int | list | dict | None = None, + projection_axis: tuple[int, int, int] | None = None, norm: bool = False, ) -> list[np.ndarray]: """Get the properties for all isosurfaces. @@ -407,13 +412,13 @@ def from_band_structure( band_structure: BandStructure, mu: float = 0.0, wigner_seitz: bool = False, - decimate_factor: Optional[float] = None, + decimate_factor: float | None = None, decimate_method: str = "quadric", smooth: bool = False, - property_data: Optional[dict[Spin, np.ndarray]] = None, - property_kpoints: Optional[np.ndarray] = None, + property_data: dict[Spin, np.ndarray] | None = None, + property_kpoints: np.ndarray | None = None, calculate_dimensionality: bool = False, - ) -> "FermiSurface": + ) -> FermiSurface: """Create a FermiSurface from a pymatgen band structure object. Args: @@ -515,7 +520,7 @@ def get_fermi_slice(self, plane_normal: tuple[int, int, int], distance: float = return FermiSlice.from_fermi_surface(self, plane_normal, distance=distance) @classmethod - def from_dict(cls, d) -> "FermiSurface": + def from_dict(cls, d) -> FermiSurface: """Return FermiSurface object from dict.""" fs = super().from_dict(d) fs.isosurfaces = {Spin(int(k)): v for k, v in fs.isosurfaces.items()} @@ -533,11 +538,11 @@ def compute_isosurfaces( kpoints: np.ndarray, fermi_level: float, reciprocal_space: ReciprocalCell, - decimate_factor: Optional[float] = None, + decimate_factor: float | None = None, decimate_method: str = "quadric", smooth: bool = False, calculate_dimensionality: bool = False, - property_interpolator: Optional[LinearInterpolator] = None, + property_interpolator: LinearInterpolator | None = None, ) -> dict[Spin, list[Isosurface]]: """Compute the isosurfaces at a particular energy level. @@ -616,11 +621,11 @@ def _calculate_band_isosurfaces( spacing: np.ndarray, reference: np.ndarray, reciprocal_space: ReciprocalCell, - decimate_factor: Optional[float], + decimate_factor: float | None, decimate_method: str, smooth: bool, calculate_dimensionality: bool, - property_interpolator: Optional[LinearInterpolator], + property_interpolator: LinearInterpolator | None, ): """Helper function to calculate the connected isosurfaces for a band.""" from skimage.measure import marching_cubes