-
Notifications
You must be signed in to change notification settings - Fork 16
/
run-seacells.py
149 lines (122 loc) · 4.04 KB
/
run-seacells.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#!/usr/bin/env python3
"""
SEACells Analysis Script
Joshua Shapiro
2024-07-18
This script runs the SEACells algorithm on a dataset.
"""
import argparse
import contextlib
import pathlib
import pickle
import sys
import anndata
import numpy as np
import scanpy as sc
import SEACells
import session_info
def convert_adata(adata: anndata.AnnData) -> anndata.AnnData:
    """
    Prepare an ScPCA AnnData object for use with scanpy and SEACells.

    Genes whose ``gene_ids`` value appears in ``adata.uns["highly_variable_genes"]``
    are flagged in a boolean ``highly_variable`` column of ``adata.var``. PCA
    (50 components, restricted to the flagged genes), the neighbor graph, and
    UMAP are then recomputed on the object.

    The same object that was passed in is returned.
    """
    hvg_column = "highly_variable"

    # flag each gene by membership in the stored list of highly variable genes
    adata.var[hvg_column] = adata.var.gene_ids.isin(
        adata.uns["highly_variable_genes"]
    )

    # rerun the embedding steps; their return values are not used, so they are
    # called for their effect on adata (including the metadata scanpy records)
    sc.tl.pca(adata, mask_var=hvg_column, n_comps=50)
    sc.pp.neighbors(adata)
    sc.tl.umap(adata)

    return adata
def run_seacells(
    adata: anndata.AnnData, cell_ratio: float = 75, verbose: bool = False
) -> tuple[anndata.AnnData, SEACells.core.SEACells]:
    """
    Run the SEACells algorithm on the given dataset.

    Parameters
    ----------
    adata : anndata.AnnData
        An AnnData object containing the data to run the SEACells algorithm on.
        Should contain an X_pca field with the PCA coordinates of the cells.
    cell_ratio : float
        The ratio of cells to metacells to use; i.e. number of cells per metacell.
        Must be greater than zero.
    verbose : bool
        Whether to print verbose output during the SEACells algorithm

    Returns
    -------
    anndata.AnnData
        The input AnnData object with the metacell assignments added to the obs table with the key "SEACell"
    SEACells.core.SEACells
        The SEACells model object

    Raises
    ------
    ValueError
        If cell_ratio is not a positive number, or if the dataset has too few
        cells for the requested cell_ratio to yield at least one metacell.
    """
    # `not (x > 0)` rather than `x <= 0` so that NaN is rejected as well
    if not (cell_ratio > 0):
        raise ValueError(f"cell_ratio must be a positive number, got {cell_ratio}")

    n_metacells = round(adata.n_obs / cell_ratio)
    # fail early with a clear message instead of handing the model a
    # metacell count of zero
    if n_metacells < 1:
        raise ValueError(
            f"Too few cells ({adata.n_obs}) to form a metacell "
            f"at {cell_ratio} cells per metacell"
        )

    n_eigs = 10  # number of eigenvalues for initialization

    # initialize the SEACells model
    model = SEACells.core.SEACells(
        adata,
        build_kernel_on="X_pca",
        n_SEACells=n_metacells,
        n_waypoint_eigs=n_eigs,
        convergence_epsilon=1e-5,
        verbose=verbose,
    )

    # initialize and fit model
    model.construct_kernel_matrix()
    model.initialize_archetypes()
    model.fit(min_iter=10, max_iter=50)

    return (adata, model)
def main() -> None:
    """
    Command-line entry point: run SEACells on an H5AD file and save the results.

    Reads the input AnnData file, reformats it with convert_adata(), fits the
    model with run_seacells(), then writes the AnnData object to --adata_out
    and, when --model_out is given, pickles the model object to that path.

    When --logfile is given, standard output is redirected to that file for
    the duration of the run, verbose model output is enabled, and
    session_info.show() is called after the results are saved.

    Raises
    ------
    ValueError
        If --adata_out does not end in .h5ad, or if --model_out is given and
        does not end in .pkl.
    """
    parser = argparse.ArgumentParser(
        description="Run the SEACell algorithm on the given dataset."
    )
    parser.add_argument(
        "adata_file", type=pathlib.Path, help="The input data in H5AD format."
    )
    parser.add_argument(
        "--adata_out",
        type=pathlib.Path,
        required=True,
        help="The output file path for the AnnData object (should end in .h5ad).",
    )
    parser.add_argument(
        "--model_out",
        type=pathlib.Path,
        required=False,
        help="The output file path for the SEACells model object.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        help="The random seed to use for reproducibility.",
        default=2024,
    )
    parser.add_argument(
        "--logfile", type=pathlib.Path, help="File path for log outputs"
    )
    args = parser.parse_args()

    # check filenames
    if args.adata_out.suffix != ".h5ad":
        raise ValueError("Output file must end in .h5ad")
    if args.model_out and args.model_out.suffix != ".pkl":
        raise ValueError("Model output file must end in .pkl")

    use_log = args.logfile is not None

    # The ExitStack restores stdout and closes the log file however the run
    # ends, including when one of the steps below raises.
    with contextlib.ExitStack() as stack:
        if use_log:
            logs = stack.enter_context(open(args.logfile, "w"))
            stack.enter_context(contextlib.redirect_stdout(logs))

        # set seed for reproducibility
        np.random.seed(args.seed)

        adata = anndata.read_h5ad(args.adata_file)
        adata = convert_adata(adata)
        adata, seacell_model = run_seacells(adata, verbose=use_log)

        # save the results
        print(f"Saving results to {args.adata_out}")
        adata.write_h5ad(args.adata_out, compression="gzip")

        if args.model_out:
            print(f"Saving results to {args.model_out}")
            with open(args.model_out, "wb") as f:
                pickle.dump(seacell_model, f)

        # only write session info if a logfile is provided
        if use_log:
            session_info.show(dependencies=True)
# Run main() only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()