Skip to content

Commit

Permalink
Quick patch of change patching
Browse files Browse the repository at this point in the history
  • Loading branch information
b3m2a1 committed Jul 20, 2023
1 parent 7f39479 commit 7566fa8
Show file tree
Hide file tree
Showing 4 changed files with 1,696 additions and 163 deletions.
19 changes: 18 additions & 1 deletion Psience/BasisReps/HarmonicOscillator.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,17 @@ def get_poly_coeffs(cls, a, b, shift=0):
coeffs = DensePolynomial._compute_shifted_coeffs(coeffs, shift)
return coeffs

@classmethod
def get_direction_change_poly(cls, delta, shift):
if not isinstance(delta, (int, np.integer)):
return [cls.get_direction_change_poly(d, s) for d,s in zip(delta, shift)]
sqrt_contrib = None
if delta != 0:
dir_change = np.sign(delta) == -np.sign(shift) # avoid the zero case
if dir_change:
sqrt_contrib = HarmonicOscillatorRaisingLoweringPolyTerms.get_sqrt_remainder_coeffs(delta, shift)
return sqrt_contrib

@classmethod
def get_sqrt_remainder_coeffs(cls, delta, k):
# provides rising or falling coeffs with a starting shift
Expand Down Expand Up @@ -526,6 +537,8 @@ def _get_poly_coeffs(cls, terms, delta):
return 0
if abs(delta) > len(terms):
return 0
if len(terms) == 0:
return [1] # just the constant overlap term
# we know we need to change by delta
# over
if terms not in cls._size_blocks_cache:
Expand Down Expand Up @@ -554,7 +567,11 @@ def poly_coeffs(self, delta, shift=0):
if delta not in self._poly_cache:
self._poly_cache[delta] = self._get_poly_coeffs(self.terms, delta)
if shift != 0:
self._poly_cache[(delta, shift)] = DensePolynomial._compute_shifted_coeffs(self._poly_cache[delta], shift=shift)
base_coeffs = self._poly_cache[delta]
if isinstance(base_coeffs, (int, float, np.integer, np.floating)):
self._poly_cache[(delta, shift)] = base_coeffs
else:
self._poly_cache[(delta, shift)] = DensePolynomial._compute_shifted_coeffs(base_coeffs, shift=shift)
else:
self._poly_cache[(delta, shift)] = self._poly_cache[delta]
return self._poly_cache[(delta, shift)]
Expand Down
3 changes: 2 additions & 1 deletion Psience/BasisReps/StateSpaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -3254,7 +3254,8 @@ def union(self,
)
if len(where_inds) > 0:
where_inds = np.sort(where_inds)
changes[i_new] = np.concatenate([changes[i_new], c2[where_inds]])
changes[i_new] = np.concatenate([changes[i_new], *[c2[w] for w in where_inds]], axis=0)
# changes[i_new] = np.concatenate([changes[i_new], c2[where_inds]])
# print(" >", changes[i_new])

new_spaces[i_new] = new_spaces[i_new].union(
Expand Down
Loading

0 comments on commit 7566fa8

Please sign in to comment.