Skip to content

Commit

Permalink
fix errors
Browse files Browse the repository at this point in the history
  • Loading branch information
RylieWeaver committed Oct 8, 2024
1 parent 6c21612 commit 0feb249
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 14 deletions.
10 changes: 1 addition & 9 deletions hydragnn/utils/model/mace_utils/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,16 @@
ScaleShiftBlock,
)

from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis
from .radial import BesselBasis, GaussianBasis, PolynomialCutoff
from .symmetric_contraction import SymmetricContraction

interaction_classes: Dict[str, Type[InteractionBlock]] = {
"RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock,
}

# gate_dict: Dict[str, Optional[Callable]] = {
# "abs": torch.abs,
# "tanh": torch.tanh,
# "silu": torch.nn.functional.silu,
# "None": None,
# }

__all__ = [
"AtomicEnergiesBlock",
"RadialEmbeddingBlock",
"ZBLBasis",
"LinearNodeEmbeddingBlock",
"LinearReadoutBlock",
"EquivariantProductBasisBlock",
Expand Down
5 changes: 0 additions & 5 deletions tests/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,7 @@ def pytest_train_model_vectoroutput(model_type, overwrite_data=False):
"DimeNet",
"EGNN",
"PNAEq",
"MACE",
],
)
def pytest_train_model_conv_head(model_type, overwrite_data=False):
unittest_train_model(model_type, "ci_conv_head.json", False, overwrite_data)


def train_model_conv_head(model_type, overwrite_data=False):
unittest_train_model(model_type, "ci_conv_head.json", False, overwrite_data)

0 comments on commit 0feb249

Please sign in to comment.