Skip to content

Commit

Permalink
fix sydney model for batches > 128
Browse files Browse the repository at this point in the history
  • Loading branch information
wkirgsn committed Aug 4, 2024
1 parent a41f42c commit 50020f0
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 21 deletions.
24 changes: 9 additions & 15 deletions src_py/magnethub/sydney.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,18 +159,14 @@ def __call__(self, data_B, data_F, data_T):
if data_B.ndim == 1:
data_B = np.array(data_B).reshape(1, -1)

loader = get_dataloader(data_B, data_F, data_T, self.mdl.norm)
_, ts_feats, scalar_feats = get_dataloader(data_B, data_F, data_T, self.mdl.norm)

# 2.Validate the models
data_P = torch.Tensor([]).to(self.device) # Allocate memory to store loss density

with torch.no_grad():
self.mdl.eval()
with torch.inference_mode():
# Start model evaluation explicitly
self.mdl.eval()
for inputs, vars in loader:
Pv, h_series = self.mdl(inputs.to(self.device), vars.to(self.device))
data_P, h_series = self.mdl(ts_feats.to(self.device), scalar_feats.to(self.device))

data_P = torch.cat((data_P, Pv.to(self.device)), dim=0)
data_P, h_series = data_P.cpu().numpy(), h_series.cpu().numpy()

# 3.Return results
Expand Down Expand Up @@ -337,9 +333,7 @@ def forward(self, x, hidden=None):


def get_dataloader(data_B, data_F, data_T, norm, n_init=32):
"""
Preprocess data into a data loader.
"""Preprocess data into a data loader.
Get a test dataloader.
Parameters
Expand Down Expand Up @@ -394,14 +388,14 @@ def get_dataloader(data_B, data_F, data_T, norm, n_init=32):

s0 = get_operator_init(in_B[:, 0, 0] - in_dB[:, 0, 0], in_dB, max_B, min_B) # Operator inital state

ts_feats = torch.cat((in_B, in_dB, in_dB_dt), dim=2)
scalar_feats = torch.cat((in_F, in_T, s0), dim=1)
# 6. Create dataloader to speed up data processing
test_dataset = torch.utils.data.TensorDataset(
torch.cat((in_B, in_dB, in_dB_dt), dim=2), torch.cat((in_F, in_T, s0), dim=1)
)
test_dataset = torch.utils.data.TensorDataset(ts_feats, scalar_feats)
kwargs = {"num_workers": 0, "batch_size": 128, "drop_last": False}
test_loader = torch.utils.data.DataLoader(test_dataset, **kwargs)

return test_loader
return test_loader, ts_feats, scalar_feats


# %% Predict the operator state at t0
Expand Down
27 changes: 27 additions & 0 deletions tests/debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import numpy as np
import pandas as pd
from pathlib import Path
from magnethub.loss import LossModel, MATERIALS

test_ds = pd.read_csv(
Path.cwd() / "tests" / "test_files" / "all_data.csv.gzip", dtype={"material": str}
)
errs_d = {}
for m_lbl in MATERIALS:
mdl = LossModel(material=m_lbl, team="paderborn")
test_mat_df = test_ds.query("material == @m_lbl")
p, h = mdl(
test_mat_df.loc[:, [c for c in test_mat_df if c.startswith("B_t_")]].to_numpy(),
test_mat_df.loc[:, "freq"].to_numpy(),
test_mat_df.loc[:, "temp"].to_numpy(),
)
rel_err = np.abs(test_mat_df.ploss - p) / test_mat_df.ploss
errs_d[m_lbl] = {
"avg": np.mean(rel_err),
"95th": np.quantile(rel_err, 0.95),
"99th": np.quantile(rel_err, 0.99),
'samples': len(rel_err),
}
rel_df = pd.DataFrame(errs_d).T
print(f"Rel. errors")
print(rel_df)
12 changes: 6 additions & 6 deletions tests/test_sydney.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ def test_longer_sequence():

def test_batch_execution():
mdl = LossModel(material="3C92", team=TEAM_NAME)

b_waves = np.random.randn(100, 1024) * 200e-3 # mT
freqs = np.random.randint(100e3, 750e3, size=100)
temps = np.random.randint(20, 80, size=100)
seq_len = 1412
b_waves = np.random.randn(seq_len, 1024) * 200e-3 # mT
freqs = np.random.randint(100e3, 750e3, size=seq_len)
temps = np.random.randint(20, 80, size=seq_len)
p, h = mdl(b_waves, freqs, temps)

assert p.size == 100, f"{p.size=}"
assert h.shape == (100, 1024), f"{h.shape=}"
assert p.size == seq_len, f"{p.size=}"
assert h.shape == (seq_len, 1024), f"{h.shape=}"


def test_material_availability():
Expand Down

0 comments on commit 50020f0

Please sign in to comment.