The problem encountered in extracting features from spot patches #5

tqwang743 opened this issue Apr 1, 2023 · 2 comments

@tqwang743

Hello,
I was trying to retrain the conST model without using the trained weights conST_151673.pth, but I encountered difficulties while performing the following step. Could you share the code for this step?
[image]
Thank you!

@wenhuidu

wenhuidu commented Jul 3, 2023

I have the same problem.

@bbchond

bbchond commented Mar 20, 2024

I faced the same problem, and I used the following code to extract features from the spot patches:

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from tqdm import tqdm
from PIL import Image

from MAEpytorch.modeling_pretrain import pretrain_mae_base_patch16_224
from MAEpytorch.datasets import DataAugmentationForMAE

cudnn.benchmark = True
input_size = 224
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load the hires tissue image and convert it to uint8 if needed
image = adata_h5.uns['spatial'][section_id]['images']['hires']
if image.dtype == np.float32 or image.dtype == np.float64:
    image = (image * 255).astype(np.uint8)

# spot coordinates in hires-image pixels
scale_factor = adata_h5.uns['spatial'][section_id]['scalefactors']['tissue_hires_scalef']
image_coord = adata_h5.obsm['spatial'] * scale_factor
crop_size = input_size // 2  # half the 224 x 224 patch on each side of the spot centre

# load the pretrained MAE model
model = pretrain_mae_base_patch16_224()
patch_size = model.encoder.patch_embed.patch_size
print("Patch size = %s" % str(patch_size))

window_size = (input_size // patch_size[0], input_size // patch_size[1])
model.to(device)

checkpoint = torch.load('./MAEpytorch/pretrain_mae_vit_base_mask_0.75_400e.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.eval()

# params is the conST/MAE argument namespace; DataAugmentationForMAE expects the
# usual MAE pretraining fields (e.g. window_size, input_size, mask_ratio)
params.window_size = window_size
transforms = DataAugmentationForMAE(params)
image_pillow = Image.fromarray(image)
features_all = []

with tqdm(total=len(adata_h5), desc='image cropping and feature extracting...', bar_format='{l_bar}{bar} [ time left: {remaining} ]') as pbar:
    # 'image_row' / 'image_col' hold each spot's pixel coordinates on the hires
    # image (derived from image_coord above)
    for image_row, image_col in zip(adata_h5.obs['image_row'], adata_h5.obs['image_col']):
        # crop a 224 x 224 patch centred on the spot and make sure it has the MAE input size
        patch = image_pillow.crop((image_col - crop_size, image_row - crop_size, image_col + crop_size, image_row + crop_size))
        patch = patch.resize((input_size, input_size), Image.LANCZOS)

        img, bool_masked_pos = transforms(patch)
        bool_masked_pos = torch.from_numpy(bool_masked_pos)

        with torch.no_grad():
            img = img[None, :]
            bool_masked_pos = bool_masked_pos[None, :]
            img = img.to(device, non_blocking=True)
            bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool)

            # mean-pool the encoder output over the token dimension
            features = model.encoder.forward(img, bool_masked_pos)
            features_pool = features.mean(dim=1)
            features_all.append(features_pool.detach().cpu().numpy())
        pbar.update(1)
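
Since the MAE forward pass over every spot is slow, it may be worth caching the extracted features so the loop only has to run once (a small sketch; the file name is just an example):

import numpy as np

# stack the per-spot features into an (n_spots, hidden_dim) matrix and cache it
features_matrix = np.vstack(features_all)
np.save('mae_spot_features.npy', features_matrix)  # example file name

# later runs can reload the cached features instead of re-running the encoder
# features_matrix = np.load('mae_spot_features.npy')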

You can use the phenograph Python package to visualize the domain identification result based on the extracted morphological features:

import numpy as np
import scanpy as sc
from sklearn.decomposition import PCA
import phenograph

params.cell_num = adata_h5.shape[0]
# stack the per-spot features and rescale them to match the statistics of the
# preprocessed expression matrix adata_X used by conST
img_transformed = np.asarray(features_all).reshape(params.cell_num, -1)
img_transformed = (img_transformed - img_transformed.mean()) / img_transformed.std() * adata_X.std() + adata_X.mean()

pca = PCA(n_components=50, random_state=42)
img_feat_pca = pca.fit_transform(img_transformed)

adata_h5.obsm['image_feat_pca'] = img_feat_pca
graph_label, _, _ = phenograph.cluster(adata_h5.obsm['image_feat_pca'])
adata_h5.obs['graph_label'] = graph_label
adata_h5.obs['graph_label'] = adata_h5.obs['graph_label'].astype('category')
sc.pl.spatial(adata_h5, color='graph_label')

However, there are still some bugs to fix: the dimensions of the network layers do not match the latent dimensions, and the final domain identification result sometimes looks strange. Hope this helps.
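
For what it's worth, a quick shape check before wiring the features into conST (a rough sketch; 768 is the hidden size of the MAE-base encoder, and the expected conST-side size depends on your configuration):

import numpy as np

# the mean-pooled MAE-base features should stack to (n_spots, 768)
feat = np.asarray(features_all).reshape(adata_h5.shape[0], -1)
print(feat.shape)

# after PCA the morphological input is (n_spots, 50); if the conST network
# layers expect a different dimensionality, adjust n_components above (or the
# corresponding layer size in the conST config) so the two agree
print(img_feat_pca.shape)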
