reduce code duplication in IOHandler _read_particle_coords and _read_particle_fields #4597

Merged
merged 2 commits on Jul 27, 2023
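The refactor follows one pattern across all of the touched frontends: the hand-rolled loop that collects and sorts each chunk's data files is replaced by a call to a shared iterator. Below is a minimal before/after sketch of that call-site pattern, assembled from the hunks that follow; the (filename, start) sort key is the one most of the non-halo frontends were using.

```python
# Before: each frontend collected and sorted the data files itself.
def _read_particle_coords(self, chunks, ptf):
    chunks = list(chunks)
    data_files = set()
    for chunk in chunks:
        for obj in chunk.objs:
            data_files.update(obj.data_files)
    for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
        ...  # read coordinates from data_file

# After: the duplicated collection/sort loop lives behind one shared helper.
def _read_particle_coords(self, chunks, ptf):
    for data_file in self._sorted_chunk_iterator(chunks):
        ...  # read coordinates from data_file
```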
19 changes: 7 additions & 12 deletions yt/frontends/adaptahop/io.py
@@ -36,16 +36,11 @@ def _yield_coordinates(self, data_file):

def _read_particle_coords(self, chunks, ptf):
# This will read chunks and yield the results.
chunks = list(chunks)
data_files = set()
# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
ptype = "halos"
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=attrgetter("filename")):
for data_file in self._sorted_chunk_iterator(chunks):
pcount = (
data_file.ds.parameters["nhalos"] + data_file.ds.parameters["nsubs"]
)
@@ -56,14 +51,10 @@ def _read_particle_coords(self, chunks, ptf):

def _read_particle_fields(self, chunks, ptf, selector):
# Now we have all the sizes, and we can allocate
chunks = list(chunks)
data_files = set()

# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)

def iterate_over_attributes(attr_list):
for attr, *_ in attr_list:
@@ -76,7 +67,7 @@ def iterate_over_attributes(attr_list):

attr_pos = partial(_find_attr_position, halo_attributes=halo_attributes)

for data_file in sorted(data_files, key=attrgetter("filename")):
for data_file in self._sorted_chunk_iterator(chunks):
pcount = (
data_file.ds.parameters["nhalos"] + data_file.ds.parameters["nsubs"]
)
@@ -194,6 +185,10 @@ def members(self, ihalo):
members = fpu.read_attrs(todo.pop(0))["particle_identities"]
return members

def _sorted_chunk_iterator(self, chunks):
data_files = self._get_data_files(chunks)
yield from sorted(data_files, key=attrgetter("filename"))
Comment from the PR author:
@matthewturk just pointing out an example where I added an override of _sorted_chunk_iterator so that this frontend continues to sort by filename only. The change in ahf is similar.
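For reference, the shared default that these overrides replace is not part of this diff. The following is a rough sketch of what the base-class helpers presumably look like, reconstructed from the loops being deleted at the call sites; the class name and placement are assumptions, not taken from the PR.

```python
class BaseParticleIOHandler:  # name/location assumed; the real base class lives elsewhere in yt
    def _get_data_files(self, chunks):
        # Collect the unique data files touched by this set of chunks.
        data_files = set()
        for chunk in chunks:
            for obj in chunk.objs:
                data_files.update(obj.data_files)
        return data_files

    def _sorted_chunk_iterator(self, chunks):
        # Default ordering: by filename, then by starting index within the file.
        data_files = self._get_data_files(chunks)
        yield from sorted(data_files, key=lambda x: (x.filename, x.start))
```

Frontends such as adaptahop and ahf override only _sorted_chunk_iterator (as in the hunk above) so that their original filename-only ordering is preserved.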



def _todo_from_attributes(attributes: ATTR_T, halo_attributes: ATTR_T):
# Helper function to generate a list of read-skip instructions given a list of
29 changes: 13 additions & 16 deletions yt/frontends/ahf/io.py
@@ -16,7 +16,11 @@ def _read_particle_coords(self, chunks, ptf):
# This needs to *yield* a series of tuples of (ptype, (x, y, z), hsml).
# chunks is a list of chunks, and ptf is a dict where the keys are
# ptypes and the values are lists of fields.
for data_file in self._get_data_files(chunks, ptf):

# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
for data_file in self._sorted_chunk_iterator(chunks):
pos = data_file._get_particle_positions("halos")
x, y, z = (pos[:, i] for i in range(3))
yield "halos", (x, y, z), 0.0
@@ -34,7 +38,10 @@ def _read_particle_fields(self, chunks, ptf, selector):
# reading ptype, field and applying the selector to the data read in.
# Selector objects have a .select_points(x,y,z) that returns a mask, so
# you need to do your masking here.
for data_file in self._get_data_files(chunks, ptf):
# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
for data_file in self._sorted_chunk_iterator(chunks):
si, ei = data_file.start, data_file.end
cols = []
for field_list in ptf.values():
@@ -65,17 +72,7 @@ def _identify_fields(self, data_file):
fields = [("halos", f) for f in data_file.col_names]
return fields, {}

# Helper methods

def _get_data_files(self, chunks, ptf):
# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
# Get data_files
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
data_files = sorted(data_files, key=attrgetter("filename"))
yield from data_files
def _sorted_chunk_iterator(self, chunks):
# yield from sorted list of data_files
data_files = self._get_data_files(chunks)
yield from sorted(data_files, key=attrgetter("filename"))
14 changes: 2 additions & 12 deletions yt/frontends/gadget_fof/io.py
@@ -21,12 +21,7 @@ def _read_fluid_selection(self, chunks, selector, fields, size):

def _read_particle_coords(self, chunks, ptf):
# This will read chunks and yield the results.
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
with h5py.File(data_file.filename, mode="r") as f:
for ptype in sorted(ptf):
coords = data_file._get_particle_positions(ptype, f=f)
@@ -71,12 +66,7 @@ def _read_offset_particle_field(self, field, data_file, fh):

def _read_particle_fields(self, chunks, ptf, selector):
# Now we have all the sizes, and we can allocate
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
si, ei = data_file.start, data_file.end
with h5py.File(data_file.filename, mode="r") as f:
for ptype, field_list in sorted(ptf.items()):
15 changes: 2 additions & 13 deletions yt/frontends/halo_catalog/io.py
@@ -17,17 +17,12 @@ def _read_fluid_selection(self, chunks, selector, fields, size):

def _read_particle_coords(self, chunks, ptf):
# This will read chunks and yield the results.
chunks = list(chunks)
data_files = set()
# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
ptype = "halos"
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
pn = "particle_position_%s"
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
with h5py.File(data_file.filename, mode="r") as f:
units = parse_h5_attr(f[pn % "x"], "units")
pos = data_file._get_particle_positions(ptype, f=f)
@@ -46,17 +41,11 @@ def _yield_coordinates(self, data_file):
yield "halos", pos

def _read_particle_fields(self, chunks, ptf, selector):
# Now we have all the sizes, and we can allocate
chunks = list(chunks)
data_files = set()
# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
pn = "particle_position_%s"
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
si, ei = data_file.start, data_file.end
with h5py.File(data_file.filename, mode="r") as f:
for ptype, field_list in sorted(ptf.items()):
13 changes: 2 additions & 11 deletions yt/frontends/http_stream/io.py
@@ -33,12 +33,7 @@ def _identify_fields(self, data_file):
return f, {}

def _read_particle_coords(self, chunks, ptf):
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
for ptype in ptf:
s = self._open_stream(data_file, (ptype, "Coordinates"))
c = np.frombuffer(s, dtype="float64")
@@ -47,11 +42,7 @@ def _read_particle_coords(self, chunks, ptf):

def _read_particle_fields(self, chunks, ptf, selector):
# Now we have all the sizes, and we can allocate
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
for ptype, field_list in sorted(ptf.items()):
s = self._open_stream(data_file, (ptype, "Coordinates"))
c = np.frombuffer(s, dtype="float64")
14 changes: 2 additions & 12 deletions yt/frontends/owls_subfind/io.py
@@ -20,12 +20,7 @@ def _read_fluid_selection(self, chunks, selector, fields, size):

def _read_particle_coords(self, chunks, ptf):
# This will read chunks and yield the results.
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
with h5py.File(data_file.filename, mode="r") as f:
for ptype in sorted(ptf):
pcount = data_file.total_particles[ptype]
@@ -67,12 +62,7 @@ def _read_offset_particle_field(self, field, data_file, fh):

def _read_particle_fields(self, chunks, ptf, selector):
# Now we have all the sizes, and we can allocate
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
with h5py.File(data_file.filename, mode="r") as f:
for ptype, field_list in sorted(ptf.items()):
pcount = data_file.total_particles[ptype]
16 changes: 3 additions & 13 deletions yt/frontends/rockstar/io.py
@@ -19,16 +19,12 @@ def _read_fluid_selection(self, chunks, selector, fields, size):

def _read_particle_coords(self, chunks, ptf):
# This will read chunks and yield the results.
chunks = list(chunks)
data_files = set()

# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
ptype = "halos"
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
pcount = data_file.header["num_halos"]
if pcount == 0:
continue
@@ -37,16 +33,10 @@ def _read_particle_coords(self, chunks, ptf):
yield "halos", (pos[:, i] for i in range(3)), 0.0

def _read_particle_fields(self, chunks, ptf, selector):
# Now we have all the sizes, and we can allocate
chunks = list(chunks)
data_files = set()
# Only support halo reading for now.
assert len(ptf) == 1
assert list(ptf.keys())[0] == "halos"
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
si, ei = data_file.start, data_file.end
pcount = data_file.header["num_halos"]
if pcount == 0:
12 changes: 2 additions & 10 deletions yt/frontends/sdf/io.py
@@ -15,13 +15,9 @@ def _read_fluid_selection(self, chunks, selector, fields, size):
raise NotImplementedError

def _read_particle_coords(self, chunks, ptf):
chunks = list(chunks)
data_files = set()
assert len(ptf) == 1
assert ptf.keys()[0] == "dark_matter"
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
data_files = self._get_data_files(chunks)
assert len(data_files) == 1
for _data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
yield "dark_matter", (
@@ -31,13 +27,9 @@ def _read_particle_coords(self, chunks, ptf):
), 0.0

def _read_particle_fields(self, chunks, ptf, selector):
chunks = list(chunks)
data_files = set()
assert len(ptf) == 1
assert ptf.keys()[0] == "dark_matter"
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
data_files = self._get_data_files(chunks)
assert len(data_files) == 1
for _data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
Comment from the PR author:
Just a note: I did not use _sorted_chunk_iterator here because I wanted to keep the assert len(data_files) == 1 check above.
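A short note on why: assuming _sorted_chunk_iterator is a generator (as the adaptahop and ahf definitions above suggest), there is nothing to call len() on without materializing it first, whereas _get_data_files hands back the collected container directly.

```python
data_files = self._get_data_files(chunks)  # a concrete set, so len() is cheap
assert len(data_files) == 1                # the length check the author wanted to keep
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
    ...  # proceed exactly as _sorted_chunk_iterator would have

# By contrast, len(self._sorted_chunk_iterator(chunks)) would raise TypeError,
# since generators do not support len().
```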

for ptype, field_list in sorted(ptf.items()):
15 changes: 2 additions & 13 deletions yt/frontends/stream/io.py
@@ -104,9 +104,7 @@ def __init__(self, ds):
super().__init__(ds)

def _read_particle_coords(self, chunks, ptf):
for data_file in sorted(
self._get_data_files(chunks), key=lambda x: (x.filename, x.start)
):
for data_file in self._sorted_chunk_iterator(chunks):
f = self.fields[data_file.filename]
# This double-reads
for ptype in sorted(ptf):
@@ -117,19 +115,10 @@ def _read_particle_coords(self, chunks, ptf):
), 0.0

def _read_smoothing_length(self, chunks, ptf, ptype):
for data_file in sorted(
self._get_data_files(chunks), key=lambda x: (x.filename, x.start)
):
for data_file in self._sorted_chunk_iterator(chunks):
f = self.fields[data_file.filename]
return f[ptype, "smoothing_length"]

def _get_data_files(self, chunks):
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
return data_files

def _read_particle_data_file(self, data_file, ptf, selector=None):
return_data = {}
f = self.fields[data_file.filename]
6 changes: 1 addition & 5 deletions yt/frontends/tipsy/io.py
@@ -85,12 +85,8 @@ def _fill_fields(self, fields, vals, hsml, mask, data_file):
return rv

def _read_particle_coords(self, chunks, ptf):
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
chunksize = self.ds.index.chunksize
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
poff = data_file.field_offsets
tp = data_file.total_particles
f = open(data_file.filename, "rb")
21 changes: 3 additions & 18 deletions yt/frontends/ytdata/io.py
@@ -191,12 +191,7 @@ def _yield_coordinates(self, data_file):

def _read_particle_coords(self, chunks, ptf):
# This will read chunks and yield the results.
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
index_mask = slice(data_file.start, data_file.end)
with h5py.File(data_file.filename, mode="r") as f:
for ptype in sorted(ptf):
@@ -271,12 +266,7 @@ class IOHandlerYTSpatialPlotHDF5(IOHandlerYTDataContainerHDF5):

def _read_particle_coords(self, chunks, ptf):
# This will read chunks and yield the results.
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
with h5py.File(data_file.filename, mode="r") as f:
for ptype in sorted(ptf):
pcount = data_file.total_particles[ptype]
@@ -292,12 +282,7 @@ def _read_particle_coords(self, chunks, ptf):

def _read_particle_fields(self, chunks, ptf, selector):
# Now we have all the sizes, and we can allocate
chunks = list(chunks)
data_files = set()
for chunk in chunks:
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
for data_file in self._sorted_chunk_iterator(chunks):
all_count = self._count_particles(data_file)
with h5py.File(data_file.filename, mode="r") as f:
for ptype, field_list in sorted(ptf.items()):