Skip to content

Commit

Permalink
Merge branch 'master' into gh-pages
Browse files Browse the repository at this point in the history
  • Loading branch information
b3m2a1 committed Sep 19, 2024
2 parents f45a2eb + 9075b27 commit cf4219b
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 72 deletions.
205 changes: 176 additions & 29 deletions Psience/VPT2/Corrections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
from McUtils.Numputils import SparseArray
import McUtils.Numputils as nput
from McUtils.Data import UnitsData
from McUtils.Scaffolding import NullLogger, Checkpointer
from McUtils.Scaffolding import Logger, NullLogger, Checkpointer

from ..Spectra import DiscreteSpectrum
from ..BasisReps import BasisStateSpace, BasisMultiStateSpace, SelectionRuleStateSpace, HarmonicOscillatorProductBasis
from ..BasisReps import (
BasisStateSpace, BasisMultiStateSpace, SelectionRuleStateSpace, HarmonicOscillatorProductBasis,
StateMaker
)
from .Common import PerturbationTheoryException, _safe_dot

__all__ = [
Expand Down Expand Up @@ -920,12 +923,14 @@ class AnalyticPerturbationTheoryCorrections:
degenerate_states: 'Iterable[BasisStateSpace]' = None
_degenerate_hamiltonians: 'Iterable[np.ndarray]' = None
_degenerate_coefficients: 'Iterable[np.ndarray]' = None
_degenerate_state_list_transformations: 'Iterable[list[np.ndarray, np.ndarray]]' = None
energy_corrections: BasicAPTCorrections = None
transition_moment_corrections: 'Iterable[BasicAPTCorrections]' = None
degenerate_hamiltonian_corrections: 'Iterable[BasicAPTCorrections]' = None
operator_corrections: 'Iterable[BasicAPTCorrections]' = None
_deperturbed_operator_values: 'Iterable[np.ndarray]' = None
_operator_values: 'Iterable[np.ndarray]' = None
logger: 'Logger' = None

_zpe_pos: int = None

Expand Down Expand Up @@ -1027,25 +1032,36 @@ def get_degenerate_transformations(self, basis, energies):

return energies, (hams, transf)

def apply_degenerate_transformations(self, initial_states, final_states, subcorr):
initial_space = BasisStateSpace(
def _apply_degenerate_transformations(self,
initial_states, final_states, subcorr,
initial_space=None, final_space=None,
all_degs=None,
degenerate_mapping=None
):
if initial_space is None:
initial_space = BasisStateSpace(
HarmonicOscillatorProductBasis(len(initial_states[0])),
initial_states
)
if final_space is None:
final_space = BasisStateSpace(
HarmonicOscillatorProductBasis(len(initial_states[0])),
final_states
)
if all_degs is None:
all_degs = BasisStateSpace(
HarmonicOscillatorProductBasis(len(initial_states[0])),
initial_states
np.concatenate(self.degenerate_states, axis=0)
)
final_space = BasisStateSpace(
HarmonicOscillatorProductBasis(len(initial_states[0])),
final_states
)
all_degs = BasisStateSpace(
HarmonicOscillatorProductBasis(len(initial_states[0])),
np.concatenate(self.degenerate_states, axis=0)
)
deg_map_row = np.concatenate([
[i] * len(g) for i,g in enumerate(self.degenerate_states)
])
deg_map_col = np.concatenate([
np.arange(len(g)) for i, g in enumerate(self.degenerate_states)
])
if degenerate_mapping is None:
deg_map_row = np.concatenate([
[i] * len(g) for i,g in enumerate(self.degenerate_states)
])
deg_map_col = np.concatenate([
np.arange(len(g)) for i, g in enumerate(self.degenerate_states)
])
else:
deg_map_row, deg_map_col = degenerate_mapping
init_pos = all_degs.find(initial_states, missing_val=-1)
final_pos = all_degs.find(final_states, missing_val=-1)

Expand Down Expand Up @@ -1092,23 +1108,154 @@ def get_deperturbed_freqs(self):
else:
return self.get_freqs()

@property
def degenerate_transformation_pairs(self):
if self._degenerate_state_list_transformations is None:
self._degenerate_state_list_transformations = self._get_degenerate_tfs_mats()
return self._degenerate_state_list_transformations
def _get_degenerate_tfs_mats(self, logger=None):
#TODO: add checks to ensure that our blocks are complete and we have proper unitary tfs at the end
if logger is None:
logger = self.logger
logger = Logger.lookup(logger)
all_degs = BasisStateSpace(
HarmonicOscillatorProductBasis(len(self.state_lists[0][0][0])),
np.concatenate(self.degenerate_states, axis=0)
)
deg_map_row = np.concatenate([
[i] * len(g) for i, g in enumerate(self.degenerate_states)
])
deg_map_col = np.concatenate([
np.arange(len(g)) for i, g in enumerate(self.degenerate_states)
])
for i, block in enumerate(self.degenerate_states):
with logger.block(tag="Degenerate block {i}", i=i + 1):
logger.log_print(
"{blocks}",
blocks=block,
preformatter=lambda **kw: dict(
kw,
blocks="\n".join(StateMaker.parse_state(e) for e in kw['blocks'])
)
)

tfs = []
for block_idx, (init_states, final_states) in enumerate(self.state_lists):
initial_space = BasisStateSpace(
HarmonicOscillatorProductBasis(len(init_states[0])),
init_states
)
final_space = BasisStateSpace(
HarmonicOscillatorProductBasis(len(init_states[0])),
final_states
)
init_pos = all_degs.find(initial_space, missing_val=-1)
final_pos = all_degs.find(final_space, missing_val=-1)

# subcorr is a n_init x n_final object, but we need to figure out the transformation to apply to each axis
# to do so we find the appropriate transformation and insert it
row_tf = np.eye(len(init_states))
for n,i in enumerate(init_pos):
if i != -1:
col_pos = deg_map_row[i]
row_pos = deg_map_col[i]
deg_block = self.degenerate_coefficients[col_pos]
block_idx = initial_space.find(self.degenerate_states[col_pos])
# print(row_pos, col_pos, col_tf.shape)
row_tf[n, block_idx] = deg_block[row_pos]

nz_init_pos = [i for i in init_pos if i > -1]
nz_init = [s for i, s in zip(init_pos, init_states) if i > -1]
if len(nz_init_pos) > 0:
with logger.block(tag="Initial states ({ix})",
ix=[(deg_map_row[i], deg_map_col[i]) for i in nz_init_pos]):
logger.log_print(
"{initials}",
initials=nz_init,
preformatter=lambda **kw: dict(
kw,
initials="\n".join(StateMaker.parse_state(e) for e in kw['initials'])
)
)
logger.log_print(
"{tf}",
tf=row_tf,
preformatter=lambda **kw:dict(kw, tf="\n".join(logger.prep_array(kw['tf'])))
)
if np.sum(np.abs((row_tf @ row_tf.T) - np.eye(len(init_states))).flatten()) > 1e-3:
raise ValueError("Non-unitary row tf, something wrong with initial state degs")

col_tf = np.eye(len(final_states))
for n,i in enumerate(final_pos):
if i != -1:
col_pos = deg_map_row[i]
deg_block = self.degenerate_coefficients[col_pos]
block_idx = final_space.find(self.degenerate_states[col_pos])
row_pos = deg_map_col[i]
# print(row_pos, col_pos, col_tf.shape)
col_tf[n, block_idx] = deg_block[row_pos]
nz_final_pos = [i for i in final_pos if i > -1]
nz_final = [s for i, s in zip(final_pos, final_states) if i > -1]
if len(nz_final_pos) > 0:
with logger.block(tag="Final states ({ix})",
ix=[(deg_map_row[i], deg_map_col[i]) for i in nz_final_pos]):
logger.log_print(
"{finals}",
finals=nz_final,
preformatter=lambda **kw: dict(
kw,
finals="\n".join(StateMaker.parse_state(e) for e in kw['finals'])
)
)
logger.log_print(
"{tf}",
tf=col_tf,
preformatter=lambda **kw:dict(kw, tf="\n".join(logger.prep_array(kw['tf'])))
)
if np.sum(np.abs((row_tf @ row_tf.T) - np.eye(len(init_states))).flatten()) > 1e-3:
raise ValueError("Non-unitary col tf, something wrong with final state degs")

tfs.append([row_tf, col_tf])
return tfs
def _apply_degs_to_corrs(self, corrs, logger=None):
if logger is None:
logger = self.logger
logger = Logger.lookup(logger)


has_subcorrs = not (isinstance(corrs[0], np.ndarray) and corrs[0].ndim == 2)
if has_subcorrs:
all_tms = [[] for _ in corrs] # num axes
else:
all_tms = []
tfs = self.degenerate_transformation_pairs

for block_idx, (row_tf, col_tf) in enumerate(tfs):

if has_subcorrs:
for a, (storage, axis_moms) in enumerate(zip(all_tms, corrs)):
storage.append(
row_tf @ axis_moms[block_idx] @ col_tf.T
)
else:
all_tms.append(row_tf @ corrs[block_idx] @ col_tf.T)

return all_tms

@property
def transition_moments(self):
if self._transition_moments is None:
logger = Logger.lookup(self.logger)
null = NullLogger()
if self.degenerate_states is None:
self._transition_moments = self.deperturbed_transition_moments
else:
tmoms = self.deperturbed_transition_moments
all_tms = [[] for _ in tmoms] # num axes
for block_idx, (init_states, final_states) in enumerate(self.state_lists):
for storage, axis_moms in zip(all_tms, tmoms):
storage.append(
self.apply_degenerate_transformations(
init_states, final_states,
axis_moms[block_idx]
)
)
self._transition_moments = all_tms

self._transition_moments = self._apply_degs_to_corrs(
tmoms,
logger = self.logger
)

return self._transition_moments

Expand Down
4 changes: 3 additions & 1 deletion Psience/VPT2/Hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(self,
self.molecule = molecule
if modes is None:
modes = molecule.normal_modes.modes
full_modes = modes
if isinstance(mode_selection, dict):
submodes = mode_selection.get('modes', None)
if submodes is not None:
Expand Down Expand Up @@ -139,7 +140,8 @@ def __init__(self,
V_terms = None
else:
V_terms = PotentialTerms(self.molecule,
modes=modes, mode_selection=mode_selection,
modes=modes,
mode_selection=mode_selection,
full_surface_mode_selection=full_surface_mode_selection,
potential_derivatives=potential_derivatives,
allow_higher_potential_terms=allow_higher_potential_terms,
Expand Down
7 changes: 5 additions & 2 deletions Psience/VPT2/Runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2518,6 +2518,8 @@ def construct(cls,
mixed_derivative_handling_mode=mixed_derivative_handling_mode,
**settings
)
if isinstance(degeneracy_specs, (list, tuple)) and len(degeneracy_specs) == 0:
degeneracy_specs = None

if states is not None and not isinstance(states, MultiVPTStateSpace):
if (
Expand Down Expand Up @@ -2771,7 +2773,8 @@ def get_operator_corrections(self,
for s in initials
],
order=order, terms=terms,
verbose=verbose, degenerate_states=states.degenerate_pairs,
verbose=verbose,
degenerate_states=states.degenerate_pairs,
**opts
)

Expand Down Expand Up @@ -3112,7 +3115,7 @@ def run_VPT(self,
# states.flat_space.degenerate_states
# )

corrs = AnalyticPerturbationTheoryCorrections(basis, states.state_list_pairs)
corrs = AnalyticPerturbationTheoryCorrections(basis, states.state_list_pairs, logger=self.logger)
with self.logger.block(tag="Calculating frequency corrections"):
energy_corrections = self.get_energy_corrections(
states,
Expand Down
Loading

0 comments on commit cf4219b

Please sign in to comment.