Skip to content

Commit

Permalink
Merge pull request #4595 from chrishavlin/ytdata_chunking
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros authored Jul 24, 2023
2 parents 7ea8cc0 + 6405200 commit 6a91640
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions yt/frontends/ytdata/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,18 @@ def _read_particle_coords(self, chunks, ptf):
for obj in chunk.objs:
data_files.update(obj.data_files)
for data_file in sorted(data_files, key=lambda x: (x.filename, x.start)):
index_mask = slice(data_file.start, data_file.end)
with h5py.File(data_file.filename, mode="r") as f:
for ptype in sorted(ptf):
pcount = data_file.total_particles[ptype]
if pcount == 0:
continue
units = _get_position_array_units(ptype, f, "x")
x, y, z = (
self.ds.arr(_get_position_array(ptype, f, ax), units)
self.ds.arr(
_get_position_array(ptype, f, ax, index_mask=index_mask),
units,
)
for ax in "xyz"
)
yield ptype, (x, y, z), 0.0
Expand All @@ -213,13 +217,17 @@ def _read_particle_data_file(self, data_file, ptf, selector):
data_return = {}

with h5py.File(data_file.filename, mode="r") as f:
index_mask = slice(data_file.start, data_file.end)
for ptype, field_list in sorted(ptf.items()):
if selector is None or getattr(selector, "is_all_data", False):
mask = slice(None, None, None)
mask = index_mask
else:
units = _get_position_array_units(ptype, f, "x")
x, y, z = (
self.ds.arr(_get_position_array(ptype, f, ax), units)
self.ds.arr(
_get_position_array(ptype, f, ax, index_mask=index_mask),
units,
)
for ax in "xyz"
)
mask = selector.select_points(x, y, z, 0.0)
Expand Down Expand Up @@ -308,12 +316,14 @@ def _read_particle_fields(self, chunks, ptf, selector):
yield (ptype, field), data


def _get_position_array(ptype, f, ax):
def _get_position_array(ptype, f, ax, index_mask=None):
if index_mask is None:
index_mask = slice(None, None, None)
if ptype == "grid":
pos_name = ""
else:
pos_name = "particle_position_"
return f[ptype][pos_name + ax][()].astype("float64")
return f[ptype][pos_name + ax][index_mask].astype("float64")


def _get_position_array_units(ptype, f, ax):
Expand Down

0 comments on commit 6a91640

Please sign in to comment.