Skip to content

Commit

Permalink
update for latest jax and pandas version
Browse files Browse the repository at this point in the history
  • Loading branch information
zekun-shi committed Apr 29, 2024
1 parent a2cac4d commit e0b846d
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion d4ft/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def log_step(self, metrics: NamedTuple, t: int, thresh: float) -> None:
def get_segment_summary(self) -> pd.DataFrame:
segment_df = self.data_df[self.last_t:]
self.last_t = self.data_df.index[-1]
logging.info(f"Iter: {self.last_t}\n{segment_df.mean()}")
logging.info(f"Iter: {self.last_t}\n{segment_df.iloc[-1]}")
return segment_df

def log_summary(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions d4ft/solver/sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

def scipy_opt(
solver_cfg: GDConfig, H: Hamiltonian, params: hk.Params,
key: jax.random.KeyArray
key: jax.Array
) -> float:
energy_fn_jit = jax.jit(lambda mo_coeff: H.energy_fn(mo_coeff, key)[0])
import jaxopt
Expand All @@ -40,7 +40,7 @@ def scipy_opt(

def sgd(
solver_cfg: GDConfig, H: Hamiltonian, params: hk.Params,
key: jax.random.KeyArray
key: jax.Array
) -> Tuple[RunLogger, Trajectory]:

@jax.jit
Expand Down
6 changes: 3 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
import string
from pathlib import Path
from typing import Any
import jax

import matplotlib.pyplot as plt
import pandas as pd
import shortuuid
from absl import app, flags, logging
from jax.config import config
from ml_collections.config_flags import config_flags

from d4ft.config import D4FTConfig
Expand Down Expand Up @@ -60,8 +60,8 @@ def get_rxn_energy(rxn: str, benchmark: str, df: pd.DataFrame) -> float:


def main(_: Any) -> None:
config.update("jax_enable_x64", FLAGS.use_f64)
config.update("jax_debug_nans", FLAGS.debug_nans)
jax.config.update("jax_enable_x64", FLAGS.use_f64)
jax.config.update("jax_debug_nans", FLAGS.debug_nans)

cfg: D4FTConfig = FLAGS.config
print(cfg)
Expand Down

0 comments on commit e0b846d

Please sign in to comment.