diff --git a/doc/helper_scripts/show_fields.py b/doc/helper_scripts/show_fields.py
index b826829d891..eac8627fda8 100644
--- a/doc/helper_scripts/show_fields.py
+++ b/doc/helper_scripts/show_fields.py
@@ -37,7 +37,7 @@ def _strip_ftype(field):
 
 
 np.random.seed(int(0x4D3D3D3))
-units = [base_ds._get_field_info(*f).units for f in fields]
+units = [base_ds._get_field_info(f).units for f in fields]
 fields = [_strip_ftype(f) for f in fields]
 ds = fake_random_ds(16, fields=fields, units=units, particles=1)
 ds.parameters["HydroMethod"] = "streaming"
diff --git a/yt/data_objects/construction_data_containers.py b/yt/data_objects/construction_data_containers.py
index 206c91be7c7..53bdd022176 100644
--- a/yt/data_objects/construction_data_containers.py
+++ b/yt/data_objects/construction_data_containers.py
@@ -249,15 +249,16 @@ def get_data(self, fields=None):
         def _sq_field(field, data, fname: FieldKey):
             return data[fname] ** 2
 
-        for fname in fields:
-            fd = self.ds._get_field_info(*fname)
+        for field in fields:
+            fd = self.ds._get_field_info(field)
+            ftype, fname = field
             self.ds.add_field(
-                (fname[0], f"tmp_{fname[1]}_squared"),
-                partial(_sq_field, fname=fname),
+                (ftype, f"tmp_{fname}_squared"),
+                partial(_sq_field, fname=field),
                 sampling_type=fd.sampling_type,
                 units=f"({fd.units})*({fd.units})",
             )
-            sfields.append((fname[0], f"tmp_{fname[1]}_squared"))
+            sfields.append((ftype, f"tmp_{fname}_squared"))
         nfields = len(fields)
         nsfields = len(sfields)
         # We need a new tree for every single set of fields we add
@@ -387,7 +388,7 @@ def _initialize_projected_units(self, fields, chunk):
         for field in self.data_source._determine_fields(fields):
             if field in self._projected_units:
                 continue
-            finfo = self.ds._get_field_info(*field)
+            finfo = self.ds._get_field_info(field)
             if finfo.units is None:
                 # First time calling a units="auto" field, infer units and cache
                 # for future field accesses.
@@ -930,7 +931,7 @@ def _split_fields(self, fields_to_get):
         particles = []
         alias = {}
         for field in gen:
-            finfo = self.ds._get_field_info(*field)
+            finfo = self.ds._get_field_info(field)
             if finfo.is_alias:
                 alias[field] = finfo
                 continue
@@ -939,7 +940,7 @@ def _split_fields(self, fields_to_get):
             except NeedsOriginalGrid:
                 fill.append(field)
         for field in fill:
-            finfo = self.ds._get_field_info(*field)
+            finfo = self.ds._get_field_info(field)
             if finfo.sampling_type == "particle":
                 particles.append(field)
         gen = [f for f in gen if f not in fill and f not in alias]
@@ -1086,9 +1087,9 @@ def _fill_fields(self, fields):
         if self.comm.size > 1:
             for i in range(len(fields)):
                 output_fields[i] = self.comm.mpi_allreduce(output_fields[i], op="sum")
-        for name, v in zip(fields, output_fields):
-            fi = self.ds._get_field_info(*name)
-            self[name] = self.ds.arr(v, fi.units)
+        for field, v in zip(fields, output_fields):
+            fi = self.ds._get_field_info(field)
+            self[field] = self.ds.arr(v, fi.units)
 
     def _generate_container_field(self, field):
         rv = self.ds.arr(np.ones(self.ActiveDimensions, dtype="float64"), "")
@@ -1492,11 +1493,11 @@ def _fill_fields(self, fields):
                 category=RuntimeWarning,
             )
             mylog.debug("Caught %d runtime errors.", runtime_errors_count)
-        for name, v in zip(fields, ls.fields):
+        for field, v in zip(fields, ls.fields):
             if self.level > 0:
                 v = v[1:-1, 1:-1, 1:-1]
-            fi = self.ds._get_field_info(*name)
-            self[name] = self.ds.arr(v, fi.units)
+            fi = self.ds._get_field_info(field)
+            self[field] = self.ds.arr(v, fi.units)
 
     def _initialize_level_state(self, fields):
         ls = LevelState()
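The hunks above unpack the field key once and reuse it when registering the temporary squared field. For orientation, a minimal sketch of the underlying trick (field names here are illustrative, not from the patch): projecting f and f² together lets a caller recover a standard deviation, since Var(f) = ⟨f²⟩ − ⟨f⟩².

```python
# Illustrative sketch only: registering a temporary "squared" field so a
# projection of it yields <f^2>. Dataset and field names are examples.
from functools import partial

from yt.testing import fake_random_ds

ds = fake_random_ds(16)
field = ("gas", "density")


def _sq_field(field_obj, data, fname):
    # The field function simply squares the source field.
    return data[fname] ** 2


fd = ds._get_field_info(field)  # the new single-argument call form
ftype, fname = field
ds.add_field(
    (ftype, f"tmp_{fname}_squared"),
    partial(_sq_field, fname=field),
    sampling_type=fd.sampling_type,
    units=f"({fd.units})*({fd.units})",
)
```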
diff --git a/yt/data_objects/data_containers.py b/yt/data_objects/data_containers.py
index 9f411a8a5bc..69b3ccf1673 100644
--- a/yt/data_objects/data_containers.py
+++ b/yt/data_objects/data_containers.py
@@ -2,18 +2,18 @@
 import weakref
 from collections import defaultdict
 from contextlib import contextmanager
-from typing import List, Tuple
+from typing import TYPE_CHECKING, List, Tuple
 
 import numpy as np
 
 from yt._maintenance.deprecation import issue_deprecation_warning
-from yt._typing import AnyFieldKey
+from yt._typing import AnyFieldKey, FieldKey, FieldName
 from yt.config import ytcfg
 from yt.data_objects.field_data import YTFieldData
 from yt.data_objects.profiles import create_profile
 from yt.fields.field_exceptions import NeedsGridType
 from yt.frontends.ytdata.utilities import save_as_dataset
-from yt.funcs import get_output_filename, is_sequence, iter_fields, mylog
+from yt.funcs import get_output_filename, iter_fields, mylog
 from yt.units._numpy_wrapper_functions import uconcatenate
 from yt.units.yt_array import YTArray, YTQuantity
 from yt.utilities.amr_kdtree.api import AMRKDTree
@@ -21,7 +21,6 @@
     YTCouldNotGenerateField,
     YTException,
     YTFieldNotFound,
-    YTFieldNotParseable,
     YTFieldTypeNotFound,
     YTNonIndexedDataContainer,
     YTSpatialFieldUnitError,
@@ -30,6 +29,9 @@
 from yt.utilities.on_demand_imports import _firefly as firefly
 from yt.utilities.parameter_file_storage import ParameterFileStore
 
+if TYPE_CHECKING:
+    from yt.data_objects.static_output import Dataset
+
 
 def sanitize_weight_field(ds, field, weight):
     field_object = ds._get_field_info(field)
@@ -85,6 +87,8 @@ def __init__(self, ds, field_parameters):
         # Dataset._add_object_class but it can also be passed as a parameter to the
         # constructor, in which case it will override the default.
         # This code ensures it is never not set.
+
+        self.ds: "Dataset"
         if ds is not None:
             self.ds = ds
         else:
@@ -162,7 +166,7 @@ def apply_units(self, arr, units):
         except AttributeError:
             return self.ds.arr(arr, units=units)
 
-    def _first_matching_field(self, field):
+    def _first_matching_field(self, field: FieldName) -> FieldKey:
         for ftype, fname in self.ds.derived_field_list:
             if fname == field:
                 return (ftype, fname)
@@ -270,10 +274,7 @@ def __getitem__(self, key):
             try:
                 rv = self.field_data[f]
             except KeyError:
-                if isinstance(f, tuple):
-                    fi = self.ds._get_field_info(*f)
-                elif isinstance(f, bytes):
-                    fi = self.ds._get_field_info("unknown", f)
+                fi = self.ds._get_field_info(f)
                 rv = self.ds.arr(self.field_data[key], fi.units)
         return rv
 
@@ -296,7 +297,7 @@ def __delitem__(self, key):
 
     def _generate_field(self, field):
         ftype, fname = field
-        finfo = self.ds._get_field_info(*field)
+        finfo = self.ds._get_field_info(field)
         with self._field_type_state(ftype, finfo):
             if fname in self._container_fields:
                 tr = self._generate_container_field(field)
@@ -310,8 +311,7 @@ def _generate_fluid_field(self, field):
         # First we check the validator
-        ftype, fname = field
-        finfo = self.ds._get_field_info(ftype, fname)
+        finfo = self.ds._get_field_info(field)
         if self._current_chunk is None or self._current_chunk.chunk_type != "spatial":
             gen_obj = self
         else:
@@ -326,7 +326,7 @@ def _generate_fluid_field(self, field):
         return rv
 
     def _generate_spatial_fluid(self, field, ngz):
-        finfo = self.ds._get_field_info(*field)
+        finfo = self.ds._get_field_info(field)
         if finfo.units is None:
             raise YTSpatialFieldUnitError(field)
         units = finfo.units
@@ -383,7 +383,7 @@ def _generate_particle_field(self, field):
         else:
             gen_obj = self._current_chunk.objs[0]
         try:
-            finfo = self.ds._get_field_info(*field)
+            finfo = self.ds._get_field_info(field)
             finfo.check_available(gen_obj)
         except NeedsGridType as ngt_exception:
             if ngt_exception.ghost_zones != 0:
@@ -407,7 +407,7 @@ def _generate_particle_field(self, field):
                     ind += data.size
         else:
             with self._field_type_state(ftype, finfo, gen_obj):
-                rv = self.ds._get_field_info(*field)(gen_obj)
+                rv = self.ds._get_field_info(field)(gen_obj)
         return rv
 
     def _count_particles(self, ftype):
@@ -1487,49 +1487,6 @@ def _field_type_state(self, ftype, finfo, obj=None):
             obj._current_particle_type = old_particle_type
             obj._current_fluid_type = old_fluid_type
 
-    def _tupleize_field(self, field):
-
-        try:
-            ftype, fname = field.name
-            return ftype, fname
-        except AttributeError:
-            pass
-
-        if is_sequence(field) and not isinstance(field, str):
-            try:
-                ftype, fname = field
-                if not all(isinstance(_, str) for _ in field):
-                    raise TypeError
-                return ftype, fname
-            except TypeError as e:
-                raise YTFieldNotParseable(field) from e
-            except ValueError:
-                pass
-
-        try:
-            fname = field
-            finfo = self.ds._get_field_info(field)
-            if finfo.sampling_type == "particle":
-                ftype = self._current_particle_type
-                if hasattr(self.ds, "_sph_ptypes"):
-                    ptypes = self.ds._sph_ptypes
-                    if finfo.name[0] in ptypes:
-                        ftype = finfo.name[0]
-                    elif finfo.is_alias and finfo.alias_name[0] in ptypes:
-                        ftype = self._current_fluid_type
-            else:
-                ftype = self._current_fluid_type
-            if (ftype, fname) not in self.ds.field_info:
-                ftype = self.ds._last_freq[0]
-            return ftype, fname
-        except YTFieldNotFound:
-            pass
-
-        if isinstance(field, str):
-            return "unknown", field
-
-        raise YTFieldNotParseable(field)
-
     def _determine_fields(self, fields):
         if str(fields) in self.ds._determined_fields:
             return self.ds._determined_fields[str(fields)]
@@ -1539,9 +1496,8 @@ def _determine_fields(self, fields):
                 explicit_fields.append(field)
                 continue
-            ftype, fname = self._tupleize_field(field)
-            finfo = self.ds._get_field_info(ftype, fname)
-
+            finfo = self.ds._get_field_info(field)
+            ftype, fname = finfo.name
             # really ugly check to ensure that this field really does exist somewhere,
             # in some naming convention, before returning it as a possible field type
             if (
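With _tupleize_field gone, _determine_fields above resolves every non-explicit field spec through ds._get_field_info and reads the canonical (ftype, fname) pair back from the returned DerivedField. A rough sketch of the flow, assuming the usual stream/gas aliasing in yt's fake test datasets:

```python
# Sketch of the simplified resolution path (toy dataset; not the verbatim
# implementation).
from yt.testing import fake_random_ds

ds = fake_random_ds(16)
ad = ds.all_data()

for spec in ["density", ("gas", "density")]:
    finfo = ds._get_field_info(spec)  # raises YTFieldNotFound if unresolvable
    ftype, fname = finfo.name         # canonical key comes from the DerivedField
    assert (ftype, fname) == ("gas", "density")

# _determine_fields caches and returns canonical keys:
assert ad._determine_fields("density") == [("gas", "density")]
```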
diff --git a/yt/data_objects/index_subobjects/grid_patch.py b/yt/data_objects/index_subobjects/grid_patch.py
index cc6db81c7b1..3d4bc1c57c3 100644
--- a/yt/data_objects/index_subobjects/grid_patch.py
+++ b/yt/data_objects/index_subobjects/grid_patch.py
@@ -78,7 +78,7 @@ def __getitem__(self, key):
             fields = self._determine_fields(key)
         except YTFieldTypeNotFound:
             return tr
-        finfo = self.ds._get_field_info(*fields[0])
+        finfo = self.ds._get_field_info(fields[0])
         if not finfo.sampling_type == "particle":
             num_nodes = 2 ** sum(finfo.nodal_flag)
             new_shape = list(self.ActiveDimensions)
diff --git a/yt/data_objects/index_subobjects/octree_subset.py b/yt/data_objects/index_subobjects/octree_subset.py
index 7a0f55bb5a2..0b32c987299 100644
--- a/yt/data_objects/index_subobjects/octree_subset.py
+++ b/yt/data_objects/index_subobjects/octree_subset.py
@@ -64,7 +64,7 @@ def __getitem__(self, key):
             fields = self._determine_fields(key)
         except YTFieldTypeNotFound:
             return tr
-        finfo = self.ds._get_field_info(*fields[0])
+        finfo = self.ds._get_field_info(fields[0])
         if not finfo.sampling_type == "particle":
             # We may need to reshape the field, if it is being queried from
             # field_data.  If it's already cached, it just passes through.
diff --git a/yt/data_objects/region_expression.py b/yt/data_objects/region_expression.py
index 2fa37da548d..821e291bd46 100644
--- a/yt/data_objects/region_expression.py
+++ b/yt/data_objects/region_expression.py
@@ -3,7 +3,11 @@
 from yt.funcs import obj_length
 from yt.units.yt_array import YTQuantity
-from yt.utilities.exceptions import YTDimensionalityError, YTFieldNotParseable
+from yt.utilities.exceptions import (
+    YTDimensionalityError,
+    YTFieldNotFound,
+    YTFieldNotParseable,
+)
 from yt.visualization.line_plot import LineBuffer
 
 from .data_containers import _get_ipython_key_completion
@@ -23,7 +27,11 @@ def __getitem__(self, item):
         # that result in a rectangular prism or a slice.
         try:
             return self.all_data[item]
-        except (TypeError, YTFieldNotParseable):
+        except (YTFieldNotParseable, YTFieldNotFound):
+            # Any error raised by self.ds._get_field_info signals that item
+            # is not a field key; we then fall through to the geometric
+            # interpretation below. Plain TypeError is deliberately not
+            # caught, as swallowing it could mask subtle internal bugs.
             pass
 
         if isinstance(item, slice):
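The region_expression change above narrows the fallback in ds.r.__getitem__: only field-lookup failures fall through to the geometric branch. A hedged sketch of the two access styles this dispatch supports (toy dataset assumed):

```python
# Sketch: ds.r first tries to interpret the key as a field, then as a
# geometric selection. Dataset is illustrative.
from yt.testing import fake_random_ds

ds = fake_random_ds(16)

density = ds.r["gas", "density"]   # resolved as a field key
subvolume = ds.r[0.25:0.75, :, :]  # falls through to a region selection

# Field-lookup failures (YTFieldNotFound / YTFieldNotParseable) trigger the
# geometric branch; a plain TypeError propagates so real bugs stay visible.
```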
diff --git a/yt/data_objects/selection_objects/data_selection_objects.py b/yt/data_objects/selection_objects/data_selection_objects.py
index 0a4081882ea..71e1432297e 100644
--- a/yt/data_objects/selection_objects/data_selection_objects.py
+++ b/yt/data_objects/selection_objects/data_selection_objects.py
@@ -107,7 +107,7 @@ def _identify_dependencies(self, fields_to_get, spatial=False):
             if inspected >= len(fields_to_get):
                 break
             inspected += 1
-            fi = self.ds._get_field_info(*field)
+            fi = self.ds._get_field_info(field)
             fd = self.ds.field_dependencies.get(
                 field, None
             ) or self.ds.field_dependencies.get(field[1], None)
@@ -165,7 +165,7 @@ def get_data(self, fields=None):
         for field in self._determine_fields(fields):
             if field in self.field_data:
                 continue
-            finfo = self.ds._get_field_info(*field)
+            finfo = self.ds._get_field_info(field)
             try:
                 finfo.check_available(self)
             except NeedsGridType:
@@ -183,13 +183,13 @@ def get_data(self, fields=None):
         # We now split up into readers for the types of fields
         fluids, particles = [], []
         finfos = {}
-        for ftype, fname in fields_to_get:
-            finfo = self.ds._get_field_info(ftype, fname)
-            finfos[ftype, fname] = finfo
+        for field_key in fields_to_get:
+            finfo = self.ds._get_field_info(field_key)
+            finfos[field_key] = finfo
             if finfo.sampling_type == "particle":
-                particles.append((ftype, fname))
-            elif (ftype, fname) not in fluids:
-                fluids.append((ftype, fname))
+                particles.append(field_key)
+            elif field_key not in fluids:
+                fluids.append(field_key)
         # The _read method will figure out which fields it needs to get from
         # disk, and return a dict of those fields along with the fields that
         # need to be generated.
@@ -228,7 +228,7 @@ def _generate_fields(self, fields_to_generate):
                 index += 1
                 if field in self.field_data:
                     continue
-                fi = self.ds._get_field_info(*field)
+                fi = self.ds._get_field_info(field)
                 try:
                     fd = self._generate_field(field)
                     if hasattr(fd, "units"):
diff --git a/yt/data_objects/static_output.py b/yt/data_objects/static_output.py
index 9a7eed66cfc..9ed0f8267e1 100644
--- a/yt/data_objects/static_output.py
+++ b/yt/data_objects/static_output.py
@@ -20,12 +20,12 @@
 from unyt.exceptions import UnitConversionError, UnitParseError
 
 from yt._maintenance.deprecation import issue_deprecation_warning
-from yt._typing import AnyFieldKey, FieldKey, FieldType, ParticleType
+from yt._typing import AnyFieldKey, FieldKey, FieldType, ImplicitFieldKey, ParticleType
 from yt.config import ytcfg
 from yt.data_objects.particle_filters import ParticleFilter, filter_registry
 from yt.data_objects.region_expression import RegionExpression
 from yt.data_objects.unions import ParticleUnion
-from yt.fields.derived_field import ValidateSpatial
+from yt.fields.derived_field import DerivedField, ValidateSpatial
 from yt.fields.field_type_container import FieldTypeContainer
 from yt.fields.fluid_fields import setup_gradient_fields
 from yt.funcs import is_sequence, iter_fields, mylog, set_intersection, setdefaultattr
@@ -53,6 +53,7 @@
 from yt.utilities.cosmology import Cosmology
 from yt.utilities.exceptions import (
     YTFieldNotFound,
+    YTFieldNotParseable,
     YTGeometryNotSupported,
     YTIllDefinedParticleFilter,
     YTObjectNotImplemented,
@@ -146,9 +147,7 @@ class Dataset(abc.ABC):
     default_fluid_type = "gas"
     default_field = ("gas", "density")
     fluid_types: Tuple[FieldType, ...] = ("gas", "deposit", "index")
-    particle_types: Optional[Tuple[ParticleType, ...]] = (
-        "io",
-    )  # By default we have an 'all'
+    particle_types: Tuple[ParticleType, ...] = ("io",)  # By default we have an 'all'
     particle_types_raw: Optional[Tuple[ParticleType, ...]] = ("io",)
     geometry = "cartesian"
     coordinates = None
= ("io",) # By default we have an 'all' particle_types_raw: Optional[Tuple[ParticleType, ...]] = ("io",) geometry = "cartesian" coordinates = None @@ -659,7 +658,6 @@ def create_field_info(self): # Now that we've detected the fields, set this flag so that # deprecated fields will be logged if they are used self.fields_detected = True - self._last_freq = (None, None) def set_field_label_format(self, format_property, value): """ @@ -754,7 +752,7 @@ def add_particle_union(self, union): units = set() for s in union: # First we check our existing fields for units - funits = self._get_field_info(s, field).units + funits = self._get_field_info((s, field)).units # Then we override with field_units settings. funits = self.field_units.get((s, field), funits) units.add(funits) @@ -860,11 +858,12 @@ def _setup_particle_types(self, ptypes=None): df += self._setup_particle_type(ptype) return df - _last_freq = (None, None) - _last_finfo = None - - def _get_field_info(self, ftype, fname=None): - field_info, candidates = self._get_field_info_helper(ftype, fname) + def _get_field_info( + self, + field: Union[FieldKey, ImplicitFieldKey, DerivedField], + /, + ) -> DerivedField: + field_info, candidates = self._get_field_info_helper(field) if field_info.name[1] in ("px", "py", "pz", "pdx", "pdy", "pdz"): # escape early as a bandaid solution to @@ -886,9 +885,9 @@ def _are_ambiguous(candidates: List[FieldKey]) -> bool: all_equivalent_particle_fields: bool if ( - self.particle_types is None - or self.particle_unions is None - or self.particle_types_raw is None + not self.particle_types + or not self.particle_unions + or not self.particle_types_raw ): all_equivalent_particle_fields = False elif all(ft in self.particle_types for ft in ftypes): @@ -920,60 +919,31 @@ def _are_ambiguous(candidates: List[FieldKey]) -> bool: ) return field_info - def _get_field_info_helper(self, ftype, fname=None): + def _get_field_info_helper( + self, + field: Union[FieldKey, ImplicitFieldKey, DerivedField], + /, + ) -> Tuple[DerivedField, List[FieldKey]]: self.index - # store the original inputs in case we need to raise an error - INPUT = ftype, fname - if fname is None: - try: - ftype, fname = ftype.name - except AttributeError: - ftype, fname = "unknown", ftype - - candidates: List[FieldKey] = [] - - # storing this condition before altering it - guessing_type = ftype == "unknown" - if guessing_type: - ftype = self._last_freq[0] or ftype - candidates = [(ft, fn) for ft, fn in self.field_info.keys() if fn == fname] - - field = (ftype, fname) - - if ( - field == self._last_freq - and field not in self.field_info.field_aliases.values() - ): - return self._last_finfo, candidates - if field in self.field_info: - self._last_freq = field - self._last_finfo = self.field_info[(ftype, fname)] - return self._last_finfo, candidates + ftype: str + fname: str + if isinstance(field, str): + ftype, fname = "unknown", field + elif isinstance(field, tuple) and len(field) == 2: + ftype, fname = field + elif isinstance(field, DerivedField): + ftype, fname = field.name + else: + raise YTFieldNotParseable(field) - try: - # Sometimes, if guessing_type == True, this will be switched for - # the type of field it is. So we look at the field type and - # determine if we need to change the type. 
-            fi = self._last_finfo = self.field_info[fname]
-            if (
-                fi.sampling_type == "particle"
-                and self._last_freq[0] not in self.particle_types
-            ):
-                field = "all", field[1]
-            elif (
-                not fi.sampling_type == "particle"
-                and self._last_freq[0] not in self.fluid_types
-            ):
-                field = self.default_fluid_type, field[1]
-            self._last_freq = field
-            return self._last_finfo, candidates
-        except KeyError:
-            pass
+        if ftype == "unknown":
+            candidates: List[FieldKey] = [
+                (ft, fn) for ft, fn in self.field_info if fn == fname
+            ]
 
-        # We also should check "all" for particles, which can show up if you're
-        # mixing deposition/gas fields with particle fields.
-        if guessing_type:
+            # We also should check "all" for particles, which can show up if you're
+            # mixing deposition/gas fields with particle fields.
             if hasattr(self, "_sph_ptype"):
                 to_guess = [self.default_fluid_type, "all"]
             else:
@@ -981,10 +951,12 @@ def _get_field_info_helper(self, ftype, fname=None):
                 to_guess += list(self.fluid_types) + list(self.particle_types)
             for ftype in to_guess:
                 if (ftype, fname) in self.field_info:
-                    self._last_freq = (ftype, fname)
-                    self._last_finfo = self.field_info[(ftype, fname)]
-                    return self._last_finfo, candidates
-            raise YTFieldNotFound(field=INPUT, ds=self)
+                    return self.field_info[ftype, fname], candidates
+
+        elif (ftype, fname) in self.field_info:
+            return self.field_info[ftype, fname], []
+
+        raise YTFieldNotFound(field, ds=self)
 
     def _setup_classes(self):
         # Called by subclass
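This is the heart of the patch: _get_field_info is now positional-only and takes exactly one field specification, with the stateful _last_freq/_last_finfo cache removed. A quick sketch of the accepted call forms (dataset is illustrative):

```python
# Sketch of the new positional-only API.
from yt.testing import fake_random_ds

ds = fake_random_ds(16)

finfo = ds._get_field_info(("gas", "density"))  # explicit FieldKey
finfo = ds._get_field_info("density")           # ImplicitFieldKey: type is guessed
finfo = ds._get_field_info(finfo)               # a DerivedField resolves via .name

# The old two-argument form, ds._get_field_info("gas", "density"),
# now raises a TypeError because of the `/` positional-only marker.
```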
diff --git a/yt/data_objects/tests/test_cutting_plane.py b/yt/data_objects/tests/test_cutting_plane.py
index 0c451f2f67d..50bf2b522da 100644
--- a/yt/data_objects/tests/test_cutting_plane.py
+++ b/yt/data_objects/tests/test_cutting_plane.py
@@ -40,7 +40,7 @@ def test_cutting_plane():
     for width in [(1.0, "unitary"), 1.0, ds.quan(0.5, "code_length")]:
         frb = cut.to_frb(width, 64)
         for cut_field in [("index", "ones"), ("gas", "density")]:
-            fi = ds._get_field_info("unknown", cut_field)
+            fi = ds._get_field_info(cut_field)
             data = frb[cut_field]
             assert_equal(data.info["data_source"], cut.__str__())
             assert_equal(data.info["axis"], 4)
diff --git a/yt/data_objects/tests/test_data_containers.py b/yt/data_objects/tests/test_data_containers.py
index 6ea7ca9818b..39c8b9ef0c1 100644
--- a/yt/data_objects/tests/test_data_containers.py
+++ b/yt/data_objects/tests/test_data_containers.py
@@ -56,7 +56,7 @@ def test_yt_data_container(self):
         # Delete a non-existent field
         with assert_raises(YTFieldNotFound) as ex:
             del proj["p_mass"]
-        desired = "Could not find field ('unknown', 'p_mass') in UniformGridData."
+        desired = "Could not find field 'p_mass' in UniformGridData."
         assert_equal(str(ex.exception), desired)
 
     def test_write_out(self):
diff --git a/yt/data_objects/tests/test_firefly.py b/yt/data_objects/tests/test_firefly.py
index 3bd47c296c4..0d8ff851267 100644
--- a/yt/data_objects/tests/test_firefly.py
+++ b/yt/data_objects/tests/test_firefly.py
@@ -221,11 +221,6 @@ def test_field_tuple_specification(
             ["code_length"],
             YTFieldNotFound,
         ),  # Test nonexistent field (dinos)
-        (
-            ["pt2only_field"],
-            ["code_length"],
-            YTFieldNotFound,
-        ),  # Test unique field (match_any_particle_types=False)
         (
             ["common_field"],
             ["code_length"],
@@ -236,11 +231,6 @@ def test_field_tuple_specification(
             ["code_length"],
             YTFieldNotFound,
         ),  # Test nonexistent field tuple (pt1, pt2only_field)
-        (
-            ["pt2only_field", ("pt1", "common_field")],
-            ["code_length", "code_length"],
-            YTFieldNotFound,
-        ),  # Test mixed field spec (match_any_particle_types=False)
     ],
 )
 def test_field_invalid_specification(
diff --git a/yt/fields/field_detector.py b/yt/fields/field_detector.py
index 6bebccf54aa..0ae4ca79ffe 100644
--- a/yt/fields/field_detector.py
+++ b/yt/fields/field_detector.py
@@ -1,4 +1,5 @@
 from collections import defaultdict
+from typing import Tuple, Union
 
 import numpy as np
 
@@ -100,14 +101,14 @@ def _reshape_vals(self, arr):
             return arr
         return arr.reshape(self.ActiveDimensions, order="C")
 
-    def __missing__(self, item):
+    def __missing__(self, item: Union[Tuple[str, str], str]):
         from yt.fields.derived_field import NullFunc
 
         if not isinstance(item, tuple):
             field = ("unknown", item)
         else:
             field = item
-        finfo = self.ds._get_field_info(*field)
+        finfo = self.ds._get_field_info(field)
         params, permute_params = finfo._get_needed_parameters(self)
         self.field_parameters.update(params)
         # For those cases where we are guessing the field type, we will
@@ -116,7 +117,7 @@ def __missing__(self, item):
         # types not getting correctly identified.
         # Note that the *only* way this works is if we also fix our field
         # dependencies during checking.  Bug #627 talks about this.
-        item = self.ds._last_freq
+        _item: Tuple[str, str] = finfo.name
         if finfo is not None and finfo._function is not NullFunc:
             try:
                 for param, param_v in permute_params.items():
@@ -144,43 +145,43 @@ def __missing__(self, item):
                     self.requested_parameters.append(i)
             if vv is not None:
                 if not self.flat:
-                    self[item] = vv
+                    self[_item] = vv
                 else:
-                    self[item] = vv.ravel()
-                return self[item]
+                    self[_item] = vv.ravel()
+                return self[_item]
         elif finfo is not None and finfo.sampling_type == "particle":
             io = io_registry[self.ds.dataset_type](self.ds)
             if hasattr(io, "_vector_fields") and (
-                item in io._vector_fields or item[1] in io._vector_fields
+                _item in io._vector_fields or _item[1] in io._vector_fields
             ):
                 try:
-                    cols = io._vector_fields[item]
+                    cols = io._vector_fields[_item]
                 except KeyError:
-                    cols = io._vector_fields[item[1]]
+                    cols = io._vector_fields[_item[1]]
                 # A vector
-                self[item] = YTArray(
+                self[_item] = YTArray(
                     np.ones((self.NumberOfParticles, cols)),
                     finfo.units,
                     registry=self.ds.unit_registry,
                 )
             else:
                 # Not a vector
-                self[item] = YTArray(
+                self[_item] = YTArray(
                     np.ones(self.NumberOfParticles),
                     finfo.units,
                     registry=self.ds.unit_registry,
                 )
-            if item == ("STAR", "BIRTH_TIME"):
+            if _item == ("STAR", "BIRTH_TIME"):
                 # hack for the artio frontend so we pass valid times to
                 # the artio functions for calculating physical times
                 # from internal times
-                self[item] *= -0.1
-            self.requested.append(item)
-            return self[item]
-        self.requested.append(item)
-        if item not in self:
-            self[item] = self._read_data(item)
-        return self[item]
+                self[_item] *= -0.1
+            self.requested.append(_item)
+            return self[_item]
+        self.requested.append(_item)
+        if _item not in self:
+            self[_item] = self._read_data(_item)
+        return self[_item]
 
     def _debug(self):
         # We allow this to pass through.
@@ -212,7 +213,7 @@ def particle_operation(self, *args, **kwargs):
 
     def _read_data(self, field_name):
         self.requested.append(field_name)
-        finfo = self.ds._get_field_info(*field_name)
+        finfo = self.ds._get_field_info(field_name)
         if finfo.sampling_type == "particle":
             self.requested.append(field_name)
             return np.ones(self.NumberOfParticles)
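FieldDetector.__missing__ above now keys its fake data on finfo.name instead of the deleted ds._last_freq global. Roughly, the detector hands ones-filled arrays to field functions so their dependencies can be recorded; a heavily condensed sketch of that idea (not the real class):

```python
# Condensed illustration only: map any requested key to its canonical name
# and return placeholder data so field functions can run during detection.
import numpy as np


def missing(detector, item):
    finfo = detector.ds._get_field_info(item)  # str or tuple both accepted
    _item = finfo.name                         # canonical (ftype, fname) key
    detector[_item] = np.ones(detector.NumberOfParticles)  # placeholder data
    detector.requested.append(_item)
    return detector[_item]
```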
diff --git a/yt/fields/particle_fields.py b/yt/fields/particle_fields.py
index 6b9d8ed787b..6c24d5cbd4f 100644
--- a/yt/fields/particle_fields.py
+++ b/yt/fields/particle_fields.py
@@ -67,7 +67,6 @@ def _field_concat(fname):
     def _AllFields(field, data):
         v = []
         for ptype in data.ds.particle_types:
-            data.ds._last_freq = (ptype, None)
             if ptype == "all" or ptype in data.ds.known_filters:
                 continue
             v.append(data[ptype, fname].copy())
@@ -81,7 +80,6 @@ def _field_concat_slice(fname, axi):
     def _AllFields(field, data):
         v = []
         for ptype in data.ds.particle_types:
-            data.ds._last_freq = (ptype, None)
             if ptype == "all" or ptype in data.ds.known_filters:
                 continue
             v.append(data[ptype, fname][:, axi])
diff --git a/yt/fields/tests/test_fields.py b/yt/fields/tests/test_fields.py
index f1d77a6714c..ee130021921 100644
--- a/yt/fields/tests/test_fields.py
+++ b/yt/fields/tests/test_fields.py
@@ -68,7 +68,7 @@ def __init__(self, field_name, ds, nprocs):
         self.ds = ds
 
     def __call__(self):
-        field = self.ds._get_field_info(*self.field_name)
+        field = self.ds._get_field_info(self.field_name)
         skip_grids = False
         needs_spatial = False
         for v in field.validators:
@@ -489,14 +489,6 @@ def test_morton_index():
     assert_array_equal(a1, a2)
 
 
-def test_field_inference():
-    ds = fake_random_ds(16)
-    ds.index
-    # If this is not true this means the result of field inference depends
-    # on the order we did field detection, which is random in Python3
-    assert_equal(ds._last_freq, (None, None))
-
-
 @requires_module("h5py")
 @requires_file(ISOGAL)
 def test_deposit_amr():
diff --git a/yt/fields/xray_emission_fields.py b/yt/fields/xray_emission_fields.py
index 2fb0e1a97a2..2d6c1c09d50 100644
--- a/yt/fields/xray_emission_fields.py
+++ b/yt/fields/xray_emission_fields.py
@@ -216,7 +216,7 @@ def add_xray_emissivity_field(
     """
     if not isinstance(metallicity, float) and metallicity is not None:
         try:
-            metallicity = ds._get_field_info(*metallicity)
+            metallicity = ds._get_field_info(metallicity)
        except YTFieldNotFound as e:
            raise RuntimeError(
                f"Your dataset does not have a {metallicity} field! "
diff --git a/yt/frontends/enzo_e/data_structures.py b/yt/frontends/enzo_e/data_structures.py
index a4d3f69ffda..1c557827133 100644
--- a/yt/frontends/enzo_e/data_structures.py
+++ b/yt/frontends/enzo_e/data_structures.py
@@ -1,5 +1,6 @@
 import os
 from functools import cached_property
+from typing import Tuple
 
 import numpy as np
 
@@ -290,7 +291,7 @@ class EnzoEDataset(Dataset):
     _index_class = EnzoEHierarchy
     _field_info_class = EnzoEFieldInfo
     _suffix = ".block_list"
-    particle_types = None
+    particle_types: Tuple[str, ...] = ()
     particle_types_raw = None
 
     def __init__(
diff --git a/yt/frontends/stream/tests/test_stream_particles.py b/yt/frontends/stream/tests/test_stream_particles.py
index 72071b9f7c7..49b6eb53600 100644
--- a/yt/frontends/stream/tests/test_stream_particles.py
+++ b/yt/frontends/stream/tests/test_stream_particles.py
@@ -56,35 +56,35 @@ def test_stream_particles():
 
     for ptype in ("all", "io"):
         assert (
-            ug1._get_field_info(ptype, "particle_position_x").sampling_type
+            ug1._get_field_info((ptype, "particle_position_x")).sampling_type
             == "particle"
         )
         assert (
-            ug1._get_field_info(ptype, "particle_position_y").sampling_type
+            ug1._get_field_info((ptype, "particle_position_y")).sampling_type
            == "particle"
         )
         assert (
-            ug1._get_field_info(ptype, "particle_position_z").sampling_type
+            ug1._get_field_info((ptype, "particle_position_z")).sampling_type
             == "particle"
         )
-        assert ug1._get_field_info(ptype, "particle_mass").sampling_type == "particle"
-    assert not ug1._get_field_info("gas", "density").sampling_type == "particle"
+        assert ug1._get_field_info((ptype, "particle_mass")).sampling_type == "particle"
+    assert not ug1._get_field_info(("gas", "density")).sampling_type == "particle"
 
     for ptype in ("all", "io"):
         assert (
-            ug2._get_field_info(ptype, "particle_position_x").sampling_type
+            ug2._get_field_info((ptype, "particle_position_x")).sampling_type
             == "particle"
         )
         assert (
-            ug2._get_field_info(ptype, "particle_position_y").sampling_type
+            ug2._get_field_info((ptype, "particle_position_y")).sampling_type
             == "particle"
         )
         assert (
-            ug2._get_field_info(ptype, "particle_position_z").sampling_type
+            ug2._get_field_info((ptype, "particle_position_z")).sampling_type
             == "particle"
         )
-        assert ug2._get_field_info(ptype, "particle_mass").sampling_type == "particle"
-    assert not ug2._get_field_info("gas", "density").sampling_type == "particle"
+        assert ug2._get_field_info((ptype, "particle_mass")).sampling_type == "particle"
+    assert not ug2._get_field_info(("gas", "density")).sampling_type == "particle"
 
 
     # Now perform similar checks, but with multiple particle types
@@ -138,31 +138,31 @@ def test_stream_particles():
 
     for ptype in ("dm", "star"):
         assert (
-            ug3._get_field_info(ptype, "particle_position_x").sampling_type
+            ug3._get_field_info((ptype, "particle_position_x")).sampling_type
             == "particle"
         )
         assert (
-            ug3._get_field_info(ptype, "particle_position_y").sampling_type
+            ug3._get_field_info((ptype, "particle_position_y")).sampling_type
             == "particle"
         )
         assert (
-            ug3._get_field_info(ptype, "particle_position_z").sampling_type
+            ug3._get_field_info((ptype, "particle_position_z")).sampling_type
             == "particle"
         )
-        assert ug3._get_field_info(ptype, "particle_mass").sampling_type == "particle"
+        assert ug3._get_field_info((ptype, "particle_mass")).sampling_type == "particle"
         assert (
-            ug4._get_field_info(ptype, "particle_position_x").sampling_type
+            ug4._get_field_info((ptype, "particle_position_x")).sampling_type
             == "particle"
         )
         assert (
-            ug4._get_field_info(ptype, "particle_position_y").sampling_type
+            ug4._get_field_info((ptype, "particle_position_y")).sampling_type
             == "particle"
         )
         assert (
-            ug4._get_field_info(ptype, "particle_position_z").sampling_type
+            ug4._get_field_info((ptype, "particle_position_z")).sampling_type
             == "particle"
         )
-        assert ug4._get_field_info(ptype, "particle_mass").sampling_type == "particle"
+        assert ug4._get_field_info((ptype, "particle_mass")).sampling_type == "particle"
 
 
 def test_load_particles_types():
diff --git a/yt/frontends/ytdata/data_structures.py b/yt/frontends/ytdata/data_structures.py
index 2874c8e9fe8..c1ac240b96d 100644
--- a/yt/frontends/ytdata/data_structures.py
+++ b/yt/frontends/ytdata/data_structures.py
@@ -407,7 +407,7 @@ def __getitem__(self, key):
             fields = self._determine_fields(key)
         except YTFieldTypeNotFound:
             return tr
-        finfo = self.ds._get_field_info(*fields[0])
+        finfo = self.ds._get_field_info(fields[0])
         if not finfo.sampling_type == "particle":
             return tr.reshape(self.ActiveDimensions[: self.ds.dimensionality])
         return tr
@@ -578,7 +578,7 @@ def __getitem__(self, key):
             fields = self._determine_fields(key)
         except YTFieldTypeNotFound:
             return tr
-        self.ds._get_field_info(*fields[0])
+        self.ds._get_field_info(fields[0])
         return tr
 
     def get_data(self, fields=None):
@@ -610,7 +610,7 @@ def get_data(self, fields=None):
         for field in self._determine_fields(fields):
             if field in self.field_data:
                 continue
-            finfo = self.ds._get_field_info(*field)
+            finfo = self.ds._get_field_info(field)
             try:
                 finfo.check_available(self)
             except NeedsGridType:
@@ -628,13 +628,13 @@ def get_data(self, fields=None):
         # We now split up into readers for the types of fields
         fluids, particles = [], []
         finfos = {}
-        for ftype, fname in fields_to_get:
-            finfo = self.ds._get_field_info(ftype, fname)
-            finfos[ftype, fname] = finfo
+        for field_key in fields_to_get:
+            finfo = self.ds._get_field_info(field_key)
+            finfos[field_key] = finfo
             if finfo.sampling_type == "particle":
-                particles.append((ftype, fname))
-            elif (ftype, fname) not in fluids:
-                fluids.append((ftype, fname))
+                particles.append(field_key)
+            elif field_key not in fluids:
+                fluids.append(field_key)
 
         # The _read method will figure out which fields it needs to get from
         # disk, and return a dict of those fields along with the fields that
diff --git a/yt/funcs.py b/yt/funcs.py
index 5a9596f46fb..55475fb51f0 100644
--- a/yt/funcs.py
+++ b/yt/funcs.py
@@ -1080,10 +1080,7 @@ def obj_length(v):
 
 def array_like_field(data, x, field):
     field = data._determine_fields(field)[0]
-    if isinstance(field, tuple):
-        finfo = data.ds._get_field_info(field[0], field[1])
-    else:
-        finfo = data.ds._get_field_info(field)
+    finfo = data.ds._get_field_info(field)
     if finfo.sampling_type == "particle":
         units = finfo.output_units
     else:
diff --git a/yt/utilities/exceptions.py b/yt/utilities/exceptions.py
index c5324c7a2d1..4875f1d71ea 100644
--- a/yt/utilities/exceptions.py
+++ b/yt/utilities/exceptions.py
@@ -127,7 +127,7 @@ def _get_suggestions(self) -> List[FieldKey]:
         ]
 
     def __str__(self):
-        msg = f"Could not find field {self.field} in {self.ds}."
+        msg = f"Could not find field {self.field!r} in {self.ds}."
         try:
             suggestions = self._get_suggestions()
         except AttributeError:
@@ -408,7 +408,7 @@ def __init__(self, field):
         self.field = field
 
     def __str__(self):
-        return f"Cannot identify field {self.field}"
+        return f"Cannot identify field {self.field!r}"
 
 
 class YTDataSelectorNotImplemented(YTException):
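With the !r formatting above, the offending field spec is shown verbatim, so a bare string no longer masquerades as an ('unknown', name) tuple. For example (hypothetical session; dataset is a standard fake stream dataset):

```python
# Hypothetical session illustrating the new message formatting.
from yt.testing import fake_random_ds
from yt.utilities.exceptions import YTFieldNotFound

ds = fake_random_ds(16)
try:
    ds._get_field_info("p_mass")
except YTFieldNotFound as err:
    print(err)  # e.g.: Could not find field 'p_mass' in UniformGridData.
```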
diff --git a/yt/utilities/lib/octree_raytracing.py b/yt/utilities/lib/octree_raytracing.py
index 8995ba2dc0c..da1a6c9d603 100644
--- a/yt/utilities/lib/octree_raytracing.py
+++ b/yt/utilities/lib/octree_raytracing.py
@@ -44,7 +44,7 @@ def vertex_centered_data(self, field):
         data_source = self.data_source
         chunks = data_source.index._chunk(data_source, "spatial", ngz=1)
 
-        finfo = data_source.ds._get_field_info(*field)
+        finfo = data_source.ds._get_field_info(field)
         units = finfo.units
         rv = data_source.ds.arr(
             np.zeros((2, 2, 2, data_source.ires.size), dtype="float64"), units
diff --git a/yt/visualization/fits_image.py b/yt/visualization/fits_image.py
index 4be8446a2da..220d4271f71 100644
--- a/yt/visualization/fits_image.py
+++ b/yt/visualization/fits_image.py
@@ -1447,8 +1447,8 @@ def __init__(
         fields = source._determine_fields(list(iter_fields(fields)))
         stddev_str = "_stddev" if moment == 2 else ""
         for item in fields:
-
-            key = (item[0], item[1] + stddev_str)
+            ftype, fname = item
+            key = (ftype, f"{fname}{stddev_str}")
             buf[key] = off_axis_projection(
                 source,
@@ -1467,9 +1467,8 @@ def __init__(
             def _sq_field(field, data, item: FieldKey):
                 return data[item] ** 2
 
-            fd = ds._get_field_info(*item)
-
-            field_sq = (item[0], f"tmp_{item[1]}_squared")
+            fd = ds._get_field_info(item)
+            field_sq = (ftype, f"tmp_{fname}_squared")
 
             ds.add_field(
                 field_sq,
diff --git a/yt/visualization/fixed_resolution.py b/yt/visualization/fixed_resolution.py
index c0c8d9f0c43..c52b38ba875 100644
--- a/yt/visualization/fixed_resolution.py
+++ b/yt/visualization/fixed_resolution.py
@@ -219,7 +219,7 @@ def _get_data_source_fields(self):
     def _get_info(self, item):
         info = {}
         ftype, fname = field = self.data_source._determine_fields(item)[0]
-        finfo = self.data_source.ds._get_field_info(*field)
+        finfo = self.data_source.ds._get_field_info(field)
         info["data_source"] = self.data_source.__str__()
         info["axis"] = self.data_source.axis
         info["field"] = str(item)
@@ -627,9 +627,10 @@ def __getitem__(self, item):
            def _sq_field(field, data, item: FieldKey):
                 return data[item] ** 2
 
-            fd = self.ds._get_field_info(*item)
+            fd = self.ds._get_field_info(item)
+            ftype, fname = item
 
-            item_sq = (item[0], f"tmp_{item[1]}_squared")
+            item_sq = (ftype, f"tmp_{fname}_squared")
             self.ds.add_field(
                 item_sq,
                 partial(_sq_field, item=item),
diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py
index fce184b97d3..84c2e42d462 100644
--- a/yt/visualization/plot_container.py
+++ b/yt/visualization/plot_container.py
@@ -1107,7 +1107,7 @@ def _get_plot_instance(self, field):
         axrect = self._get_axrect()
 
         pnh = NormHandler(self.data_source, display_units=self.data_source[field].units)
-        finfo = self.data_source.ds._get_field_info(*field)
+        finfo = self.data_source.ds._get_field_info(field)
         if not finfo.take_log:
             pnh.norm_type = Normalize
         plot = PlotMPL(self.figure_size, axrect, norm_handler=pnh)
diff --git a/yt/visualization/plot_modifications.py b/yt/visualization/plot_modifications.py
index d5112e2cb81..4d8b1377cf4 100644
--- a/yt/visualization/plot_modifications.py
+++ b/yt/visualization/plot_modifications.py
@@ -951,7 +951,7 @@ def __call__(self, plot):
             take_log = self.take_log
         else:
             field = data._determine_fields([self.field])[0]
-            take_log = plot.ds._get_field_info(*field).take_log
+            take_log = plot.ds._get_field_info(field).take_log
 
         if take_log:
             zi = np.log10(zi)
diff --git a/yt/visualization/plot_window.py b/yt/visualization/plot_window.py
index 63b726b1142..9cf231b228a 100644
--- a/yt/visualization/plot_window.py
+++ b/yt/visualization/plot_window.py
@@ -241,7 +241,7 @@ def __init__(
             self._setup_plots()
 
         for field in self.data_source._determine_fields(self.fields):
-            finfo = self.data_source.ds._get_field_info(*field)
+            finfo = self.data_source.ds._get_field_info(field)
             pnh = self.plots[field].norm_handler
             if finfo.take_log is False:
                 # take_log can be `None` so we explicitly compare against a boolean
diff --git a/yt/visualization/tests/test_geo_projections.py b/yt/visualization/tests/test_geo_projections.py
index 9f0d1751df6..3d113ef7c43 100644
--- a/yt/visualization/tests/test_geo_projections.py
+++ b/yt/visualization/tests/test_geo_projections.py
@@ -64,7 +64,7 @@ class TestGeoProjections(unittest.TestCase):
     def setUp(self):
         self.ds = fake_amr_ds(geometry="geographic")
         # switch off the log plot to avoid some unrelated matplotlib issues
-        f = self.ds._get_field_info("stream", "Density")
+        f = self.ds._get_field_info(("stream", "Density"))
         f.take_log = False
 
     @requires_module("cartopy")
diff --git a/yt/visualization/volume_rendering/old_camera.py b/yt/visualization/volume_rendering/old_camera.py
index 5455a41ede3..81e425ef787 100644
--- a/yt/visualization/volume_rendering/old_camera.py
+++ b/yt/visualization/volume_rendering/old_camera.py
@@ -203,7 +203,7 @@ def __init__(
            dd = self.ds.all_data()
            efields = dd._determine_fields(self.fields)
            if self.log_fields is None:
-                self.log_fields = [self.ds._get_field_info(*f).take_log for f in efields]
+                self.log_fields = [self.ds._get_field_info(f).take_log for f in efields]
            self.no_ghost = no_ghost
            self.use_light = use_light
            self.light_dir = None
@@ -1703,7 +1703,7 @@ def __init__(
         fields = dd._determine_fields(fields)
         self.fields = fields
         if log_fields is None:
-            log_fields = [self.ds._get_field_info(*f).take_log for f in fields]
+            log_fields = [self.ds._get_field_info(f).take_log for f in fields]
         self.log_fields = log_fields
         self.sub_samples = sub_samples
         if volume is None:
@@ -2144,7 +2144,7 @@ def finalize_image(self, image):
         ds = self.ds
         dd = ds.all_data()
         field = dd._determine_fields([self.field])[0]
-        finfo = ds._get_field_info(*field)
+        finfo = ds._get_field_info(field)
         dl = 1.0
         if self.method == "integrate":
             if self.weight is None:
@@ -2201,7 +2201,7 @@ def _render(self, double_check, num_threads, image, sampler):
     def save_image(self, image, fn=None, clip_ratio=None):
         dd = self.ds.all_data()
         field = dd._determine_fields([self.field])[0]
-        finfo = self.ds._get_field_info(*field)
+        finfo = self.ds._get_field_info(field)
         if finfo.take_log:
             im = np.log10(image)
         else:
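Across all of these call sites the returned object is a DerivedField, and the handful of attributes the patch consumes is the same everywhere. A closing summary sketch, with the attribute list taken from the usages visible in this diff (dataset is illustrative):

```python
# Attributes of the DerivedField returned by ds._get_field_info, as
# exercised throughout this diff.
from yt.testing import fake_random_ds

ds = fake_random_ds(16)
finfo = ds._get_field_info(("gas", "density"))

finfo.name           # canonical ("gas", "density") key
finfo.units          # unit string used to wrap raw arrays via ds.arr
finfo.sampling_type  # "particle" vs fluid dispatch in the readers
finfo.take_log       # default log scaling for plots and volume renders
finfo.is_alias       # alias detection in covering-grid field splitting
```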