Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: continuous, embedded covariates re-injected during training #3032

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

ori-kron-wis
Copy link
Collaborator

@ori-kron-wis ori-kron-wis commented Oct 27, 2024

Added other covariates types (continuous, embedded) to be able to reinjected, and not just to the input layer

@ori-kron-wis ori-kron-wis self-assigned this Oct 27, 2024
@ori-kron-wis ori-kron-wis added this to the scvi-tools 1.2 milestone Oct 27, 2024
@ori-kron-wis
Copy link
Collaborator Author

close #3021

@ori-kron-wis ori-kron-wis changed the title Added other covariates types (continuous, embedded) to be able to rei… feat: continuous, embedded covariates re-injected during training Oct 27, 2024
"""A helper class to build fully-connected layers for a neural network.
"""FCLayers class of scvi-tools adapted to also inject continous covariates.

The only adaptation is addition of `n_cont` parameter in init and `cont` in forward,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this docstring part - I only added it as info for implementing the change.

@@ -135,13 +145,18 @@ def _hook_fn_zero_out(grad):
b = layer.bias.register_hook(_hook_fn_zero_out)
self.hooks.append(b)

def forward(self, x: torch.Tensor, *cat_list: int):
def forward(
self, x: torch.Tensor, cat_list: list | None = None, cont: torch.Tensor | None = None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the forward parametrization from (x, *cat_list) to (x,cat_list,cont) - so this will break with other Modules using FCLayers by passing cat_list in expanded * format. So either the parametrization here needs to be fixed or the use of FCLayers elsewhere.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my comment in #3021

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I thought I would be able to finish this in one shot for all models, but its a bit more complex than I thought. So, I'll be focusing on scvi and scanvi for now only.

Copy link

codecov bot commented Oct 29, 2024

Codecov Report

Attention: Patch coverage is 97.82609% with 2 lines in your changes missing coverage. Please review.

Project coverage is 84.35%. Comparing base (795297e) to head (77df3e2).

Files with missing lines Patch % Lines
src/scvi/nn/_base_components.py 95.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3032      +/-   ##
==========================================
- Coverage   84.80%   84.35%   -0.46%     
==========================================
  Files         173      173              
  Lines       14797    14806       +9     
==========================================
- Hits        12549    12489      -60     
- Misses       2248     2317      +69     
Files with missing lines Coverage Δ
src/scvi/external/contrastivevi/_module.py 98.61% <100.00%> (ø)
src/scvi/external/methylvi/_base_components.py 100.00% <100.00%> (ø)
src/scvi/external/methylvi/_module.py 81.08% <100.00%> (ø)
src/scvi/module/_mrdeconv.py 95.13% <100.00%> (ø)
src/scvi/module/_multivae.py 82.22% <100.00%> (ø)
src/scvi/module/_peakvae.py 96.22% <100.00%> (+4.71%) ⬆️
src/scvi/module/_scanvae.py 86.23% <100.00%> (+0.84%) ⬆️
src/scvi/module/_totalvae.py 87.54% <100.00%> (ø)
src/scvi/module/_vae.py 94.50% <100.00%> (+0.02%) ⬆️
src/scvi/module/_vaec.py 85.05% <100.00%> (-0.34%) ⬇️
... and 1 more

... and 5 files with indirect coverage changes

…i, other modules work but w/o the option for continous covariates injection to deep layers, which might be a bit more complex to implement
@ori-kron-wis ori-kron-wis marked this pull request as ready for review October 30, 2024 13:35
@ori-kron-wis
Copy link
Collaborator Author

This kind of change requires touching the module for each model. I don't think there is anywhere around it. Its risky and this is why I implemented it only for scvi & scanvi at the moment (maybe others will work also but I didn't test thoroughly). the rest of the changed modules are mainly things in function headers and placeholders to able to cope with the major changes in base_components (i.e cont_cov is usually None and cant be injected to their hidden layers).

To summarise we have 2 parameters to think of:

  • deeply_inject (bool): covariates (continuous or categorial whether being encoded or not) are always injected to the first layer. If this parameter is True they will also be injected to the hidden layer of the encoder/decoder.
  • encode_covariates (bool): this is only relevant for the categorial covariates. If this parameter is false, they will be injected as one-hot, if true, they will be encoded as an embedded matrix (this is dependent on other meta-parameters such as the size of embedded vector and the batch_representation which should be "embedded" as well if we want it, otherwise it will remain one-hot)

@Hrovatin please try to use this branch

@@ -372,17 +382,17 @@ def _regular_inference(
if self.batch_representation == "embedding" and self.encode_covariates:
batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index)
encoder_input = torch.cat([encoder_input, batch_rep], dim=-1)
qz, z = self.z_encoder(encoder_input, *categorical_input)
qz, z = self.z_encoder(encoder_input, cont_covs, *categorical_input)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think here batch_rep should be also concat to cont_covs to be treated as continuous covariate no?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to keep all seperated. So you could do deep injection of the batch embedding or covariate.

@@ -360,7 +370,7 @@ def _regular_inference(
if self.log_variational:
x_ = torch.log1p(x_)

if cont_covs is not None and self.encode_covariates:
if cont_covs is not None:
encoder_input = torch.cat((x_, cont_covs), dim=-1)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you always concatenate continous cov to expression instead of handling them separately as a covariate (sam as for one-hot)?

@@ -45,13 +47,26 @@ class FCLayers(nn.Module):
Whether to inject covariates in each layer, or just the first (default).
activation_fn
Which activation function to use
encode_covariates
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not used anywhere?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

encode_covariates and batch_representation are used in Module (e.g. vae) for handling of covariates not here

layers_dim = [n_in] + (n_layers - 1) * [n_hidden] + [n_out]

if n_cat_list is not None:
# n_cat = 1 will be ignored
if n_cat_list is not None and self.batch_representation == "one-hot":
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think you should check batch_representation here - if n_cat_list is not empty the covariates need to be accounted for here irrespective of what batch_representation is - as the batch_representation is used in Module to determine if batch will be added to the cat list or not. So the input cat cov list will already contain batch if it is not embedded ,besides other cat covariates

encode_covariates
If ``True``, covariates are concatenated to gene expression prior to passing through
the encoder(s). Else, only gene expression is used.
batch_representation
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be also removed and only kept in Module

if i > 0 and self.inject_covariates and cont_covs is not None:
# Need to inject the continous covariates to hidden layers
x = torch.cat((x, cont_covs), dim=-1)
if self.batch_representation == "one-hot":
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be done irregardless of what is batch representation for all covariates that were passed as categorical

@Hrovatin
Copy link

Hrovatin commented Nov 1, 2024

@ori-kron-wis I think there are things that need to be fixed as per my review comments. To summarise:
I think the behaviour should be as follows:

  • Module passes to layers expression, continuous covariates (can be None), and categorical covariates that need to be one-hot encoded (can be empty list)
  • Module is the one that decides if batch will be embedded and concatenated to continuous covariates or added to the categorical list
  • Layers should not check how the batch is encoded as their init params already tell the number of continous/categorical covariates as determined by the Module.
  • The layers then either inject covariates (continuous and one-hote encoded categorical) into all layers or only add to the first one. - So Module should not be concatenating continuous covariates to expression in advance.

@canergen can you confirm?

@@ -163,7 +163,7 @@ def inference(self, mc, cov, batch_index, cat_covs=None, n_samples=1):
else:
categorical_input = ()

qz, z = self.z_encoder(methylation_input, batch_index, *categorical_input)
qz, z = self.z_encoder(methylation_input, None, batch_index, *categorical_input)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the None as an argument in line 148. It makes it cleaner where it comes from.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is throughout this PR.

@@ -51,9 +53,9 @@ def __init__(
)
self.output = torch.nn.Sequential(torch.nn.Linear(n_hidden, 1), torch.nn.LeakyReLU())

def forward(self, x: torch.Tensor, *cat_list: int):
def forward(self, x: torch.Tensor, cont_covs: torch.Tensor | None = None, *cat_list: int):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here it's correct.

@@ -581,7 +604,7 @@ def inference(
mask_acc = x_chr.sum(dim=1) > 0
mask_pro = y.sum(dim=1) > 0

if cont_covs is not None and self.encode_covariates:
if cont_covs is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't remove the self.encode covariates check.

@@ -597,21 +620,21 @@ def inference(

# Z Encoders
qzm_acc, qzv_acc, z_acc = self.z_encoder_accessibility(
encoder_input_accessibility, batch_index, *categorical_input
encoder_input_accessibility, None, batch_index, *categorical_input
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this hard-coded?

@@ -185,6 +189,7 @@ def __init__(
n_output=self.n_latent,
n_hidden=self.n_hidden,
n_cat_list=encoder_cat_list,
n_cont=n_continuous_cov,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only if encode covariates

cont_covs: torch.Tensor = tensors[REGISTRY_KEYS.CONT_COVS_KEY]
cont_covs = broadcast_labels(cont_covs, n_broadcast=self.n_labels)[1]
else:
cont_covs = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please explain.

qz2, z2 = self.encoder_z2_z1(z1s, ys)
pz1_m, pz1_v = self.decoder_z1_z2(z2, ys)
qz2, z2 = self.encoder_z2_z1(
torch.cat((z1s, cont_covs), dim=-1) if cont_covs is not None else z1s, cont_covs, ys
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no this is independent of cont_covs

torch.cat((z1s, cont_covs), dim=-1) if cont_covs is not None else z1s, cont_covs, ys
)
pz1_m, pz1_v = self.decoder_z1_z2(
torch.cat((z2, cont_covs), dim=-1) if cont_covs is not None else z2, cont_covs, ys
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this to independent of cont_covs

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants