
[WIP] Support arbitrary outputs in TorchMD_Net #239

Open · wants to merge 5 commits into main

Conversation

@RaulPPelaez (Collaborator) commented Nov 3, 2023

Following the discussion in #198, this PR attempts to give TorchMD_Net the ability to return more than just one output ("y") and its derivative ("neg_dy").

This PR is still a draft as I am trying to figure out the final design.

This PR introduces user-facing breaking changes:

  • It changes some names in the configuration file (for instance, Scalar is no longer a thing), although a conversion could be made when processing the configuration.
  • Datasets must provide "energy" and "force" instead of "y" and "neg_dy" (a hypothetical remapping shim is sketched after this list).
  • TorchMD_Net is now expected to always compute at least an energy, instead of a generic label called "y". Maybe I am missing some use cases here, so we will see...
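As a backward-compatibility illustration only (this helper is not part of the PR), legacy dataset fields could be remapped when samples are loaded:

LEGACY_KEY_MAP = {"y": "energy", "neg_dy": "force"}

def remap_legacy_keys(sample):
    # Rename legacy dataset fields to the new names, leaving other keys untouched.
    return {LEGACY_KEY_MAP.get(key, key): value for key, value in sample.items()}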

New design proposed for the outputs of the model:

  • TorchMD_Net is composed of a representation model plus an arbitrary number of heads stacked sequentially.
  • There is no distinction between a Prior and what used to be an OutputModel; they are all Heads now.
  • The EnergyHead is always the first one and the ForceHead the last (if derivative=True).
  • There is some level of customization, akin to the Heads, for computing the loss of each output and for reducing to the total loss.
  • The user provides a weight for each model output that should be considered in the loss computation (like y_weight and neg_dy_weight today); a minimal sketch follows this list.
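For illustration, here is a minimal sketch (the helper name and the MSE reduction are assumptions, not part of this PR) of how such per-output weights could be combined into a total loss:

import torch.nn.functional as F

def compute_total_loss(results, targets, loss_weights):
    # loss_weights maps output name -> weight, e.g. {"energy": 1.0, "force": 0.1};
    # outputs with a zero weight (or absent from results) are skipped.
    total = 0.0
    for name, weight in loss_weights.items():
        if weight != 0.0 and name in results:
            total = total + weight * F.mse_loss(results[name], targets[name])
    return total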

This is the BaseHead interface I propose:

import torch
from torch import nn

class BaseHead(nn.Module):
    def __init__(self, dtype=torch.float32):
        super(BaseHead, self).__init__()
        self.dtype = dtype

    def reset_parameters(self):
        pass

    def per_point(self, point_features, results, z, pos, batch, extra_args):
        # First pass: operates on per-atom features; may add or modify outputs.
        return point_features, results

    def per_sample(self, point_features, results, z, pos, batch, extra_args):
        # Second pass: may reduce per-atom quantities to per-sample ones.
        return point_features, results

The forward call of TorchMD_Net would then go like this:

        results = {}
        # Per-atom features from the representation model.
        point_features = self.representation_model(z, pos, batch, q=q, s=s)
        # First pass: every head sees the per-atom features, in order.
        for head in self.head_list:
            point_features, results = head.per_point(point_features, results, z=z, pos=pos, batch=batch, extra_args=extra_args)
        # Second pass: heads reduce per-atom results to per-sample ones.
        for head in self.head_list:
            point_features, results = head.per_sample(point_features, results, z=z, pos=pos, batch=batch, extra_args=extra_args)

Each head is free to add a new key to results, to modify point_features, or to modify the existing contents of results (e.g., add a term to the energy). For instance, the EnergyHead:

class EnergyHead(BaseHead):
    def __init__(self,
                 hidden_channels,
                 activation="silu",
                 dtype=torch.float32):
        super(EnergyHead, self).__init__(dtype=dtype)
        # act_class_mapping and scatter are existing torchmd-net utilities.
        act_class = act_class_mapping[activation]
        self.output_network = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
            act_class(),
            nn.Linear(hidden_channels // 2, 1, dtype=dtype),
        )
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.output_network[0].weight)
        self.output_network[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.output_network[2].weight)
        self.output_network[2].bias.data.fill_(0)

    def per_point(self, point_features, results, z, pos, batch, extra_args):
        # Predict a per-atom energy contribution from the per-atom features.
        results["energy"] = self.output_network(point_features)
        return point_features, results

    def per_sample(self, point_features, results, z, pos, batch, extra_args):
        # Sum the per-atom contributions into one energy per sample.
        results["energy"] = scatter(results["energy"], batch, dim=0)
        return point_features, results
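The ForceHead is not shown in this draft; as a rough sketch (an assumption, not part of the PR), it could compute forces as the negative gradient of the energy, provided pos had requires_grad enabled before the representation model ran:

class ForceHead(BaseHead):
    def per_sample(self, point_features, results, z, pos, batch, extra_args):
        # Differentiate the total energy with respect to the positions;
        # assumes pos.requires_grad_(True) was called before the forward pass.
        grad_outputs = torch.ones_like(results["energy"])
        (dy,) = torch.autograd.grad(
            results["energy"],
            pos,
            grad_outputs=grad_outputs,
            create_graph=self.training,  # keep the graph when training on forces
            retain_graph=self.training,
        )
        results["force"] = -dy
        return point_features, results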

There are some challenges I still have to deal with:

  • Not sure how happy TorchScript is going to be with this.
  • Not sure how the user should specify a list of predefined heads. Perhaps something like an option (a possible factory for it is sketched below):
    head_list: energy_head, coulomb_prior, some_other_prior, charge_head, some_charge_prior, force_head
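As one possibility (the registry contents and the build_heads helper are hypothetical), such an option could be resolved through a name-to-factory registry:

# Hypothetical registry mapping config names to head factories.
HEAD_REGISTRY = {
    "energy_head": lambda cfg: EnergyHead(hidden_channels=cfg["hidden_channels"], dtype=cfg["dtype"]),
    "force_head": lambda cfg: ForceHead(dtype=cfg["dtype"]),
    # priors such as coulomb_prior would register themselves here as well
}

def build_heads(head_list, cfg):
    # Instantiate heads in the order given by the comma-separated option.
    names = [name.strip() for name in head_list.split(",")]
    return nn.ModuleList(HEAD_REGISTRY[name](cfg) for name in names)

For example, build_heads("energy_head, force_head", {"hidden_channels": 128, "dtype": torch.float32}) would return the heads in execution order.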

Tasks:

  • Adapt TorchMD_Net
  • Reintroduce support for mean and std
  • Make Equivariant versions of the heads for ET
  • Adapt LNNP
  • Adapt Datasets
  • Make priors into heads
  • Generalize the loss computation
  • Handle user input
  • Update tests

@giadefa (Contributor) commented Nov 3, 2023 via email

@peastman (Collaborator) commented Nov 3, 2023

I was thinking of something a bit more generic than this. You can define an arbitrary set of output heads and loss terms. I imagine the description in the config file looking something like this.

output_heads:
  - scalar:
      name: energy
  - coulomb  # the Coulomb head is hardcoded to output a scalar "energy" and a vector "charges"
losses:
  - l2:
      output: energy  # since multiple heads have "energy" outputs, they get summed before computing the loss
      dataset_field: y
      weight: 1.0
  - gradient_l2:
      output: energy
      dataset_field: neg_dy
      weight: 0.1
  - l2:
      output: charges
      dataset_field: mbis_charges
      weight: 0.1
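To make the summing behavior concrete, here is a small sketch (an assumption about the eventual implementation, not part of the proposal) of combining same-named head outputs before a loss term is evaluated:

import torch

def combined_output(per_head_results, name):
    # Sum the tensors that every head produced under the same output name;
    # assumes the contributing tensors all share one shape.
    tensors = [r[name] for r in per_head_results if name in r]
    return torch.stack(tensors).sum(dim=0)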

The configuration for a totally different sort of model might look like this.

output_heads:
  - scalar:
      name: solubility
losses:
  - l2:
      output: solubility
      dataset_field: solubility
      # if weight is omitted, it defaults to 1

@peastman (Collaborator) commented

Is it ok if I try implementing the design described above?

@RaulPPelaez (Collaborator, Author) commented

Hi Peter, I am working on it, but I have not had much time; sorry about that.
It is fine if you want to give it a try. Feel free to open a new PR if/when you have something and we can iterate; I would love to see your take.
I like your design very much, by the way, with the possible exception that I would rather have the gradient be a property of the heads instead of the losses. Thinking about how an inference configuration should work, when reading it I would not immediately look at the loss section to learn whether gradients are computed.

@giadefa (Contributor) commented Nov 15, 2023 via email
