From d411f6ca35b06a6e70ea1fd80a73c6665e3a625a Mon Sep 17 00:00:00 2001 From: Ingmar Schoegl Date: Mon, 21 Feb 2022 08:29:20 -0600 Subject: [PATCH 1/2] [Python] Propagate 'extra' in SolutionArray.__getitem__ --- interfaces/cython/cantera/composite.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/interfaces/cython/cantera/composite.py b/interfaces/cython/cantera/composite.py index ed4a7b5e62..94bcc3f4c9 100644 --- a/interfaces/cython/cantera/composite.py +++ b/interfaces/cython/cantera/composite.py @@ -617,14 +617,15 @@ def __init__(self, phase, shape=(0,), states=None, extra=None, meta=None): def __getitem__(self, index): states = self._states[index] + extra = OrderedDict({key: val[index] for key, val in self._extra.items()}) if(isinstance(states, list)): num_rows = len(states) if num_rows == 0: states = None - return SolutionArray(self._phase, num_rows, states) + return SolutionArray(self._phase, num_rows, states, extra=extra) else: shape = states.shape[:-1] - return SolutionArray(self._phase, shape, states) + return SolutionArray(self._phase, shape, states, extra=extra) def __getattr__(self, name): if name in self._extra: From c7f08654c4c066f93178a36bb1962e2d0a7aaa9d Mon Sep 17 00:00:00 2001 From: Ingmar Schoegl Date: Mon, 21 Feb 2022 08:39:12 -0600 Subject: [PATCH 2/2] [UnitTests] Add unit test for propagated SolutionArray 'extra' --- interfaces/cython/cantera/test/test_composite.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/interfaces/cython/cantera/test/test_composite.py b/interfaces/cython/cantera/test/test_composite.py index f3aabf6a5e..24219be945 100644 --- a/interfaces/cython/cantera/test/test_composite.py +++ b/interfaces/cython/cantera/test/test_composite.py @@ -224,6 +224,14 @@ def test_collect_data(self): self.assertIn('X', collected) self.assertEqual(collected['X'].shape, (0, self.gas.n_species)) + def test_getitem(self): + states = ct.SolutionArray(self.gas, 10, extra={"index": range(10)}) + for ix, state in enumerate(states): + assert state.index == ix + + assert list(states[:2].index) == [0, 1] + assert list(states[100:102].index) == [] # outside of range + def test_append_state(self): gas = ct.Solution("h2o2.yaml") gas.TPX = 300, ct.one_atm, 'H2:0.5, O2:0.4'