Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP: add static typing to Dataset._get_field_info_helper #4229

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
189c7d1
TYP: narrow down type of Dataset.particle_types
neutrinoceros Dec 4, 2022
93508b3
TYP: add static typing to Dataset._get_field_info_helper
neutrinoceros Nov 27, 2022
e0dc8bc
RFC: internal var renaming
neutrinoceros Dec 4, 2022
71f5754
RFC: cleanup calls to _get_field_info using the 2 args signature
neutrinoceros Dec 4, 2022
ccd65c5
RFC: drop second positional argument in Dataset._get_field_info
neutrinoceros Dec 4, 2022
4764124
TYP: reduce runtime type flexibility in private routine Dataset._get_…
neutrinoceros Dec 4, 2022
e84267a
RFC: drop unused caching from Dataset._get_field_info (Dataset._last_…
neutrinoceros Dec 4, 2022
0e4b2a5
TYP: improve exception specificity in Dataset._get_field_info
neutrinoceros Dec 4, 2022
8f8875b
RFC: simplify one-liner
neutrinoceros Dec 4, 2022
5c950af
MNT: adjust YTFieldNotFound error message to evolutions in Dataset._g…
neutrinoceros Dec 4, 2022
d7b5391
RFC: simplify unused branch
neutrinoceros Dec 4, 2022
ade1c28
MNT: cleanup unused private method
neutrinoceros Dec 4, 2022
85759e5
RFC: simplify internal logic in Dataset._get_field_info_helper
neutrinoceros Dec 5, 2022
b7c622c
RFC: fix duck typing in Dataset._get_field_info
neutrinoceros Dec 5, 2022
35e836f
TYP: use yt._typing explicit type annotations more extensively
neutrinoceros Dec 5, 2022
f9dc662
RFC: avoid unpacking field tuples only to reassemble in the same scope
neutrinoceros Dec 21, 2022
43c1004
TYP: add type hints to YTDataContainer._first_matching_field
neutrinoceros Dec 22, 2022
183425c
TST: drop a couple badly defined tests
neutrinoceros Dec 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/helper_scripts/show_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _strip_ftype(field):


np.random.seed(int(0x4D3D3D3))
units = [base_ds._get_field_info(*f).units for f in fields]
units = [base_ds._get_field_info(f).units for f in fields]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm, this is because it's always a tuple now, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More exactly, it's because we don't have a second positional argument anymore. Passing a full tuple as the first argument already worked on main, I just made it the only way to do it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does this line use base_ds._get_field_info(f) at all instead of base_ds.field_info[f].units?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no idea, probably no good reason, but this feels just slightly out of scope here. Do you mind if I leave the line as is?

fields = [_strip_ftype(f) for f in fields]
ds = fake_random_ds(16, fields=fields, units=units, particles=1)
ds.parameters["HydroMethod"] = "streaming"
Expand Down
29 changes: 15 additions & 14 deletions yt/data_objects/construction_data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,15 +249,16 @@ def get_data(self, fields=None):
def _sq_field(field, data, fname: FieldKey):
return data[fname] ** 2

for fname in fields:
fd = self.ds._get_field_info(*fname)
for field in fields:
fd = self.ds._get_field_info(field)
ftype, fname = field
self.ds.add_field(
(fname[0], f"tmp_{fname[1]}_squared"),
partial(_sq_field, fname=fname),
(ftype, f"tmp_{fname}_squared"),
partial(_sq_field, fname=field),
sampling_type=fd.sampling_type,
units=f"({fd.units})*({fd.units})",
)
sfields.append((fname[0], f"tmp_{fname[1]}_squared"))
sfields.append((ftype, f"tmp_{fname}_squared"))
nfields = len(fields)
nsfields = len(sfields)
# We need a new tree for every single set of fields we add
Expand Down Expand Up @@ -387,7 +388,7 @@ def _initialize_projected_units(self, fields, chunk):
for field in self.data_source._determine_fields(fields):
if field in self._projected_units:
continue
finfo = self.ds._get_field_info(*field)
finfo = self.ds._get_field_info(field)
if finfo.units is None:
# First time calling a units="auto" field, infer units and cache
# for future field accesses.
Expand Down Expand Up @@ -930,7 +931,7 @@ def _split_fields(self, fields_to_get):
particles = []
alias = {}
for field in gen:
finfo = self.ds._get_field_info(*field)
finfo = self.ds._get_field_info(field)
if finfo.is_alias:
alias[field] = finfo
continue
Expand All @@ -939,7 +940,7 @@ def _split_fields(self, fields_to_get):
except NeedsOriginalGrid:
fill.append(field)
for field in fill:
finfo = self.ds._get_field_info(*field)
finfo = self.ds._get_field_info(field)
if finfo.sampling_type == "particle":
particles.append(field)
gen = [f for f in gen if f not in fill and f not in alias]
Expand Down Expand Up @@ -1086,9 +1087,9 @@ def _fill_fields(self, fields):
if self.comm.size > 1:
for i in range(len(fields)):
output_fields[i] = self.comm.mpi_allreduce(output_fields[i], op="sum")
for name, v in zip(fields, output_fields):
fi = self.ds._get_field_info(*name)
self[name] = self.ds.arr(v, fi.units)
for field, v in zip(fields, output_fields):
fi = self.ds._get_field_info(field)
self[field] = self.ds.arr(v, fi.units)

def _generate_container_field(self, field):
rv = self.ds.arr(np.ones(self.ActiveDimensions, dtype="float64"), "")
Expand Down Expand Up @@ -1492,11 +1493,11 @@ def _fill_fields(self, fields):
category=RuntimeWarning,
)
mylog.debug("Caught %d runtime errors.", runtime_errors_count)
for name, v in zip(fields, ls.fields):
for field, v in zip(fields, ls.fields):
if self.level > 0:
v = v[1:-1, 1:-1, 1:-1]
fi = self.ds._get_field_info(*name)
self[name] = self.ds.arr(v, fi.units)
fi = self.ds._get_field_info(field)
self[field] = self.ds.arr(v, fi.units)

def _initialize_level_state(self, fields):
ls = LevelState()
Expand Down
78 changes: 17 additions & 61 deletions yt/data_objects/data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,25 @@
import weakref
from collections import defaultdict
from contextlib import contextmanager
from typing import List, Tuple
from typing import TYPE_CHECKING, List, Tuple

import numpy as np

from yt._maintenance.deprecation import issue_deprecation_warning
from yt._typing import AnyFieldKey
from yt._typing import AnyFieldKey, FieldKey, FieldName
from yt.config import ytcfg
from yt.data_objects.field_data import YTFieldData
from yt.data_objects.profiles import create_profile
from yt.fields.field_exceptions import NeedsGridType
from yt.frontends.ytdata.utilities import save_as_dataset
from yt.funcs import get_output_filename, is_sequence, iter_fields, mylog
from yt.funcs import get_output_filename, iter_fields, mylog
from yt.units._numpy_wrapper_functions import uconcatenate
from yt.units.yt_array import YTArray, YTQuantity
from yt.utilities.amr_kdtree.api import AMRKDTree
from yt.utilities.exceptions import (
YTCouldNotGenerateField,
YTException,
YTFieldNotFound,
YTFieldNotParseable,
YTFieldTypeNotFound,
YTNonIndexedDataContainer,
YTSpatialFieldUnitError,
Expand All @@ -30,6 +29,9 @@
from yt.utilities.on_demand_imports import _firefly as firefly
from yt.utilities.parameter_file_storage import ParameterFileStore

if TYPE_CHECKING:
from yt.data_objects.static_output import Dataset


def sanitize_weight_field(ds, field, weight):
field_object = ds._get_field_info(field)
Expand Down Expand Up @@ -85,6 +87,8 @@ def __init__(self, ds, field_parameters):
# Dataset._add_object_class but it can also be passed as a parameter to the
# constructor, in which case it will override the default.
# This code ensures it is never not set.

self.ds: "Dataset"
if ds is not None:
self.ds = ds
else:
Expand Down Expand Up @@ -162,7 +166,7 @@ def apply_units(self, arr, units):
except AttributeError:
return self.ds.arr(arr, units=units)

def _first_matching_field(self, field):
def _first_matching_field(self, field: FieldName) -> FieldKey:
for ftype, fname in self.ds.derived_field_list:
if fname == field:
return (ftype, fname)
Expand Down Expand Up @@ -270,10 +274,7 @@ def __getitem__(self, key):
try:
rv = self.field_data[f]
except KeyError:
if isinstance(f, tuple):
fi = self.ds._get_field_info(*f)
elif isinstance(f, bytes):
fi = self.ds._get_field_info("unknown", f)
Comment on lines -275 to -276
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is interesting. any idea what the history of using a byte string here is? @matthewturk?

fi = self.ds._get_field_info(f)
rv = self.ds.arr(self.field_data[key], fi.units)
return rv

Expand All @@ -296,7 +297,7 @@ def __delitem__(self, key):

def _generate_field(self, field):
ftype, fname = field
finfo = self.ds._get_field_info(*field)
finfo = self.ds._get_field_info(field)
with self._field_type_state(ftype, finfo):
if fname in self._container_fields:
tr = self._generate_container_field(field)
Expand All @@ -310,8 +311,7 @@ def _generate_field(self, field):

def _generate_fluid_field(self, field):
# First we check the validator
ftype, fname = field
finfo = self.ds._get_field_info(ftype, fname)
finfo = self.ds._get_field_info(field)
if self._current_chunk is None or self._current_chunk.chunk_type != "spatial":
gen_obj = self
else:
Expand All @@ -326,7 +326,7 @@ def _generate_fluid_field(self, field):
return rv

def _generate_spatial_fluid(self, field, ngz):
finfo = self.ds._get_field_info(*field)
finfo = self.ds._get_field_info(field)
if finfo.units is None:
raise YTSpatialFieldUnitError(field)
units = finfo.units
Expand Down Expand Up @@ -383,7 +383,7 @@ def _generate_particle_field(self, field):
else:
gen_obj = self._current_chunk.objs[0]
try:
finfo = self.ds._get_field_info(*field)
finfo = self.ds._get_field_info(field)
finfo.check_available(gen_obj)
except NeedsGridType as ngt_exception:
if ngt_exception.ghost_zones != 0:
Expand All @@ -407,7 +407,7 @@ def _generate_particle_field(self, field):
ind += data.size
else:
with self._field_type_state(ftype, finfo, gen_obj):
rv = self.ds._get_field_info(*field)(gen_obj)
rv = self.ds._get_field_info(field)(gen_obj)
return rv

def _count_particles(self, ftype):
Expand Down Expand Up @@ -1487,49 +1487,6 @@ def _field_type_state(self, ftype, finfo, obj=None):
obj._current_particle_type = old_particle_type
obj._current_fluid_type = old_fluid_type

def _tupleize_field(self, field):

try:
ftype, fname = field.name
return ftype, fname
except AttributeError:
pass

if is_sequence(field) and not isinstance(field, str):
try:
ftype, fname = field
if not all(isinstance(_, str) for _ in field):
raise TypeError
return ftype, fname
except TypeError as e:
raise YTFieldNotParseable(field) from e
except ValueError:
pass

try:
fname = field
finfo = self.ds._get_field_info(field)
if finfo.sampling_type == "particle":
ftype = self._current_particle_type
if hasattr(self.ds, "_sph_ptypes"):
ptypes = self.ds._sph_ptypes
if finfo.name[0] in ptypes:
ftype = finfo.name[0]
elif finfo.is_alias and finfo.alias_name[0] in ptypes:
ftype = self._current_fluid_type
else:
ftype = self._current_fluid_type
if (ftype, fname) not in self.ds.field_info:
ftype = self.ds._last_freq[0]
return ftype, fname
except YTFieldNotFound:
pass

if isinstance(field, str):
return "unknown", field

raise YTFieldNotParseable(field)

def _determine_fields(self, fields):
if str(fields) in self.ds._determined_fields:
return self.ds._determined_fields[str(fields)]
Expand All @@ -1539,9 +1496,8 @@ def _determine_fields(self, fields):
explicit_fields.append(field)
continue

ftype, fname = self._tupleize_field(field)
finfo = self.ds._get_field_info(ftype, fname)

finfo = self.ds._get_field_info(field)
ftype, fname = finfo.name
# really ugly check to ensure that this field really does exist somewhere,
# in some naming convention, before returning it as a possible field type
if (
Expand Down
2 changes: 1 addition & 1 deletion yt/data_objects/index_subobjects/grid_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __getitem__(self, key):
fields = self._determine_fields(key)
except YTFieldTypeNotFound:
return tr
finfo = self.ds._get_field_info(*fields[0])
finfo = self.ds._get_field_info(fields[0])
if not finfo.sampling_type == "particle":
num_nodes = 2 ** sum(finfo.nodal_flag)
new_shape = list(self.ActiveDimensions)
Expand Down
2 changes: 1 addition & 1 deletion yt/data_objects/index_subobjects/octree_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __getitem__(self, key):
fields = self._determine_fields(key)
except YTFieldTypeNotFound:
return tr
finfo = self.ds._get_field_info(*fields[0])
finfo = self.ds._get_field_info(fields[0])
if not finfo.sampling_type == "particle":
# We may need to reshape the field, if it is being queried from
# field_data. If it's already cached, it just passes through.
Expand Down
12 changes: 10 additions & 2 deletions yt/data_objects/region_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

from yt.funcs import obj_length
from yt.units.yt_array import YTQuantity
from yt.utilities.exceptions import YTDimensionalityError, YTFieldNotParseable
from yt.utilities.exceptions import (
YTDimensionalityError,
YTFieldNotFound,
YTFieldNotParseable,
)
from yt.visualization.line_plot import LineBuffer

from .data_containers import _get_ipython_key_completion
Expand All @@ -23,7 +27,11 @@ def __getitem__(self, item):
# that result in a rectangular prism or a slice.
try:
return self.all_data[item]
except (TypeError, YTFieldNotParseable):
except (YTFieldNotParseable, YTFieldNotFound):
# any error raised by self.ds._get_field_info
# signals a type error (not a field), however we don't want to
# catch plain TypeErrors as this may create subtle bugs very hard
# to decipher, like broken internal function calls.
pass

if isinstance(item, slice):
Expand Down
18 changes: 9 additions & 9 deletions yt/data_objects/selection_objects/data_selection_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _identify_dependencies(self, fields_to_get, spatial=False):
if inspected >= len(fields_to_get):
break
inspected += 1
fi = self.ds._get_field_info(*field)
fi = self.ds._get_field_info(field)
fd = self.ds.field_dependencies.get(
field, None
) or self.ds.field_dependencies.get(field[1], None)
Expand Down Expand Up @@ -165,7 +165,7 @@ def get_data(self, fields=None):
for field in self._determine_fields(fields):
if field in self.field_data:
continue
finfo = self.ds._get_field_info(*field)
finfo = self.ds._get_field_info(field)
try:
finfo.check_available(self)
except NeedsGridType:
Expand All @@ -183,13 +183,13 @@ def get_data(self, fields=None):
# We now split up into readers for the types of fields
fluids, particles = [], []
finfos = {}
for ftype, fname in fields_to_get:
finfo = self.ds._get_field_info(ftype, fname)
finfos[ftype, fname] = finfo
for field_key in fields_to_get:
finfo = self.ds._get_field_info(field_key)
finfos[field_key] = finfo
if finfo.sampling_type == "particle":
particles.append((ftype, fname))
elif (ftype, fname) not in fluids:
fluids.append((ftype, fname))
particles.append(field_key)
elif field_key not in fluids:
fluids.append(field_key)
# The _read method will figure out which fields it needs to get from
# disk, and return a dict of those fields along with the fields that
# need to be generated.
Expand Down Expand Up @@ -228,7 +228,7 @@ def _generate_fields(self, fields_to_generate):
index += 1
if field in self.field_data:
continue
fi = self.ds._get_field_info(*field)
fi = self.ds._get_field_info(field)
try:
fd = self._generate_field(field)
if hasattr(fd, "units"):
Expand Down
Loading