diff --git a/.github/workflows/type-checking.yaml b/.github/workflows/type-checking.yaml new file mode 100644 index 00000000000..51f9fc4e88a --- /dev/null +++ b/.github/workflows/type-checking.yaml @@ -0,0 +1,38 @@ +name: type checking + +on: + push: + branches: + - main + pull_request: + paths-ignore: + - "doc/**" + - README.md + +jobs: + build: + runs-on: ubuntu-latest + name: mypy + timeout-minutes: 60 + + concurrency: + group: ${{ github.ref }}-dev + cancel-in-progress: true + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + + - name: Build yt + type check deps + run: | + python3 -m pip install --upgrade pip + python3 -m pip install -e .[typecheck] + + - name: Run mypy + run: mypy yt diff --git a/pyproject.toml b/pyproject.toml index f2c10880fa4..bf0d062579e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,3 +129,15 @@ ignore = [ "yt/utilities/lib/ragged_arrays.c", "yt/utilities/lib/write_array.c", ] + + +[tool.mypy] +python_version = 3.7 +ignore_missing_imports = true +warn_unused_configs = true +warn_unused_ignores = true +warn_unreachable = true +show_error_context = true + +# color_maps can be reactivated after https://github.com/yt-project/yt/pull/3441 is merged +exclude = "(test_*|lodgeit|color_maps)" diff --git a/setup.cfg b/setup.cfg index 414f9a7e8b1..6b4bd9678ec 100644 --- a/setup.cfg +++ b/setup.cfg @@ -107,6 +107,13 @@ test = nose-timer~=1.0.0 pytest>=6.1 pytest-xdist~=2.1.0 +typecheck = + mypy==0.910 + types-PyYAML==5.4.10 + types-chardet==4.0.0 + types-requests==2.25.9 + types-setuptools==57.4.0 + types-toml==0.10.0 [flake8] max-line-length = 88 diff --git a/yt/__init__.py b/yt/__init__.py index 207580be25f..d9727ff465e 100644 --- a/yt/__init__.py +++ b/yt/__init__.py @@ -69,7 +69,7 @@ uunion1d, uvstack, ) -from yt.units.unit_object import define_unit +from yt.units.unit_object import define_unit # type: ignore from yt.utilities.logger import set_log_level, ytLogger as mylog frontends = _frontend_container() @@ -89,7 +89,7 @@ load_unstructured_mesh, ) from yt.testing import run_nose -from yt.units.unit_systems import UnitSystem, unit_system_registry +from yt.units.unit_systems import UnitSystem, unit_system_registry # type: ignore # Import some helpful math utilities from yt.utilities.math_utils import ortho_find, periodic_position, quartiles diff --git a/yt/_maintenance/backports.py b/yt/_maintenance/backports.py index d3ce83a877b..a1b7ad41eb1 100644 --- a/yt/_maintenance/backports.py +++ b/yt/_maintenance/backports.py @@ -21,7 +21,7 @@ import sys if sys.version_info < (3, 8): - from _thread import RLock + from _thread import RLock # type: ignore from typing import List GenericAlias = type(List[int]) diff --git a/yt/_typing.py b/yt/_typing.py new file mode 100644 index 00000000000..98c7c0ed128 --- /dev/null +++ b/yt/_typing.py @@ -0,0 +1,4 @@ +from typing import List, Optional, Tuple + +FieldDescT = Tuple[str, Tuple[str, List[str], Optional[str]]] +KnownFieldsT = Tuple[FieldDescT, ...] diff --git a/yt/config.py b/yt/config.py index fd943e5b95a..d2e3868b9c3 100644 --- a/yt/config.py +++ b/yt/config.py @@ -190,7 +190,7 @@ def _repr_json_(self): _local_config_file = YTConfig.get_local_config_file() if not os.path.exists(_global_config_file): - cfg = {"yt": {}} + cfg = {"yt": {}} # type: ignore try: with open(_global_config_file, mode="w") as fd: toml.dump(cfg, fd) diff --git a/yt/data_objects/construction_data_containers.py b/yt/data_objects/construction_data_containers.py index 6f86a5176d0..76e164e6384 100644 --- a/yt/data_objects/construction_data_containers.py +++ b/yt/data_objects/construction_data_containers.py @@ -24,8 +24,8 @@ from yt.geometry import particle_deposit as particle_deposit from yt.geometry.coordinates.cartesian_coordinates import all_data from yt.loaders import load_uniform_grid -from yt.units.unit_object import Unit -from yt.units.yt_array import YTArray, uconcatenate +from yt.units.unit_object import Unit # type: ignore +from yt.units.yt_array import YTArray, uconcatenate # type: ignore from yt.utilities.exceptions import ( YTNoAPIKey, YTNotInsideNotebook, diff --git a/yt/data_objects/data_containers.py b/yt/data_objects/data_containers.py index 4dfe2e5b912..923f15f8910 100644 --- a/yt/data_objects/data_containers.py +++ b/yt/data_objects/data_containers.py @@ -1,6 +1,8 @@ +import abc import weakref from collections import defaultdict from contextlib import contextmanager +from typing import List, Tuple, Union import numpy as np @@ -9,7 +11,7 @@ from yt.fields.field_exceptions import NeedsGridType from yt.frontends.ytdata.utilities import save_as_dataset from yt.funcs import get_output_filename, is_sequence, iter_fields, mylog -from yt.units.yt_array import YTArray, YTQuantity, uconcatenate +from yt.units.yt_array import YTArray, YTQuantity, uconcatenate # type: ignore from yt.utilities.amr_kdtree.api import AMRKDTree from yt.utilities.exceptions import ( YTCouldNotGenerateField, @@ -50,7 +52,7 @@ def _get_ipython_key_completion(ds): return tuple_keys + fnames -class YTDataContainer: +class YTDataContainer(abc.ABC): """ Generic YTDataContainer container. By itself, will attempt to generate field, read fields (method defined by derived classes) @@ -59,13 +61,14 @@ class YTDataContainer: _chunk_info = None _num_ghost_zones = 0 - _con_args = () + _con_args: Tuple[str, ...] = () _skip_add = False - _container_fields = () - _tds_attrs = () - _tds_fields = () + _container_fields: Tuple[Union[str, Tuple[str, str]], ...] = () + _tds_attrs: Tuple[str, ...] = () + _tds_fields: Tuple[str, ...] = () _field_cache = None _index = None + _key_fields: List[str] def __init__(self, ds, field_parameters): """ @@ -428,8 +431,6 @@ def _parameter_iterate(self, seq): yield obj obj.field_parameters = old_fp - _key_fields = None - def write_out(self, filename, fields=None, format="%0.16e"): """Write out the YTDataContainer object in a text file. @@ -469,9 +470,6 @@ def write_out(self, filename, fields=None, format="%0.16e"): if fields is None: fields = sorted(self.field_data.keys()) - if self._key_fields is None: - raise ValueError - field_order = [("index", k) for k in self._key_fields] diff_fields = [field for field in fields if field not in field_order] field_order += diff_fields diff --git a/yt/data_objects/image_array.py b/yt/data_objects/image_array.py index c18c0e58f45..af8f1cc9818 100644 --- a/yt/data_objects/image_array.py +++ b/yt/data_objects/image_array.py @@ -1,13 +1,13 @@ import warnings import numpy as np +from unyt import unyt_array from yt.config import ytcfg -from yt.units.yt_array import YTArray from yt.visualization.image_writer import write_bitmap, write_image -class ImageArray(YTArray): +class ImageArray(unyt_array): r"""A custom Numpy ndarray used for images. This differs from ndarray in that you can optionally specify an diff --git a/yt/data_objects/index_subobjects/grid_patch.py b/yt/data_objects/index_subobjects/grid_patch.py index 8e7ed4ae3e0..5940bec4daf 100644 --- a/yt/data_objects/index_subobjects/grid_patch.py +++ b/yt/data_objects/index_subobjects/grid_patch.py @@ -282,7 +282,7 @@ def get_vertex_centered_data( "a single field as an argument." ) warnings.warn(message, DeprecationWarning, stacklevel=2) - fields = [fields] + fields = [fields] # type: ignore # Make sure the field list has only unique entries fields = list(set(fields)) diff --git a/yt/data_objects/index_subobjects/octree_subset.py b/yt/data_objects/index_subobjects/octree_subset.py index 885706b5f61..d39cc8b24ce 100644 --- a/yt/data_objects/index_subobjects/octree_subset.py +++ b/yt/data_objects/index_subobjects/octree_subset.py @@ -1,7 +1,9 @@ from contextlib import contextmanager from itertools import product, repeat +from typing import Tuple import numpy as np +from unyt import unyt_array import yt.geometry.particle_deposit as particle_deposit import yt.geometry.particle_smooth as particle_smooth @@ -9,8 +11,7 @@ YTSelectionContainer, ) from yt.geometry.particle_oct_container import ParticleOctreeContainer -from yt.units.dimensions import length -from yt.units.yt_array import YTArray +from yt.units.dimensions import length # type: ignore from yt.utilities.exceptions import ( YTFieldTypeNotFound, YTInvalidPositionArray, @@ -37,7 +38,7 @@ class OctreeSubset(YTSelectionContainer): _num_ghost_zones = 0 _type_name = "octree_subset" _skip_add = True - _con_args = ("base_region", "domain", "ds") + _con_args: Tuple[str, ...] = ("base_region", "domain", "ds") _domain_offset = 0 _cell_count = -1 _block_order = "C" @@ -658,7 +659,7 @@ def __iter__(self): yield i, OctreeSubsetBlockSlicePosition(i, self) -class YTPositionArray(YTArray): +class YTPositionArray(unyt_array): @property def morton(self): self.validate() diff --git a/yt/data_objects/particle_filters.py b/yt/data_objects/particle_filters.py index 65ab819b6d8..0626a40c408 100644 --- a/yt/data_objects/particle_filters.py +++ b/yt/data_objects/particle_filters.py @@ -1,12 +1,13 @@ import copy from contextlib import contextmanager +from typing import Dict from yt.fields.field_info_container import NullFunc, TranslationFunc from yt.funcs import mylog from yt.utilities.exceptions import YTIllDefinedFilter # One to one mapping -filter_registry = {} +filter_registry: Dict[str, "ParticleFilter"] = {} class DummyFieldInfo: diff --git a/yt/data_objects/profiles.py b/yt/data_objects/profiles.py index ee1f9de757d..e704c32edca 100644 --- a/yt/data_objects/profiles.py +++ b/yt/data_objects/profiles.py @@ -5,7 +5,7 @@ from yt.fields.derived_field import DerivedField from yt.frontends.ytdata.utilities import save_as_dataset from yt.funcs import get_output_filename, is_sequence, iter_fields, mylog -from yt.units.unit_object import Unit +from yt.units.unit_object import Unit # type: ignore from yt.units.yt_array import YTQuantity, array_like_field from yt.utilities.exceptions import ( YTIllDefinedBounds, diff --git a/yt/data_objects/selection_objects/data_selection_objects.py b/yt/data_objects/selection_objects/data_selection_objects.py index 4b7fa7571f8..3f96223e6d8 100644 --- a/yt/data_objects/selection_objects/data_selection_objects.py +++ b/yt/data_objects/selection_objects/data_selection_objects.py @@ -1,3 +1,4 @@ +import abc import itertools import uuid from collections import defaultdict @@ -31,13 +32,13 @@ ) -class YTSelectionContainer(YTDataContainer, ParallelAnalysisInterface): +class YTSelectionContainer(YTDataContainer, ParallelAnalysisInterface, abc.ABC): _locked = False _sort_by = None _selector = None _current_chunk = None _data_source = None - _dimensionality = None + _dimensionality: int _max_level = None _min_level = None _derived_quantity_chunking = "io" diff --git a/yt/data_objects/static_output.py b/yt/data_objects/static_output.py index 7295c6a6ba6..1f8d3d001b1 100644 --- a/yt/data_objects/static_output.py +++ b/yt/data_objects/static_output.py @@ -3,12 +3,14 @@ import itertools import os import pickle +import sys import time import warnings import weakref from collections import defaultdict from importlib.util import find_spec from stat import ST_CTIME +from typing import Optional, Tuple, Type, Union import numpy as np from unyt.exceptions import UnitConversionError, UnitParseError @@ -32,11 +34,15 @@ SpectralCubeCoordinateHandler, SphericalCoordinateHandler, ) +from yt.geometry.geometry_handler import Index from yt.units import UnitContainer, _wrap_display_ytarray, dimensions -from yt.units.dimensions import current_mks -from yt.units.unit_object import Unit, define_unit -from yt.units.unit_registry import UnitRegistry -from yt.units.unit_systems import create_code_unit_system, unit_system_registry +from yt.units.dimensions import current_mks # type: ignore +from yt.units.unit_object import Unit, define_unit # type: ignore +from yt.units.unit_registry import UnitRegistry # type: ignore +from yt.units.unit_systems import ( # type: ignore + create_code_unit_system, + unit_system_registry, +) from yt.units.yt_array import YTArray, YTQuantity from yt.utilities.cosmology import Cosmology from yt.utilities.exceptions import ( @@ -51,12 +57,18 @@ from yt.utilities.parallel_tools.parallel_analysis_interface import parallel_root_only from yt.utilities.parameter_file_storage import NoParameterShelf, ParameterFileStore +if sys.version_info >= (3, 9): + from collections.abc import MutableMapping +else: + from typing import MutableMapping # We want to support the movie format in the future. # When such a thing comes to pass, I'll move all the stuff that is constant up # to here, and then have it instantiate EnzoDatasets as appropriate. -_cached_datasets = weakref.WeakValueDictionary() +_cached_datasets: MutableMapping[ + Union[int, str], "Dataset" +] = weakref.WeakValueDictionary() _ds_store = ParameterFileStore() @@ -114,20 +126,20 @@ class Dataset(abc.ABC): default_fluid_type = "gas" default_field = ("gas", "density") - fluid_types = ("gas", "deposit", "index") - particle_types = ("io",) # By default we have an 'all' - particle_types_raw = ("io",) + fluid_types: Tuple[str, ...] = ("gas", "deposit", "index") + particle_types: Optional[Tuple[str, ...]] = ("io",) # By default we have an 'all' + particle_types_raw: Optional[Tuple[str, ...]] = ("io",) geometry = "cartesian" coordinates = None storage_filename = None particle_unions = None known_filters = None - _index_class = None + _index_class: Type[Index] field_units = None derived_field_list = requires_index("derived_field_list") fields = requires_index("fields") _instantiated = False - _unique_identifier = None + _unique_identifier: Optional[Union[str, int]] = None _particle_type_counts = None _proj_type = "quad_proj" _ionization_label_format = "roman_numeral" diff --git a/yt/data_objects/time_series.py b/yt/data_objects/time_series.py index bae16b5c426..4b42735b261 100644 --- a/yt/data_objects/time_series.py +++ b/yt/data_objects/time_series.py @@ -4,7 +4,7 @@ import os import weakref from functools import wraps -from typing import Optional +from typing import Optional, Type import numpy as np from more_itertools import always_iterable @@ -142,6 +142,10 @@ class DatasetSeries: ... SlicePlot(ds, "x", ("gas", "density")).save() """ + # this annotation should really be Optional[Type[Dataset]] + # but we cannot import the yt.data_objects.static_output.Dataset + # class here without creating a circular import for now + _dataset_cls: Optional[Type] = None def __init_subclass__(cls, *args, **kwargs): super().__init_subclass__(*args, **kwargs) @@ -365,8 +369,6 @@ def from_output_log(cls, output_log, line_prefix="DATASET WRITTEN", parallel=Tru obj = cls(filenames, parallel=parallel) return obj - _dataset_cls = None - def _load(self, output_fn, *, hint: Optional[str] = None, **kwargs): from yt.loaders import load diff --git a/yt/fields/derived_field.py b/yt/fields/derived_field.py index 81551b4114b..7057391ee2e 100644 --- a/yt/fields/derived_field.py +++ b/yt/fields/derived_field.py @@ -2,13 +2,13 @@ import inspect import re import warnings -from typing import Tuple +from typing import Optional, Tuple, Union from more_itertools import always_iterable import yt.units.dimensions as ytdims from yt.funcs import iter_fields, validate_field_key -from yt.units.unit_object import Unit +from yt.units.unit_object import Unit # type: ignore from yt.utilities.exceptions import YTFieldNotFound from yt.utilities.logger import ytLogger as mylog @@ -118,7 +118,7 @@ def __init__( name: Tuple[str, str], sampling_type, function, - units=None, + units: Optional[Union[str, bytes, Unit]] = None, take_log=True, validators=None, particle_type=None, @@ -163,6 +163,7 @@ def __init__( self.validators = list(always_iterable(validators)) # handle units + self.units: Optional[Union[str, bytes, Unit]] if units is None: self.units = "" elif isinstance(units, str): diff --git a/yt/fields/domain_context.py b/yt/fields/domain_context.py index 4468e4bb744..0ebc04c9b01 100644 --- a/yt/fields/domain_context.py +++ b/yt/fields/domain_context.py @@ -1,14 +1,17 @@ +import abc +from typing import Tuple + domain_context_registry = {} -class DomainContext: +class DomainContext(abc.ABC): class __metaclass__(type): def __init__(cls, name, b, d): type.__init__(cls, name, b, d) domain_context_registry[name] = cls - _known_fluid_fields = () - _known_particle_fields = () + _known_fluid_fields: Tuple[Tuple[str, str], ...] + _known_particle_fields: Tuple[Tuple[str, str], ...] def __init__(self, ds): self.ds = ds diff --git a/yt/fields/field_info_container.py b/yt/fields/field_info_container.py index 87a1be0b358..05c2935ae31 100644 --- a/yt/fields/field_info_container.py +++ b/yt/fields/field_info_container.py @@ -5,11 +5,12 @@ import numpy as np from unyt.exceptions import UnitConversionError +from yt._typing import KnownFieldsT from yt.fields.field_exceptions import NeedsConfiguration from yt.funcs import mylog, only_on_root from yt.geometry.geometry_handler import is_curvilinear -from yt.units.dimensions import dimensionless -from yt.units.unit_object import Unit +from yt.units.dimensions import dimensionless # type: ignore +from yt.units.unit_object import Unit # type: ignore from yt.utilities.exceptions import ( YTCoordinateNotImplemented, YTDomainOverflow, @@ -49,9 +50,9 @@ class FieldInfoContainer(dict): """ fallback = None - known_other_fields = () - known_particle_fields = () - extra_union_fields = () + known_other_fields: KnownFieldsT = () + known_particle_fields: KnownFieldsT = () + extra_union_fields: Tuple[Tuple[str, str], ...] = () def __init__(self, ds, field_list, slice_info=None): self._show_field_errors = [] @@ -418,7 +419,7 @@ def alias( alias_name, original_name, units=None, - deprecate: Optional[Tuple[str]] = None, + deprecate: Optional[Tuple[str, str]] = None, ): """ Alias one field to another field. diff --git a/yt/fields/fluid_fields.py b/yt/fields/fluid_fields.py index acd13344dca..3427d43f8d1 100644 --- a/yt/fields/fluid_fields.py +++ b/yt/fields/fluid_fields.py @@ -1,6 +1,6 @@ import numpy as np -from yt.units.unit_object import Unit +from yt.units.unit_object import Unit # type: ignore from yt.utilities.chemical_formulas import compute_mu from yt.utilities.lib.misc_utilities import obtain_relative_velocity_vector diff --git a/yt/fields/magnetic_field.py b/yt/fields/magnetic_field.py index 4442b96aa3c..5fea3002993 100644 --- a/yt/fields/magnetic_field.py +++ b/yt/fields/magnetic_field.py @@ -2,7 +2,7 @@ from yt.fields.derived_field import ValidateParameter from yt.units import dimensions -from yt.units.yt_array import ustack +from yt.units.yt_array import ustack # type: ignore from yt.utilities.math_utils import get_sph_phi_component, get_sph_theta_component from .field_plugin_registry import register_field_plugin @@ -278,7 +278,7 @@ def setup_magnetic_field_aliases(registry, ds_ftype, ds_fields, ftype="gas"): # Add fields if sampling_type in ["cell", "local"]: # Grid dataset case - def mag_field(fd): + def mag_field_from_field(fd): def _mag_field(field, data): return data[fd].to(field.units) @@ -288,12 +288,12 @@ def _mag_field(field, data): registry.add_field( (ftype, f"magnetic_field_{ax}"), sampling_type=sampling_type, - function=mag_field(fd), + function=mag_field_from_field(fd), units=units, ) else: # Particle dataset case - def mag_field(ax): + def mag_field_from_ax(ax): def _mag_field(field, data): return data[ds_field][:, "xyz".index(ax)] @@ -304,7 +304,7 @@ def _mag_field(field, data): registry.add_field( (ds_ftype, fname), sampling_type=sampling_type, - function=mag_field(ax), + function=mag_field_from_ax(ax), units=units, ) sph_ptypes = getattr(registry.ds, "_sph_ptypes", tuple()) diff --git a/yt/fields/particle_fields.py b/yt/fields/particle_fields.py index 7f3c7aa80f6..41d25bcf1a9 100644 --- a/yt/fields/particle_fields.py +++ b/yt/fields/particle_fields.py @@ -1,7 +1,7 @@ import numpy as np from yt.fields.derived_field import ValidateParameter, ValidateSpatial -from yt.units.yt_array import uconcatenate, ucross +from yt.units.yt_array import uconcatenate, ucross # type: ignore from yt.utilities.lib.misc_utilities import ( obtain_position_vector, obtain_relative_velocity_vector, diff --git a/yt/fields/tests/test_vector_fields.py b/yt/fields/tests/test_vector_fields.py index 2f29b7b1a7a..a2cd1ec11a0 100644 --- a/yt/fields/tests/test_vector_fields.py +++ b/yt/fields/tests/test_vector_fields.py @@ -1,7 +1,7 @@ import numpy as np from yt.testing import assert_allclose_units, fake_random_ds, requires_file -from yt.units import cm, s +from yt.units import cm, s # type: ignore from yt.utilities.answer_testing.framework import data_dir_load from yt.visualization.volume_rendering.off_axis_projection import off_axis_projection diff --git a/yt/frontends/adaptahop/data_structures.py b/yt/frontends/adaptahop/data_structures.py index 4213f94e0fc..91ddfac3d99 100644 --- a/yt/frontends/adaptahop/data_structures.py +++ b/yt/frontends/adaptahop/data_structures.py @@ -11,6 +11,7 @@ import re import stat from itertools import product +from typing import Optional import numpy as np @@ -21,7 +22,7 @@ from yt.frontends.halo_catalog.data_structures import HaloCatalogFile from yt.funcs import mylog, setdefaultattr from yt.geometry.particle_geometry_handler import ParticleIndex -from yt.units import Mpc +from yt.units import Mpc # type: ignore from yt.utilities.cython_fortran_utils import FortranFile from .definitions import ADAPTAHOP_TEMPLATES, ATTR_T, HEADER_ATTRIBUTES @@ -57,8 +58,8 @@ class AdaptaHOPDataset(Dataset): # AdaptaHOP internally assumes 1Mpc == 3.0824cm _code_length_to_Mpc = (1.0 * Mpc).to("cm").value / 3.08e24 - _header_attributes: ATTR_T = None - _halo_attributes: ATTR_T = None + _header_attributes: Optional[ATTR_T] = None + _halo_attributes: Optional[ATTR_T] = None def __init__( self, @@ -94,7 +95,7 @@ def _set_code_unit_attributes(self): setdefaultattr(self, "velocity_unit", self.quan(1.0, "km / s")) setdefaultattr(self, "time_unit", self.length_unit / self.velocity_unit) - def _guess_headers_from_file(self, filename) -> ATTR_T: + def _guess_headers_from_file(self, filename) -> None: with FortranFile(filename) as fpu: ok = False for dp, longint in product((True, False), (True, False)): diff --git a/yt/frontends/adaptahop/definitions.py b/yt/frontends/adaptahop/definitions.py index 824de1fe469..dda62353761 100644 --- a/yt/frontends/adaptahop/definitions.py +++ b/yt/frontends/adaptahop/definitions.py @@ -9,7 +9,7 @@ from yt.funcs import mylog -ATTR_T = Tuple[Tuple[Union[Tuple[str, str], str], int, str]] +ATTR_T = Tuple[Tuple[Union[Tuple[str, ...], str], int, str], ...] def HEADER_ATTRIBUTES(*, double: bool, longint: bool) -> ATTR_T: @@ -29,8 +29,6 @@ def HEADER_ATTRIBUTES(*, double: bool, longint: bool) -> ATTR_T: class AdaptaHOPDefTemplate(abc.ABC): - templates = [] - def __init_subclass__(cls, *args, **kwargs): super().__init_subclass__(*args, **kwargs) mylog.debug("Registering AdaptaHOP template class %s", cls.__name__) diff --git a/yt/frontends/adaptahop/fields.py b/yt/frontends/adaptahop/fields.py index a6339df11db..fe5edd7ef62 100644 --- a/yt/frontends/adaptahop/fields.py +++ b/yt/frontends/adaptahop/fields.py @@ -5,8 +5,7 @@ """ - - +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer m_units = "1e11 * Msun" @@ -18,9 +17,7 @@ class AdaptaHOPFieldInfo(FieldInfoContainer): - known_other_fields = () - - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("particle_identifier", ("", [], "Halo Identity")), ("raw_position_x", (r_units, [], None)), ("raw_position_y", (r_units, [], None)), diff --git a/yt/frontends/adaptahop/io.py b/yt/frontends/adaptahop/io.py index 1b07cd3d494..6bc11dd3934 100644 --- a/yt/frontends/adaptahop/io.py +++ b/yt/frontends/adaptahop/io.py @@ -9,6 +9,7 @@ from functools import partial from operator import attrgetter +from typing import List, Tuple, Union import numpy as np @@ -199,15 +200,18 @@ def _todo_from_attributes(attributes: ATTR_T, halo_attributes: ATTR_T): # attributes. This is used to skip fields most of the fields when reading # the tree_brick files. iskip = 0 - todo = [] + todo: List[Union[int, List[Tuple[Union[Tuple[str, ...], str], int, str]]]] = [] - attributes = set(attributes) + attributes = tuple(set(attributes)) for i, (attrs, l, k) in enumerate(halo_attributes): - if not isinstance(attrs, tuple): - attrs_list = (attrs,) - else: + attrs_list: Tuple[str, ...] + if isinstance(attrs, tuple): + if not all(isinstance(a, str) for a in attrs): + raise TypeError(f"Expected a single str or a tuple of str, got {attrs}") attrs_list = attrs + else: + attrs_list = (attrs,) ok = False for attr in attrs_list: if attr in attributes: @@ -228,6 +232,8 @@ def _todo_from_attributes(attributes: ATTR_T, halo_attributes: ATTR_T): todo.append(iskip) todo.append([]) iskip = 0 + if not isinstance(todo[-1], list): + raise TypeError todo[-1].append((attrs, l, k)) state = "read" else: diff --git a/yt/frontends/ahf/fields.py b/yt/frontends/ahf/fields.py index 4ca77d5f2a8..fdd1dbfba6d 100644 --- a/yt/frontends/ahf/fields.py +++ b/yt/frontends/ahf/fields.py @@ -1,3 +1,4 @@ +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer m_units = "Msun/h" @@ -7,11 +8,10 @@ class AHFHalosFieldInfo(FieldInfoContainer): - known_other_fields = () # See http://popia.ft.uam.es/AHF/files/AHF.pdf # and search for '*.AHF_halos'. - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("ID", ("", ["particle_identifier"], None)), ("hostHalo", ("", [], None)), ("numSubStruct", ("", [], None)), diff --git a/yt/frontends/amrvac/data_structures.py b/yt/frontends/amrvac/data_structures.py index 4cf8dc32986..a4c6d1f95c9 100644 --- a/yt/frontends/amrvac/data_structures.py +++ b/yt/frontends/amrvac/data_structures.py @@ -37,7 +37,10 @@ def _is_relative_to(self, *other): except ValueError: return False - Path.is_relative_to = _is_relative_to + Path.is_relative_to = _is_relative_to # type: ignore +else: + # an else block is mandated for pyupgrade to enable auto-cleanup + pass class AMRVACGrid(AMRGridPatch): diff --git a/yt/frontends/art/data_structures.py b/yt/frontends/art/data_structures.py index 87c99b34049..50600c8e254 100644 --- a/yt/frontends/art/data_structures.py +++ b/yt/frontends/art/data_structures.py @@ -2,6 +2,7 @@ import os import struct import weakref +from typing import Type import numpy as np @@ -28,7 +29,7 @@ b2t, ) from yt.funcs import mylog, setdefaultattr -from yt.geometry.geometry_handler import YTDataChunk +from yt.geometry.geometry_handler import Index, YTDataChunk from yt.geometry.oct_container import ARTOctreeContainer from yt.geometry.oct_geometry_handler import OctreeIndex from yt.geometry.particle_geometry_handler import ParticleIndex @@ -132,7 +133,7 @@ def _chunk_io(self, dobj, cache=True, local_only=False): class ARTDataset(Dataset): - _index_class = ARTIndex + _index_class: Type[Index] = ARTIndex _field_info_class = ARTFieldInfo def __init__( @@ -388,7 +389,6 @@ def _is_valid(cls, filename, *args, **kwargs): return True except Exception: return False - return False class ARTParticleFile(ParticleFile): @@ -721,7 +721,6 @@ def _is_valid(cls, filename, *args, **kwargs): return True except Exception: return False - return False class ARTDomainSubset(OctreeSubset): diff --git a/yt/frontends/art/fields.py b/yt/frontends/art/fields.py index ec3cd1c9bfa..6cddf72d8f1 100644 --- a/yt/frontends/art/fields.py +++ b/yt/frontends/art/fields.py @@ -1,3 +1,4 @@ +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer b_units = "code_magnetic" @@ -10,7 +11,7 @@ class ARTFieldInfo(FieldInfoContainer): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("Density", (rho_units, ["density"], None)), ("TotalEnergy", (en_units, ["total_energy_density"], None)), ("XMomentumDensity", (mom_units, ["momentum_density_x"], None)), @@ -25,7 +26,7 @@ class ARTFieldInfo(FieldInfoContainer): ("PotentialOld", ("", ["gas_potential"], None)), ) - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("particle_position_x", ("code_length", [], None)), ("particle_position_y", ("code_length", [], None)), ("particle_position_z", ("code_length", [], None)), diff --git a/yt/frontends/artio/fields.py b/yt/frontends/artio/fields.py index 57cdff331b4..662b8226f46 100644 --- a/yt/frontends/artio/fields.py +++ b/yt/frontends/artio/fields.py @@ -1,5 +1,6 @@ import numpy as np +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer from yt.units.yt_array import YTArray from yt.utilities.physical_constants import amu_cgs, boltzmann_constant_cgs @@ -15,7 +16,7 @@ class ARTIOFieldInfo(FieldInfoContainer): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("HVAR_GAS_DENSITY", (rho_units, ["density"], None)), ("HVAR_GAS_ENERGY", (en_units, ["total_energy_density"], None)), ("HVAR_INTERNAL_ENERGY", (en_units, ["thermal_energy_density"], None)), @@ -36,7 +37,7 @@ class ARTIOFieldInfo(FieldInfoContainer): ("RT_HVAR_HeIII", (rho_units, ["He_p2_density"], None)), ) - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("POSITION_X", ("code_length", ["particle_position_x"], None)), ("POSITION_Y", ("code_length", ["particle_position_y"], None)), ("POSITION_Z", ("code_length", ["particle_position_z"], None)), diff --git a/yt/frontends/athena/fields.py b/yt/frontends/athena/fields.py index 271c758a281..ed94a864d76 100644 --- a/yt/frontends/athena/fields.py +++ b/yt/frontends/athena/fields.py @@ -1,3 +1,4 @@ +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer from yt.utilities.physical_constants import kboltz, mh @@ -15,7 +16,7 @@ def _velocity(field, data): class AthenaFieldInfo(FieldInfoContainer): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("density", ("code_mass/code_length**3", ["density"], None)), ("cell_centered_B_x", (b_units, [], None)), ("cell_centered_B_y", (b_units, [], None)), diff --git a/yt/frontends/athena_pp/fields.py b/yt/frontends/athena_pp/fields.py index 361e9e93476..65bb7d5535a 100644 --- a/yt/frontends/athena_pp/fields.py +++ b/yt/frontends/athena_pp/fields.py @@ -1,3 +1,4 @@ +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer from yt.utilities.physical_constants import kboltz, mh @@ -15,7 +16,7 @@ def _velocity(field, data): class AthenaPPFieldInfo(FieldInfoContainer): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("rho", (rho_units, ["density"], None)), ("dens", (rho_units, ["density"], None)), ("Bcc1", (b_units, [], None)), diff --git a/yt/frontends/boxlib/data_structures.py b/yt/frontends/boxlib/data_structures.py index 2158ea801b3..f403632c774 100644 --- a/yt/frontends/boxlib/data_structures.py +++ b/yt/frontends/boxlib/data_structures.py @@ -3,11 +3,13 @@ import re from collections import namedtuple from stat import ST_CTIME +from typing import Type import numpy as np from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch from yt.data_objects.static_output import Dataset +from yt.fields.field_info_container import FieldInfoContainer from yt.funcs import mylog, setdefaultattr from yt.geometry.grid_geometry_handler import GridIndex from yt.utilities.io_handler import io_registry @@ -628,10 +630,9 @@ class BoxlibDataset(Dataset): """ _index_class = BoxlibHierarchy - _field_info_class = BoxlibFieldInfo + _field_info_class: Type[FieldInfoContainer] = BoxlibFieldInfo _output_prefix = None _default_cparam_filename = "job_info" - _periodicity = (False, False, False) def __init__( self, @@ -728,6 +729,7 @@ def _parse_parameter_file(self): Parses the parameter file and establishes the various dictionaries. """ + self._periodicity = (False, False, False) self._parse_header_file() # Let's read the file hfn = os.path.join(self.output_dir, "Header") diff --git a/yt/frontends/boxlib/fields.py b/yt/frontends/boxlib/fields.py index 84d41c73356..5be3d5861b1 100644 --- a/yt/frontends/boxlib/fields.py +++ b/yt/frontends/boxlib/fields.py @@ -3,6 +3,7 @@ import numpy as np +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer from yt.frontends.boxlib.misc import BoxlibSetupParticleFieldsMixin from yt.units import YTQuantity @@ -48,7 +49,7 @@ def _temperature(field, data): class WarpXFieldInfo(FieldInfoContainer): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("Bx", ("T", ["magnetic_field_x", "B_x"], None)), ("By", ("T", ["magnetic_field_y", "B_y"], None)), ("Bz", ("T", ["magnetic_field_z", "B_z"], None)), @@ -60,7 +61,7 @@ class WarpXFieldInfo(FieldInfoContainer): ("jz", ("A", ["current_z", "Jz", "J_z"], None)), ) - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("particle_weight", ("", ["particle_weighting"], None)), ("particle_position_x", ("m", [], None)), ("particle_position_y", ("m", [], None)), @@ -179,8 +180,7 @@ def get_velocity_z(field, data): class NyxFieldInfo(FieldInfoContainer): - known_other_fields = () - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("particle_position_x", ("code_length", [], None)), ("particle_position_y", ("code_length", [], None)), ("particle_position_z", ("code_length", [], None)), @@ -188,7 +188,7 @@ class NyxFieldInfo(FieldInfoContainer): class BoxlibFieldInfo(FieldInfoContainer, BoxlibSetupParticleFieldsMixin): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("density", (rho_units, ["density"], None)), ("eden", (eden_units, ["total_energy_density"], None)), ("xmom", (mom_units, ["momentum_density_x"], None)), @@ -204,7 +204,7 @@ class BoxlibFieldInfo(FieldInfoContainer, BoxlibSetupParticleFieldsMixin): ("zvel", ("cm/s", ["velocity_z"], None)), ) - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("particle_mass", ("code_mass", [], None)), ("particle_position_x", ("code_length", [], None)), ("particle_position_y", ("code_length", [], None)), @@ -286,7 +286,7 @@ def momentum(field, data): class CastroFieldInfo(FieldInfoContainer): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("density", ("g/cm**3", ["density"], r"\rho")), ("xmom", ("g/(cm**2 * s)", ["momentum_density_x"], r"\rho u")), ("ymom", ("g/(cm**2 * s)", ["momentum_density_y"], r"\rho v")), @@ -330,7 +330,7 @@ class CastroFieldInfo(FieldInfoContainer): ("rot_z", ("cm/s**2", [], r"\mathbf{f}_{\rm{rot}} \cdot \mathbf{e}_z")), ) - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("particle_position_x", ("code_length", [], None)), ("particle_position_y", ("code_length", [], None)), ("particle_position_z", ("code_length", [], None)), @@ -364,7 +364,7 @@ def setup_fluid_fields(self): class MaestroFieldInfo(FieldInfoContainer): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("density", ("g/cm**3", ["density"], None)), ("x_vel", ("cm/s", ["velocity_x"], r"\tilde{u}")), ("y_vel", ("cm/s", ["velocity_y"], r"\tilde{v}")), diff --git a/yt/frontends/chombo/data_structures.py b/yt/frontends/chombo/data_structures.py index a0c941a8ab6..5b06188ac23 100644 --- a/yt/frontends/chombo/data_structures.py +++ b/yt/frontends/chombo/data_structures.py @@ -1,11 +1,13 @@ import os import re import weakref +from typing import Type import numpy as np from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch from yt.data_objects.static_output import Dataset +from yt.fields.field_info_container import FieldInfoContainer from yt.frontends.boxlib.misc import BoxlibReadParticleFileMixin from yt.funcs import mylog, setdefaultattr from yt.geometry.grid_geometry_handler import GridIndex @@ -237,7 +239,7 @@ def _reconstruct_parent_child(self): class ChomboDataset(Dataset): _index_class = ChomboHierarchy - _field_info_class = ChomboFieldInfo + _field_info_class: Type[FieldInfoContainer] = ChomboFieldInfo def __init__( self, diff --git a/yt/frontends/chombo/fields.py b/yt/frontends/chombo/fields.py index 89938e32546..7270f47bc94 100644 --- a/yt/frontends/chombo/fields.py +++ b/yt/frontends/chombo/fields.py @@ -1,5 +1,6 @@ import numpy as np +from yt._typing import KnownFieldsT from yt.fields.field_info_container import ( FieldInfoContainer, particle_deposition_functions, @@ -7,7 +8,7 @@ standard_particle_fields, ) from yt.frontends.boxlib.misc import BoxlibSetupParticleFieldsMixin -from yt.units.unit_object import Unit +from yt.units.unit_object import Unit # type: ignore from yt.utilities.exceptions import YTFieldNotFound rho_units = "code_mass / code_length**3" @@ -17,17 +18,16 @@ b_units = "code_magnetic" -# Chombo does not have any known fields by itself. class ChomboFieldInfo(FieldInfoContainer): - known_other_fields = () - known_particle_fields = () + # no custom behaviour is needed yet + pass # Orion 2 Fields # We duplicate everything here from Boxlib, because we want to be able to # subclass it and that can be somewhat tricky. class Orion2FieldInfo(ChomboFieldInfo, BoxlibSetupParticleFieldsMixin): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("density", (rho_units, ["density"], None)), ("energy-density", (eden_units, ["total_energy_density"], None)), ("radiation-energy-density", (eden_units, ["radiation_energy_density"], None)), @@ -43,7 +43,7 @@ class Orion2FieldInfo(ChomboFieldInfo, BoxlibSetupParticleFieldsMixin): ("directrad-dpydt-density", (mom_units, ["directrad-dpydt-density"], None)), ("directrad-dpzdt-density", (mom_units, ["directrad-dpzdt-density"], None)), ) - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("particle_mass", ("code_mass", [], None)), ("particle_position_x", ("code_length", [], None)), ("particle_position_y", ("code_length", [], None)), @@ -176,7 +176,7 @@ def velocity(field, data): class ChomboPICFieldInfo3D(FieldInfoContainer): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("density", (rho_units, ["density", "Density"], None)), ( "potential", @@ -186,7 +186,7 @@ class ChomboPICFieldInfo3D(FieldInfoContainer): ("gravitational_field_y", ("code_length / code_time**2", [], None)), ("gravitational_field_z", ("code_length / code_time**2", [], None)), ) - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("particle_mass", ("code_mass", [], None)), ("particle_position_x", ("code_length", [], None)), ("particle_position_y", ("code_length", [], None)), @@ -262,7 +262,7 @@ def _dummy_field(field, data): class ChomboPICFieldInfo2D(ChomboPICFieldInfo3D): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("density", (rho_units, ["density", "Density"], None)), ( "potential", @@ -271,7 +271,7 @@ class ChomboPICFieldInfo2D(ChomboPICFieldInfo3D): ("gravitational_field_x", ("code_length / code_time**2", [], None)), ("gravitational_field_y", ("code_length / code_time**2", [], None)), ) - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("particle_mass", ("code_mass", [], None)), ("particle_position_x", ("code_length", [], None)), ("particle_position_y", ("code_length", [], None)), @@ -307,7 +307,7 @@ def __init__(self, ds, field_list): class ChomboPICFieldInfo1D(ChomboPICFieldInfo3D): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("density", (rho_units, ["density", "Density"], None)), ( "potential", @@ -315,7 +315,7 @@ class ChomboPICFieldInfo1D(ChomboPICFieldInfo3D): ), ("gravitational_field_x", ("code_length / code_time**2", [], None)), ) - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("particle_mass", ("code_mass", [], None)), ("particle_position_x", ("code_length", [], None)), ("particle_velocity_x", ("code_length / code_time", [], None)), @@ -367,7 +367,7 @@ def __init__(self, ds, field_list): class PlutoFieldInfo(ChomboFieldInfo): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("rho", (rho_units, ["density"], None)), ("prs", ("code_mass / (code_length * code_time**2)", ["pressure"], None)), ("vx1", (vel_units, ["velocity_x"], None)), diff --git a/yt/frontends/eagle/data_structures.py b/yt/frontends/eagle/data_structures.py index 019cfce377c..92046e4e3fa 100644 --- a/yt/frontends/eagle/data_structures.py +++ b/yt/frontends/eagle/data_structures.py @@ -1,6 +1,9 @@ +from typing import Type + import numpy as np import yt.units +from yt.fields.field_info_container import FieldInfoContainer from yt.frontends.gadget.data_structures import GadgetHDF5Dataset from yt.frontends.owls.fields import OWLSFieldInfo from yt.utilities.on_demand_imports import _h5py as h5py @@ -10,7 +13,7 @@ class EagleDataset(GadgetHDF5Dataset): _particle_mass_name = "Mass" - _field_info_class = OWLSFieldInfo + _field_info_class: Type[FieldInfoContainer] = OWLSFieldInfo _time_readin_ = "Time" def _parse_parameter_file(self): diff --git a/yt/frontends/enzo/fields.py b/yt/frontends/enzo/fields.py index c1fcbe98d42..a09ed39061d 100644 --- a/yt/frontends/enzo/fields.py +++ b/yt/frontends/enzo/fields.py @@ -1,5 +1,6 @@ import numpy as np +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer from yt.utilities.physical_constants import me, mp @@ -48,7 +49,7 @@ class EnzoFieldInfo(FieldInfoContainer): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("Cooling_Time", ("s", ["cooling_time"], None)), ("Dengo_Cooling_Rate", ("erg/g/s", [], None)), ("Grackle_Cooling_Rate", ("erg/s/cm**3", [], None)), @@ -89,7 +90,7 @@ class EnzoFieldInfo(FieldInfoContainer): ("Electron_Density", (rho_units, [], None)), ) - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("particle_position_x", ("code_length", [], None)), ("particle_position_y", ("code_length", [], None)), ("particle_position_z", ("code_length", [], None)), diff --git a/yt/frontends/enzo/io.py b/yt/frontends/enzo/io.py index 81433f45286..9e6112bfedc 100644 --- a/yt/frontends/enzo/io.py +++ b/yt/frontends/enzo/io.py @@ -1,3 +1,5 @@ +from typing import Dict + import numpy as np from yt.geometry.selection_routines import GridSelector @@ -7,7 +9,7 @@ _convert_mass = ("particle_mass", "mass") -_particle_position_names = {} +_particle_position_names: Dict[str, str] = {} class IOHandlerPackedHDF5(BaseIOHandler): diff --git a/yt/frontends/enzo_e/fields.py b/yt/frontends/enzo_e/fields.py index 3a6b56daf42..c2d3fdd6e43 100644 --- a/yt/frontends/enzo_e/fields.py +++ b/yt/frontends/enzo_e/fields.py @@ -1,5 +1,6 @@ import numpy as np +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer from yt.fields.magnetic_field import setup_magnetic_field_aliases from yt.fields.particle_fields import add_union_field @@ -11,8 +12,6 @@ energy_units = "code_velocity**2" b_units = "code_magnetic" -known_species_names = {} - NODAL_FLAGS = { "bfieldi_x": [1, 0, 0], "bfieldi_y": [0, 1, 0], @@ -21,7 +20,7 @@ class EnzoEFieldInfo(FieldInfoContainer): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("velocity_x", (vel_units, ["velocity_x"], None)), ("velocity_y", (vel_units, ["velocity_y"], None)), ("velocity_z", (vel_units, ["velocity_z"], None)), @@ -40,7 +39,7 @@ class EnzoEFieldInfo(FieldInfoContainer): ("bfieldi_z", (b_units, [], None)), ) - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("x", ("code_length", ["particle_position_x"], None)), ("y", ("code_length", ["particle_position_y"], None)), ("z", ("code_length", ["particle_position_z"], None)), diff --git a/yt/frontends/enzo_e/io.py b/yt/frontends/enzo_e/io.py index f50f33684ed..f9a370de349 100644 --- a/yt/frontends/enzo_e/io.py +++ b/yt/frontends/enzo_e/io.py @@ -4,8 +4,6 @@ from yt.utilities.io_handler import BaseIOHandler from yt.utilities.on_demand_imports import _h5py as h5py -_particle_position_names = {} - class EnzoEIOHandler(BaseIOHandler): diff --git a/yt/frontends/fits/data_structures.py b/yt/frontends/fits/data_structures.py index 580681d0eff..e06b0c3e9fb 100644 --- a/yt/frontends/fits/data_structures.py +++ b/yt/frontends/fits/data_structures.py @@ -4,6 +4,7 @@ import warnings import weakref from collections import defaultdict +from typing import Type import numpy as np import numpy.core.defchararray as np_char @@ -12,12 +13,16 @@ from yt.config import ytcfg from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch from yt.data_objects.static_output import Dataset +from yt.fields.field_info_container import FieldInfoContainer from yt.funcs import mylog, setdefaultattr from yt.geometry.geometry_handler import YTDataChunk from yt.geometry.grid_geometry_handler import GridIndex from yt.units import dimensions -from yt.units.unit_lookup_table import default_unit_symbol_lut, unit_prefixes -from yt.units.unit_object import UnitParseError +from yt.units.unit_lookup_table import ( # type: ignore + default_unit_symbol_lut, + unit_prefixes, +) +from yt.units.unit_object import UnitParseError # type: ignore from yt.units.yt_array import YTQuantity from yt.utilities.decompose import decompose_array, get_psize from yt.utilities.file_handler import FITSFileHandler @@ -32,9 +37,9 @@ spec_names = {"V": "Velocity", "F": "Frequency", "E": "Energy", "W": "Wavelength"} space_prefixes = list(set(lon_prefixes + lat_prefixes)) -sky_prefixes = set(space_prefixes) -sky_prefixes.difference_update({"X", "Y", "LINEAR"}) -sky_prefixes = list(sky_prefixes) +unique_sky_prefixes = set(space_prefixes) +unique_sky_prefixes.difference_update({"X", "Y", "LINEAR"}) +sky_prefixes = list(unique_sky_prefixes) spec_prefixes = list(spec_names.keys()) @@ -313,7 +318,7 @@ def check_sky_coords(filename, ndim): class FITSDataset(Dataset): _index_class = FITSHierarchy - _field_info_class = FITSFieldInfo + _field_info_class: Type[FieldInfoContainer] = FITSFieldInfo _dataset_type = "fits" _handle = None diff --git a/yt/frontends/fits/fields.py b/yt/frontends/fits/fields.py index 2ef4f663e9a..244ba17139e 100644 --- a/yt/frontends/fits/fields.py +++ b/yt/frontends/fits/fields.py @@ -1,3 +1,4 @@ +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer @@ -12,7 +13,7 @@ def __init__(self, ds, field_list, slice_info=None): class YTFITSFieldInfo(FieldInfoContainer): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("density", ("code_mass/code_length**3", ["density"], None)), ( "dark_matter_density", diff --git a/yt/frontends/flash/data_structures.py b/yt/frontends/flash/data_structures.py index 9243936cf81..50eee652ef1 100644 --- a/yt/frontends/flash/data_structures.py +++ b/yt/frontends/flash/data_structures.py @@ -1,11 +1,13 @@ import os import weakref +from typing import Type import numpy as np from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch from yt.data_objects.static_output import Dataset, ParticleFile, validate_index_order from yt.funcs import mylog, setdefaultattr +from yt.geometry.geometry_handler import Index from yt.geometry.grid_geometry_handler import GridIndex from yt.geometry.particle_geometry_handler import ParticleIndex from yt.utilities.file_handler import HDF5FileHandler, warn_h5py @@ -162,7 +164,7 @@ def _populate_grid_objects(self): class FLASHDataset(Dataset): - _index_class = FLASHHierarchy + _index_class: Type[Index] = FLASHHierarchy _field_info_class = FLASHFieldInfo _handle = None diff --git a/yt/frontends/flash/fields.py b/yt/frontends/flash/fields.py index 45525d4c07f..1aa662147ab 100644 --- a/yt/frontends/flash/fields.py +++ b/yt/frontends/flash/fields.py @@ -1,3 +1,4 @@ +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer # Common fields in FLASH: (Thanks to John ZuHone for this list) @@ -24,7 +25,7 @@ class FLASHFieldInfo(FieldInfoContainer): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("velx", ("code_length/code_time", ["velocity_x"], None)), ("vely", ("code_length/code_time", ["velocity_y"], None)), ("velz", ("code_length/code_time", ["velocity_z"], None)), diff --git a/yt/frontends/gadget/data_structures.py b/yt/frontends/gadget/data_structures.py index a6cedd40730..1e4f8057eca 100644 --- a/yt/frontends/gadget/data_structures.py +++ b/yt/frontends/gadget/data_structures.py @@ -1,12 +1,15 @@ import os import stat import struct +from typing import Type import numpy as np from yt.data_objects.static_output import ParticleFile +from yt.fields.field_info_container import FieldInfoContainer from yt.frontends.sph.data_structures import SPHDataset, SPHParticleIndex from yt.funcs import only_on_root +from yt.geometry.geometry_handler import Index from yt.utilities.chemical_formulas import compute_mu from yt.utilities.cosmology import Cosmology from yt.utilities.fortran_utils import read_record @@ -208,9 +211,9 @@ def _initialize_frontend_specific(self): class GadgetDataset(SPHDataset): - _index_class = GadgetBinaryIndex - _file_class = GadgetBinaryFile - _field_info_class = GadgetFieldInfo + _index_class: Type[Index] = GadgetBinaryIndex + _file_class: Type[ParticleFile] = GadgetBinaryFile + _field_info_class: Type[FieldInfoContainer] = GadgetFieldInfo _particle_mass_name = "Mass" _particle_coordinates_name = "Coordinates" _particle_velocity_name = "Velocities" @@ -562,7 +565,7 @@ def _is_valid(cls, filename, *args, **kwargs): class GadgetHDF5Dataset(GadgetDataset): _file_class = ParticleFile _index_class = SPHParticleIndex - _field_info_class = GadgetFieldInfo + _field_info_class: Type[FieldInfoContainer] = GadgetFieldInfo _particle_mass_name = "Masses" _sph_ptypes = ("PartType0",) _suffix = ".hdf5" diff --git a/yt/frontends/gadget/io.py b/yt/frontends/gadget/io.py index 0cf1de53300..ee483f2da95 100644 --- a/yt/frontends/gadget/io.py +++ b/yt/frontends/gadget/io.py @@ -4,7 +4,7 @@ import numpy as np from yt.frontends.sph.io import IOHandlerSPH -from yt.units.yt_array import uconcatenate +from yt.units.yt_array import uconcatenate # type: ignore from yt.utilities.lib.particle_kdtree_tools import generate_smoothing_length from yt.utilities.logger import ytLogger as mylog from yt.utilities.on_demand_imports import _h5py as h5py @@ -308,7 +308,7 @@ def _identify_fields(self, data_file): class IOHandlerGadgetBinary(IOHandlerSPH): _dataset_type = "gadget_binary" - _vector_fields = ( + _vector_fields = ( # type: ignore ("Coordinates", 3), ("Velocity", 3), ("Velocities", 3), diff --git a/yt/frontends/gadget_fof/fields.py b/yt/frontends/gadget_fof/fields.py index a13af6ba2f5..2f80c1dbee0 100644 --- a/yt/frontends/gadget_fof/fields.py +++ b/yt/frontends/gadget_fof/fields.py @@ -1,3 +1,4 @@ +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer m_units = "code_mass" @@ -5,13 +6,13 @@ v_units = "code_velocity" _pnums = 6 -_type_fields = tuple( +_type_fields: KnownFieldsT = tuple( ("%s%sType_%d" % (ptype, field, pnum), (units, [], None)) for pnum in range(_pnums) for field, units in (("Mass", m_units), ("Len", p_units)) for ptype in ("Group", "Subhalo") ) -_sub_type_fields = tuple( +_sub_type_fields: KnownFieldsT = tuple( ("Subhalo%sType_%d" % (field, pnum), (units, [], None)) for pnum in range(_pnums) for field, units in ( @@ -73,8 +74,6 @@ class GadgetFOFFieldInfo(FieldInfoContainer): - known_other_fields = () - known_particle_fields = _particle_fields # these are extra fields to be created for the "all" particle type @@ -92,6 +91,4 @@ class GadgetFOFFieldInfo(FieldInfoContainer): class GadgetFOFHaloFieldInfo(FieldInfoContainer): - known_other_fields = () - known_particle_fields = _particle_fields + (("ID", ("", ["member_ids"], None)),) diff --git a/yt/frontends/gamer/fields.py b/yt/frontends/gamer/fields.py index 9cf5f21cf8b..8334100082e 100644 --- a/yt/frontends/gamer/fields.py +++ b/yt/frontends/gamer/fields.py @@ -1,5 +1,6 @@ import numpy as np +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer b_units = "code_magnetic" @@ -14,7 +15,7 @@ class GAMERFieldInfo(FieldInfoContainer): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( # hydro fields on disk (GAMER outputs conservative variables) ("Dens", (rho_units, [], None)), ("MomX", (mom_units, ["momentum_density_x"], None)), @@ -34,7 +35,7 @@ class GAMERFieldInfo(FieldInfoContainer): ("TotalDens", (rho_units, ["total_density_on_grid"], None)), ) - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("ParMass", ("code_mass", ["particle_mass"], None)), ("ParPosX", ("code_length", ["particle_position_x"], None)), ("ParPosY", ("code_length", ["particle_position_y"], None)), diff --git a/yt/frontends/gdf/data_structures.py b/yt/frontends/gdf/data_structures.py index 4ab77c2377f..91e8edde86c 100644 --- a/yt/frontends/gdf/data_structures.py +++ b/yt/frontends/gdf/data_structures.py @@ -7,9 +7,9 @@ from yt.data_objects.static_output import Dataset from yt.funcs import just_one, setdefaultattr from yt.geometry.grid_geometry_handler import GridIndex -from yt.units.dimensions import dimensionless as sympy_one -from yt.units.unit_object import Unit -from yt.units.unit_systems import unit_system_registry +from yt.units.dimensions import dimensionless as sympy_one # type: ignore +from yt.units.unit_object import Unit # type: ignore +from yt.units.unit_systems import unit_system_registry # type: ignore from yt.utilities.exceptions import YTGDFUnknownGeometry from yt.utilities.lib.misc_utilities import get_box_grids_level from yt.utilities.logger import ytLogger as mylog diff --git a/yt/frontends/gizmo/fields.py b/yt/frontends/gizmo/fields.py index 9bc71485582..46b480467e6 100644 --- a/yt/frontends/gizmo/fields.py +++ b/yt/frontends/gizmo/fields.py @@ -1,3 +1,4 @@ +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer from yt.fields.magnetic_field import setup_magnetic_field_aliases from yt.fields.species_fields import add_species_field_by_density, setup_species_fields @@ -10,7 +11,7 @@ class GizmoFieldInfo(GadgetFieldInfo): # The known fields list is according to the GIZMO User Guide. See # http://www.tapir.caltech.edu/~phopkins/Site/GIZMO_files/gizmo_documentation.html#snaps-reading - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("Coordinates", ("code_length", ["particle_position"], None)), ("Velocities", ("code_velocity", ["particle_velocity"], None)), ("ParticleIDs", ("", ["particle_index"], None)), diff --git a/yt/frontends/halo_catalog/fields.py b/yt/frontends/halo_catalog/fields.py index 6884141e60a..cacb2c415e4 100644 --- a/yt/frontends/halo_catalog/fields.py +++ b/yt/frontends/halo_catalog/fields.py @@ -1,3 +1,4 @@ +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer m_units = "g" @@ -5,7 +6,7 @@ v_units = "cm / s" r_units = "cm" -_particle_fields = ( +_particle_fields: KnownFieldsT = ( ("particle_identifier", ("", [], None)), ("particle_position_x", (p_units, [], None)), ("particle_position_y", (p_units, [], None)), diff --git a/yt/frontends/halo_catalog/io.py b/yt/frontends/halo_catalog/io.py index f7566342184..c4315d0cc5d 100644 --- a/yt/frontends/halo_catalog/io.py +++ b/yt/frontends/halo_catalog/io.py @@ -4,7 +4,7 @@ from yt.frontends.gadget_fof.io import IOHandlerGadgetFOFHaloHDF5 from yt.funcs import parse_h5_attr -from yt.units.yt_array import uvstack +from yt.units.yt_array import uvstack # type: ignore from yt.utilities.io_handler import BaseIOHandler from yt.utilities.on_demand_imports import _h5py as h5py @@ -116,7 +116,12 @@ def _read_particle_fields(self, dobj, ptf): _read_particle_selection = IOHandlerGadgetFOFHaloHDF5._read_particle_selection -class IOHandlerYTHalo(HaloDatasetIOHandler, IOHandlerYTHaloCatalog): +# ignoring type in this mixing to circunvent this error from mypy +# Definition of "_read_particle_fields" in base class "HaloDatasetIOHandler" +# is incompatible with definition in base class "IOHandlerYTHaloCatalog" +# +# it may not be possible to refactor out of this situation without breaking downstream +class IOHandlerYTHalo(HaloDatasetIOHandler, IOHandlerYTHaloCatalog): # type: ignore _dataset_type = "ythalo" def _identify_fields(self, data_file): diff --git a/yt/frontends/open_pmd/data_structures.py b/yt/frontends/open_pmd/data_structures.py index e87afe0d87d..3830b551eb3 100644 --- a/yt/frontends/open_pmd/data_structures.py +++ b/yt/frontends/open_pmd/data_structures.py @@ -2,6 +2,7 @@ from operator import mul from os import listdir, path from re import match +from typing import List, Optional import numpy as np from packaging.version import Version @@ -32,8 +33,9 @@ class OpenPMDGrid(AMRGridPatch): _id_offset = 0 __slots__ = ["_level_id"] # Every particle species and mesh might have different hdf5-indices and offsets - ftypes = [] - ptypes = [] + + ftypes: Optional[List[str]] = [] + ptypes: Optional[List[str]] = [] findex = 0 foffset = 0 pindex = 0 diff --git a/yt/frontends/open_pmd/fields.py b/yt/frontends/open_pmd/fields.py index 86f3664d1c1..1edfe86fd2b 100644 --- a/yt/frontends/open_pmd/fields.py +++ b/yt/frontends/open_pmd/fields.py @@ -1,3 +1,5 @@ +from typing import List + import numpy as np from yt.fields.field_info_container import FieldInfoContainer @@ -139,7 +141,7 @@ class OpenPMDFieldInfo(FieldInfoContainer): * [1] http://yt-project.org/docs/dev/reference/field_list.html#universal-fields """ - _mag_fields = [] + _mag_fields: List[str] = [] def __init__(self, ds, field_list): f = ds._handle diff --git a/yt/frontends/owls/fields.py b/yt/frontends/owls/fields.py index 5bf919f202d..0ea6ede86e1 100644 --- a/yt/frontends/owls/fields.py +++ b/yt/frontends/owls/fields.py @@ -1,4 +1,5 @@ import os +from typing import Tuple import numpy as np @@ -44,7 +45,7 @@ def _get_ion_mass_frac(ion, ftype, itab, data): class OWLSFieldInfo(SPHFieldInfo): - _ions = ( + _ions: Tuple[str, ...] = ( "c1", "c2", "c3", diff --git a/yt/frontends/owls_subfind/fields.py b/yt/frontends/owls_subfind/fields.py index 0a137ade725..8bc38a9c667 100644 --- a/yt/frontends/owls_subfind/fields.py +++ b/yt/frontends/owls_subfind/fields.py @@ -1,3 +1,4 @@ +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer m_units = "code_mass" @@ -7,9 +8,7 @@ class OWLSSubfindFieldInfo(FieldInfoContainer): - known_other_fields = () - - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("CenterOfMass_0", (p_units, ["particle_position_x"], None)), ("CenterOfMass_1", (p_units, ["particle_position_y"], None)), ("CenterOfMass_2", (p_units, ["particle_position_z"], None)), diff --git a/yt/frontends/ramses/definitions.py b/yt/frontends/ramses/definitions.py index d446312fe47..41365079d88 100644 --- a/yt/frontends/ramses/definitions.py +++ b/yt/frontends/ramses/definitions.py @@ -71,7 +71,7 @@ def ramses_header(hvals): if ytcfg.has_section("ramses-families"): for key in particle_families.keys(): - val = ytcfg.get("ramses-families", key, fallback=None) + val = ytcfg.get("ramses-families", key, callback=None) if val is not None: mylog.info( "Changing family %s from %s to %s", key, particle_families[key], val diff --git a/yt/frontends/ramses/field_handlers.py b/yt/frontends/ramses/field_handlers.py index a50cfae95eb..691de7cf288 100644 --- a/yt/frontends/ramses/field_handlers.py +++ b/yt/frontends/ramses/field_handlers.py @@ -1,7 +1,7 @@ import abc import glob import os -from typing import List +from typing import List, Optional, Set, Tuple, Type from yt.config import ytcfg from yt.funcs import mylog @@ -10,7 +10,7 @@ from .io import _read_fluid_file_descriptor from .io_utils import read_offset -FIELD_HANDLERS = set() +FIELD_HANDLERS: Set[Type["FieldFileHandler"]] = set() def get_field_handlers(): @@ -21,7 +21,7 @@ def register_field_handler(ph): FIELD_HANDLERS.add(ph) -DETECTED_FIELDS = {} +DETECTED_FIELDS = {} # type: ignore class HandlerMixin: @@ -141,13 +141,15 @@ class FieldFileHandler(abc.ABC, HandlerMixin): _file_type = "field" # These properties are static properties - ftype = None # The name to give to the field type - fname = None # The name of the file(s) - attrs = None # The attributes of the header + ftype: Optional[str] = None # The name to give to the field type + fname: Optional[str] = None # The name of the file(s) + attrs: Optional[ + Tuple[Tuple[str, int, str], ...] + ] = None # The attributes of the header known_fields = None # A list of tuple containing the field name and its type - config_field = None # Name of the config section (if any) + config_field: Optional[str] = None # Name of the config section (if any) - file_descriptor = None # The name of the file descriptor (if any) + file_descriptor: Optional[str] = None # The name of the file descriptor (if any) # These properties are computed dynamically field_offsets = None # Mapping from field to offset in file diff --git a/yt/frontends/ramses/fields.py b/yt/frontends/ramses/fields.py index ccd2c5ebc8d..87a04a1bcfe 100644 --- a/yt/frontends/ramses/fields.py +++ b/yt/frontends/ramses/fields.py @@ -3,6 +3,7 @@ import numpy as np from yt import units +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer from yt.frontends.ramses.io import convert_ramses_ages from yt.utilities.cython_fortran_utils import FortranFile @@ -92,7 +93,7 @@ class RAMSESFieldInfo(FieldInfoContainer): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("Density", (rho_units, ["density"], None)), ("x-velocity", (vel_units, ["velocity_x"], None)), ("y-velocity", (vel_units, ["velocity_y"], None)), @@ -114,7 +115,7 @@ class RAMSESFieldInfo(FieldInfoContainer): ("B_z_left", (b_units, ["magnetic_field_z_left"], None)), ("B_z_right", (b_units, ["magnetic_field_z_right"], None)), ) - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("particle_position_x", ("code_length", [], None)), ("particle_position_y", ("code_length", [], None)), ("particle_position_z", ("code_length", [], None)), @@ -131,7 +132,7 @@ class RAMSESFieldInfo(FieldInfoContainer): ("particle_tag", ("", [], None)), ) - known_sink_fields = ( + known_sink_fields: KnownFieldsT = ( ("particle_position_x", ("code_length", [], None)), ("particle_position_y", ("code_length", [], None)), ("particle_position_z", ("code_length", [], None)), diff --git a/yt/frontends/ramses/particle_handlers.py b/yt/frontends/ramses/particle_handlers.py index 39866edb4f8..92374601e9f 100644 --- a/yt/frontends/ramses/particle_handlers.py +++ b/yt/frontends/ramses/particle_handlers.py @@ -1,5 +1,6 @@ import abc import os +from typing import List, Optional, Set, Tuple, Type from yt.config import ytcfg from yt.funcs import mylog @@ -8,7 +9,7 @@ from .field_handlers import HandlerMixin from .io import _read_part_file_descriptor -PARTICLE_HANDLERS = set() +PARTICLE_HANDLERS: Set[Type["ParticleFileHandler"]] = set() def get_particle_handlers(): @@ -32,13 +33,15 @@ class ParticleFileHandler(abc.ABC, HandlerMixin): _file_type = "particle" # These properties are static properties - ptype = None # The name to give to the particle type - fname = None # The name of the file(s). - file_descriptor = None # The name of the file descriptor (if any) - - attrs = None # The attributes of the header - known_fields = None # A list of tuple containing the field name and its type - config_field = None # Name of the config section (if any) + ptype: Optional[str] = None # The name to give to the particle type + fname: Optional[str] = None # The name of the file(s). + file_descriptor: Optional[str] = None # The name of the file descriptor (if any) + + attrs: Tuple[Tuple[str, int, str], ...] # The attributes of the header + known_fields: Optional[ + List[Tuple[str, str]] + ] = None # A list of tuple containing the field name and its type + config_field: Optional[str] = None # Name of the config section (if any) # These properties are computed dynamically field_offsets = None # Mapping from field to offset in file diff --git a/yt/frontends/rockstar/definitions.py b/yt/frontends/rockstar/definitions.py index 1d3fe63af83..f5c4c5e96f3 100644 --- a/yt/frontends/rockstar/definitions.py +++ b/yt/frontends/rockstar/definitions.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, List, Tuple, Union + import numpy as np BINARY_HEADER_SIZE = 256 @@ -23,9 +25,11 @@ # Note the final field here, which is a field for min/max format revision in # which the field appears. -KNOWN_REVISIONS = [0, 1, 2] +KNOWN_REVISIONS: List[int] = [0, 1, 2] -halo_dt = [ +# using typing.Any here in lieu of numpy.typing.DTypeLike (should be backported for numpy < 1.20) +HaloDataType = Union[Tuple[str, Any], Tuple[str, Any, Tuple[int, int]]] +halo_dt: List[HaloDataType] = [ ("particle_identifier", np.int64), ("particle_position_x", np.float32), ("particle_position_y", np.float32), @@ -94,18 +98,21 @@ ("av_density", np.float32, (2, 100)), ] -halo_dts = {} +# using typing.Any here in lieu of numpy.typing.DTypeLike (should be backported for numpy < 1.20) +halo_dts_tmp: Dict[int, List[HaloDataType]] = {} +halo_dts: Dict[int, np.dtype] = {} for rev in KNOWN_REVISIONS: - halo_dts[rev] = [] + halo_dts_tmp[rev] = [] for item in halo_dt: if len(item) == 2: - halo_dts[rev].append(item) - else: - mi, ma = item[2] + halo_dts_tmp[rev].append(item) + elif len(item) == 3: + mi, ma = item[2] # type: ignore if (mi <= rev) and (rev <= ma): - halo_dts[rev].append(item[:2]) - halo_dts[rev] = np.dtype(halo_dts[rev], align=True) + halo_dts_tmp[rev].append(item[:2]) + halo_dts[rev] = np.dtype(halo_dts_tmp[rev], align=True) +del halo_dts_tmp particle_dt = np.dtype( [ diff --git a/yt/frontends/rockstar/fields.py b/yt/frontends/rockstar/fields.py index 724483e74f3..72520df0743 100644 --- a/yt/frontends/rockstar/fields.py +++ b/yt/frontends/rockstar/fields.py @@ -1,3 +1,4 @@ +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer m_units = "Msun / h" # Msun / h @@ -7,9 +8,8 @@ class RockstarFieldInfo(FieldInfoContainer): - known_other_fields = () - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("particle_identifier", ("", [], None)), ("particle_position_x", (p_units, [], None)), ("particle_position_y", (p_units, [], None)), diff --git a/yt/frontends/sph/fields.py b/yt/frontends/sph/fields.py index d5139e0b34e..27e6e9e036b 100644 --- a/yt/frontends/sph/fields.py +++ b/yt/frontends/sph/fields.py @@ -1,11 +1,11 @@ +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer from yt.fields.species_fields import setup_species_fields class SPHFieldInfo(FieldInfoContainer): - known_other_fields = () - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("Mass", ("code_mass", ["particle_mass"], None)), ("Masses", ("code_mass", ["particle_mass"], None)), ("Coordinates", ("code_length", ["particle_position"], None)), diff --git a/yt/frontends/stream/data_structures.py b/yt/frontends/stream/data_structures.py index da3fa1e6aed..b786081c4cf 100644 --- a/yt/frontends/stream/data_structures.py +++ b/yt/frontends/stream/data_structures.py @@ -4,6 +4,7 @@ import weakref from itertools import chain, product, repeat from numbers import Number as numeric_type +from typing import Type import numpy as np from more_itertools import always_iterable @@ -19,7 +20,7 @@ from yt.data_objects.static_output import Dataset, ParticleFile from yt.data_objects.unions import MeshUnion from yt.frontends.sph.data_structures import SPHParticleIndex -from yt.geometry.geometry_handler import YTDataChunk +from yt.geometry.geometry_handler import Index, YTDataChunk from yt.geometry.grid_geometry_handler import GridIndex from yt.geometry.oct_container import OctreeContainer from yt.geometry.oct_geometry_handler import OctreeIndex @@ -262,7 +263,7 @@ def update_data(self, data): class StreamDataset(Dataset): - _index_class = StreamHierarchy + _index_class: Type[Index] = StreamHierarchy _field_info_class = StreamFieldInfo _dataset_type = "stream" diff --git a/yt/frontends/stream/fields.py b/yt/frontends/stream/fields.py index cc36d4f832f..1f14e39cfb6 100644 --- a/yt/frontends/stream/fields.py +++ b/yt/frontends/stream/fields.py @@ -1,8 +1,9 @@ +from yt._typing import KnownFieldsT from yt.fields.field_info_container import FieldInfoContainer class StreamFieldInfo(FieldInfoContainer): - known_other_fields = ( + known_other_fields: KnownFieldsT = ( ("density", ("code_mass/code_length**3", ["density"], None)), ( "dark_matter_density", @@ -46,7 +47,7 @@ class StreamFieldInfo(FieldInfoContainer): ("dii_density", ("code_mass/code_length**3", ["dii_density"], None)), ) - known_particle_fields = ( + known_particle_fields: KnownFieldsT = ( ("particle_position", ("code_length", ["particle_position"], None)), ("particle_position_x", ("code_length", ["particle_position_x"], None)), ("particle_position_y", ("code_length", ["particle_position_y"], None)), diff --git a/yt/frontends/ytdata/data_structures.py b/yt/frontends/ytdata/data_structures.py index cfdda49ccdd..5af309c52d3 100644 --- a/yt/frontends/ytdata/data_structures.py +++ b/yt/frontends/ytdata/data_structures.py @@ -2,6 +2,7 @@ import weakref from collections import defaultdict from numbers import Number as numeric_type +from typing import Tuple, Type import numpy as np @@ -14,12 +15,14 @@ ) from yt.data_objects.static_output import Dataset, ParticleFile, validate_index_order from yt.fields.field_exceptions import NeedsGridType +from yt.fields.field_info_container import FieldInfoContainer from yt.funcs import is_root, parse_h5_attr +from yt.geometry.geometry_handler import Index from yt.geometry.grid_geometry_handler import GridIndex from yt.geometry.particle_geometry_handler import ParticleIndex from yt.units import dimensions -from yt.units.unit_registry import UnitRegistry -from yt.units.yt_array import YTQuantity, uconcatenate +from yt.units.unit_registry import UnitRegistry # type: ignore +from yt.units.yt_array import YTQuantity, uconcatenate # type: ignore from yt.utilities.exceptions import GenerationInProgress, YTFieldTypeNotFound from yt.utilities.logger import ytLogger as mylog from yt.utilities.on_demand_imports import _h5py as h5py @@ -37,7 +40,7 @@ class SavedDataset(Dataset): Base dataset class for products of calling save_as_dataset. """ - _con_attrs = () + _con_attrs: Tuple[str, ...] = () def _parse_parameter_file(self): self.refine_by = 2 @@ -229,7 +232,7 @@ class YTDataContainerDataset(YTDataset): _index_class = ParticleIndex _file_class = YTDataHDF5File - _field_info_class = YTDataContainerFieldInfo + _field_info_class: Type[FieldInfoContainer] = YTDataContainerFieldInfo _suffix = ".h5" fluid_types = ("grid", "gas", "deposit", "index") @@ -476,12 +479,12 @@ def _populate_grid_objects(self): class YTGridDataset(YTDataset): """Dataset for saved covering grids, arbitrary grids, and FRBs.""" - _index_class = YTGridHierarchy + _index_class: Type[Index] = YTGridHierarchy _field_info_class = YTGridFieldInfo _dataset_type = "ytgridhdf5" geometry = "cartesian" default_fluid_type = "grid" - fluid_types = ("grid", "gas", "deposit", "index") + fluid_types: Tuple[str, ...] = ("grid", "gas", "deposit", "index") def __init__(self, filename, unit_system="cgs"): super().__init__(filename, self._dataset_type, unit_system=unit_system) @@ -716,7 +719,7 @@ class YTNonspatialDataset(YTGridDataset): _dataset_type = "ytnonspatialhdf5" geometry = "cartesian" default_fluid_type = "data" - fluid_types = ("data", "gas") + fluid_types: Tuple[str, ...] = ("data", "gas") def _parse_parameter_file(self): super(YTGridDataset, self)._parse_parameter_file() diff --git a/yt/frontends/ytdata/io.py b/yt/frontends/ytdata/io.py index a84ee843f56..9104f3207e8 100644 --- a/yt/frontends/ytdata/io.py +++ b/yt/frontends/ytdata/io.py @@ -2,7 +2,7 @@ from yt.funcs import mylog, parse_h5_attr from yt.geometry.selection_routines import GridSelector -from yt.units.yt_array import uvstack +from yt.units.yt_array import uvstack # type: ignore from yt.utilities.io_handler import BaseIOHandler from yt.utilities.on_demand_imports import _h5py as h5py diff --git a/yt/geometry/coordinates/cartesian_coordinates.py b/yt/geometry/coordinates/cartesian_coordinates.py index cc6b06efe89..56e38a1ac21 100644 --- a/yt/geometry/coordinates/cartesian_coordinates.py +++ b/yt/geometry/coordinates/cartesian_coordinates.py @@ -2,7 +2,7 @@ from yt.data_objects.index_subobjects.unstructured_mesh import SemiStructuredMesh from yt.funcs import mylog -from yt.units.yt_array import YTArray, uconcatenate, uvstack +from yt.units.yt_array import YTArray, uconcatenate, uvstack # type: ignore from yt.utilities.lib.pixelization_routines import ( interpolate_sph_grid_gather, normalization_2d_utility, diff --git a/yt/geometry/coordinates/coordinate_handler.py b/yt/geometry/coordinates/coordinate_handler.py index 95fea60f6bb..983cb19a9e4 100644 --- a/yt/geometry/coordinates/coordinate_handler.py +++ b/yt/geometry/coordinates/coordinate_handler.py @@ -129,7 +129,7 @@ def validate_sequence_width(width, ds, unit=None): class CoordinateHandler(abc.ABC): - name = None + name: str def __init__(self, ds, ordering): self.ds = weakref.proxy(ds) diff --git a/yt/geometry/geometry_handler.py b/yt/geometry/geometry_handler.py index e37f58097a2..fa6e6a42233 100644 --- a/yt/geometry/geometry_handler.py +++ b/yt/geometry/geometry_handler.py @@ -1,11 +1,12 @@ import abc import os import weakref +from typing import Tuple import numpy as np from yt.config import ytcfg -from yt.units.yt_array import YTArray, uconcatenate +from yt.units.yt_array import YTArray, uconcatenate # type: ignore from yt.utilities.exceptions import YTFieldNotFound from yt.utilities.io_handler import io_registry from yt.utilities.logger import ytLogger as mylog @@ -19,8 +20,8 @@ class Index(ParallelAnalysisInterface, abc.ABC): """The base index class""" - _unsupported_objects = () - _index_properties = () + _unsupported_objects: Tuple[str, ...] = () + _index_properties: Tuple[str, ...] = () def __init__(self, ds, dataset_type): ParallelAnalysisInterface.__init__(self) diff --git a/yt/mods.py b/yt/mods.py index d15c75db2c6..b73279d748b 100644 --- a/yt/mods.py +++ b/yt/mods.py @@ -24,7 +24,7 @@ unparsed_args = __startup_tasks.unparsed_args -if _level >= int(ytcfg_defaults["yt"]["log_level"]): +if _level >= int(ytcfg_defaults["yt"]["log_level"]): # type: ignore # This won't get displayed. mylog.debug("Turning off NumPy error reporting") np.seterr(all="ignore") diff --git a/yt/sample_data/api.py b/yt/sample_data/api.py index 83131202486..2e77133d5a4 100644 --- a/yt/sample_data/api.py +++ b/yt/sample_data/api.py @@ -65,8 +65,15 @@ def _parse_byte_size(s: str): # input is not a string (likely a np.nan) return pd.NA - val = float(re.search(num_exp, s).group()) - unit = re.search(byte_unit_exp, s).group() + match = re.search(num_exp, s) + if match is None: + raise ValueError + val = float(match.group()) + + match = re.search(byte_unit_exp, s) + if match is None: + raise ValueError + unit = match.group() prefixes = ["B", "K", "M", "G", "T"] raw_res = val * 1024 ** prefixes.index(unit[0]) return int(float(f"{raw_res:.3e}")) diff --git a/yt/startup_tasks.py b/yt/startup_tasks.py index af4eb777df5..211bc143739 100644 --- a/yt/startup_tasks.py +++ b/yt/startup_tasks.py @@ -4,6 +4,7 @@ import os import signal import sys +from typing import List from yt.config import ytcfg from yt.funcs import ( @@ -135,7 +136,7 @@ def error(self, message): if not hasattr(sys, "argv") or sys.argv is None: sys.argv = [] -unparsed_args = [] +unparsed_args: List[str] = [] parallel_capable = False if not ytcfg.get("yt", "internals", "command_line"): diff --git a/yt/units/unit_registry.py b/yt/units/unit_registry.py index 2f8d840aa49..bc3c1cf4316 100644 --- a/yt/units/unit_registry.py +++ b/yt/units/unit_registry.py @@ -1,6 +1,6 @@ from unyt.dimensions import dimensionless from unyt.unit_registry import * -default_unit_registry = UnitRegistry(unit_system="cgs") +default_unit_registry = UnitRegistry(unit_system="cgs") # type: ignore default_unit_registry.add("h", 1.0, dimensionless, tex_repr=r"h") diff --git a/yt/utilities/command_line.py b/yt/utilities/command_line.py index b4320baf4ae..4fd72128d2d 100644 --- a/yt/utilities/command_line.py +++ b/yt/utilities/command_line.py @@ -8,6 +8,7 @@ import textwrap import urllib import urllib.request +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np from more_itertools import always_iterable @@ -44,7 +45,6 @@ pass _default_colormap = ytcfg.get("yt", "default_colormap") -_arg_groups = {} def _fix_ds(arg, *args, **kwargs): @@ -187,12 +187,12 @@ def __init__(cls, name, b, d): class YTCommand(metaclass=YTCommandSubtype): - args = () - name = None - description = "" + args: Tuple[Union[str, Dict[str, Any]], ...] = () + name: Optional[Union[str, List[str]]] = None + description: str = "" aliases = () - ndatasets = 1 - subparser = None + ndatasets: int = 1 + subparser: Optional[str] = None @classmethod def run(cls, args): @@ -1360,21 +1360,18 @@ def load_config(self, args): self.config_file = config_file -_global_local_args = [ - ( - "exclusive", - dict( - short="--local", - action="store_true", - help="Store the configuration in the local configuration file.", - ), - dict( - short="--global", - action="store_true", - help="Store the configuration in the global configuration file.", - ), +_global_local_args = ( + dict( + short="--local", + action="store_true", + help="Store the configuration in the local configuration file.", ), -] + dict( + short="--global", + action="store_true", + help="Store the configuration in the global configuration file.", + ), +) class YTConfigGetCmd(YTCommand, YTConfigLocalConfigHandler): diff --git a/yt/utilities/cosmology.py b/yt/utilities/cosmology.py index 9446bf2129a..b69fbaa1837 100644 --- a/yt/utilities/cosmology.py +++ b/yt/utilities/cosmology.py @@ -3,8 +3,8 @@ import numpy as np from yt.units import dimensions -from yt.units.unit_object import Unit -from yt.units.unit_registry import UnitRegistry +from yt.units.unit_object import Unit # type: ignore +from yt.units.unit_registry import UnitRegistry # type: ignore from yt.units.yt_array import YTArray, YTQuantity from yt.utilities.physical_constants import ( gravitational_constant_cgs as G, diff --git a/yt/utilities/io_handler.py b/yt/utilities/io_handler.py index ed45fb95c77..a9bd252aab4 100644 --- a/yt/utilities/io_handler.py +++ b/yt/utilities/io_handler.py @@ -2,6 +2,7 @@ from collections import defaultdict from contextlib import contextmanager from functools import _make_key, lru_cache +from typing import Tuple import numpy as np @@ -20,8 +21,8 @@ def _make_io_key(args, *_args, **kwargs): class BaseIOHandler: - _vector_fields = () - _dataset_type = None + _vector_fields: Tuple[str, ...] = () + _dataset_type: str _particle_reader = False _cache_on = False _misses = 0 diff --git a/yt/utilities/lib/cykdtree/tests/test_utils.py b/yt/utilities/lib/cykdtree/tests/test_utils.py index d677be2493b..f3341f09b5c 100644 --- a/yt/utilities/lib/cykdtree/tests/test_utils.py +++ b/yt/utilities/lib/cykdtree/tests/test_utils.py @@ -1,7 +1,7 @@ import numpy as np from nose.tools import assert_equal -from yt.utilities.lib.cykdtree import utils +from yt.utilities.lib.cykdtree import utils # type: ignore from yt.utilities.lib.cykdtree.tests import assert_less_equal, parametrize diff --git a/yt/utilities/minimal_representation.py b/yt/utilities/minimal_representation.py index 58b8f8211ef..918038dfc0d 100644 --- a/yt/utilities/minimal_representation.py +++ b/yt/utilities/minimal_representation.py @@ -1,6 +1,7 @@ import abc import json import os +from typing import Tuple from uuid import uuid4 import numpy as np @@ -180,7 +181,7 @@ def _generate_post(self): class MinimalMappableData(MinimalRepresentation): - _attr_list = ( + _attr_list: Tuple[str, ...] = ( "field_data", "field", "weight_field", diff --git a/yt/utilities/object_registries.py b/yt/utilities/object_registries.py index f80055849d5..fde2edd5619 100644 --- a/yt/utilities/object_registries.py +++ b/yt/utilities/object_registries.py @@ -1,8 +1,21 @@ # These are some of the data object registries that are used in different places in the # code. Not all of the self-registering objects are included in these. -analysis_task_registry = {} -data_object_registry = {} -derived_quantity_registry = {} -output_type_registry = {} -simulation_time_series_registry = {} +# type hints are simplified as raw Type (instead of, e.g., Type[Dataset]) +# to workaround circular imports +from typing import Dict, Type + +# subclasses of yt.data_objects.analyzer_objects.AnalysisTask +analysis_task_registry: Dict[str, Type] = {} + +# subclasses of yt.data_objects.data_containers.YTDataContainer +data_object_registry: Dict[str, Type] = {} + +# suclasses of yt.data_objects.derived_quantity.DerivedQuantity +derived_quantity_registry: Dict[str, Type] = {} + +# suclasses of yt.data_objects.static_outputs.Dataset +output_type_registry: Dict[str, Type] = {} + +# subclasses of yt.data_objects.time_series.DatasetSeries +simulation_time_series_registry: Dict[str, Type] = {} diff --git a/yt/utilities/parallel_tools/parallel_analysis_interface.py b/yt/utilities/parallel_tools/parallel_analysis_interface.py index 57728c00837..d50f5e5224a 100644 --- a/yt/utilities/parallel_tools/parallel_analysis_interface.py +++ b/yt/utilities/parallel_tools/parallel_analysis_interface.py @@ -5,7 +5,7 @@ import traceback from functools import wraps from io import StringIO -from typing import Optional +from typing import List import numpy as np from more_itertools import always_iterable @@ -14,7 +14,7 @@ from yt.config import ytcfg from yt.data_objects.image_array import ImageArray from yt.funcs import is_sequence -from yt.units.unit_registry import UnitRegistry +from yt.units.unit_registry import UnitRegistry # type: ignore from yt.units.yt_array import YTArray from yt.utilities.exceptions import YTNoDataInObjectError from yt.utilities.lib.quad_tree import QuadTree, merge_quadtrees @@ -69,9 +69,7 @@ def default_mpi_excepthook(exception_type, exception_value, tb): MPI.COMM_WORLD.Abort(1) -def enable_parallelism( - suppress_logging: bool = False, communicator: Optional["MPI.Comm"] = None -) -> bool: +def enable_parallelism(suppress_logging: bool = False, communicator=None) -> bool: """ This method is used inside a script to turn on MPI parallelism, via mpi4py. More information about running yt in parallel can be found @@ -123,7 +121,7 @@ def enable_parallelism( ytcfg["yt", "internals", "global_parallel_rank"] = communicator.rank ytcfg["yt", "internals", "global_parallel_size"] = communicator.size ytcfg["yt", "internals", "parallel"] = True - if exe_name == "embed_enzo" or ("_parallel" in dir(sys) and sys._parallel): + if exe_name == "embed_enzo" or ("_parallel" in dir(sys) and sys._parallel): # type: ignore ytcfg["yt", "inline"] = True yt.utilities.logger.uncolorize_logging() # Even though the uncolorize function already resets the format string, @@ -665,7 +663,7 @@ def parallel_ring(objects, generator_func, mutable=False): class CommunicationSystem: - communicators = [] + communicators: List["Communicator"] = [] def __init__(self): self.communicators.append(Communicator(None)) diff --git a/yt/utilities/parameter_file_storage.py b/yt/utilities/parameter_file_storage.py index f9e733e5157..20020eeaaa4 100644 --- a/yt/utilities/parameter_file_storage.py +++ b/yt/utilities/parameter_file_storage.py @@ -37,7 +37,7 @@ class ParameterFileStore: """ - _shared_state = {} + _shared_state = {} # type: ignore _distributed = True _processing = False _owner = 0 diff --git a/yt/utilities/particle_generator.py b/yt/utilities/particle_generator.py index 2a9564cca29..18a2b24a8c7 100644 --- a/yt/utilities/particle_generator.py +++ b/yt/utilities/particle_generator.py @@ -1,7 +1,7 @@ import numpy as np from yt.funcs import get_pbar -from yt.units.yt_array import uconcatenate +from yt.units.yt_array import uconcatenate # type: ignore from yt.utilities.lib.particle_mesh_operations import CICSample_3 diff --git a/yt/utilities/performance_counters.py b/yt/utilities/performance_counters.py index d0ca2c1f4d5..5dcfc5ac579 100644 --- a/yt/utilities/performance_counters.py +++ b/yt/utilities/performance_counters.py @@ -10,7 +10,7 @@ class PerformanceCounters: - _shared_state = {} + _shared_state = {} # type: ignore def __new__(cls, *args, **kwargs): self = object.__new__(cls, *args, **kwargs) diff --git a/yt/visualization/eps_writer.py b/yt/visualization/eps_writer.py index b0f7a40f7f5..e3aa1ffc1e1 100644 --- a/yt/visualization/eps_writer.py +++ b/yt/visualization/eps_writer.py @@ -5,7 +5,7 @@ from matplotlib import cm, pyplot as plt from yt.config import ytcfg -from yt.units.unit_object import Unit +from yt.units.unit_object import Unit # type: ignore from yt.units.yt_array import YTQuantity from yt.utilities.logger import ytLogger as mylog diff --git a/yt/visualization/fits_image.py b/yt/visualization/fits_image.py index bc65793586c..7aa8bc3c4e5 100644 --- a/yt/visualization/fits_image.py +++ b/yt/visualization/fits_image.py @@ -10,7 +10,7 @@ from yt.fields.derived_field import DerivedField from yt.funcs import fix_axis, is_sequence, iter_fields, mylog from yt.units import dimensions -from yt.units.unit_object import Unit +from yt.units.unit_object import Unit # type: ignore from yt.units.yt_array import YTArray, YTQuantity from yt.utilities.on_demand_imports import _astropy from yt.utilities.parallel_tools.parallel_analysis_interface import parallel_root_only diff --git a/yt/visualization/fixed_resolution.py b/yt/visualization/fixed_resolution.py index 7244d426e0e..db365e340f9 100644 --- a/yt/visualization/fixed_resolution.py +++ b/yt/visualization/fixed_resolution.py @@ -7,7 +7,10 @@ from yt.frontends.ytdata.utilities import save_as_dataset from yt.funcs import get_output_filename, iter_fields, mylog from yt.loaders import load_uniform_grid -from yt.utilities.lib.api import CICDeposit_2, add_points_to_greyscale_image +from yt.utilities.lib.api import ( # type: ignore + CICDeposit_2, + add_points_to_greyscale_image, +) from yt.utilities.lib.pixelization_routines import pixelize_cylinder from yt.utilities.on_demand_imports import _h5py as h5py diff --git a/yt/visualization/geo_plot_utils.py b/yt/visualization/geo_plot_utils.py index fc2c933e89e..db6716e42e8 100644 --- a/yt/visualization/geo_plot_utils.py +++ b/yt/visualization/geo_plot_utils.py @@ -1,4 +1,7 @@ -valid_transforms = {} +from types import FunctionType +from typing import Any, Dict, Optional, Tuple + +valid_transforms: Dict[str, FunctionType] = {} transform_list = [ "PlateCarree", @@ -40,7 +43,7 @@ def _func(*args, **kwargs): return _func -def get_mpl_transform(mpl_proj): +def get_mpl_transform(mpl_proj) -> Optional[FunctionType]: r"""This returns an instantiated transform function given a transform function name and arguments. @@ -67,11 +70,11 @@ def get_mpl_transform(mpl_proj): # check to see if mpl_proj is a string or tuple, and construct args and # kwargs to pass to cartopy function based on that. - key = None + key: Optional[str] = None + args: Tuple = () + kwargs: Dict[str, Any] = {} if isinstance(mpl_proj, str): key = mpl_proj - args = () - kwargs = {} instantiated_func = valid_transforms[key](*args, **kwargs) elif isinstance(mpl_proj, tuple): if len(mpl_proj) == 2: @@ -79,6 +82,12 @@ def get_mpl_transform(mpl_proj): kwargs = {} elif len(mpl_proj) == 3: key, args, kwargs = mpl_proj + else: + raise ValueError(f"Expected a tuple with len 2 or 3, received {mpl_proj}") + if not isinstance(key, str): + raise TypeError( + f"Expected a string a the first element in mpl_proj, got {key!r}" + ) instantiated_func = valid_transforms[key](*args, **kwargs) elif hasattr(mpl_proj, "globe"): # cartopy transforms have a globe method associated with them diff --git a/yt/visualization/line_plot.py b/yt/visualization/line_plot.py index 34838a39c9a..481ce94390a 100644 --- a/yt/visualization/line_plot.py +++ b/yt/visualization/line_plot.py @@ -3,7 +3,7 @@ import numpy as np from yt.funcs import is_sequence, mylog -from yt.units.unit_object import Unit +from yt.units.unit_object import Unit # type: ignore from yt.units.yt_array import YTArray from yt.visualization.base_plot_types import PlotMPL from yt.visualization.plot_container import ( diff --git a/yt/visualization/particle_plots.py b/yt/visualization/particle_plots.py index 4051caf9e6c..d75faea903b 100644 --- a/yt/visualization/particle_plots.py +++ b/yt/visualization/particle_plots.py @@ -1,3 +1,5 @@ +from typing import List + import numpy as np from yt.data_objects.profiles import create_profile @@ -14,7 +16,7 @@ class ParticleAxisAlignedDummyDataSource: _dimensionality = 2 _con_args = ("center", "axis", "width", "fields", "weight_field") _tds_attrs = () - _key_fields = [] + _key_fields: List[str] = [] def __init__( self, diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index 1868f2908ab..27aa21eae40 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -1,3 +1,4 @@ +import abc import base64 import builtins import os @@ -5,7 +6,7 @@ import warnings from collections import defaultdict from functools import wraps -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np from matplotlib.cm import get_cmap @@ -16,7 +17,7 @@ from yt.data_objects.time_series import DatasetSeries from yt.funcs import dictWithFactory, ensure_dir, is_sequence, iter_fields, mylog from yt.units import YTQuantity -from yt.units.unit_object import Unit +from yt.units.unit_object import Unit # type: ignore from yt.utilities.definitions import formatted_length_unit_names from yt.utilities.exceptions import YTNotInsideNotebook @@ -209,10 +210,10 @@ def __init__(self, data_source, default_factory=None): return defaultdict.__init__(self, default_factory) -class PlotContainer: +class PlotContainer(abc.ABC): """A container for generic plots""" - _plot_type = None + _plot_type: Optional[str] = None _plot_valid = False # Plot defaults @@ -521,7 +522,7 @@ def set_figure_size(self, size): @validate_plot def save( self, - name: Optional[str] = None, + name: Optional[Union[str, List[str], Tuple[str, ...]]] = None, suffix: str = ".png", mpl_kwargs: Optional[Dict[str, Any]] = None, ): @@ -557,9 +558,11 @@ def save( if name is None: name = str(self.ds) - - # ///// Magic area. Muggles, keep out ! - if isinstance(name, (tuple, list)): + elif isinstance(name, (list, tuple)): + if not all(isinstance(_, str) for _ in name): + raise TypeError( + f"Expected a single str or an iterable of str, got {name!r}" + ) name = os.path.join(*name) name = os.path.expanduser(name) @@ -575,7 +578,8 @@ def save( new_name = validate_image_name(name, suffix) if new_name == name: - for v in self.plots.values(): + # somehow mypy thinks we may not have a plots attr yet, hence we turn it off here + for v in self.plots.values(): # type: ignore out_name = v.save(name, mpl_kwargs) names.append(out_name) return names @@ -595,10 +599,17 @@ def save( weight = weight[1].replace(" ", "_") if "Cutting" in self.data_source.__class__.__name__: plot_type = "OffAxisSlice" - for k, v in self.plots.items(): + + # somehow mypy thinks we may not have a plots attr yet, hence we turn it off here + for k, v in self.plots.items(): # type: ignore if isinstance(k, tuple): k = k[1] + if plot_type is None: + # implemented this check to make mypy happy, because we can't use str.join + # with PlotContainer._plot_type = None + raise TypeError(f"{self.__class__} is missing a _plot_type value (str)") + name_elements = [prefix, plot_type] if axis: name_elements.append(axis) diff --git a/yt/visualization/plot_modifications.py b/yt/visualization/plot_modifications.py index 23aaa5091ac..5ece51d6e8c 100644 --- a/yt/visualization/plot_modifications.py +++ b/yt/visualization/plot_modifications.py @@ -2,6 +2,7 @@ import warnings from functools import wraps from numbers import Number +from typing import Tuple import matplotlib import numpy as np @@ -15,7 +16,7 @@ from yt.geometry.geometry_handler import is_curvilinear from yt.geometry.unstructured_mesh_handler import UnstructuredIndex from yt.units import dimensions -from yt.units.yt_array import YTArray, YTQuantity, uhstack +from yt.units.yt_array import YTArray, YTQuantity, uhstack # type: ignore from yt.utilities.exceptions import YTDataTypeUnsupported from yt.utilities.lib.geometry_utils import triangle_plane_intersect from yt.utilities.lib.line_integral_convolution import line_integral_convolution_2d @@ -54,7 +55,7 @@ class PlotCallback: # "figure" this is disregarded. If "force" is included in the tuple, it # will *not* check whether or not the coord_system is in axis or figure, # and will only look at the geometries. - _supported_geometries = None + _supported_geometries: Tuple[str, ...] def __init_subclass__(cls, *args, **kwargs): super().__init_subclass__(*args, **kwargs) @@ -1042,7 +1043,12 @@ class LinePlotCallback(PlotCallback): """ _type_name = "line" - _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical") + _supported_geometries: Tuple[str, ...] = ( + "cartesian", + "spectral_cube", + "polar", + "cylindrical", + ) def __init__(self, p1, p2, data_coords=False, coord_system="data", plot_args=None): PlotCallback.__init__(self) @@ -1707,7 +1713,12 @@ class TextLabelCallback(PlotCallback): """ _type_name = "text" - _supported_geometries = ("cartesian", "spectral_cube", "polar", "cylindrical") + _supported_geometries: Tuple[str, ...] = ( + "cartesian", + "spectral_cube", + "polar", + "cylindrical", + ) def __init__( self, diff --git a/yt/visualization/plot_window.py b/yt/visualization/plot_window.py index 9caa64c606a..3301527944e 100644 --- a/yt/visualization/plot_window.py +++ b/yt/visualization/plot_window.py @@ -2,7 +2,7 @@ from collections import defaultdict from functools import wraps from numbers import Number -from typing import Union +from typing import List, Optional, Type, Union import matplotlib import matplotlib.pyplot as plt @@ -16,8 +16,8 @@ from yt.data_objects.image_array import ImageArray from yt.frontends.ytdata.data_structures import YTSpatialPlotDataset from yt.funcs import fix_axis, fix_unitary, is_sequence, iter_fields, mylog, obj_length -from yt.units.unit_object import Unit -from yt.units.unit_registry import UnitParseError +from yt.units.unit_object import Unit # type: ignore +from yt.units.unit_registry import UnitParseError # type: ignore from yt.units.yt_array import YTArray, YTQuantity from yt.utilities.exceptions import ( YTCannotParseUnitDisplayName, @@ -270,26 +270,22 @@ def piter(self, *args, **kwargs): _frb = None - def frb(): - doc = "The frb property." - - def fget(self): - if self._frb is None or not self._data_valid: - self._recreate_frb() - return self._frb - - def fset(self, value): - self._frb = value - self._data_valid = True - - def fdel(self): - del self._frb - self._frb = None - self._data_valid = False + @property + def frb(self): + if self._frb is None or not self._data_valid: + self._recreate_frb() + return self._frb - return locals() + @frb.setter + def frb(self, value): + self._frb = value + self._data_valid = True - frb = property(**frb()) + @frb.deleter + def frb(self): + del self._frb + self._frb = None + self._data_valid = False def _recreate_frb(self): old_fields = None @@ -849,8 +845,8 @@ class PWViewerMPL(PlotWindow): """Viewer using matplotlib as a backend via the WindowPlotMPL.""" _current_field = None - _frb_generator = None - _plot_type = None + _frb_generator: Optional[Type[FixedResolutionBuffer]] = None + _plot_type: Optional[str] = None _data_valid = False def __init__(self, *args, **kwargs): @@ -1487,7 +1483,7 @@ def sanitize_normal_vector(ds, normal) -> Union[str, np.ndarray]: ) return normal - if isinstance(normal, (int, np.integer)): + if isinstance(normal, int): if normal not in (0, 1, 2): raise ValueError( f"{normal} is not a valid axis identifier. Expected either 0, 1, or 2." @@ -1661,7 +1657,11 @@ class SlicePlot(NormalPlot): """ - def __new__( + # ignoring type check here, because mypy doesn't allow __new__ methods to + # return instances of subclasses. The design we use here is however based + # on the pathlib.Path class from the standard library + # https://github.com/python/mypy/issues/1020 + def __new__( # type: ignore cls, ds, normal, fields, *args, **kwargs ) -> Union["AxisAlignedSlicePlot", "OffAxisSlicePlot"]: if cls is SlicePlot: @@ -1718,7 +1718,11 @@ class ProjectionPlot(NormalPlot): """ - def __new__( + # ignoring type check here, because mypy doesn't allow __new__ methods to + # return instances of subclasses. The design we use here is however based + # on the pathlib.Path class from the standard library + # https://github.com/python/mypy/issues/1020 + def __new__( # type: ignore cls, ds, normal, fields, *args, **kwargs ) -> Union["AxisAlignedProjectionPlot", "OffAxisProjectionPlot"]: if cls is ProjectionPlot: @@ -2300,7 +2304,7 @@ def __init__( class OffAxisProjectionDummyDataSource: _type_name = "proj" - _key_fields = [] + _key_fields: List[str] = [] def __init__( self, diff --git a/yt/visualization/tests/test_splat.py b/yt/visualization/tests/test_splat.py index 8871b34f77a..a7d1b13dc70 100644 --- a/yt/visualization/tests/test_splat.py +++ b/yt/visualization/tests/test_splat.py @@ -7,7 +7,7 @@ import yt from yt.testing import assert_equal -from yt.utilities.lib.api import add_rgba_points_to_image +from yt.utilities.lib.api import add_rgba_points_to_image # type: ignore def setup(): diff --git a/yt/visualization/volume_rendering/camera.py b/yt/visualization/volume_rendering/camera.py index 2c208ce627f..34bcaaa853a 100644 --- a/yt/visualization/volume_rendering/camera.py +++ b/yt/visualization/volume_rendering/camera.py @@ -151,8 +151,9 @@ def __init__(self, scene, data_source=None, lens_type="plane-parallel", auto=Fal self.set_lens(lens_type) - def position(): - doc = """ + @property + def position(self): + r""" The location of the camera. Parameters @@ -163,31 +164,28 @@ def position(): coordinates. If an iterable, must contain only scalars or (length, unit) tuples. """ + return self._position - def fget(self): - return self._position - - def fset(self, value): - position = _sanitize_camera_property_units(value, self.scene) - if np.array_equal(position, self.focus): - raise RuntimeError( - "Cannot set the camera focus and position to the same value" - ) - self._position = position - self.switch_orientation( - normal_vector=self.focus - self._position, - north_vector=self.north_vector, + @position.setter + def position(self, value): + position = _sanitize_camera_property_units(value, self.scene) + if np.array_equal(position, self.focus): + raise RuntimeError( + "Cannot set the camera focus and position to the same value" ) + self._position = position + self.switch_orientation( + normal_vector=self.focus - self._position, + north_vector=self.north_vector, + ) - def fdel(self): - del self._position - - return locals() - - position = property(**position()) + @position.deleter + def position(self): + del self._position - def width(): - doc = """The width of the region that will be seen in the image. + @property + def width(self): + r"""The width of the region that will be seen in the image. Parameters ---------- @@ -198,25 +196,22 @@ def width(): all three directions. If an iterable, must contain only scalars or (length, unit) tuples. """ + return self._width - def fget(self): - return self._width - - def fset(self, value): - width = _sanitize_camera_property_units(value, self.scene) - self._width = width - self.switch_orientation() - - def fdel(self): - del self._width - self._width = None - - return locals() + @width.setter + def width(self, value): + width = _sanitize_camera_property_units(value, self.scene) + self._width = width + self.switch_orientation() - width = property(**width()) + @width.deleter + def width(self): + del self._width + self._width = None - def focus(): - doc = """ + @property + def focus(self): + r""" The focus defines the point the Camera is pointed at. Parameters @@ -228,50 +223,43 @@ def focus(): all three directions. If an iterable, must contain only scalars or (length, unit) tuples. """ + return self._focus - def fget(self): - return self._focus - - def fset(self, value): - focus = _sanitize_camera_property_units(value, self.scene) - if np.array_equal(focus, self.position): - raise RuntimeError( - "Cannot set the camera focus and position to the same value" - ) - self._focus = focus - self.switch_orientation( - normal_vector=self.focus - self._position, north_vector=None + @focus.setter + def focus(self, value): + focus = _sanitize_camera_property_units(value, self.scene) + if np.array_equal(focus, self.position): + raise RuntimeError( + "Cannot set the camera focus and position to the same value" ) + self._focus = focus + self.switch_orientation( + normal_vector=self.focus - self._position, north_vector=None + ) - def fdel(self): - del self._focus - - return locals() - - focus = property(**focus()) - - def resolution(): - doc = """The resolution is the number of pixels in the image that - will be produced. Must be a 2-tuple of integers or an integer.""" - - def fget(self): - return self._resolution - - def fset(self, value): - if is_sequence(value): - if len(value) != 2: - raise RuntimeError - else: - value = (value, value) - self._resolution = value - - def fdel(self): - del self._resolution - self._resolution = None - - return locals() + @focus.deleter + def focus(self): + del self._focus + + @property + def resolution(self): + r"""The resolution is the number of pixels in the image that + will be produced. Must be a 2-tuple of integers or an integer.""" + return self._resolution + + @resolution.setter + def resolution(self, value): + if is_sequence(value): + if len(value) != 2: + raise RuntimeError + else: + value = (value, value) + self._resolution = value - resolution = property(**resolution()) + @resolution.deleter + def resolution(self): + del self._resolution + self._resolution = None def set_resolution(self, resolution): """ diff --git a/yt/visualization/volume_rendering/lens.py b/yt/visualization/volume_rendering/lens.py index d18b9fc734c..ecc61839a14 100644 --- a/yt/visualization/volume_rendering/lens.py +++ b/yt/visualization/volume_rendering/lens.py @@ -1,7 +1,7 @@ import numpy as np from yt.data_objects.image_array import ImageArray -from yt.units.yt_array import uhstack, unorm, uvstack +from yt.units.yt_array import uhstack, unorm, uvstack # type: ignore from yt.utilities.lib.grid_traversal import arr_fisheye_vectors from yt.utilities.math_utils import get_rotation_matrix from yt.utilities.parallel_tools.parallel_analysis_interface import ( diff --git a/yt/visualization/volume_rendering/off_axis_projection.py b/yt/visualization/volume_rendering/off_axis_projection.py index b371556cee6..110a6f11fae 100644 --- a/yt/visualization/volume_rendering/off_axis_projection.py +++ b/yt/visualization/volume_rendering/off_axis_projection.py @@ -2,7 +2,7 @@ from yt.data_objects.api import ImageArray from yt.funcs import is_sequence, mylog -from yt.units.unit_object import Unit +from yt.units.unit_object import Unit # type: ignore from yt.utilities.lib.partitioned_grid import PartitionedGrid from yt.utilities.lib.pixelization_routines import ( normalization_2d_utility, diff --git a/yt/visualization/volume_rendering/old_camera.py b/yt/visualization/volume_rendering/old_camera.py index e245ce01138..f94e49d2fe7 100644 --- a/yt/visualization/volume_rendering/old_camera.py +++ b/yt/visualization/volume_rendering/old_camera.py @@ -1,5 +1,6 @@ import builtins from copy import deepcopy +from typing import List import numpy as np @@ -1185,7 +1186,7 @@ def rotation(self, theta, n_steps, rot_vector=None, clip_ratio=None): class InteractiveCamera(Camera): - frames = [] + frames: List[ImageArray] = [] def snapshot(self, fn=None, clip_ratio=None): self._pyplot.figure(2) diff --git a/yt/visualization/volume_rendering/render_source.py b/yt/visualization/volume_rendering/render_source.py index 969badccc8e..2f42ef7afd6 100644 --- a/yt/visualization/volume_rendering/render_source.py +++ b/yt/visualization/volume_rendering/render_source.py @@ -1,5 +1,6 @@ import abc from functools import wraps +from typing import Optional import numpy as np @@ -36,13 +37,13 @@ from .zbuffer_array import ZBuffer try: - from yt.utilities.lib.embree_mesh import mesh_traversal + from yt.utilities.lib.embree_mesh import mesh_traversal # type: ignore # Catch ValueError in case size of objects in Cython change except (ImportError, ValueError): mesh_traversal = NotAModule("pyembree") ytcfg["yt", "ray_tracing_engine"] = "yt" try: - from yt.utilities.lib.embree_mesh import mesh_construction + from yt.utilities.lib.embree_mesh import mesh_construction # type: ignore # Catch ValueError in case size of objects in Cython change except (ImportError, ValueError): mesh_construction = NotAModule("pyembree") @@ -84,15 +85,14 @@ def wrapper(*args, **kwargs): return wrapper -class RenderSource(ParallelAnalysisInterface): - +class RenderSource(ParallelAnalysisInterface, abc.ABC): """Base Class for Render Sources. Will be inherited for volumes, streamlines, etc. """ - volume_method = None + volume_method: Optional[str] = None def __init__(self): super().__init__() @@ -182,7 +182,6 @@ class VolumeSource(RenderSource, abc.ABC): _image = None data_source = None - volume_method = None def __init__(self, data_source, field): r"""Initialize a new volumetric source for rendering.""" @@ -278,7 +277,7 @@ def field(self): """The field to be rendered""" return self._field - @field.setter + @field.setter # type: ignore @invalidate_volume def field(self, value): field = self.data_source._determine_fields(value) @@ -304,7 +303,7 @@ def log_field(self): """Whether or not the field rendering is computed in log space""" return self._log_field - @log_field.setter + @log_field.setter # type: ignore @invalidate_volume def log_field(self, value): self.transfer_function = None @@ -317,7 +316,7 @@ def use_ghost_zones(self): values at grid boundaries""" return self._use_ghost_zones - @use_ghost_zones.setter + @use_ghost_zones.setter # type: ignore @invalidate_volume def use_ghost_zones(self, value): self._use_ghost_zones = value @@ -330,7 +329,7 @@ def weight_field(self): """ return self._weight_field - @weight_field.setter + @weight_field.setter # type: ignore @invalidate_volume def weight_field(self, value): self._weight_field = value @@ -701,26 +700,23 @@ def __init__(self, data_source, field): "Invalid ray-tracing engine selected. Choices are 'embree' and 'yt'." ) - def cmap(): + @property + def cmap(self): """ This is the name of the colormap that will be used when rendering this MeshSource object. Should be a string, like 'cmyt.arbre', or 'cmyt.dusk'. """ + return self._cmap - def fget(self): - return self._cmap - - def fset(self, cmap_name): - self._cmap = cmap_name - if hasattr(self, "data"): - self.current_image = self.apply_colormap() + @cmap.setter + def cmap(self, cmap_name): + self._cmap = cmap_name + if hasattr(self, "data"): + self.current_image = self.apply_colormap() - return locals() - - cmap = property(**cmap()) - - def color_bounds(): + @property + def color_bounds(self): """ These are the bounds that will be used with the colormap to the display the rendered image. Should be a (vmin, vmax) tuple, like (0.0, 2.0). If @@ -728,18 +724,13 @@ def color_bounds(): the rendered data. """ + return self._color_bounds - def fget(self): - return self._color_bounds - - def fset(self, bounds): - self._color_bounds = bounds - if hasattr(self, "data"): - self.current_image = self.apply_colormap() - - return locals() - - color_bounds = property(**color_bounds()) + @color_bounds.setter + def color_bounds(self, bounds): + self._color_bounds = bounds + if hasattr(self, "data"): + self.current_image = self.apply_colormap() def _validate(self): """Make sure that all dependencies have been met""" @@ -1009,6 +1000,9 @@ def __init__(self, positions, colors=None, color_stride=1, radii=None): self.color_stride = color_stride self.radii = radii + def _validate(self): + pass + def render(self, camera, zbuffer=None): """Renders an image using the provided camera @@ -1131,6 +1125,9 @@ def __init__(self, positions, colors=None, color_stride=1): self.colors = colors self.color_stride = color_stride + def _validate(self): + pass + def render(self, camera, zbuffer=None): """Renders an image using the provided camera @@ -1258,6 +1255,9 @@ def __init__(self, left_edge, right_edge, color=None): super().__init__(vertices, color, color_stride=24) + def _validate(self): + pass + class GridSource(LineSource): r"""A render source for drawing grids in a scene. @@ -1418,6 +1418,9 @@ def __init__(self, colors=None, alpha=1.0): colors[:, 3] = alpha self.colors = colors + def _validate(self): + pass + def render(self, camera, zbuffer=None): """Renders an image using the provided camera diff --git a/yt/visualization/volume_rendering/scene.py b/yt/visualization/volume_rendering/scene.py index 90f9b449107..8ac8cf83d22 100644 --- a/yt/visualization/volume_rendering/scene.py +++ b/yt/visualization/volume_rendering/scene.py @@ -7,8 +7,8 @@ from yt.config import ytcfg from yt.funcs import mylog -from yt.units.dimensions import length -from yt.units.unit_registry import UnitRegistry +from yt.units.dimensions import length # type: ignore +from yt.units.unit_registry import UnitRegistry # type: ignore from yt.units.yt_array import YTArray, YTQuantity from yt.utilities.exceptions import YTNotInsideNotebook from yt.visualization._commons import get_canvas, validate_image_name @@ -248,7 +248,7 @@ def _render_on_demand(self, render): def _get_render_sources(self): return [s for s in self.sources.values() if isinstance(s, RenderSource)] - def _setup_save(self, fname, render): + def _setup_save(self, fname, render) -> str: self._render_on_demand(render) @@ -363,7 +363,7 @@ def save_annotated( self, fname: Optional[str] = None, label_fmt: Optional[str] = None, - text_annotate: Optional[str] = None, + text_annotate=None, dpi: int = 100, sigma_clip: Optional[float] = None, render: bool = True, @@ -620,60 +620,55 @@ def add_camera(self, data_source=None, lens_type="plane-parallel", auto=False): self._camera = Camera(self, data_source, lens_type, auto) return self.camera - def camera(): - doc = r"""The camera property. + @property + def camera(self): + r"""The camera property. This is the default camera that will be used when rendering. Can be set manually, but Camera type will be checked for validity. """ + return self._camera - def fget(self): - return self._camera - - def fset(self, value): - value.width = self.arr(value.width) - value.focus = self.arr(value.focus) - value.position = self.arr(value.position) - self._camera = value - - def fdel(self): - del self._camera - self._camera = None - - return locals() - - camera = property(**camera()) - - def unit_registry(): - def fget(self): - ur = self._unit_registry - if ur is None: - ur = UnitRegistry() - # This will be updated when we add a volume source - ur.add("unitary", 1.0, length) - self._unit_registry = ur - return self._unit_registry - - def fset(self, value): - self._unit_registry = value - if self.camera is not None: - self.camera.width = YTArray( - self.camera.width.in_units("unitary"), registry=value - ) - self.camera.focus = YTArray( - self.camera.focus.in_units("unitary"), registry=value - ) - self.camera.position = YTArray( - self.camera.position.in_units("unitary"), registry=value - ) + @camera.setter + def camera(self, value): + value.width = self.arr(value.width) + value.focus = self.arr(value.focus) + value.position = self.arr(value.position) + self._camera = value - def fdel(self): - del self._unit_registry - self._unit_registry = None + @camera.deleter + def camera(self): + del self._camera + self._camera = None - return locals() + @property + def unit_registry(self): + ur = self._unit_registry + if ur is None: + ur = UnitRegistry() + # This will be updated when we add a volume source + ur.add("unitary", 1.0, length) + self._unit_registry = ur + return self._unit_registry + + @unit_registry.setter + def unit_registry(self, value): + self._unit_registry = value + if self.camera is not None: + self.camera.width = YTArray( + self.camera.width.in_units("unitary"), registry=value + ) + self.camera.focus = YTArray( + self.camera.focus.in_units("unitary"), registry=value + ) + self.camera.position = YTArray( + self.camera.position.in_units("unitary"), registry=value + ) - unit_registry = property(**unit_registry()) + @unit_registry.deleter + def unit_registry(self): + del self._unit_registry + self._unit_registry = None def set_camera(self, camera): r""" diff --git a/yt/visualization/volume_rendering/tests/test_zbuff.py b/yt/visualization/volume_rendering/tests/test_zbuff.py index 224458c922c..222ead86526 100644 --- a/yt/visualization/volume_rendering/tests/test_zbuff.py +++ b/yt/visualization/volume_rendering/tests/test_zbuff.py @@ -14,6 +14,15 @@ ) +class FakeOpaqueSource(OpaqueSource): + # A minimal (mock) concrete implementation of OpaqueSource + def render(self, camera, zbuffer=None): + pass + + def _validate(self): + pass + + def setup(): """Test specific setup.""" from yt.config import ytcfg @@ -65,7 +74,7 @@ def test_composite_vr(self): empty[:, :, 2] = 1.0 # Set blue to 1's empty[:, :, 3] = 1.0 # Set alpha to 1's zbuffer = ZBuffer(empty, z) - zsource = OpaqueSource() + zsource = FakeOpaqueSource() zsource.set_zbuffer(zbuffer) sc.add_source(zsource) diff --git a/yt/visualization/volume_rendering/utils.py b/yt/visualization/volume_rendering/utils.py index fff5271d82f..d275f37b414 100644 --- a/yt/visualization/volume_rendering/utils.py +++ b/yt/visualization/volume_rendering/utils.py @@ -13,7 +13,7 @@ from yt.utilities.on_demand_imports import NotAModule try: - from yt.utilities.lib.embree_mesh import mesh_traversal + from yt.utilities.lib.embree_mesh import mesh_traversal # type: ignore # Catch ValueError in case size of objects in Cython change except (ImportError, ValueError): mesh_traversal = NotAModule("pyembree") diff --git a/yt/visualization/volume_rendering/zbuffer_array.py b/yt/visualization/volume_rendering/zbuffer_array.py index 1ec1a70beab..27dd47d6e5c 100644 --- a/yt/visualization/volume_rendering/zbuffer_array.py +++ b/yt/visualization/volume_rendering/zbuffer_array.py @@ -1,3 +1,5 @@ +from typing import List, Tuple + import numpy as np @@ -74,8 +76,9 @@ def paint(self, ind, value, z): if __name__ == "__main__": - shape = (64, 64) - for shape in [(64, 64), (16, 16, 4), (128), (16, 32)]: + shape: Tuple[int, ...] = (64, 64) + shapes: List[Tuple[int, ...]] = [(64, 64), (16, 16, 4), (128,), (16, 32)] + for shape in shapes: b1 = ZBuffer(np.random.random(shape), np.ones(shape)) b2 = ZBuffer(np.random.random(shape), np.zeros(shape)) c = b1 + b2