diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 9d18831a4..181266a85 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -352,7 +352,7 @@ def __init__(self, hidden_channels, activation="silu", dtype=torch.float32): - super(ChargeHead, self).__init__(dtype=dtype) + super(PointChargeHead, self).__init__(dtype=dtype) act_class = act_class_mapping[activation] self.output_network = nn.Sequential( nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),