Skip to content

Commit

Permalink
TYP: more type hints fixes for gadget
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Oct 8, 2021
1 parent 5b9dbe1 commit 7c974c7
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 18 deletions.
2 changes: 1 addition & 1 deletion yt/frontends/gadget/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def _initialize_frontend_specific(self):

class GadgetDataset(SPHDataset):
_index_class: Type[Index] = GadgetBinaryIndex
_file_class = GadgetBinaryFile
_file_class: Type = GadgetBinaryFile
_field_info_class: Type[FieldInfoContainer] = GadgetFieldInfo
_particle_mass_name = "Mass"
_particle_coordinates_name = "Coordinates"
Expand Down
28 changes: 15 additions & 13 deletions yt/frontends/gadget/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def _identify_fields(self, data_file):
if not hasattr(g[kk], "shape"):
continue
if len(g[kk].shape) > 1:
self._vector_fields[kk] = g[kk].shape[1]
self._vector_fields_dict[kk] = g[kk].shape[1]
fields.append((ptype, str(kk)))

f.close()
Expand All @@ -306,16 +306,22 @@ def _identify_fields(self, data_file):
ZeroMass = object()


class IOHandlerGadgetBinary(IOHandlerSPH):
_dataset_type = "gadget_binary"
_vector_fields = (
_VECTOR_FIELDS = dict(
(
("Coordinates", 3),
("Velocity", 3),
("Velocities", 3),
("MagneticField", 3),
("FourMetalFractions", 4),
("ElevenMetalMasses", 11),
)
)


class IOHandlerGadgetBinary(IOHandlerSPH):
_dataset_type = "gadget_binary"
_vector_fields = tuple(_VECTOR_FIELDS.keys())
_vector_fields_dict = _VECTOR_FIELDS

# Particle types (Table 3 in GADGET-2 user guide)
#
Expand All @@ -337,7 +343,6 @@ class IOHandlerGadgetBinary(IOHandlerSPH):
_format = None

def __init__(self, ds, *args, **kwargs):
self._vector_fields = dict(self._vector_fields)
self._fields = ds._field_spec
self._ptypes = ds._ptype_spec
self.data_files = set()
Expand Down Expand Up @@ -437,15 +442,14 @@ def _read_field_from_file(self, f, count, name):
else:
dt = self._endian + self._float_type
dt = np.dtype(dt)
if name in self._vector_fields:
count *= self._vector_fields[name]
count *= self._vector_fields_dict.get(name, 1)
arr = np.fromfile(f, dtype=dt, count=count)
# ensure data are in native endianness to avoid errors
# when field data are passed to cython
dt = dt.newbyteorder("N")
arr = arr.astype(dt)
if name in self._vector_fields:
factor = self._vector_fields[name]
if name in self._vector_fields_dict:
factor = self._vector_fields_dict[name]
arr = arr.reshape((count // factor, factor), order="C")
return arr

Expand Down Expand Up @@ -530,14 +534,12 @@ def _calculate_field_offsets(
if (ptype, field) not in field_list:
continue
start_offset = df_start * fs
if field in self._vector_fields:
start_offset *= self._vector_fields[field]
start_offset *= self._vector_fields_dict.get(field, 1)
pos += start_offset
offsets[(ptype, field)] = pos
any_ptypes = True
remain_offset = (pcount[ptype] - df_start) * fs
if field in self._vector_fields:
remain_offset *= self._vector_fields[field]
remain_offset *= self._vector_fields_dict.get(field, 1)
pos += remain_offset
pos += 4
if not any_ptypes:
Expand Down
2 changes: 1 addition & 1 deletion yt/frontends/gadget/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .definitions import gadget_field_specs, gadget_ptype_specs
from .io import IOHandlerGadgetBinary

vector_fields = dict(IOHandlerGadgetBinary._vector_fields)
vector_fields = dict(IOHandlerGadgetBinary._vector_fields_dict)

block_ids = {
"Coordinates": "POS",
Expand Down
6 changes: 3 additions & 3 deletions yt/frontends/gadget_fof/fields.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from yt.fields.field_info_container import FieldInfoContainer
from yt.fields.field_info_container import FieldInfoContainer, KnownFieldsT

m_units = "code_mass"
p_units = "code_length"
v_units = "code_velocity"

_pnums = 6
_type_fields = tuple(
_type_fields: KnownFieldsT = tuple(
("%s%sType_%d" % (ptype, field, pnum), (units, [], None))
for pnum in range(_pnums)
for field, units in (("Mass", m_units), ("Len", p_units))
for ptype in ("Group", "Subhalo")
)
_sub_type_fields = tuple(
_sub_type_fields: KnownFieldsT = tuple(
("Subhalo%sType_%d" % (field, pnum), (units, [], None))
for pnum in range(_pnums)
for field, units in (
Expand Down

0 comments on commit 7c974c7

Please sign in to comment.