Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adopt protostructure naming #84

Merged
merged 10 commits into from
Jul 22, 2024
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.0
rev: v0.5.3
hooks:
- id: ruff
args: [--fix]
Expand All @@ -30,7 +30,7 @@ repos:
args: [--check-filenames]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.1
rev: v1.11.0
hooks:
- id: mypy
exclude: (tests|examples)/
Expand Down
4 changes: 2 additions & 2 deletions aviary/roost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def __getitem__(self, idx: int):
n_elems = len(elements)
self_idx = []
nbr_idx = []
for idx in range(n_elems):
self_idx += [idx] * n_elems
for elem_idx in range(n_elems):
self_idx += [elem_idx] * n_elems
nbr_idx += list(range(n_elems))

# convert all data to tensors
Expand Down
8 changes: 4 additions & 4 deletions aviary/segments.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def forward(self, x: Tensor, index: Tensor) -> Tensor:
"""
gate = self.gate_nn(x)

gate = gate - scatter_max(gate, index, dim=0)[0][index]
gate -= scatter_max(gate, index, dim=0)[0][index]
gate = gate.exp()
gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-10)
gate /= scatter_add(gate, index, dim=0)[index] + 1e-10

x = self.message_nn(x)
return scatter_add(gate * x, index, dim=0)
Expand Down Expand Up @@ -78,9 +78,9 @@ def forward(self, x: Tensor, index: Tensor, weights: Tensor) -> Tensor:
"""
gate = self.gate_nn(x)

gate = gate - scatter_max(gate, index, dim=0)[0][index]
gate -= scatter_max(gate, index, dim=0)[0][index]
gate = (weights**self.pow) * gate.exp()
gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-10)
gate /= scatter_add(gate, index, dim=0)[index] + 1e-10

x = self.message_nn(x)
return scatter_add(gate * x, index, dim=0)
Expand Down
13 changes: 6 additions & 7 deletions aviary/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,13 @@ def train_model(
print("Starting stochastic weight averaging...")
swa_model.update_parameters(model)
swa_scheduler.step()
elif scheduler_name == "ReduceLROnPlateau":
val_metric = val_metrics[target_col][
"MAE" if task_type == reg_key else "Accuracy"
]
lr_scheduler.step(val_metric)
else:
if scheduler_name == "ReduceLROnPlateau":
val_metric = val_metrics[target_col][
"MAE" if task_type == reg_key else "Accuracy"
]
lr_scheduler.step(val_metric)
else:
lr_scheduler.step()
lr_scheduler.step()

model.epoch += 1

Expand Down
54 changes: 20 additions & 34 deletions aviary/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,15 +237,12 @@ def initialize_losses(
raise NameError(
"Only L1 or L2 losses are allowed for robust regression tasks"
)
elif loss_name_dict[name] == "L1":
loss_func_dict[name] = (task, L1Loss())
elif loss_name_dict[name] == "L2":
loss_func_dict[name] = (task, MSELoss())
else:
if loss_name_dict[name] == "L1":
loss_func_dict[name] = (task, L1Loss())
elif loss_name_dict[name] == "L2":
loss_func_dict[name] = (task, MSELoss())
else:
raise NameError(
"Only L1 or L2 losses are allowed for regression tasks"
)
raise NameError("Only L1 or L2 losses are allowed for regression tasks")

return loss_func_dict

Expand Down Expand Up @@ -723,46 +720,35 @@ def save_results_dict(
"""Save the results to a file after model evaluation.

Args:
ids (dict[str, list[str | int]]): ): Each key is the name of an identifier
ids (dict[str, list[str | int]]): Each key is the name of an identifier
(e.g. material ID, composition, ...) and its value a list of IDs.
results_dict (dict[str, Any]): ): nested dictionary of results
{name: {col: data}}
model_name (str): ): The name given the model via the --model-name flag.
run_id (str): ): The run ID given to the model via the --run-id flag.
results_dict (dict[str, Any]): nested dictionary of results {name: {col: data}}
model_name (str): The name given the model via the --model-name flag.
run_id (str): The run ID given to the model via the --run-id flag.
"""
results = {}
results: dict[str, np.ndarray] = {}

for target_name in results_dict:
for col, data in results_dict[target_name].items():
for target_name, target_data in results_dict.items():
for col, data in target_data.items():
# NOTE we save pre_logits rather than logits due to fact
# that with the heteroskedastic setup we want to be able to
# sample from the Gaussian distributed pre_logits we parameterize.
if "pre-logits" in col:
for n_ens, y_pre_logit in enumerate(data):
results.update(
{
f"{target_name}_{col}_c{lab}_n{n_ens}": val.ravel()
for lab, val in enumerate(y_pre_logit.T)
}
)
results |= {
f"{target_name}_{col}_c{lab}_n{n_ens}": val.ravel()
for lab, val in enumerate(y_pre_logit.T)
}

elif "pred" in col:
preds = {
elif "pred" in col or "ale" in col:
# elif so that pre-logit-ale doesn't trigger
results |= {
f"{target_name}_{col}_n{n_ens}": val.ravel()
for (n_ens, val) in enumerate(data)
}
results.update(preds)

elif "ale" in col: # elif so that pre-logit-ale doesn't trigger
results.update(
{
f"{target_name}_{col}_n{n_ens}": val.ravel()
for (n_ens, val) in enumerate(data)
}
)

elif col == "target":
results.update({f"{target_name}_target": data})
results |= {f"{target_name}_target": data}

df = pd.DataFrame({**ids, **results})

Expand Down
22 changes: 15 additions & 7 deletions aviary/wren/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ def __getitem__(self, idx: int):
- list[str | int]: identifiers like material_id, composition
"""
row = self.df.iloc[idx]
wyckoff_str = row[self.inputs]
protostructure_label = row[self.inputs]
material_ids = row[self.identifiers].to_list()

parsed_output = parse_aflow_wyckoff_str(wyckoff_str)
parsed_output = parse_protostructure_label(protostructure_label)
spg_num, wyk_site_multiplcities, elements, augmented_wyks = parsed_output

wyk_site_multiplcities = np.atleast_2d(wyk_site_multiplcities).T / np.sum(
Expand Down Expand Up @@ -256,21 +256,29 @@ def collate_batch(
)


def parse_aflow_wyckoff_str(
aflow_label: str,
def parse_protostructure_label(
protostructure_label: str,
) -> tuple[str, list[float], list[str], list[tuple[str, ...]]]:
"""Parse the Wren AFLOW-like Wyckoff encoding.

Args:
aflow_label (str): AFLOW-style prototype string with appended chemical system
protostructure_label (str): label constructed as `aflow_label:chemsys` where
aflow_label is an AFLOW-style prototype label and chemsys is the alphabetically
sorted chemical system.

Returns:
tuple[str, list[float], list[str], list[str]]: spacegroup number, Wyckoff site
multiplicities, elements symbols and equivalent wyckoff sets
"""
proto, chemsys = aflow_label.split(":")
aflow_label, chemsys = protostructure_label.split(":")
elems = chemsys.split("-")
_, _, spg_num, *wyckoff_letters = proto.split("_")
_, _, spg_num, *wyckoff_letters = aflow_label.split("_")

if len(elems) != len(wyckoff_letters):
raise ValueError(
f"Chemical system {chemsys} does not match Wyckoff letters "
f"{wyckoff_letters}"
)

wyckoff_site_multiplicities = []
elements = []
Expand Down
Loading