From fa43a29b07598adfe22cfd617a68c8dc09dcaf57 Mon Sep 17 00:00:00 2001 From: sergiomarco25 Date: Thu, 19 Dec 2024 13:25:26 +0100 Subject: [PATCH] errors_correction5 --- src/troutpy/pl/plotting.py | 279 ++++++++++++++++++++++++++++++++----- 1 file changed, 248 insertions(+), 31 deletions(-) diff --git a/src/troutpy/pl/plotting.py b/src/troutpy/pl/plotting.py index 6b5ab39..2633c46 100644 --- a/src/troutpy/pl/plotting.py +++ b/src/troutpy/pl/plotting.py @@ -195,16 +195,16 @@ def plot_crosstab(data, xvar: str = '', yvar: str = '', normalize=True, axis=1, ----------- data : pd.DataFrame Input dataset containing the variables for the cross-tabulation. - + xvar : str, optional (default: '') The variable to use on the x-axis for the cross-tabulation. - + yvar : str, optional (default: '') The variable to use on the y-axis for the cross-tabulation. - + normalize : bool, optional (default: True) Whether to normalize the cross-tabulated data (percentages). If True, the data will be normalized. - + axis : int, optional (default: 1) The axis to normalize across. Use `1` for row normalization and `0` for column normalization. @@ -314,7 +314,7 @@ def pie_of_positive(data, groupby: str = '', figures_path: str = '', save: bool """ plt.figure() - y = np.array([np.sum(data[groupby] == False), np.sum(data[groupby] == True)]) + y = np.array([np.sum(data[groupby] == False), np.sum(data[groupby] )]) mylabels = [f"{groupby}=False", f"{groupby}=True"] plt.pie(y, labels=mylabels, colors=['#a0b7e0', '#c5e493']) @@ -325,7 +325,7 @@ def pie_of_positive(data, groupby: str = '', figures_path: str = '', save: bool plt.savefig(os.path.join(figures_path, plot_filename)) def genes_over_noise(sdata, scores_by_genes,layer='extracellular_transcripts', output_path:str='',save=True,format:str='pdf'): - """This function plots log fold change per gene over noise using a boxplot. + """Function that plots log fold change per gene over noise using a boxplot. Parameters: - data_quantified: DataFrame containing the extracellular transcript data, including feature names and codeword categories. @@ -460,25 +460,105 @@ def proportion_above_threshold( plt.savefig(os.path.join(figures_path, f'{filename}.{format}')) plt.show() -def nmf_factors_exrna_cells_W(sdata,nmf_adata_key='nmf_data', save=True,saving_path='',spot_size:int=30,cmap='viridis'): +def nmf_factors_exrna_cells_W(sdata, nmf_adata_key: str = 'nmf_data', save: bool = True, saving_path: str = '', + spot_size: int = 30, cmap: str = 'viridis'): + """ + Plot NMF factors for each cell in a spatial transcriptomics dataset. + + This function extracts the NMF (Non-negative Matrix Factorization) factors from the specified AnnData object + within the spatial data (`sdata`) and creates spatial plots for each factor. The plots can be displayed or saved to disk. + + Parameters: + ----------- + sdata : AnnData or SpatialData object + A spatial transcriptomics dataset that contains the NMF factors in the specified key. + nmf_adata_key : str, optional + The key in `sdata` that contains the AnnData object with NMF results. Defaults to 'nmf_data'. + save : bool, optional + Whether to save the spatial factor plots to disk. Defaults to True. + saving_path : str, optional + Path where the plots should be saved if `save` is True. The plots are saved in a `figures` subdirectory. + Defaults to an empty string. + spot_size : int, optional + Size of the spots in the spatial plot. Defaults to 30. + cmap : str, optional + Colormap to use for the spatial plots. Defaults to 'viridis'. + + Returns: + -------- + None + Displays the spatial plots for each NMF factor. If `save` is True, the plots are saved as PNG files. + + Notes: + ------ + - The NMF factors are expected to be stored in `adata.obsm['W_nmf']`, where `adata` is extracted from `sdata`. + - A maximum of 20 factors is plotted by iterating through the columns of `W_nmf`. + - When saving, each plot is named `spatialnmf{factor}.png` and stored in a `figures` directory inside `saving_path`. + + Example: + -------- + >>> nmf_factors_exrna_cells_W(sdata, nmf_adata_key='nmf_data', save=True, saving_path='./results', spot_size=50, cmap='plasma') + """ # Plot the factors for each cell in a spatial plot - adata=sdata[nmf_adata_key] + adata = sdata[nmf_adata_key] W = adata.obsm['W_nmf'] for factor in range(20): # Add the factor values to adata.obs for plotting adata.obs[f'NMF_factor_{factor + 1}'] = W[:, factor] # Plot spatial map of cells colored by this factor if save: - sc.pl.spatial(adata, color=f'NMF_factor_{factor + 1}', cmap=cmap, title=f'NMF Factor {factor + 1}', spot_size=30,show=False) - plt.savefig(saving_path+'/figures/'+ f'spatialnmf{factor}.png') + sc.pl.spatial(adata, color=f'NMF_factor_{factor + 1}', cmap=cmap, title=f'NMF Factor {factor + 1}', + spot_size=30, show=False) + plt.savefig(saving_path + '/figures/' + f'spatialnmf{factor}.png') plt.show() else: - sc.pl.spatial(adata, color=f'NMF_factor_{factor + 1}', cmap=cmap, title=f'NMF Factor {factor + 1}', spot_size=spot_size) + sc.pl.spatial(adata, color=f'NMF_factor_{factor + 1}', cmap=cmap, title=f'NMF Factor {factor + 1}', + spot_size=spot_size) + +def nmf_gene_contributions(sdata, nmf_adata_key: str = 'nmf_data', save: bool = True, vmin: float = 0.0, vmax: float = 0.02, + saving_path: str = '', cmap: str = 'viridis', figsize: tuple = (5, 5)): + """Plot a heatmap of gene contributions to NMF factors. -def nmf_gene_contributions(sdata,nmf_adata_key='nmf_data', save=True, vmin=0.0, vmax=0.02,saving_path='',cmap='viridis',figsize=(5,5)): - adata=sdata[nmf_adata_key] - loadings=pd.DataFrame(adata.uns['H_nmf'],columns=adata.var.index) - loadings_filtered=loadings.loc[:,np.max(loadings,axis=0)>0.05].transpose() + This function extracts the NMF (Non-negative Matrix Factorization) gene loadings matrix from the specified AnnData object within the spatial data (`sdata`), filters genes based on their maximum loading value, and plots a heatmap of the filtered loadings. + + Parameters: + ----------- + sdata : AnnData or SpatialData object + A spatial transcriptomics dataset that contains the NMF factors in the specified key. + nmf_adata_key : str, optional + The key in `sdata` that contains the AnnData object with NMF results. Defaults to 'nmf_data'. + save : bool, optional + Whether to save the heatmap plot to disk. Defaults to True. + vmin : float, optional + Minimum value for the colormap scale. Defaults to 0.0. + vmax : float, optional + Maximum value for the colormap scale. Defaults to 0.02. + saving_path : str, optional + Path where the plot should be saved if `save` is True. The plot is saved in a `figures` subdirectory. + Defaults to an empty string. + cmap : str, optional + Colormap to use for the heatmap. Defaults to 'viridis'. + figsize : tuple, optional + Size of the heatmap figure. Defaults to (5, 5). + + Returns: + -------- + None + Displays a heatmap of gene contributions to NMF factors. If `save` is True, the heatmap is saved as a PDF file. + + Notes: + ------ + - The gene loadings matrix is expected to be stored in `adata.uns['H_nmf']`, where `adata` is extracted from `sdata`. + - Genes with a maximum loading value greater than 0.05 are included in the heatmap. + - The rows of the heatmap are sorted based on the factor with the highest contribution for each gene. + + Example: + -------- + >>> nmf_gene_contributions(sdata, nmf_adata_key='nmf_data', save=True, saving_path='./results', cmap='plasma', figsize=(10, 8)) + """ + adata = sdata[nmf_adata_key] + loadings = pd.DataFrame(adata.uns['H_nmf'], columns=adata.var.index) + loadings_filtered = loadings.loc[:, np.max(loadings, axis=0) > 0.05].transpose() figures_path = os.path.join(saving_path, 'figures') os.makedirs(figures_path, exist_ok=True) @@ -494,28 +574,60 @@ def nmf_gene_contributions(sdata,nmf_adata_key='nmf_data', save=True, vmin=0.0, plt.savefig(os.path.join(figures_path, "loadings_NMF.pdf")) plt.show() plt.close() # Close the figure to avoid memory issues - + def apply_exrnaH_to_cellular_to_create_cellularW(adata_extracellular_with_nmf, adata_annotated_cellular): + """Apply extracellular RNA NMF loadings (H) to cellular data to generate cellular NMF factors (W). + + This function transfers the gene loadings (H matrix) derived from extracellular RNA analysis to a cellular dataset. It calculates the new W matrix for cellular data by multiplying the gene expression values of the cellular dataset with the filtered H matrix. + + Parameters: + ----------- + adata_extracellular_with_nmf : AnnData + An AnnData object containing the extracellular RNA data with the NMF results. + The H matrix is expected to be stored in `adata.uns['H_nmf']`. + adata_annotated_cellular : AnnData + An AnnData object containing the cellular RNA data with annotated gene expression values. + + Returns: + -------- + AnnData + The input `adata_annotated_cellular` object with the following updates: + - Adds the calculated NMF factors (W matrix) as a DataFrame to `adata.obsm['factors']`. + - Adds each NMF factor as individual columns in `adata.obs` with names `NMF_factor_1`, `NMF_factor_2`, etc. + + Notes: + ------ + - Only the genes common between the extracellular RNA data and the cellular data are used for the computation. + - The gene intersection ensures compatibility between the NMF H matrix and the cellular gene expression matrix. + + Example: + -------- + >>> adata_cellular = apply_exrnaH_to_cellular_to_create_cellularW(adata_extracellular_with_nmf, adata_annotated_cellular) + + """ + # Extract the H matrix (NMF gene loadings) from the extracellular data H = adata_extracellular_with_nmf.uns['H_nmf'] - - # Check the number of genes in adata_annotated and spots2region_output to match gene loadings (H) + + # Check the genes in both datasets genes_spots2region = adata_extracellular_with_nmf.var_names genes_annotated = adata_annotated_cellular.var_names - # Get intersection of genes between the two datasets + # Get the intersection of genes between the two datasets common_genes = genes_annotated.intersection(genes_spots2region) - # Filter both datasets to keep only common genes + # Filter both datasets to retain only common genes adata_annotated_cellular = adata_annotated_cellular[:, common_genes] - H_filtered = H[:, np.isin(genes_spots2region, common_genes)] # Filtered NMF gene loadings for common genes + H_filtered = H[:, np.isin(genes_spots2region, common_genes)] # Filter H matrix to include only common genes - # Apply the NMF factors to the annotated dataset - # Calculate the new W matrix by multiplying the annotated data with the filtered H + # Compute the new W matrix for the cellular dataset W_annotated = adata_annotated_cellular.X @ H_filtered.T - adata_annotated_cellular.obsm['factors']=pd.DataFrame(W_annotated,index=adata_annotated_cellular.obs.index) - #print(W_annotated[:, 0].shape) - # Add the factors as new columns in adata_annotated.obs + # Store the W matrix in the obsm attribute as a DataFrame + adata_annotated_cellular.obsm['factors'] = pd.DataFrame( + W_annotated, index=adata_annotated_cellular.obs.index + ) + + # Add individual NMF factors to adata.obs for factor in range(W_annotated.shape[1]): adata_annotated_cellular.obs[f'NMF_factor_{factor + 1}'] = W_annotated[:, factor] @@ -614,10 +726,45 @@ def paired_nmf_factors( plt.tight_layout() plt.show() -def W(adata, n_factors, save=True): # not very intuitive - # Plot the spatial map of cells colored by each factor +def plot_nmf_factors_spatial(adata, n_factors, save=True): + """ + Plot spatial maps of cells colored by NMF factors. + + This function visualizes the spatial distribution of cells, colored by their corresponding NMF factor values, stored in `adata.obs`. It iterates over all specified NMF factors and generates spatial plots for each factor. + + Parameters: + ----------- + adata : AnnData + An AnnData object containing the dataset with NMF factors already added as columns in `adata.obs`.Each factor should be named `NMF_factor_1`, `NMF_factor_2`, ..., `NMF_factor_n`. + n_factors : int + The number of NMF factors to plot. + save : bool, optional (default=True) + If `True`, saves the plots to files with filenames `exo_to_cell_spatial_.png`. + + Returns: + -------- + None + This function does not return anything but generates and optionally saves spatial plots. + + Notes: + ------ + - The plots are colored using the 'plasma' colormap. + - The spot size for the spatial plots is set to 15 by default. + - Files are saved in the current working directory unless specified otherwise using `sc.settings.figdir`. + + Example: + -------- + >>> plot_nmf_factors_spatial(adata, n_factors=10, save=True) + """ for factor in range(n_factors): - sc.pl.spatial(adata, color=f'NMF_factor_{factor + 1}', cmap='plasma', title=f'NMF Factor {factor + 1}', spot_size=15, save=f'exo_to_cell_spatial_{factor}.png') + sc.pl.spatial( + adata, + color=f'NMF_factor_{factor + 1}', + cmap='plasma', + title=f'NMF Factor {factor + 1}', + spot_size=15, + save=f'exo_to_cell_spatial_{factor}.png' if save else None + ) def spatial_interactions( sdata: AnnData, @@ -641,6 +788,78 @@ def spatial_interactions( save: Optional[Union[str, Path]] = None, **kwargs ): + """ + Visualizes the spatial interactions of extracellular RNA and associated cells. + + This function generates a scatter plot showing the positions of target cells, source cells, and extracellular RNA transcripts within a spatial omics dataset. The target and source cells are highlighted in different colors, while the RNA transcripts are shown as points at their respective positions. Optionally, a background image (e.g., tissue section) can be displayed. + + Parameters: + ---------- + sdata : AnnData + An AnnData object containing the spatial omics data, including transcript expression and cell positions. + + layer : str, optional, default: 'extracellular_transcripts_enriched' + The layer in the AnnData object that contains the extracellular RNA transcript data. + + gene : str, optional, default: 'Arc' + The gene of interest to be visualized in terms of its spatial interaction with source and target cells. + + gene_key : str, optional, default: 'feature_name' + The column name in the AnnData object used to identify the gene. + + cell_id_key : str, optional, default: 'cell_id' + The column name in the AnnData object used to identify individual cells. + + color_target : str, optional, default: 'blue' + The color to be used for target cells in the plot. + + color_source : str, optional, default: 'red' + The color to be used for source cells in the plot. + + color_transcript : str, optional, default: 'green' + The color to be used for the RNA transcripts in the plot. + + spatial_key : str, optional, default: 'spatial' + The key in the AnnData object that stores the spatial coordinates of the cells. + + img : Optional[Union[bool, Sequence]], optional, default: None + A background image to overlay on the plot, such as a tissue section. Can be set to `None` to omit. + + img_alpha : Optional[float], optional, default: None + The transparency level of the background image. Ignored if `img` is `None`. + + image_cmap : Optional[Colormap], optional, default: None + The colormap to be used for the background image, if applicable. + + size : Optional[Union[float, Sequence[float]]], optional, default: 8 + The size of the scatter plot points for the cells and transcripts. + + alpha : float, optional, default: 0.6 + The transparency level for the scatter plot points. + + title : Optional[Union[str, Sequence[str]]], optional, default: None + The title of the plot. If `None`, the gene name is used. + + legend_loc : Optional[str], optional, default: 'best' + The location of the legend in the plot. + + figsize : Tuple[float, float], optional, default: (10, 10) + The dimensions of the plot in inches. + + dpi : Optional[int], optional, default: 100 + The resolution (dots per inch) for the plot. + + save : Optional[Union[str, Path]], optional, default: None + The path to save the plot image. If `None`, the plot is displayed but not saved. + + **kwargs : Additional keyword arguments + Any additional arguments passed to the `scatter` or `imshow` functions for customizing plot appearance. + + Returns: + ------- + None + The function generates and displays (or saves) a scatter plot. + """ # Extract relevant data transcripts = sdata.points[layer] trans_filt = transcripts[transcripts[gene_key] == gene] @@ -657,8 +876,6 @@ def spatial_interactions( plt.scatter(cell_positions.loc[source_cells, 'x'], cell_positions.loc[source_cells, 'y'], c=color_source, s=size, label='Source Cells', **kwargs) plt.scatter(trans_filt['x'], trans_filt['y'], c=color_transcript, s=size*0.4, label='Transcripts', **kwargs) - - # Titles and Legends plt.title(title or gene) plt.legend(loc=legend_loc)