Skip to content

Commit

Permalink
Merge pull request #141 from maxnus/refactor_solver_init
Browse files Browse the repository at this point in the history
Refactors solver/__init__.py
  • Loading branch information
basilib authored Aug 21, 2024
2 parents 6500b83 + 7888e35 commit 0df8a9e
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 105 deletions.
2 changes: 1 addition & 1 deletion vayesta/core/qemb/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,7 +1160,7 @@ def check_solver(self, solver):
is_eb = "crpa_full" in self.opts.screening
else:
is_eb = False
check_solver_config(is_uhf, is_eb, solver, self.log)
check_solver_config(solver, is_uhf, is_eb, self.log)

def get_solver(self, solver=None):
if solver is None:
Expand Down
2 changes: 1 addition & 1 deletion vayesta/core/qemb/qemb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1755,4 +1755,4 @@ def check_solver(self, solver):
is_eb = "crpa_full" in self.opts.screening
else:
is_eb = False
check_solver_config(is_uhf, is_eb, solver, self.log)
check_solver_config(solver, is_uhf, is_eb, self.log)
2 changes: 1 addition & 1 deletion vayesta/edmet/edmet.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def e_nonlocal(self, value):
def check_solver(self, solver):
is_uhf = np.ndim(self.mo_coeff[1]) == 2
is_eb = True
check_solver_config(is_uhf, is_eb, solver, self.log)
check_solver_config(solver, is_uhf, is_eb, self.log)

def kernel(self):
t_start = timer()
Expand Down
2 changes: 1 addition & 1 deletion vayesta/edmet/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def energy_couplings(self, value):
def check_solver(self, solver):
is_uhf = np.ndim(self.base.mo_coeff[1]) == 2
is_eb = True
check_solver_config(is_uhf, is_eb, solver, self.log)
check_solver_config(solver, is_uhf, is_eb, self.log)

def get_fock(self):
f = self.base.get_fock()
Expand Down
187 changes: 86 additions & 101 deletions vayesta/solver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations
from typing import *

from vayesta.solver.ccsd import RCCSD_Solver, UCCSD_Solver
from vayesta.solver.cisd import RCISD_Solver, UCISD_Solver
from vayesta.solver.coupled_ccsd import coupledRCCSD_Solver
Expand All @@ -12,124 +15,106 @@

try:
from vayesta.solver.ebcc import REBCC_Solver, UEBCC_Solver, EB_REBCC_Solver, EB_UEBCC_Solver
_has_ebcc = True
except ImportError:
REBCC_Solver = UEBCC_Solver = EB_REBCC_Solver = EB_UEBCC_Solver = None
_has_ebcc = False
else:
_has_ebcc = True

if TYPE_CHECKING:
from logging import Logger


def get_solver_class(ham, solver):
assert is_ham(ham)
uhf = is_uhf_ham(ham)
eb = is_eb_ham(ham)
return _get_solver_class(uhf, eb, solver, ham.log)
return _get_solver_class(solver, uhf, eb, ham.log)


def check_solver_config(is_uhf, is_eb, solver, log):
_get_solver_class(is_uhf, is_eb, solver, log)
def check_solver_config(solver, is_uhf, is_eb, log):
_get_solver_class(solver, is_uhf, is_eb, log)


def _get_solver_class(is_uhf, is_eb, solver, log):
def _get_solver_class(solver: str, is_uhf: bool, is_eb: bool, log: Logger) -> Type:
try:
solver_cls = _get_solver_class_internal(is_uhf, is_eb, solver, log)
solver_cls = _get_solver_class_internal(solver, is_uhf, is_eb, log)
return solver_cls
except ValueError as e:
spinmessage = "unrestricted" if is_uhf else "restricted"
bosmessage = "coupled electron-boson" if is_eb else "purely electronic"

fullmessage = f"Error; solver {solver} not available for {spinmessage} {bosmessage} systems"
ebmessage = " with electron-boson coupling" if is_eb else ""
fullmessage = f"solver '{solver}' not available for {spinmessage} systems{ebmessage}"
log.critical(fullmessage)
raise ValueError(fullmessage)


def _get_solver_class_internal(is_uhf, is_eb, solver, log):
# First check if we have a CC approach as implemented in pyscf.
if solver == "CCSD" and not is_eb:
# Use pyscf solvers.
if is_uhf:
return UCCSD_Solver
else:
return RCCSD_Solver
if solver == "TCCSD":
if is_uhf or is_eb:
raise ValueError("TCCSD is not implemented for unrestricted or electron-boson calculations!")
return TRCCSD_Solver
if solver == "extCCSD":
if is_eb:
raise ValueError("extCCSD is not implemented for electron-boson calculations!")
if is_uhf:
return extUCCSD_Solver
return extRCCSD_Solver
if solver == "coupledCCSD":
if is_eb:
raise ValueError("coupledCCSD is not implemented for electron-boson calculations!")
if is_uhf:
raise ValueError("coupledCCSD is not implemented for unrestricted calculations!")
return coupledRCCSD_Solver

# Now consider general CC ansatzes; these are solved via EBCC.
# Note that we support all capitalisations of `ebcc`, but need `CC` to be capitalised when also using this to
# specify an ansatz.
if "CC" in solver.upper():
if not _has_ebcc:
raise ImportError(f"{solver} solver is only accessible via ebcc. Please install ebcc.")
if is_uhf:
if is_eb:
solverclass = EB_UEBCC_Solver
else:
solverclass = UEBCC_Solver
else:
if is_eb:
solverclass = EB_REBCC_Solver
else:
solverclass = REBCC_Solver
if solver.upper() == "EBCC":
# Default to `opts.ansatz`.
return solverclass
if solver[:2].upper() == "EB":
solver = solver[2:]
if solver == "CCSD" and is_eb:
log.warning("CCSD solver requested for coupled electron-boson system; defaulting to CCSD-SD-1-1.")
solver = "CCSD-SD-1-1"

# This is just a wrapper to allow us to use the solver option as the ansatz kwarg in this case.
def get_right_CC(*args, **kwargs):
setansatz = kwargs.get("ansatz", None)
if setansatz is not None:
if setansatz != solver:
raise ValueError(
"Desired CC ansatz specified differently in solver and solver_options.ansatz."
"Please use only specify via one approach, or ensure they agree."
)
kwargs["ansatz"] = solver
return solverclass(*args, **kwargs)

return get_right_CC
if solver == "FCI":
if is_uhf:
if is_eb:
return EB_UEBFCI_Solver
else:
return UFCI_Solver
else:
if is_eb:
return EB_EBFCI_Solver
else:
return FCI_Solver
if is_eb:
raise ValueError("%s solver is not implemented for coupled electron-boson systems!", solver)
if solver == "MP2":
if is_uhf:
return UMP2_Solver
else:
return RMP2_Solver
if solver == "CISD":
if is_uhf:
return UCISD_Solver
else:
return RCISD_Solver
if solver == "DUMP":
return DumpSolver
if solver == 'CALLBACK':
return CallbackSolver
raise ValueError("Unknown solver: %s" % solver)
# (solver_string, is_uhf, is_eb) -> SolverClass
_solver_dict: Dict[Tuple[str, bool, bool], Type] = {
('MP2', False, False): RMP2_Solver,
('MP2', True, False): UMP2_Solver,
('CISD', False, False): RCISD_Solver,
('CISD', True, False): UCISD_Solver,
('CCSD', False, False): RCCSD_Solver,
('CCSD', True, False): UCCSD_Solver,
('TCCSD', False, False): TRCCSD_Solver,
('TCCSD', True, False): NotImplemented,
('extCCSD', False, False): extRCCSD_Solver,
('extCCSD', True, False): extUCCSD_Solver,
('coupledCCSD', False, False): coupledRCCSD_Solver,
('coupledCCSD', True, False): NotImplemented,
('FCI', False, False): FCI_Solver,
('FCI', True, False): UFCI_Solver,
('FCI', False, True): EB_EBFCI_Solver,
('FCI', True, True): EB_UEBFCI_Solver,
('DUMP', False, False): DumpSolver,
('DUMP', True, False): DumpSolver,
('CALLBACK', False, False): CallbackSolver,
}


# (is_uhf, is_eb) -> SolverClass
_ebcc_solver_dict: Dict[Tuple[bool, bool], Type] = {
(False, False): REBCC_Solver,
(True, False): UEBCC_Solver,
(False, True): EB_REBCC_Solver,
(True, True): EB_UEBCC_Solver,
}


def _get_solver_class_internal(solver: str, is_uhf: bool, is_eb: bool, log: Logger) -> Type | Callable:
solver_cls = _solver_dict.get((solver, is_uhf, is_eb), None)
if solver_cls is NotImplemented:
spinsym = 'unrestricted' if is_uhf else 'restricted'
raise NotImplementedError(f"solver '{solver}' for {spinsym} spin-symmetry is not implemented")
if solver_cls is not None:
return solver_cls
if 'CC' not in solver:
raise ValueError(f"unknown solver '{solver}'")
# Try EBCC next
return _get_solver_class_ebcc(solver, is_uhf, is_eb, log)


def _get_solver_class_ebcc(solver: str, is_uhf: bool, is_eb: bool, log: Logger) -> Type | Callable:
if not _has_ebcc:
raise ImportError(f"{solver} solver is only accessible via ebcc. Please install ebcc.")
solver_cls = _ebcc_solver_dict[is_uhf, is_eb]
if solver == "EBCC":
# Default to `opts.ansatz`.
return solver_cls
if solver[:2] == "EB":
solver = solver[2:]
if solver == "CCSD" and is_eb:
solver = "CCSD-SD-1-1"
log.warning(f"CCSD solver requested for coupled electron-boson system; defaulting to {solver}.")

# This is just a wrapper to allow us to use the solver option as the ansatz kwarg in this case.
def get_right_cc(*args, **kwargs):
setansatz = kwargs.get("ansatz", None)
if setansatz is not None and setansatz != solver:
raise ValueError(
f"solver '{solver}' does not match solver_options.ansatz "
f"'{setansatz}'; only specify via one argument or ensure they agree"
)
kwargs["ansatz"] = solver
return solver_cls(*args, **kwargs)

return get_right_cc

0 comments on commit 0df8a9e

Please sign in to comment.