Skip to content

Commit

Permalink
errors_correction5
Browse files Browse the repository at this point in the history
  • Loading branch information
sergiomarco25 committed Dec 19, 2024
1 parent 9a1ba0f commit fa43a29
Showing 1 changed file with 248 additions and 31 deletions.
279 changes: 248 additions & 31 deletions src/troutpy/pl/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,16 +195,16 @@ def plot_crosstab(data, xvar: str = '', yvar: str = '', normalize=True, axis=1,
-----------
data : pd.DataFrame
Input dataset containing the variables for the cross-tabulation.
xvar : str, optional (default: '')
The variable to use on the x-axis for the cross-tabulation.
yvar : str, optional (default: '')
The variable to use on the y-axis for the cross-tabulation.
normalize : bool, optional (default: True)
Whether to normalize the cross-tabulated data (percentages). If True, the data will be normalized.
axis : int, optional (default: 1)
The axis to normalize across. Use `1` for row normalization and `0` for column normalization.
Expand Down Expand Up @@ -314,7 +314,7 @@ def pie_of_positive(data, groupby: str = '', figures_path: str = '', save: bool
"""

plt.figure()
y = np.array([np.sum(data[groupby] == False), np.sum(data[groupby] == True)])
y = np.array([np.sum(data[groupby] == False), np.sum(data[groupby] )])
mylabels = [f"{groupby}=False", f"{groupby}=True"]

plt.pie(y, labels=mylabels, colors=['#a0b7e0', '#c5e493'])
Expand All @@ -325,7 +325,7 @@ def pie_of_positive(data, groupby: str = '', figures_path: str = '', save: bool
plt.savefig(os.path.join(figures_path, plot_filename))

def genes_over_noise(sdata, scores_by_genes,layer='extracellular_transcripts', output_path:str='',save=True,format:str='pdf'):
"""This function plots log fold change per gene over noise using a boxplot.
"""Function that plots log fold change per gene over noise using a boxplot.
Parameters:
- data_quantified: DataFrame containing the extracellular transcript data, including feature names and codeword categories.
Expand Down Expand Up @@ -460,25 +460,105 @@ def proportion_above_threshold(
plt.savefig(os.path.join(figures_path, f'{filename}.{format}'))
plt.show()

def nmf_factors_exrna_cells_W(sdata,nmf_adata_key='nmf_data', save=True,saving_path='',spot_size:int=30,cmap='viridis'):
def nmf_factors_exrna_cells_W(sdata, nmf_adata_key: str = 'nmf_data', save: bool = True, saving_path: str = '',
spot_size: int = 30, cmap: str = 'viridis'):
"""
Plot NMF factors for each cell in a spatial transcriptomics dataset.
This function extracts the NMF (Non-negative Matrix Factorization) factors from the specified AnnData object
within the spatial data (`sdata`) and creates spatial plots for each factor. The plots can be displayed or saved to disk.
Parameters:
-----------
sdata : AnnData or SpatialData object
A spatial transcriptomics dataset that contains the NMF factors in the specified key.
nmf_adata_key : str, optional
The key in `sdata` that contains the AnnData object with NMF results. Defaults to 'nmf_data'.
save : bool, optional
Whether to save the spatial factor plots to disk. Defaults to True.
saving_path : str, optional
Path where the plots should be saved if `save` is True. The plots are saved in a `figures` subdirectory.
Defaults to an empty string.
spot_size : int, optional
Size of the spots in the spatial plot. Defaults to 30.
cmap : str, optional
Colormap to use for the spatial plots. Defaults to 'viridis'.
Returns:
--------
None
Displays the spatial plots for each NMF factor. If `save` is True, the plots are saved as PNG files.
Notes:
------
- The NMF factors are expected to be stored in `adata.obsm['W_nmf']`, where `adata` is extracted from `sdata`.
- A maximum of 20 factors is plotted by iterating through the columns of `W_nmf`.
- When saving, each plot is named `spatialnmf{factor}.png` and stored in a `figures` directory inside `saving_path`.
Example:
--------
>>> nmf_factors_exrna_cells_W(sdata, nmf_adata_key='nmf_data', save=True, saving_path='./results', spot_size=50, cmap='plasma')
"""
# Plot the factors for each cell in a spatial plot
adata=sdata[nmf_adata_key]
adata = sdata[nmf_adata_key]
W = adata.obsm['W_nmf']
for factor in range(20):
# Add the factor values to adata.obs for plotting
adata.obs[f'NMF_factor_{factor + 1}'] = W[:, factor]
# Plot spatial map of cells colored by this factor
if save:
sc.pl.spatial(adata, color=f'NMF_factor_{factor + 1}', cmap=cmap, title=f'NMF Factor {factor + 1}', spot_size=30,show=False)
plt.savefig(saving_path+'/figures/'+ f'spatialnmf{factor}.png')
sc.pl.spatial(adata, color=f'NMF_factor_{factor + 1}', cmap=cmap, title=f'NMF Factor {factor + 1}',
spot_size=30, show=False)
plt.savefig(saving_path + '/figures/' + f'spatialnmf{factor}.png')
plt.show()
else:
sc.pl.spatial(adata, color=f'NMF_factor_{factor + 1}', cmap=cmap, title=f'NMF Factor {factor + 1}', spot_size=spot_size)
sc.pl.spatial(adata, color=f'NMF_factor_{factor + 1}', cmap=cmap, title=f'NMF Factor {factor + 1}',
spot_size=spot_size)

def nmf_gene_contributions(sdata, nmf_adata_key: str = 'nmf_data', save: bool = True, vmin: float = 0.0, vmax: float = 0.02,
saving_path: str = '', cmap: str = 'viridis', figsize: tuple = (5, 5)):
"""Plot a heatmap of gene contributions to NMF factors.
def nmf_gene_contributions(sdata,nmf_adata_key='nmf_data', save=True, vmin=0.0, vmax=0.02,saving_path='',cmap='viridis',figsize=(5,5)):
adata=sdata[nmf_adata_key]
loadings=pd.DataFrame(adata.uns['H_nmf'],columns=adata.var.index)
loadings_filtered=loadings.loc[:,np.max(loadings,axis=0)>0.05].transpose()
This function extracts the NMF (Non-negative Matrix Factorization) gene loadings matrix from the specified AnnData object within the spatial data (`sdata`), filters genes based on their maximum loading value, and plots a heatmap of the filtered loadings.
Parameters:
-----------
sdata : AnnData or SpatialData object
A spatial transcriptomics dataset that contains the NMF factors in the specified key.
nmf_adata_key : str, optional
The key in `sdata` that contains the AnnData object with NMF results. Defaults to 'nmf_data'.
save : bool, optional
Whether to save the heatmap plot to disk. Defaults to True.
vmin : float, optional
Minimum value for the colormap scale. Defaults to 0.0.
vmax : float, optional
Maximum value for the colormap scale. Defaults to 0.02.
saving_path : str, optional
Path where the plot should be saved if `save` is True. The plot is saved in a `figures` subdirectory.
Defaults to an empty string.
cmap : str, optional
Colormap to use for the heatmap. Defaults to 'viridis'.
figsize : tuple, optional
Size of the heatmap figure. Defaults to (5, 5).
Returns:
--------
None
Displays a heatmap of gene contributions to NMF factors. If `save` is True, the heatmap is saved as a PDF file.
Notes:
------
- The gene loadings matrix is expected to be stored in `adata.uns['H_nmf']`, where `adata` is extracted from `sdata`.
- Genes with a maximum loading value greater than 0.05 are included in the heatmap.
- The rows of the heatmap are sorted based on the factor with the highest contribution for each gene.
Example:
--------
>>> nmf_gene_contributions(sdata, nmf_adata_key='nmf_data', save=True, saving_path='./results', cmap='plasma', figsize=(10, 8))
"""
adata = sdata[nmf_adata_key]
loadings = pd.DataFrame(adata.uns['H_nmf'], columns=adata.var.index)
loadings_filtered = loadings.loc[:, np.max(loadings, axis=0) > 0.05].transpose()
figures_path = os.path.join(saving_path, 'figures')
os.makedirs(figures_path, exist_ok=True)

Expand All @@ -494,28 +574,60 @@ def nmf_gene_contributions(sdata,nmf_adata_key='nmf_data', save=True, vmin=0.0,
plt.savefig(os.path.join(figures_path, "loadings_NMF.pdf"))
plt.show()
plt.close() # Close the figure to avoid memory issues

def apply_exrnaH_to_cellular_to_create_cellularW(adata_extracellular_with_nmf, adata_annotated_cellular):
"""Apply extracellular RNA NMF loadings (H) to cellular data to generate cellular NMF factors (W).
This function transfers the gene loadings (H matrix) derived from extracellular RNA analysis to a cellular dataset. It calculates the new W matrix for cellular data by multiplying the gene expression values of the cellular dataset with the filtered H matrix.
Parameters:
-----------
adata_extracellular_with_nmf : AnnData
An AnnData object containing the extracellular RNA data with the NMF results.
The H matrix is expected to be stored in `adata.uns['H_nmf']`.
adata_annotated_cellular : AnnData
An AnnData object containing the cellular RNA data with annotated gene expression values.
Returns:
--------
AnnData
The input `adata_annotated_cellular` object with the following updates:
- Adds the calculated NMF factors (W matrix) as a DataFrame to `adata.obsm['factors']`.
- Adds each NMF factor as individual columns in `adata.obs` with names `NMF_factor_1`, `NMF_factor_2`, etc.
Notes:
------
- Only the genes common between the extracellular RNA data and the cellular data are used for the computation.
- The gene intersection ensures compatibility between the NMF H matrix and the cellular gene expression matrix.
Example:
--------
>>> adata_cellular = apply_exrnaH_to_cellular_to_create_cellularW(adata_extracellular_with_nmf, adata_annotated_cellular)
"""
# Extract the H matrix (NMF gene loadings) from the extracellular data
H = adata_extracellular_with_nmf.uns['H_nmf']
# Check the number of genes in adata_annotated and spots2region_output to match gene loadings (H)

# Check the genes in both datasets
genes_spots2region = adata_extracellular_with_nmf.var_names
genes_annotated = adata_annotated_cellular.var_names

# Get intersection of genes between the two datasets
# Get the intersection of genes between the two datasets
common_genes = genes_annotated.intersection(genes_spots2region)

# Filter both datasets to keep only common genes
# Filter both datasets to retain only common genes
adata_annotated_cellular = adata_annotated_cellular[:, common_genes]
H_filtered = H[:, np.isin(genes_spots2region, common_genes)] # Filtered NMF gene loadings for common genes
H_filtered = H[:, np.isin(genes_spots2region, common_genes)] # Filter H matrix to include only common genes

# Apply the NMF factors to the annotated dataset
# Calculate the new W matrix by multiplying the annotated data with the filtered H
# Compute the new W matrix for the cellular dataset
W_annotated = adata_annotated_cellular.X @ H_filtered.T

adata_annotated_cellular.obsm['factors']=pd.DataFrame(W_annotated,index=adata_annotated_cellular.obs.index)
#print(W_annotated[:, 0].shape)
# Add the factors as new columns in adata_annotated.obs
# Store the W matrix in the obsm attribute as a DataFrame
adata_annotated_cellular.obsm['factors'] = pd.DataFrame(
W_annotated, index=adata_annotated_cellular.obs.index
)

# Add individual NMF factors to adata.obs
for factor in range(W_annotated.shape[1]):
adata_annotated_cellular.obs[f'NMF_factor_{factor + 1}'] = W_annotated[:, factor]

Expand Down Expand Up @@ -614,10 +726,45 @@ def paired_nmf_factors(
plt.tight_layout()
plt.show()

def W(adata, n_factors, save=True): # not very intuitive
# Plot the spatial map of cells colored by each factor
def plot_nmf_factors_spatial(adata, n_factors, save=True):
"""
Plot spatial maps of cells colored by NMF factors.
This function visualizes the spatial distribution of cells, colored by their corresponding NMF factor values, stored in `adata.obs`. It iterates over all specified NMF factors and generates spatial plots for each factor.
Parameters:
-----------
adata : AnnData
An AnnData object containing the dataset with NMF factors already added as columns in `adata.obs`.Each factor should be named `NMF_factor_1`, `NMF_factor_2`, ..., `NMF_factor_n`.
n_factors : int
The number of NMF factors to plot.
save : bool, optional (default=True)
If `True`, saves the plots to files with filenames `exo_to_cell_spatial_<factor>.png`.
Returns:
--------
None
This function does not return anything but generates and optionally saves spatial plots.
Notes:
------
- The plots are colored using the 'plasma' colormap.
- The spot size for the spatial plots is set to 15 by default.
- Files are saved in the current working directory unless specified otherwise using `sc.settings.figdir`.
Example:
--------
>>> plot_nmf_factors_spatial(adata, n_factors=10, save=True)
"""
for factor in range(n_factors):
sc.pl.spatial(adata, color=f'NMF_factor_{factor + 1}', cmap='plasma', title=f'NMF Factor {factor + 1}', spot_size=15, save=f'exo_to_cell_spatial_{factor}.png')
sc.pl.spatial(
adata,
color=f'NMF_factor_{factor + 1}',
cmap='plasma',
title=f'NMF Factor {factor + 1}',
spot_size=15,
save=f'exo_to_cell_spatial_{factor}.png' if save else None
)

def spatial_interactions(
sdata: AnnData,
Expand All @@ -641,6 +788,78 @@ def spatial_interactions(
save: Optional[Union[str, Path]] = None,
**kwargs
):
"""
Visualizes the spatial interactions of extracellular RNA and associated cells.
This function generates a scatter plot showing the positions of target cells, source cells, and extracellular RNA transcripts within a spatial omics dataset. The target and source cells are highlighted in different colors, while the RNA transcripts are shown as points at their respective positions. Optionally, a background image (e.g., tissue section) can be displayed.
Parameters:
----------
sdata : AnnData
An AnnData object containing the spatial omics data, including transcript expression and cell positions.
layer : str, optional, default: 'extracellular_transcripts_enriched'
The layer in the AnnData object that contains the extracellular RNA transcript data.
gene : str, optional, default: 'Arc'
The gene of interest to be visualized in terms of its spatial interaction with source and target cells.
gene_key : str, optional, default: 'feature_name'
The column name in the AnnData object used to identify the gene.
cell_id_key : str, optional, default: 'cell_id'
The column name in the AnnData object used to identify individual cells.
color_target : str, optional, default: 'blue'
The color to be used for target cells in the plot.
color_source : str, optional, default: 'red'
The color to be used for source cells in the plot.
color_transcript : str, optional, default: 'green'
The color to be used for the RNA transcripts in the plot.
spatial_key : str, optional, default: 'spatial'
The key in the AnnData object that stores the spatial coordinates of the cells.
img : Optional[Union[bool, Sequence]], optional, default: None
A background image to overlay on the plot, such as a tissue section. Can be set to `None` to omit.
img_alpha : Optional[float], optional, default: None
The transparency level of the background image. Ignored if `img` is `None`.
image_cmap : Optional[Colormap], optional, default: None
The colormap to be used for the background image, if applicable.
size : Optional[Union[float, Sequence[float]]], optional, default: 8
The size of the scatter plot points for the cells and transcripts.
alpha : float, optional, default: 0.6
The transparency level for the scatter plot points.
title : Optional[Union[str, Sequence[str]]], optional, default: None
The title of the plot. If `None`, the gene name is used.
legend_loc : Optional[str], optional, default: 'best'
The location of the legend in the plot.
figsize : Tuple[float, float], optional, default: (10, 10)
The dimensions of the plot in inches.
dpi : Optional[int], optional, default: 100
The resolution (dots per inch) for the plot.
save : Optional[Union[str, Path]], optional, default: None
The path to save the plot image. If `None`, the plot is displayed but not saved.
**kwargs : Additional keyword arguments
Any additional arguments passed to the `scatter` or `imshow` functions for customizing plot appearance.
Returns:
-------
None
The function generates and displays (or saves) a scatter plot.
"""
# Extract relevant data
transcripts = sdata.points[layer]
trans_filt = transcripts[transcripts[gene_key] == gene]
Expand All @@ -657,8 +876,6 @@ def spatial_interactions(
plt.scatter(cell_positions.loc[source_cells, 'x'], cell_positions.loc[source_cells, 'y'], c=color_source, s=size, label='Source Cells', **kwargs)
plt.scatter(trans_filt['x'], trans_filt['y'], c=color_transcript, s=size*0.4, label='Transcripts', **kwargs)



# Titles and Legends
plt.title(title or gene)
plt.legend(loc=legend_loc)
Expand Down

0 comments on commit fa43a29

Please sign in to comment.