
Commit

Merge pull request #4597 from chrishavlin/particle_io_code_duplication
reduce code duplication in IOHandler _read_particle_coords and _read_particle_fields
matthewturk authored Jul 27, 2023
2 parents 2ac4606 + c3784b0 commit 28defca
Showing 12 changed files with 53 additions and 146 deletions.
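Every frontend diff below follows the same pattern: the hand-rolled loop that collected data_files from each chunk and then sorted them is replaced by a call to self._sorted_chunk_iterator(chunks) (or, where a frontend only needs the de-duplicated set, self._get_data_files(chunks)). The base-class half of the refactor lives in one of the 12 changed files whose diff did not load on this page, so the sketch below is an assumption reconstructed from the removed per-frontend code (notably the _get_data_files helper deleted from yt/frontends/stream/io.py), not the committed implementation; the default sort key and the class it hangs off are guesses.

# --- illustrative sketch, not part of the diff ---
from operator import attrgetter

class ParticleIOHandlerBase:  # hypothetical stand-in for the real base class
    def _get_data_files(self, chunks):
        # De-duplicate the data_files referenced by every chunk.
        data_files = set()
        for chunk in chunks:
            for obj in chunk.objs:
                data_files.update(obj.data_files)
        return data_files

    def _sorted_chunk_iterator(self, chunks):
        # Deterministic iteration order; most frontends previously
        # sorted on (filename, start), while adaptahop and ahf override
        # this method to sort on filename alone (see their diffs below).
        data_files = self._get_data_files(chunks)
        yield from sorted(data_files, key=attrgetter("filename", "start"))

Centralizing the iteration keeps the sorting policy in one place while still letting individual frontends override it, which is exactly what the adaptahop and ahf hunks do.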
19 changes: 7 additions & 12 deletions yt/frontends/adaptahop/io.py
@@ -36,16 +36,11 @@ def _yield_coordinates(self, data_file):

def _read_particle_coords(self, chunks, ptf):
# This will read chunks and yield the results.
chunks = list(chunks)
data_files = set()
# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
ptype = "halos"
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=attrgetter("filename")):
for data_file in self._sorted_chunk_iterator(chunks):
pcount = (
data_file.ds.parameters["nhalos"] + data_file.ds.parameters["nsubs"]
)
@@ -56,14 +51,10 @@ def _read_particle_coords(self, chunks, ptf):

def _read_particle_fields(self, chunks, ptf, selector):
# Now we have all the sizes, and we can allocate
chunks = list(chunks)
data_files = set()

# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)

def iterate_over_attributes(attr_list):
for attr, *_ in attr_list:
@@ -76,7 +67,7 @@ def iterate_over_attributes(attr_list):

attr_pos = partial(_find_attr_position, halo_attributes=halo_attributes)

for data_file in sorted(data_files, key=attrgetter("filename")):
for data_file in self._sorted_chunk_iterator(chunks):
pcount = (
data_file.ds.parameters["nhalos"] + data_file.ds.parameters["nsubs"]
)
@@ -194,6 +185,10 @@ def members(self, ihalo):
members = fpu.read_attrs(todo.pop(0))["particle_identities"]
return members

def _sorted_chunk_iterator(self, chunks):
data_files = self._get_data_files(chunks)
yield from sorted(data_files, key=attrgetter("filename"))


def _todo_from_attributes(attributes: ATTR_T, halo_attributes: ATTR_T):
# Helper function to generate a list of read-skip instructions given a list of
29 changes: 13 additions & 16 deletions yt/frontends/ahf/io.py
@@ -16,7 +16,11 @@ def _read_particle_coords(self, chunks, ptf):
# This needs to *yield* a series of tuples of (ptype, (x, y, z), hsml).
# chunks is a list of chunks, and ptf is a dict where the keys are
# ptypes and the values are lists of fields.
for data_file in self._get_data_files(chunks, ptf):

# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
for data_file in self._sorted_chunk_iterator(chunks):
pos = data_file._get_particle_positions("halos")
x, y, z = (pos[:, i] for i in range(3))
yield "halos", (x, y, z), 0.0
@@ -34,7 +38,10 @@ def _read_particle_fields(self, chunks, ptf, selector):
# reading ptype, field and applying the selector to the data read in.
# Selector objects have a .select_points(x,y,z) that returns a mask, so
# you need to do your masking here.
for data_file in self._get_data_files(chunks, ptf):
# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
for data_file in self._sorted_chunk_iterator(chunks):
si, ei = data_file.start, data_file.end
cols = []
for field_list in ptf.values():
@@ -65,17 +72,7 @@ def _identify_fields(self, data_file):
fields = [("halos", f) for f in data_file.col_names]
return fields, {}

# Helper methods

def _get_data_files(self, chunks, ptf):
# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
# Get data_files
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
data_files = sorted(data_files, key=attrgetter("filename"))
yield from data_files
def _sorted_chunk_iterator(self, chunks):
# yield from sorted list of data_files
data_files = self._get_data_files(chunks)
yield from sorted(data_files, key=attrgetter("filename"))
14 changes: 2 additions & 12 deletions yt/frontends/gadget_fof/io.py
@@ -21,12 +21,7 @@ def _read_fluid_selection(self, chunks, selector, fields, size):

def _read_particle_coords(self, chunks, ptf):
# This will read chunks and yield the results.
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
with h5py.File(data_file.filename, mode="r") as f:
for ptype in sorted(ptf):
coords = data_file._get_particle_positions(ptype, f=f)
@@ -71,12 +66,7 @@ def _read_offset_particle_field(self, field, data_file, fh):

def _read_particle_fields(self, chunks, ptf, selector):
# Now we have all the sizes, and we can allocate
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
si, ei = data_file.start, data_file.end
with h5py.File(data_file.filename, mode="r") as f:
for ptype, field_list in sorted(ptf.items()):
15 changes: 2 additions & 13 deletions yt/frontends/halo_catalog/io.py
@@ -17,17 +17,12 @@ def _read_fluid_selection(self, chunks, selector, fields, size):

def _read_particle_coords(self, chunks, ptf):
# This will read chunks and yield the results.
chunks = list(chunks)
data_files = set()
# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
ptype = "halos"
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
pn = "particle_position_%s"
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
with h5py.File(data_file.filename, mode="r") as f:
units = parse_h5_attr(f[pn % "x"], "units")
pos = data_file._get_particle_positions(ptype, f=f)
@@ -46,17 +41,11 @@ def _yield_coordinates(self, data_file):
yield "halos", pos

def _read_particle_fields(self, chunks, ptf, selector):
# Now we have all the sizes, and we can allocate
chunks = list(chunks)
data_files = set()
# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
pn = "particle_position_%s"
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
si, ei = data_file.start, data_file.end
with h5py.File(data_file.filename, mode="r") as f:
for ptype, field_list in sorted(ptf.items()):
13 changes: 2 additions & 11 deletions yt/frontends/http_stream/io.py
@@ -33,12 +33,7 @@ def _identify_fields(self, data_file):
return f, {}

def _read_particle_coords(self, chunks, ptf):
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
for ptype in ptf:
s = self._open_stream(data_file, (ptype, "Coordinates"))
c = np.frombuffer(s, dtype="float64")
@@ -47,11 +42,7 @@ def _read_particle_coords(self, chunks, ptf):

def _read_particle_fields(self, chunks, ptf, selector):
# Now we have all the sizes, and we can allocate
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
for ptype, field_list in sorted(ptf.items()):
s = self._open_stream(data_file, (ptype, "Coordinates"))
c = np.frombuffer(s, dtype="float64")
14 changes: 2 additions & 12 deletions yt/frontends/owls_subfind/io.py
@@ -20,12 +20,7 @@ def _read_fluid_selection(self, chunks, selector, fields, size):

def _read_particle_coords(self, chunks, ptf):
# This will read chunks and yield the results.
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
with h5py.File(data_file.filename, mode="r") as f:
for ptype in sorted(ptf):
pcount = data_file.total_particles[ptype]
@@ -67,12 +62,7 @@ def _read_offset_particle_field(self, field, data_file, fh):

def _read_particle_fields(self, chunks, ptf, selector):
# Now we have all the sizes, and we can allocate
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
with h5py.File(data_file.filename, mode="r") as f:
for ptype, field_list in sorted(ptf.items()):
pcount = data_file.total_particles[ptype]
16 changes: 3 additions & 13 deletions yt/frontends/rockstar/io.py
@@ -19,16 +19,12 @@ def _read_fluid_selection(self, chunks, selector, fields, size):

def _read_particle_coords(self, chunks, ptf):
# This will read chunks and yield the results.
chunks = list(chunks)
data_files = set()

# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
ptype = "halos"
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
pcount = data_file.header["num_halos"]
if pcount == 0:
continue
@@ -37,16 +33,10 @@ def _read_particle_coords(self, chunks, ptf):
yield "halos", (pos[:, i] for i in range(3)), 0.0

def _read_particle_fields(self, chunks, ptf, selector):
# Now we have all the sizes, and we can allocate
chunks = list(chunks)
data_files = set()
# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
si, ei = data_file.start, data_file.end
pcount = data_file.header["num_halos"]
if pcount == 0:
12 changes: 2 additions & 10 deletions yt/frontends/sdf/io.py
@@ -15,13 +15,9 @@ def _read_fluid_selection(self, chunks, selector, fields, size):
raise NotImplementedError

def _read_particle_coords(self, chunks, ptf):
chunks = list(chunks)
data_files = set()
assert len(ptf) == 1
assert ptf.keys()[0] == "dark_matter"
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
data_files = self._get_data_files(chunks)
assert len(data_files) == 1
for _data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
yield "dark_matter", (
@@ -31,13 +27,9 @@ def _read_particle_coords(self, chunks, ptf):
), 0.0

def _read_particle_fields(self, chunks, ptf, selector):
chunks = list(chunks)
data_files = set()
assert len(ptf) == 1
assert ptf.keys()[0] == "dark_matter"
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
data_files = self._get_data_files(chunks)
assert len(data_files) == 1
for _data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for ptype, field_list in sorted(ptf.items()):
15 changes: 2 additions & 13 deletions yt/frontends/stream/io.py
@@ -104,9 +104,7 @@ def __init__(self, ds):
super().__init__(ds)

def _read_particle_coords(self, chunks, ptf):
for data_file in sorted(
self._get_data_files(chunks), key=lambda x: (x.filename, x.start)
):
for data_file in self._sorted_chunk_iterator(chunks):
f = self.fields[data_file.filename]
# This double-reads
for ptype in sorted(ptf):
@@ -117,19 +115,10 @@ def _read_particle_coords(self, chunks, ptf):
), 0.0

def _read_smoothing_length(self, chunks, ptf, ptype):
for data_file in sorted(
self._get_data_files(chunks), key=lambda x: (x.filename, x.start)
):
for data_file in self._sorted_chunk_iterator(chunks):
f = self.fields[data_file.filename]
return f[ptype, "smoothing_length"]

def _get_data_files(self, chunks):
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
return data_files

def _read_particle_data_file(self, data_file, ptf, selector=None):
return_data = {}
f = self.fields[data_file.filename]
6 changes: 1 addition & 5 deletions yt/frontends/tipsy/io.py
@@ -85,12 +85,8 @@ def _fill_fields(self, fields, vals, hsml, mask, data_file):
return rv

def _read_particle_coords(self, chunks, ptf):
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
chunksize = self.ds.index.chunksize
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
poff = data_file.field_offsets
tp = data_file.total_particles
f = open(data_file.filename, "rb")
21 changes: 3 additions & 18 deletions yt/frontends/ytdata/io.py
@@ -191,12 +191,7 @@ def _yield_coordinates(self, data_file):

def _read_particle_coords(self, chunks, ptf):
# This will read chunks and yield the results.
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
index_mask = slice(data_file.start, data_file.end)
with h5py.File(data_file.filename, mode="r") as f:
for ptype in sorted(ptf):
@@ -271,12 +266,7 @@ class IOHandlerYTSpatialPlotHDF5(IOHandlerYTDataContainerHDF5):

def _read_particle_coords(self, chunks, ptf):
# This will read chunks and yield the results.
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
with h5py.File(data_file.filename, mode="r") as f:
for ptype in sorted(ptf):
pcount = data_file.total_particles[ptype]
@@ -292,12 +282,7 @@ def _read_particle_coords(self, chunks, ptf):

def _read_particle_fields(self, chunks, ptf, selector):
# Now we have all the sizes, and we can allocate
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
all_count = self._count_particles(data_file)
with h5py.File(data_file.filename, mode="r") as f:
for ptype, field_list in sorted(ptf.items()):
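As a quick sanity check of the behavior the refactor relies on, here is a self-contained sketch (not part of this commit) showing that the chunk-to-data_file iteration de-duplicates files shared between chunks and yields them in a stable order. All class names and filenames here are made up; only the attribute names (objs, data_files, filename, start) mirror the code above.

# --- illustrative usage sketch, not part of the diff ---
from dataclasses import dataclass
from operator import attrgetter

@dataclass(frozen=True)
class FakeDataFile:
    filename: str
    start: int

@dataclass
class FakeObj:
    data_files: tuple

@dataclass
class FakeChunk:
    objs: tuple

def sorted_chunk_iterator(chunks):
    # Same de-duplicate-then-sort logic the frontends now share.
    data_files = set()
    for chunk in chunks:
        for obj in chunk.objs:
            data_files.update(obj.data_files)
    yield from sorted(data_files, key=attrgetter("filename", "start"))

df_a = FakeDataFile("halos_0000.h5", 0)
df_b = FakeDataFile("halos_0001.h5", 0)
# df_a is referenced by both chunks but must be visited only once.
chunks = [
    FakeChunk(objs=(FakeObj(data_files=(df_b, df_a)),)),
    FakeChunk(objs=(FakeObj(data_files=(df_a,)),)),
]
assert [df.filename for df in sorted_chunk_iterator(chunks)] == [
    "halos_0000.h5",
    "halos_0001.h5",
]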