diff --git a/Psience/VPT2/Corrections.py b/Psience/VPT2/Corrections.py index 07a53ac..e229f42 100644 --- a/Psience/VPT2/Corrections.py +++ b/Psience/VPT2/Corrections.py @@ -4,10 +4,13 @@ from McUtils.Numputils import SparseArray import McUtils.Numputils as nput from McUtils.Data import UnitsData -from McUtils.Scaffolding import NullLogger, Checkpointer +from McUtils.Scaffolding import Logger, NullLogger, Checkpointer from ..Spectra import DiscreteSpectrum -from ..BasisReps import BasisStateSpace, BasisMultiStateSpace, SelectionRuleStateSpace, HarmonicOscillatorProductBasis +from ..BasisReps import ( + BasisStateSpace, BasisMultiStateSpace, SelectionRuleStateSpace, HarmonicOscillatorProductBasis, + StateMaker +) from .Common import PerturbationTheoryException, _safe_dot __all__ = [ @@ -920,12 +923,14 @@ class AnalyticPerturbationTheoryCorrections: degenerate_states: 'Iterable[BasisStateSpace]' = None _degenerate_hamiltonians: 'Iterable[np.ndarray]' = None _degenerate_coefficients: 'Iterable[np.ndarray]' = None + _degenerate_state_list_transformations: 'Iterable[list[np.ndarray, np.ndarray]]' = None energy_corrections: BasicAPTCorrections = None transition_moment_corrections: 'Iterable[BasicAPTCorrections]' = None degenerate_hamiltonian_corrections: 'Iterable[BasicAPTCorrections]' = None operator_corrections: 'Iterable[BasicAPTCorrections]' = None _deperturbed_operator_values: 'Iterable[np.ndarray]' = None _operator_values: 'Iterable[np.ndarray]' = None + logger: 'Logger' = None _zpe_pos: int = None @@ -1027,25 +1032,36 @@ def get_degenerate_transformations(self, basis, energies): return energies, (hams, transf) - def apply_degenerate_transformations(self, initial_states, final_states, subcorr): - initial_space = BasisStateSpace( + def _apply_degenerate_transformations(self, + initial_states, final_states, subcorr, + initial_space=None, final_space=None, + all_degs=None, + degenerate_mapping=None + ): + if initial_space is None: + initial_space = BasisStateSpace( + HarmonicOscillatorProductBasis(len(initial_states[0])), + initial_states + ) + if final_space is None: + final_space = BasisStateSpace( + HarmonicOscillatorProductBasis(len(initial_states[0])), + final_states + ) + if all_degs is None: + all_degs = BasisStateSpace( HarmonicOscillatorProductBasis(len(initial_states[0])), - initial_states + np.concatenate(self.degenerate_states, axis=0) ) - final_space = BasisStateSpace( - HarmonicOscillatorProductBasis(len(initial_states[0])), - final_states - ) - all_degs = BasisStateSpace( - HarmonicOscillatorProductBasis(len(initial_states[0])), - np.concatenate(self.degenerate_states, axis=0) - ) - deg_map_row = np.concatenate([ - [i] * len(g) for i,g in enumerate(self.degenerate_states) - ]) - deg_map_col = np.concatenate([ - np.arange(len(g)) for i, g in enumerate(self.degenerate_states) - ]) + if degenerate_mapping is None: + deg_map_row = np.concatenate([ + [i] * len(g) for i,g in enumerate(self.degenerate_states) + ]) + deg_map_col = np.concatenate([ + np.arange(len(g)) for i, g in enumerate(self.degenerate_states) + ]) + else: + deg_map_row, deg_map_col = degenerate_mapping init_pos = all_degs.find(initial_states, missing_val=-1) final_pos = all_degs.find(final_states, missing_val=-1) @@ -1092,23 +1108,154 @@ def get_deperturbed_freqs(self): else: return self.get_freqs() + @property + def degenerate_transformation_pairs(self): + if self._degenerate_state_list_transformations is None: + self._degenerate_state_list_transformations = self._get_degenerate_tfs_mats() + return self._degenerate_state_list_transformations + def _get_degenerate_tfs_mats(self, logger=None): + #TODO: add checks to ensure that our blocks are complete and we have proper unitary tfs at the end + if logger is None: + logger = self.logger + logger = Logger.lookup(logger) + all_degs = BasisStateSpace( + HarmonicOscillatorProductBasis(len(self.state_lists[0][0][0])), + np.concatenate(self.degenerate_states, axis=0) + ) + deg_map_row = np.concatenate([ + [i] * len(g) for i, g in enumerate(self.degenerate_states) + ]) + deg_map_col = np.concatenate([ + np.arange(len(g)) for i, g in enumerate(self.degenerate_states) + ]) + for i, block in enumerate(self.degenerate_states): + with logger.block(tag="Degenerate block {i}", i=i + 1): + logger.log_print( + "{blocks}", + blocks=block, + preformatter=lambda **kw: dict( + kw, + blocks="\n".join(StateMaker.parse_state(e) for e in kw['blocks']) + ) + ) + + tfs = [] + for block_idx, (init_states, final_states) in enumerate(self.state_lists): + initial_space = BasisStateSpace( + HarmonicOscillatorProductBasis(len(init_states[0])), + init_states + ) + final_space = BasisStateSpace( + HarmonicOscillatorProductBasis(len(init_states[0])), + final_states + ) + init_pos = all_degs.find(initial_space, missing_val=-1) + final_pos = all_degs.find(final_space, missing_val=-1) + + # subcorr is a n_init x n_final object, but we need to figure out the transformation to apply to each axis + # to do so we find the appropriate transformation and insert it + row_tf = np.eye(len(init_states)) + for n,i in enumerate(init_pos): + if i != -1: + col_pos = deg_map_row[i] + row_pos = deg_map_col[i] + deg_block = self.degenerate_coefficients[col_pos] + block_idx = initial_space.find(self.degenerate_states[col_pos]) + # print(row_pos, col_pos, col_tf.shape) + row_tf[n, block_idx] = deg_block[row_pos] + + nz_init_pos = [i for i in init_pos if i > -1] + nz_init = [s for i, s in zip(init_pos, init_states) if i > -1] + if len(nz_init_pos) > 0: + with logger.block(tag="Initial states ({ix})", + ix=[(deg_map_row[i], deg_map_col[i]) for i in nz_init_pos]): + logger.log_print( + "{initials}", + initials=nz_init, + preformatter=lambda **kw: dict( + kw, + initials="\n".join(StateMaker.parse_state(e) for e in kw['initials']) + ) + ) + logger.log_print( + "{tf}", + tf=row_tf, + preformatter=lambda **kw:dict(kw, tf="\n".join(logger.prep_array(kw['tf']))) + ) + if np.sum(np.abs((row_tf @ row_tf.T) - np.eye(len(init_states))).flatten()) > 1e-3: + raise ValueError("Non-unitary row tf, something wrong with initial state degs") + + col_tf = np.eye(len(final_states)) + for n,i in enumerate(final_pos): + if i != -1: + col_pos = deg_map_row[i] + deg_block = self.degenerate_coefficients[col_pos] + block_idx = final_space.find(self.degenerate_states[col_pos]) + row_pos = deg_map_col[i] + # print(row_pos, col_pos, col_tf.shape) + col_tf[n, block_idx] = deg_block[row_pos] + nz_final_pos = [i for i in final_pos if i > -1] + nz_final = [s for i, s in zip(final_pos, final_states) if i > -1] + if len(nz_final_pos) > 0: + with logger.block(tag="Final states ({ix})", + ix=[(deg_map_row[i], deg_map_col[i]) for i in nz_final_pos]): + logger.log_print( + "{finals}", + finals=nz_final, + preformatter=lambda **kw: dict( + kw, + finals="\n".join(StateMaker.parse_state(e) for e in kw['finals']) + ) + ) + logger.log_print( + "{tf}", + tf=col_tf, + preformatter=lambda **kw:dict(kw, tf="\n".join(logger.prep_array(kw['tf']))) + ) + if np.sum(np.abs((row_tf @ row_tf.T) - np.eye(len(init_states))).flatten()) > 1e-3: + raise ValueError("Non-unitary col tf, something wrong with final state degs") + + tfs.append([row_tf, col_tf]) + return tfs + def _apply_degs_to_corrs(self, corrs, logger=None): + if logger is None: + logger = self.logger + logger = Logger.lookup(logger) + + + has_subcorrs = not (isinstance(corrs[0], np.ndarray) and corrs[0].ndim == 2) + if has_subcorrs: + all_tms = [[] for _ in corrs] # num axes + else: + all_tms = [] + tfs = self.degenerate_transformation_pairs + + for block_idx, (row_tf, col_tf) in enumerate(tfs): + + if has_subcorrs: + for a, (storage, axis_moms) in enumerate(zip(all_tms, corrs)): + storage.append( + row_tf @ axis_moms[block_idx] @ col_tf.T + ) + else: + all_tms.append(row_tf @ corrs[block_idx] @ col_tf.T) + + return all_tms + @property def transition_moments(self): if self._transition_moments is None: + logger = Logger.lookup(self.logger) + null = NullLogger() if self.degenerate_states is None: self._transition_moments = self.deperturbed_transition_moments else: tmoms = self.deperturbed_transition_moments - all_tms = [[] for _ in tmoms] # num axes - for block_idx, (init_states, final_states) in enumerate(self.state_lists): - for storage, axis_moms in zip(all_tms, tmoms): - storage.append( - self.apply_degenerate_transformations( - init_states, final_states, - axis_moms[block_idx] - ) - ) - self._transition_moments = all_tms + + self._transition_moments = self._apply_degs_to_corrs( + tmoms, + logger = self.logger + ) return self._transition_moments diff --git a/Psience/VPT2/Hamiltonian.py b/Psience/VPT2/Hamiltonian.py index 8047416..06169e6 100755 --- a/Psience/VPT2/Hamiltonian.py +++ b/Psience/VPT2/Hamiltonian.py @@ -100,6 +100,7 @@ def __init__(self, self.molecule = molecule if modes is None: modes = molecule.normal_modes.modes + full_modes = modes if isinstance(mode_selection, dict): submodes = mode_selection.get('modes', None) if submodes is not None: @@ -139,7 +140,8 @@ def __init__(self, V_terms = None else: V_terms = PotentialTerms(self.molecule, - modes=modes, mode_selection=mode_selection, + modes=modes, + mode_selection=mode_selection, full_surface_mode_selection=full_surface_mode_selection, potential_derivatives=potential_derivatives, allow_higher_potential_terms=allow_higher_potential_terms, diff --git a/Psience/VPT2/Runner.py b/Psience/VPT2/Runner.py index 94f63dd..c2b470e 100644 --- a/Psience/VPT2/Runner.py +++ b/Psience/VPT2/Runner.py @@ -2518,6 +2518,8 @@ def construct(cls, mixed_derivative_handling_mode=mixed_derivative_handling_mode, **settings ) + if isinstance(degeneracy_specs, (list, tuple)) and len(degeneracy_specs) == 0: + degeneracy_specs = None if states is not None and not isinstance(states, MultiVPTStateSpace): if ( @@ -2771,7 +2773,8 @@ def get_operator_corrections(self, for s in initials ], order=order, terms=terms, - verbose=verbose, degenerate_states=states.degenerate_pairs, + verbose=verbose, + degenerate_states=states.degenerate_pairs, **opts ) @@ -3112,7 +3115,7 @@ def run_VPT(self, # states.flat_space.degenerate_states # ) - corrs = AnalyticPerturbationTheoryCorrections(basis, states.state_list_pairs) + corrs = AnalyticPerturbationTheoryCorrections(basis, states.state_list_pairs, logger=self.logger) with self.logger.block(tag="Calculating frequency corrections"): energy_corrections = self.get_energy_corrections( states, diff --git a/Psience/VPT2/Terms.py b/Psience/VPT2/Terms.py index 4d9eeb9..2eadb72 100755 --- a/Psience/VPT2/Terms.py +++ b/Psience/VPT2/Terms.py @@ -1278,7 +1278,7 @@ def __init__(self, :type mode_selection: None | Iterable[int] """ - self.full_modes=full_surface_mode_selection + self.full_mode_sel = full_surface_mode_selection super().__init__(molecule, modes, mode_selection=mode_selection, logger=logger, parallelizer=parallelizer, checkpointer=checkpointer, @@ -1300,7 +1300,7 @@ def v_derivs(self): if self._v_derivs is None: if self._input_derivs is None: self._input_derivs = self.molecule.potential_surface.derivatives - self._v_derivs = self._canonicalize_derivs(self.freqs, self.masses, self._input_derivs, self.full_modes) + self._v_derivs = self._canonicalize_derivs(self.freqs, self.masses, self._input_derivs, self.full_mode_sel) return self._v_derivs @v_derivs.setter @@ -1316,7 +1316,7 @@ def _check_mode_terms(self, derivs=None): if d.shape != (modes_n,) * len(d.shape): return False return True - def _canonicalize_derivs(self, freqs, masses, derivs, full_modes): + def _canonicalize_derivs(self, freqs, masses, derivs, full_mode_sel): if self._check_mode_terms(derivs): return derivs @@ -1340,20 +1340,21 @@ def _canonicalize_derivs(self, freqs, masses, derivs, full_modes): fourths = derivs[3] if len(derivs) > 3 else None n = self.num_atoms + full_modes_n = self._presel_dim modes_n = len(self.modes.freqs) internals_n = 3 * n - 6 coord_n = 3 * n - if len(derivs) > 2 and full_modes is not None and thirds.shape[0] != modes_n: - new_thirds = np.zeros((modes_n,) + fcs.shape) - new_thirds[full_modes,] = thirds + if len(derivs) > 2 and full_mode_sel is not None and thirds.shape[0] != modes_n: + new_thirds = np.zeros((full_modes_n,) + fcs.shape) + new_thirds[full_mode_sel,] = thirds thirds = new_thirds if len(derivs) > 2 and self.mode_sel is not None and thirds.shape[0] == self._presel_dim: thirds = thirds[(self.mode_sel,)] - if len(derivs) > 2 and full_modes is not None and fourths.shape[0] != modes_n: - new_fourths = np.zeros((modes_n, modes_n) + fcs.shape) - new_fourths[np.ix_(full_modes, full_modes)] = fourths + if len(derivs) > 2 and full_mode_sel is not None and fourths.shape[0] != modes_n: + new_fourths = np.zeros((full_modes_n, full_modes_n) + fcs.shape) + new_fourths[np.ix_(full_mode_sel, full_mode_sel)] = fourths fourths = new_fourths if len(derivs) > 3 and self.mode_sel is not None and fourths.shape[0] == self._presel_dim: if not isinstance(self.mode_sel, slice): @@ -2528,7 +2529,7 @@ def __init__(self, :type mode_selection: None | Iterable[int] """ self.derivs = None - self.full_modes=full_surface_mode_selection + self.full_mode_sel = full_surface_mode_selection super().__init__(molecule, modes=modes, mode_selection=mode_selection, logger=logger, parallelizer=parallelizer, checkpointer=checkpointer, **opts @@ -2538,9 +2539,9 @@ def __init__(self, self.mixed_derivs = mixed_derivs if dipole_derivatives is None: dipole_derivatives = molecule.dipole_surface.derivatives - self.derivs = self._canonicalize_derivs(self.freqs, self.masses, dipole_derivatives, self.full_modes) + self.derivs = self._canonicalize_derivs(self.freqs, self.masses, dipole_derivatives, self.full_mode_sel) - def _canonicalize_derivs(self, freqs, masses, derivs, full_modes): + def _canonicalize_derivs(self, freqs, masses, derivs, full_mode_sel): """ Makes sure all of the dipole moments are clean and ready to rotate """ @@ -2581,20 +2582,22 @@ def _canonicalize_derivs(self, freqs, masses, derivs, full_modes): thirds = derivs[3] if len(derivs) > 3 else None n = len(masses) + full_mode_n = self._presel_dim modes_n = len(self.modes.freqs) internals_n = 3 * n - 6 coord_n = 3 * n - if len(derivs) > 2 and full_modes is not None and seconds.shape[0] != modes_n: - new_seconds = np.zeros((modes_n,) + grad.shape) - new_seconds[full_modes,] = seconds + + if len(derivs) > 2 and full_mode_sel is not None and seconds.shape[0] != modes_n: + new_seconds = np.zeros((full_mode_n,) + grad.shape) + new_seconds[full_mode_sel,] = seconds seconds = new_seconds if len(derivs) > 2 and self.mode_sel is not None and seconds.shape[0] == self._presel_dim: seconds = seconds[(self.mode_sel,)] - if len(derivs) > 2 and full_modes is not None and thirds.shape[0] != modes_n: - new_thirds = np.zeros((modes_n,modes_n) + grad.shape) - new_thirds[np.ix_(full_modes,full_modes)] = thirds + if len(derivs) > 2 and full_mode_sel is not None and thirds.shape[0] != modes_n: + new_thirds = np.zeros((full_mode_n,full_mode_n) + grad.shape) + new_thirds[np.ix_(full_mode_sel, full_mode_sel)] = thirds thirds = new_thirds if self.mode_sel is not None and thirds.shape[0] == self._presel_dim: if not isinstance(self.mode_sel, slice): diff --git a/ci/tests/VPT2Tests.py b/ci/tests/VPT2Tests.py index 8356155..c992a89 100755 --- a/ci/tests/VPT2Tests.py +++ b/ci/tests/VPT2Tests.py @@ -124,7 +124,64 @@ def test_AnalyticWFC(self): handle_degeneracies=True ) - @debugTest + @inactiveTest + def test_PartialRebuild(self): + state = VPTStateMaker(7) + corrs = AnalyticVPTRunner.run_simple( + mol, + # [ + # state(), + # state(21), + # state(20), + # state(19), + # state([21, 2]), + # state([20, 2]), + # state([19, 2]), + # state(21, 19), + # state(21, 20), + # state(20, 19) + # ], + [ + [ + [state()], + [ + state(1), + state(2), + state(3), + state([1, 2]), + state([2, 2]), + state([3, 2]), + state(1, 2), + state(1, 3), + state(2, 3), + ] + ] + ], + full_surface_mode_selection=[108 - 108, 108 - 107, 108 - 106, 108 - 105, 108 - 21, 108 - 20, 108 - 19, + 108 - 1], + # degeneracy_specs = [ + # {'polyads':[[state(19), state(20, 107)]]}, + # {'polyads':[[state(106, 106), state(108, 108)]]}, + # {'polyads':[[state(106), state(107)], [state(105), state(107)], [state(105), state(108)]]} + # ], + mode_selection=[108 - 108, 108 - 107, 108 - 106, 108 - 105, 108 - 21, 108 - 20, 108 - 19], + # degeneracy_specs = [ + # {'polyads':[[state(1), state(2, 6)]]}, + # {'polyads':[[state([5, 2]), state([7, 2])]]}, + # { + # 'polyads': + # [ + # [state(4), state(6)], + # [state(4), state(7)], + # [state(5), state(6)] + # ] + # } + # ], + # logger=output_file, + expressions_file=os.path.expanduser("exprs.hdf5") + ) + + @validationTest def test_AnalyticOCHHMultiple(self): file_name = "OCHH_freq.fchk" @@ -140,33 +197,33 @@ def test_AnalyticOCHHMultiple(self): # [0, 0, 0, 1, 1, 0], ], ], - [ - [0, 0, 0, 0, 1, 0], - [ - [0, 0, 0, 0, 1, 1], - [0, 1, 0, 1, 1, 0] - ] - ] + # [ + # [0, 0, 0, 0, 1, 0], + # [ + # [0, 0, 0, 0, 1, 1], + # [0, 1, 0, 1, 1, 0] + # ] + # ] ], - expressions_file=os.path.expanduser("~/Desktop/exprs.hdf5"), - degeneracy_specs=None, + # expressions_file=os.path.expanduser("~/Desktop/exprs.hdf5"), + # degeneracy_specs=None, # degeneracy_specs = { # 'polyads': [ # [[0, 0, 0, 0, 0, 1], [0, 1, 0, 1, 0, 0]] # ] # } - # degeneracy_specs=[ - # { - # 'polyads': [ - # [[0, 0, 0, 0, 0, 1], [0, 1, 0, 1, 0, 0]] - # ] - # }, - # { - # 'polyads': [ - # [[0, 0, 0, 0, 1, 0], [0, 1, 0, 1, 0, 0]] - # ] - # } - # ] + degeneracy_specs=[ + { + 'polyads': [ + [[0, 0, 0, 0, 0, 1], [0, 1, 0, 1, 0, 0]] + ] + }, + { + 'polyads': [ + [[0, 0, 0, 0, 1, 0], [0, 1, 0, 1, 0, 0]] + ] + } + ] ) @validationTest