-
Notifications
You must be signed in to change notification settings - Fork 350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: continuous, embedded covariates re-injected during training #3032
base: main
Are you sure you want to change the base?
Conversation
…njected, and not just to the input layer
close #3021 |
src/scvi/nn/_base_components.py
Outdated
"""A helper class to build fully-connected layers for a neural network. | ||
"""FCLayers class of scvi-tools adapted to also inject continous covariates. | ||
|
||
The only adaptation is addition of `n_cont` parameter in init and `cont` in forward, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove this docstring part - I only added it as info for implementing the change.
src/scvi/nn/_base_components.py
Outdated
@@ -135,13 +145,18 @@ def _hook_fn_zero_out(grad): | |||
b = layer.bias.register_hook(_hook_fn_zero_out) | |||
self.hooks.append(b) | |||
|
|||
def forward(self, x: torch.Tensor, *cat_list: int): | |||
def forward( | |||
self, x: torch.Tensor, cat_list: list | None = None, cont: torch.Tensor | None = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed the forward parametrization from (x, *cat_list)
to (x,cat_list,cont)
- so this will break with other Modules
using FCLayers
by passing cat_list
in expanded *
format. So either the parametrization here needs to be fixed or the use of FCLayers
elsewhere.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See my comment in #3021
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. I thought I would be able to finish this in one shot for all models, but its a bit more complex than I thought. So, I'll be focusing on scvi and scanvi for now only.
…ly (not done yet)
…ly (not done yet)
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #3032 +/- ##
==========================================
- Coverage 84.80% 84.35% -0.46%
==========================================
Files 173 173
Lines 14797 14806 +9
==========================================
- Hits 12549 12489 -60
- Misses 2248 2317 +69
|
…i, other modules work but w/o the option for continous covariates injection to deep layers, which might be a bit more complex to implement
This kind of change requires touching the module for each model. I don't think there is anywhere around it. Its risky and this is why I implemented it only for scvi & scanvi at the moment (maybe others will work also but I didn't test thoroughly). the rest of the changed modules are mainly things in function headers and placeholders to able to cope with the major changes in base_components (i.e cont_cov is usually None and cant be injected to their hidden layers). To summarise we have 2 parameters to think of:
@Hrovatin please try to use this branch |
@@ -372,17 +382,17 @@ def _regular_inference( | |||
if self.batch_representation == "embedding" and self.encode_covariates: | |||
batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index) | |||
encoder_input = torch.cat([encoder_input, batch_rep], dim=-1) | |||
qz, z = self.z_encoder(encoder_input, *categorical_input) | |||
qz, z = self.z_encoder(encoder_input, cont_covs, *categorical_input) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think here batch_rep should be also concat to cont_covs to be treated as continuous covariate no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would prefer to keep all seperated. So you could do deep injection of the batch embedding or covariate.
@@ -360,7 +370,7 @@ def _regular_inference( | |||
if self.log_variational: | |||
x_ = torch.log1p(x_) | |||
|
|||
if cont_covs is not None and self.encode_covariates: | |||
if cont_covs is not None: | |||
encoder_input = torch.cat((x_, cont_covs), dim=-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you always concatenate continous cov to expression instead of handling them separately as a covariate (sam as for one-hot)?
@@ -45,13 +47,26 @@ class FCLayers(nn.Module): | |||
Whether to inject covariates in each layer, or just the first (default). | |||
activation_fn | |||
Which activation function to use | |||
encode_covariates |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not used anywhere?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
encode_covariates
and batch_representation
are used in Module (e.g. vae) for handling of covariates not here
layers_dim = [n_in] + (n_layers - 1) * [n_hidden] + [n_out] | ||
|
||
if n_cat_list is not None: | ||
# n_cat = 1 will be ignored | ||
if n_cat_list is not None and self.batch_representation == "one-hot": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do not think you should check batch_representation
here - if n_cat_list
is not empty the covariates need to be accounted for here irrespective of what batch_representation
is - as the batch_representation
is used in Module to determine if batch will be added to the cat list or not. So the input cat cov list will already contain batch if it is not embedded ,besides other cat covariates
encode_covariates | ||
If ``True``, covariates are concatenated to gene expression prior to passing through | ||
the encoder(s). Else, only gene expression is used. | ||
batch_representation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be also removed and only kept in Module
if i > 0 and self.inject_covariates and cont_covs is not None: | ||
# Need to inject the continous covariates to hidden layers | ||
x = torch.cat((x, cont_covs), dim=-1) | ||
if self.batch_representation == "one-hot": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be done irregardless of what is batch representation for all covariates that were passed as categorical
@ori-kron-wis I think there are things that need to be fixed as per my review comments. To summarise:
@canergen can you confirm? |
@@ -163,7 +163,7 @@ def inference(self, mc, cov, batch_index, cat_covs=None, n_samples=1): | |||
else: | |||
categorical_input = () | |||
|
|||
qz, z = self.z_encoder(methylation_input, batch_index, *categorical_input) | |||
qz, z = self.z_encoder(methylation_input, None, batch_index, *categorical_input) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add the None as an argument in line 148. It makes it cleaner where it comes from.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is throughout this PR.
@@ -51,9 +53,9 @@ def __init__( | |||
) | |||
self.output = torch.nn.Sequential(torch.nn.Linear(n_hidden, 1), torch.nn.LeakyReLU()) | |||
|
|||
def forward(self, x: torch.Tensor, *cat_list: int): | |||
def forward(self, x: torch.Tensor, cont_covs: torch.Tensor | None = None, *cat_list: int): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here it's correct.
@@ -581,7 +604,7 @@ def inference( | |||
mask_acc = x_chr.sum(dim=1) > 0 | |||
mask_pro = y.sum(dim=1) > 0 | |||
|
|||
if cont_covs is not None and self.encode_covariates: | |||
if cont_covs is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We shouldn't remove the self.encode covariates check.
@@ -597,21 +620,21 @@ def inference( | |||
|
|||
# Z Encoders | |||
qzm_acc, qzv_acc, z_acc = self.z_encoder_accessibility( | |||
encoder_input_accessibility, batch_index, *categorical_input | |||
encoder_input_accessibility, None, batch_index, *categorical_input |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is this hard-coded?
@@ -185,6 +189,7 @@ def __init__( | |||
n_output=self.n_latent, | |||
n_hidden=self.n_hidden, | |||
n_cat_list=encoder_cat_list, | |||
n_cont=n_continuous_cov, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
only if encode covariates
cont_covs: torch.Tensor = tensors[REGISTRY_KEYS.CONT_COVS_KEY] | ||
cont_covs = broadcast_labels(cont_covs, n_broadcast=self.n_labels)[1] | ||
else: | ||
cont_covs = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please explain.
qz2, z2 = self.encoder_z2_z1(z1s, ys) | ||
pz1_m, pz1_v = self.decoder_z1_z2(z2, ys) | ||
qz2, z2 = self.encoder_z2_z1( | ||
torch.cat((z1s, cont_covs), dim=-1) if cont_covs is not None else z1s, cont_covs, ys |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no this is independent of cont_covs
torch.cat((z1s, cont_covs), dim=-1) if cont_covs is not None else z1s, cont_covs, ys | ||
) | ||
pz1_m, pz1_v = self.decoder_z1_z2( | ||
torch.cat((z2, cont_covs), dim=-1) if cont_covs is not None else z2, cont_covs, ys |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this to independent of cont_covs
Added other covariates types (continuous, embedded) to be able to reinjected, and not just to the input layer