add function df_to_in_mem_dataloader() to extract train/test_loader creation from run_wrenformer()
janosh committed Jun 23, 2022
1 parent 3c100bb commit 983a0b8
Showing 4 changed files with 95 additions and 62 deletions.
8 changes: 5 additions & 3 deletions aviary/data.py
@@ -15,15 +15,17 @@ class InMemoryDataLoader:
Args:
*tensors: List of arrays or tensors. Must all have the same length in dimension 0.
batch_size (int, optional): Defaults to 32.
batch_size (int, optional): Usually 64, 128 or 256. Can be larger for test set loaders
to speed up inference. Defaults to 64.
shuffle (bool, optional): If True, shuffle the data *in-place* whenever an
iterator is created from this object. Defaults to False.
collate_fn (Callable, optional): Should accept variadic list of tensors and
output a minibatch of data ready for model consumption. Defaults to tuple().
"""

tensors: list[Tensor]
batch_size: int = 32
# each item must be indexable (usually torch.tensor, np.array or pd.Series)
tensors: list[Tensor | np.ndarray]
batch_size: int = 64
shuffle: bool = False
collate_fn: Callable = tuple

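Note: the type hint change above means InMemoryDataLoader now explicitly accepts plain numpy arrays as well as tensors, since each item only needs to be indexable. A minimal usage sketch with toy data (not from the repo), assuming the default collate_fn=tuple simply yields the sliced arrays for each batch:

    import numpy as np
    from aviary.data import InMemoryDataLoader

    features = np.random.rand(10, 3)  # 10 samples, 3 features each
    targets = np.random.rand(10)      # must match length in dimension 0

    loader = InMemoryDataLoader([features, targets], batch_size=4, shuffle=True)
    for x_batch, y_batch in loader:  # each batch is a (features, targets) tuple here
        print(x_batch.shape, y_batch.shape)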
56 changes: 56 additions & 0 deletions aviary/wrenformer/data.py
@@ -1,13 +1,16 @@
from __future__ import annotations

import json
from typing import Literal

import numpy as np
import pandas as pd
import torch
from pymatgen.core import Composition
from torch import LongTensor, Tensor, nn

from aviary import PKG_DIR
from aviary.data import InMemoryDataLoader
from aviary.wren.data import parse_aflow_wyckoff_str


@@ -118,3 +121,56 @@ def get_composition_embedding(formula: str) -> Tensor:
combined_features = torch.cat([element_ratios, element_features], dim=1).float()

return combined_features


def df_to_in_mem_dataloader(
df: pd.DataFrame,
target_col: str,
input_col: str = "wyckoff",
id_col: str = "material_id",
embedding_type: Literal["wyckoff", "composition"] = "wyckoff",
device: str = None,
**kwargs,
) -> InMemoryDataLoader:
"""Construct an InMemoryDataLoader with Wrenformer batch collation from a dataframe.
Can also be used for Roostformer.
Args:
df (pd.DataFrame): Expected to have columns input_col, target_col, id_col.
target_col (str): Column name holding the target values.
input_col (str): Column name holding the input values (Aflow Wyckoff labels or composition
strings) from which initial embeddings will be constructed. Defaults to "wyckoff".
id_col (str): Column name holding material identifiers. Defaults to "material_id".
embedding_type ('wyckoff' | 'composition'): Defaults to "wyckoff".
device (str): torch.device to load tensors onto. Defaults to
"cuda" if torch.cuda.is_available() else "cpu".
kwargs (dict): Keyword arguments like batch_size: int and shuffle: bool
to pass to InMemoryDataLoader. Defaults to None.
Returns:
InMemoryDataLoader: Ready for use in model.evaluate(data_loader) or
[model(x) for x in data_loader]
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

if embedding_type not in ["wyckoff", "composition"]:
raise ValueError(f"{embedding_type = } must be 'wyckoff' or 'composition'")

initial_embeddings = df[input_col].map(
wyckoff_embedding_from_aflow_str
if embedding_type == "wyckoff"
else get_composition_embedding
)
targets = torch.tensor(df[target_col], device=device)
if targets.dtype == torch.bool:
targets = targets.long() # convert binary classification targets to 0 and 1
inputs = np.empty(len(initial_embeddings), dtype=object)
for idx, tensor in enumerate(initial_embeddings):
inputs[idx] = tensor.to(device)

ids = df[id_col].to_numpy()
data_loader = InMemoryDataLoader(
[inputs, targets, ids], collate_fn=collate_batch, **kwargs
)
return data_loader
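For reference, a hedged sketch of how the new helper might be called on a composition-only dataframe (column names and values below are illustrative, not from the repo):

    import pandas as pd
    from aviary.wrenformer.data import df_to_in_mem_dataloader

    df = pd.DataFrame(
        {
            "material_id": ["mat-1", "mat-2"],
            "composition": ["Fe2O3", "SiO2"],
            "band_gap": [2.0, 8.9],
        }
    )
    loader = df_to_in_mem_dataloader(
        df,
        target_col="band_gap",
        input_col="composition",
        embedding_type="composition",
        batch_size=128,
        shuffle=False,  # keep dataframe row order so predictions align with rows
    )

As noted in the docstring, the returned loader is then ready for model.evaluate(loader) or for iterating over its batches directly.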
2 changes: 1 addition & 1 deletion examples/mat_bench/run_wrenformer.py
@@ -64,7 +64,7 @@ def run_wrenformer_on_matbench(
target_col=target,
task_type=task_type,
# set to None to disable logging
wandb_project=kwargs.get("wandb_project", "mp-wbm"),
wandb_project=kwargs.pop("wandb_project", "mp-wbm"),
id_col=id_col,
run_params={
"dataset": dataset_name,
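The switch from kwargs.get to kwargs.pop presumably matters because the remaining kwargs are forwarded elsewhere: popping removes the key first so it cannot be passed twice. A simplified sketch of the failure mode this avoids (function and argument names are illustrative only):

    def run(wandb_project=None, **kwargs):  # stand-in for the call being configured
        print(wandb_project, kwargs)

    kwargs = {"wandb_project": "my-project", "batch_size": 128}
    # run(wandb_project=kwargs.get("wandb_project", "mp-wbm"), **kwargs)  # TypeError: duplicate keyword
    run(wandb_project=kwargs.pop("wandb_project", "mp-wbm"), **kwargs)  # key removed before **kwargs is unpacked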
91 changes: 33 additions & 58 deletions examples/wrenformer.py
@@ -9,14 +9,9 @@

from aviary import ROOT
from aviary.core import Normalizer, TaskType
from aviary.data import InMemoryDataLoader
from aviary.losses import RobustL1Loss
from aviary.utils import get_metrics
from aviary.wrenformer.data import (
collate_batch,
get_composition_embedding,
wyckoff_embedding_from_aflow_str,
)
from aviary.wrenformer.data import df_to_in_mem_dataloader
from aviary.wrenformer.model import Wrenformer
from aviary.wrenformer.utils import print_walltime

@@ -42,6 +37,7 @@ def run_wrenformer(
target_col: str,
epochs: int,
timestamp: str = None,
input_col: str = None,
id_col: str = "material_id",
n_attn_layers: int = 4,
wandb_project: str = None,
@@ -67,6 +63,8 @@ def run_wrenformer(
train_df (pd.DataFrame): Dataframe containing the training data.
test_df (pd.DataFrame): Dataframe containing the test data.
target_col (str): Name of df column containing the target values.
input_col (str): Name of df column containing the input values. Defaults to 'wyckoff' if
'wren' in run_name else 'composition'.
id_col (str): Name of df column containing material IDs.
epochs (int): How many epochs to train for. Defaults to 100.
timestamp (str): Will be included in run_params and used as file name prefix for model
@@ -118,23 +116,16 @@
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Pytorch running on {device=}")

for label, df in [("training set", train_df), ("test set", test_df)]:
if "wren" in run_name.lower():
err_msg = "Missing 'wyckoff' column in dataframe. "
err_msg += (
"Please generate Aflow Wyckoff labels ahead of time."
if "structure" in df
else "Trying to deploy Wrenformer on composition-only task?"
)
assert "wyckoff" in df, err_msg
with print_walltime(
start_desc=f"Generating Wyckoff embeddings for {label}", newline=False
):
df["features"] = df.wyckoff.map(wyckoff_embedding_from_aflow_str)
elif "roost" in run_name.lower():
df["features"] = df.composition.map(get_composition_embedding)
else:
raise ValueError(f"{run_name = } must contain 'roost' or 'wren'")
if "wren" in run_name.lower():
input_col = input_col or "wyckoff"
embedding_type = "wyckoff"
elif "roost" in run_name.lower():
input_col = input_col or "composition"
embedding_type = "composition"
else:
raise ValueError(
f"{run_name = } must contain 'roost' or 'wren' (case insensitive)"
)

robust = "robust" in run_name.lower()
loss_func = (
@@ -145,42 +136,24 @@
loss_dict = {target_col: (task_type, loss_func)}
normalizer_dict = {target_col: Normalizer() if task_type == reg_key else None}

features, targets, ids = (train_df[x] for x in ["features", target_col, id_col])
targets = torch.tensor(targets, device=device)
if targets.dtype == torch.bool:
targets = targets.long()
inputs = np.empty(len(features), dtype=object)
for idx, tensor in enumerate(features):
inputs[idx] = tensor.to(device)

train_loader = InMemoryDataLoader(
[inputs, targets, ids],
batch_size=batch_size,
shuffle=True,
collate_fn=collate_batch,
data_loader_kwargs = dict(
target_col=target_col,
input_col=input_col,
id_col=id_col,
embedding_type=embedding_type,
)

features, targets, ids = (test_df[x] for x in ["features", target_col, id_col])
targets = torch.tensor(targets, device=device)
if targets.dtype == torch.bool:
targets = targets.long()
inputs = np.empty(len(features), dtype=object)
for idx, tensor in enumerate(features):
inputs[idx] = tensor.to(device)

test_loader = InMemoryDataLoader(
[inputs, targets, ids], batch_size=512, collate_fn=collate_batch
train_loader = df_to_in_mem_dataloader(
train_df, batch_size=batch_size, shuffle=True, **data_loader_kwargs
)

# n_features is the length of the embedding vector for a Wyckoff position encoding
# the element type (usually 200-dim matscholar embeddings) and Wyckoff position (see
# 'bra-alg-off.json') + 1 for the weight of that element/Wyckoff position in the
# material's composition
embedding_len = features[0].shape[-1]
assert embedding_len in (
200 + 1,
200 + 1 + 444,
) # Roost and Wren embedding size resp.
test_loader = df_to_in_mem_dataloader(test_df, batch_size=512, **data_loader_kwargs)

# embedding_len is the length of the embedding vector for a Wyckoff position encoding the
# element type (usually 200-dim matscholar embeddings) and Wyckoff position (see
# 'bra-alg-off.json') + 1 for the weight of that Wyckoff position (or element) in the material
embedding_len = train_loader.tensors[0][0].shape[-1]
# Roost and Wren embedding size resp.
assert embedding_len in (200 + 1, 200 + 1 + 444), f"{embedding_len=}"

model_params = dict(
# 1 for regression, n_classes for classification
@@ -242,6 +215,7 @@
"training_samples": len(train_df),
"test_samples": len(test_df),
"trainable_params": model.num_params,
"task_type": task_type,
"swa": {
"start": swa_start,
"epochs": int(swa_start * epochs),
@@ -272,6 +246,7 @@
for epoch in range(epochs):
if verbose:
print(f"Epoch {epoch + 1}/{epochs}")

train_metrics = model.evaluate(
train_loader,
loss_dict,
@@ -333,9 +308,9 @@
predictions = predictions.softmax(dim=1)

predictions = predictions.cpu().numpy().squeeze()
targets = targets.cpu().numpy()
targets = test_df[target_col]
pred_col = f"{target_col}_pred"
test_df[pred_col] = predictions.tolist()
test_df[pred_col] = predictions.tolist() # requires shuffle=False for test_loader

test_metrics = get_metrics(targets, predictions, task_type)
test_metrics["test_size"] = len(test_df)
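Since targets and predictions are now both read back in dataframe order, the alignment relies on the test loader not being shuffled (hence the new comment on the pred_col line). A small sketch of that assumption, not code from the repo:

    # with shuffle=False the loader's arrays stay in the order they were built from test_df,
    # so the stored ids still match the dataframe row for row
    ids_in_loader = test_loader.tensors[2]  # [inputs, targets, ids] as assembled by df_to_in_mem_dataloader
    assert list(ids_in_loader) == list(test_df[id_col])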
