Adopt protostructure naming (#84)
* doc: adopt protostructure naming

* lint: sort imports in notebooks

* clean: rename last functions

* fix outdated import pymatviz.(utils->powerups).add_identity_line

* ruff auto-fixes

* fix ruff aviary/utils.py:732:5: PLC0206 Extracting value from dictionary without calling `.items()`

and aviary/roost/data.py:116:13: PLR1704 Redefining argument with the local name `idx`

* fix save_results_dict docstring

* fea: bump version for breaking change

* fea: rename parse function used in wren data

---------

Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>
CompRhys and janosh authored Jul 22, 2024
1 parent 94fba8c commit b7655b5
Showing 21 changed files with 467 additions and 394 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -5,7 +5,7 @@ ci:

 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.5.0
+    rev: v0.5.3
     hooks:
       - id: ruff
         args: [--fix]
@@ -30,7 +30,7 @@ repos:
         args: [--check-filenames]

   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.10.1
+    rev: v1.11.0
     hooks:
       - id: mypy
         exclude: (tests|examples)/
4 changes: 2 additions & 2 deletions aviary/roost/data.py
@@ -113,8 +113,8 @@ def __getitem__(self, idx: int):
         n_elems = len(elements)
         self_idx = []
         nbr_idx = []
-        for idx in range(n_elems):
-            self_idx += [idx] * n_elems
+        for elem_idx in range(n_elems):
+            self_idx += [elem_idx] * n_elems
             nbr_idx += list(range(n_elems))

         # convert all data to tensors
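Context for the PLR1704 fix: the loop variable shadowed the idx argument of __getitem__, so it was renamed; the indices themselves still describe a fully connected graph over the composition's sites, self-loops included. A minimal standalone sketch (illustration only, not aviary code):

n_elems = 3  # e.g. a three-element composition
self_idx: list[int] = []
nbr_idx: list[int] = []
for elem_idx in range(n_elems):  # renamed from `idx` to stop shadowing the argument
    self_idx += [elem_idx] * n_elems  # source site repeated once per neighbor
    nbr_idx += list(range(n_elems))  # every site as a neighbor, self included
print(self_idx)  # [0, 0, 0, 1, 1, 1, 2, 2, 2]
print(nbr_idx)   # [0, 1, 2, 0, 1, 2, 0, 1, 2]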
8 changes: 4 additions & 4 deletions aviary/segments.py
@@ -38,9 +38,9 @@ def forward(self, x: Tensor, index: Tensor) -> Tensor:
         """
         gate = self.gate_nn(x)

-        gate = gate - scatter_max(gate, index, dim=0)[0][index]
+        gate -= scatter_max(gate, index, dim=0)[0][index]
         gate = gate.exp()
-        gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-10)
+        gate /= scatter_add(gate, index, dim=0)[index] + 1e-10

         x = self.message_nn(x)
         return scatter_add(gate * x, index, dim=0)
@@ -78,9 +78,9 @@ def forward(self, x: Tensor, index: Tensor, weights: Tensor) -> Tensor:
         """
         gate = self.gate_nn(x)

-        gate = gate - scatter_max(gate, index, dim=0)[0][index]
+        gate -= scatter_max(gate, index, dim=0)[0][index]
         gate = (weights**self.pow) * gate.exp()
-        gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-10)
+        gate /= scatter_add(gate, index, dim=0)[index] + 1e-10

         x = self.message_nn(x)
         return scatter_add(gate * x, index, dim=0)
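These two hunks only switch to augmented assignment; the pattern they touch is a numerically stable softmax over variable-size segments (the log-sum-exp trick applied per crystal in the batch). A standalone sketch, assuming the torch_scatter package aviary builds on:

import torch
from torch_scatter import scatter_add, scatter_max

def segment_softmax(gate: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Softmax of `gate` within each segment that `index` assigns its rows to."""
    # subtract each segment's max before exp() so large gate values cannot overflow
    gate = gate - scatter_max(gate, index, dim=0)[0][index]
    gate = gate.exp()
    # normalize by the per-segment sum; 1e-10 guards against division by zero
    return gate / (scatter_add(gate, index, dim=0)[index] + 1e-10)

# rows 0 and 1 form segment 0, row 2 is segment 1
out = segment_softmax(torch.tensor([[1.0], [2.0], [3.0]]), torch.tensor([0, 0, 1]))
# out sums to 1 within each segment: approximately [[0.269], [0.731], [1.0]]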
13 changes: 6 additions & 7 deletions aviary/train.py
@@ -246,14 +246,13 @@ def train_model(
                 print("Starting stochastic weight averaging...")
                 swa_model.update_parameters(model)
                 swa_scheduler.step()
+            elif scheduler_name == "ReduceLROnPlateau":
+                val_metric = val_metrics[target_col][
+                    "MAE" if task_type == reg_key else "Accuracy"
+                ]
+                lr_scheduler.step(val_metric)
             else:
-                if scheduler_name == "ReduceLROnPlateau":
-                    val_metric = val_metrics[target_col][
-                        "MAE" if task_type == reg_key else "Accuracy"
-                    ]
-                    lr_scheduler.step(val_metric)
-                else:
-                    lr_scheduler.step()
+                lr_scheduler.step()

             model.epoch += 1

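The nested else/if collapses into an elif because ReduceLROnPlateau is the one scheduler here whose step() consumes the monitored metric; the others step unconditionally. A minimal illustration of that torch.optim API difference (not aviary code):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
plateau.step(0.42)  # pass the validation metric (e.g. MAE) every epoch

step_lr = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
step_lr.step()  # steps on the epoch counter alone, no argument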
54 changes: 20 additions & 34 deletions aviary/utils.py
@@ -237,15 +237,12 @@ def initialize_losses(
                 raise NameError(
                     "Only L1 or L2 losses are allowed for robust regression tasks"
                 )
+        elif loss_name_dict[name] == "L1":
+            loss_func_dict[name] = (task, L1Loss())
+        elif loss_name_dict[name] == "L2":
+            loss_func_dict[name] = (task, MSELoss())
         else:
-            if loss_name_dict[name] == "L1":
-                loss_func_dict[name] = (task, L1Loss())
-            elif loss_name_dict[name] == "L2":
-                loss_func_dict[name] = (task, MSELoss())
-            else:
-                raise NameError(
-                    "Only L1 or L2 losses are allowed for regression tasks"
-                )
+            raise NameError("Only L1 or L2 losses are allowed for regression tasks")

     return loss_func_dict

@@ -723,46 +720,35 @@ def save_results_dict(
     """Save the results to a file after model evaluation.

     Args:
-        ids (dict[str, list[str | int]]): ): Each key is the name of an identifier
+        ids (dict[str, list[str | int]]): Each key is the name of an identifier
             (e.g. material ID, composition, ...) and its value a list of IDs.
-        results_dict (dict[str, Any]): ): nested dictionary of results
-            {name: {col: data}}
-        model_name (str): ): The name given the model via the --model-name flag.
-        run_id (str): ): The run ID given to the model via the --run-id flag.
+        results_dict (dict[str, Any]): nested dictionary of results {name: {col: data}}
+        model_name (str): The name given the model via the --model-name flag.
+        run_id (str): The run ID given to the model via the --run-id flag.
     """
-    results = {}
+    results: dict[str, np.ndarray] = {}

-    for target_name in results_dict:
-        for col, data in results_dict[target_name].items():
+    for target_name, target_data in results_dict.items():
+        for col, data in target_data.items():
             # NOTE we save pre_logits rather than logits due to fact
             # that with the heteroskedastic setup we want to be able to
             # sample from the Gaussian distributed pre_logits we parameterize.
             if "pre-logits" in col:
                 for n_ens, y_pre_logit in enumerate(data):
-                    results.update(
-                        {
-                            f"{target_name}_{col}_c{lab}_n{n_ens}": val.ravel()
-                            for lab, val in enumerate(y_pre_logit.T)
-                        }
-                    )
+                    results |= {
+                        f"{target_name}_{col}_c{lab}_n{n_ens}": val.ravel()
+                        for lab, val in enumerate(y_pre_logit.T)
+                    }

-            elif "pred" in col:
-                preds = {
+            elif "pred" in col or "ale" in col:
+                # elif so that pre-logit-ale doesn't trigger
+                results |= {
                     f"{target_name}_{col}_n{n_ens}": val.ravel()
                     for (n_ens, val) in enumerate(data)
                 }
-                results.update(preds)
-
-            elif "ale" in col:  # elif so that pre-logit-ale doesn't trigger
-                results.update(
-                    {
-                        f"{target_name}_{col}_n{n_ens}": val.ravel()
-                        for (n_ens, val) in enumerate(data)
-                    }
-                )

             elif col == "target":
-                results.update({f"{target_name}_target": data})
+                results |= {f"{target_name}_target": data}

     df = pd.DataFrame({**ids, **results})

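Two idioms carry this rewrite: iterating the dict's items instead of re-indexing it (the PLC0206 fix from the commit message) and the in-place dict union operator from PEP 584, which requires Python 3.9+. A small standalone sketch with hypothetical data, not aviary code:

import numpy as np

results_dict = {"band_gap": {"target": np.zeros(4)}}
results: dict[str, np.ndarray] = {}

# PLC0206: loop over key/value pairs rather than looking each key up again
for target_name, target_data in results_dict.items():
    for col, data in target_data.items():
        # |= merges the right-hand dict in place, like results.update(...)
        results |= {f"{target_name}_{col}": data.ravel()}

print(list(results))  # ['band_gap_target']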
22 changes: 15 additions & 7 deletions aviary/wren/data.py
@@ -108,10 +108,10 @@ def __getitem__(self, idx: int):
             - list[str | int]: identifiers like material_id, composition
         """
         row = self.df.iloc[idx]
-        wyckoff_str = row[self.inputs]
+        protostructure_label = row[self.inputs]
         material_ids = row[self.identifiers].to_list()

-        parsed_output = parse_aflow_wyckoff_str(wyckoff_str)
+        parsed_output = parse_protostructure_label(protostructure_label)
         spg_num, wyk_site_multiplcities, elements, augmented_wyks = parsed_output

         wyk_site_multiplcities = np.atleast_2d(wyk_site_multiplcities).T / np.sum(
@@ -256,21 +256,29 @@ def collate_batch(
     )


-def parse_aflow_wyckoff_str(
-    aflow_label: str,
+def parse_protostructure_label(
+    protostructure_label: str,
 ) -> tuple[str, list[float], list[str], list[tuple[str, ...]]]:
     """Parse the Wren AFLOW-like Wyckoff encoding.

     Args:
-        aflow_label (str): AFLOW-style prototype string with appended chemical system
+        protostructure_label (str): label constructed as `aflow_label:chemsys` where
+            aflow_label is an AFLOW-style prototype label and chemsys is the
+            alphabetically sorted chemical system.

     Returns:
         tuple[str, list[float], list[str], list[str]]: spacegroup number, Wyckoff site
             multiplicities, elements symbols and equivalent wyckoff sets
     """
-    proto, chemsys = aflow_label.split(":")
+    aflow_label, chemsys = protostructure_label.split(":")
     elems = chemsys.split("-")
-    _, _, spg_num, *wyckoff_letters = proto.split("_")
+    _, _, spg_num, *wyckoff_letters = aflow_label.split("_")
+
+    if len(elems) != len(wyckoff_letters):
+        raise ValueError(
+            f"Chemical system {chemsys} does not match Wyckoff letters "
+            f"{wyckoff_letters}"
+        )

     wyckoff_site_multiplicities = []
     elements = []
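For reference, a hypothetical call to the renamed parser, assuming it stays importable from aviary.wren.data; AB_cP2_221_a_b is the AFLOW label of the CsCl prototype and Cl-Cs its alphabetically sorted chemical system:

from aviary.wren.data import parse_protostructure_label

spg_num, multiplicities, elements, wyckoff_sets = parse_protostructure_label(
    "AB_cP2_221_a_b:Cl-Cs"  # aflow_label:chemsys, one Wyckoff letter per element
)
print(spg_num)  # "221", the space group number from the label's third underscore field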
(diffs for the remaining 15 changed files not shown)
