Skip to content

Commit

Permalink
define __array__ for StateVector (#280)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinsung authored Jul 10, 2024
1 parent 1085271 commit 1fc5204
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
12 changes: 12 additions & 0 deletions python/ffsim/states/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ class StateVector:
norb: int
nelec: int | tuple[int, int]

def __array__(self, dtype=None, copy=None):
# TODO in Numpy 2.0 this can be simplified to
# return np.array(self.vec, dtype=dtype, copy=copy)
if copy:
if dtype is None:
return self.vec.copy()
else:
return self.vec.astype(dtype, copy=True)
if dtype is None:
return self.vec
return self.vec.astype(dtype, copy=False)


def dims(norb: int, nelec: tuple[int, int]) -> tuple[int, int]:
"""Get the dimensions of the FCI space.
Expand Down
9 changes: 9 additions & 0 deletions tests/python/states/states_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,12 @@ def test_slater_determinant_one_rdm_diff_rotation(
expected = ffsim.rdm(vec, norb, nelec, spin_summed=spin_summed)

np.testing.assert_allclose(rdm, expected, atol=1e-12)


def test_state_vector_array():
"""Test StateVector's __array__ method."""
norb = 5
nelec = (3, 2)
vec = ffsim.random.random_state_vector(ffsim.dim(norb, nelec), seed=3556)
state_vec = ffsim.StateVector(vec, norb, nelec)
assert np.array_equal(np.abs(state_vec), np.abs(vec))

0 comments on commit 1fc5204

Please sign in to comment.