Skip to content

Commit

Permalink
merge with master
Browse files Browse the repository at this point in the history
  • Loading branch information
HaraldMelin committed Oct 29, 2023
2 parents 7696195 + 842b81a commit ba8aa99
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 28 deletions.
65 changes: 65 additions & 0 deletions src/experiments/scicone_10x.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'scicone'",
"output_type": "error",
"traceback": [
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[0;31mModuleNotFoundError\u001B[0m Traceback (most recent call last)",
"Cell \u001B[0;32mIn[1], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mscicone\u001B[39;00m\n\u001B[1;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01manndata\u001B[39;00m\n\u001B[1;32m 3\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mpandas\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mpd\u001B[39;00m\n",
"\u001B[0;31mModuleNotFoundError\u001B[0m: No module named 'scicone'"
]
}
],
"source": [
"import scicone\n",
"import anndata\n",
"import pandas as pd\n",
"import numpy\n",
"\n",
"dat10x_path = \"/home/zemp/scilife/scicone/breast_tissue_A_2k_cnv_data.h5\""
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
9 changes: 7 additions & 2 deletions src/experiments/var_tree_experiments/run_single_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from utils.evaluation import sample_dataset_generation, evaluate_victree_to_df


def run_dataset(K, M, N, seed):
def run_dataset(K, M, N, seed, extend_qt_temp=1.):
out_path = f"./dat{seed}_K{K}M{M}N{N}"
if not os.path.exists(out_path):
os.mkdir(out_path)
Expand All @@ -25,6 +25,7 @@ def run_dataset(K, M, N, seed):
# sieving=(3, 3),
split='ELBO',
debug=True)
config.qT_temp_extend = extend_qt_temp
victree = VICTree(config, jq, data_handler=dh, draft=True, elbo_rtol=1e-4)
victree.run(100)

Expand All @@ -38,7 +39,11 @@ def run_dataset(K, M, N, seed):
if __name__ == '__main__':
    # usage: run_single_dataset.py <dataset path> [qt_temp_extend]
    if len(sys.argv) < 2:
        sys.exit("usage: run_single_dataset.py <.../K<K>M<M>N<N>/<seed>.png> [qt_temp_extend]")
    # read dataset path
    dat_path = sys.argv[1]
    # optional second argument, forwarded to run_dataset as extend_qt_temp
    qt_temp_extend = 1.
    if len(sys.argv) > 2:
        qt_temp_extend = float(sys.argv[2])
        print(f"setting qt temp extend to {qt_temp_extend}")
    # K, M, N and the seed are parsed from the path itself: .../K<K>M<M>N<N>/<seed>.png
    params_re = re.compile(r'^.*/K(?P<K>\d+)M(?P<M>\d+)N(?P<N>\d+)/(?P<seed>\d+)\.png$')
    re_match = params_re.match(dat_path)
    if re_match is None:
        # fail with a readable message instead of an AttributeError on None
        sys.exit(f"cannot parse K, M, N and seed from dataset path: {dat_path}")
    run_dataset(int(re_match.group('K')), int(re_match.group('M')), int(re_match.group('N')),
                int(re_match.group('seed')), extend_qt_temp=qt_temp_extend)
8 changes: 4 additions & 4 deletions src/inference/victree.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,6 @@ def run(self, n_iter=-1, args=None):

self.step()



rel_change = np.abs((self.elbo - old_elbo) / self.elbo)

# progress bar showing elbo
Expand All @@ -175,7 +173,8 @@ def run(self, n_iter=-1, args=None):
'elbo': self.elbo,
'diff': f"{rel_change * 100:.3f}%",
'll': f"{self.q.total_log_likelihood:.3f}",
'ss': self.config.step_size
'ss': self.config.step_size,
'qttemp': self.q.t.temp
})

# early-stopping
Expand Down Expand Up @@ -385,7 +384,8 @@ def set_temperature(self, it, n_iter):
a = torch.tensor(self.config.qT_temp)
b = torch.tensor(int(n_iter * 0.2))
d = torch.tensor(1.)
c = math_utils.inverse_decay_function_calculate_c(a, b, d, torch.tensor(n_iter))
c = math_utils.inverse_decay_function_calculate_c(a, b, d, torch.tensor(n_iter),
extend=self.config.qT_temp_extend)
self.q.t.temp = math_utils.inverse_decay_function(it, a, b, c)
self.q.t.g_temp = self.q.t.temp # g(T) temp by default set to q(T) temp

Expand Down
3 changes: 2 additions & 1 deletion src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ def __init__(self,
step_size_forgetting_rate=0.7,
step_size_delay=2.,
merge_and_split_interval=5,
qT_temp=1.,
qT_temp=1., qT_temp_extend = 1.,
gT_temp=1.,
qZ_temp=1.) -> None:
self.qT_temp_extend = qT_temp_extend
self.qZ_temp = qZ_temp
self.qT_temp = qT_temp
self.gT_temp = gT_temp
Expand Down
71 changes: 55 additions & 16 deletions src/utils/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,32 @@ def best_mapping(gt_z: np.ndarray, vi_z: np.ndarray, with_score=False):
else:
return perms[best_perm_idx]

def best_vi_map(vi_z, ref_vi_z):
    """
    Find the column relabeling of `vi_z` which best matches `ref_vi_z`.

    Both inputs are 2-D arrays with one column per label; rows of the two
    arrays are compared pairwise. The narrower array is padded with all-zero
    columns so that both have max(k, kref) columns. Every permutation of the
    (padded) `vi_z` columns which keeps column 0 in place is then scored by
    the sum of the element-wise product with the (padded) `ref_vi_z`.

    Params:
        vi_z: array of shape (n, k)
        ref_vi_z: array of shape (n, kref)
    Returns:
        list of max(k, kref) column indices, starting with 0, of the
        best-scoring permutation (the first one found in case of ties).
    """
    # FIXME: maybe too intensive for K > 10, (max(k, kref) - 1)! permutations are scored
    vi_z = np.asarray(vi_z)
    ref_vi_z = np.asarray(ref_vi_z)
    k = vi_z.shape[1]
    kref = ref_vi_z.shape[1]
    k_max = max(k, kref)

    # pad both matrices to a common width so that any permutation can be scored;
    # the extra labels of the wider matrix are matched against zero columns
    dtype = np.result_type(vi_z, ref_vi_z)
    ext_vi_z = np.zeros((vi_z.shape[0], k_max), dtype=dtype)
    ext_vi_z[:, :k] = vi_z
    ext_ref_vi_z = np.zeros((ref_vi_z.shape[0], k_max), dtype=dtype)
    ext_ref_vi_z[:, :kref] = ref_vi_z

    # label 0 is never moved, only the remaining labels are permuted
    perms = [list((0,) + p) for p in itertools.permutations(range(1, k_max))]
    scores = [np.sum(ext_ref_vi_z * ext_vi_z[:, p]) for p in perms]
    return perms[int(np.argmax(scores))]


def evaluate_victree_to_df(true_joint, victree, dataset_id, df=None, tree_enumeration=False):
"""
Expand All @@ -112,14 +138,14 @@ def evaluate_victree_to_df(true_joint, victree, dataset_id, df=None, tree_enumer
"""
out_data = {}
out_data['dataset-id'] = dataset_id
out_data['dataset_id'] = dataset_id
out_data['K'] = true_joint.config.n_nodes
out_data['vK'] = victree.config.n_nodes
out_data['M'] = true_joint.config.chain_length
out_data['N'] = true_joint.config.n_cells
out_data['true-ll'] = true_joint.total_log_likelihood
out_data['vi-ll'] = victree.q.total_log_likelihood
out_data['vi-diff'] = out_data['true-ll'] - out_data['vi-ll']
out_data['true_ll'] = true_joint.total_log_likelihood
out_data['vi_ll'] = victree.q.total_log_likelihood
out_data['vi_diff'] = out_data['true_ll'] - out_data['vi_ll']
out_data['elbo'] = victree.elbo
out_data['iters'] = victree.it_counter
out_data['time'] = victree.exec_time_
Expand All @@ -128,42 +154,55 @@ def evaluate_victree_to_df(true_joint, victree, dataset_id, df=None, tree_enumer
true_lab = true_joint.z.true_params['z']
vi_lab = victree.q.z.best_assignment()
out_data['ari'] = adjusted_rand_score(true_lab, vi_lab)
out_data['v-meas'] = v_measure_score(true_lab, vi_lab)
out_data['v_meas'] = v_measure_score(true_lab, vi_lab)

# copy number calling eval
true_c = true_joint.c.true_params['c'][true_lab].numpy()
pred_c = victree.q.c.get_viterbi()[vi_lab].numpy()
cn_mad = np.abs(pred_c - true_c).mean()
out_data['cn-mad'] = cn_mad
out_data['cn_mad'] = cn_mad

# tree eval
if victree.config.n_nodes == true_joint.config.n_nodes:
best_map = best_mapping(true_lab, victree.q.z.pi.numpy())
true_tree = tree_utils.relabel_nodes(true_joint.t.true_params['tree'], best_map)
mst = nx.maximum_spanning_arborescence(victree.q.t.weighted_graph)
intersect_edges = nx.intersection(true_tree, mst).edges
out_data['edge-sensitivity'] = len(intersect_edges) / len(mst.edges)
out_data['edge-precision'] = len(intersect_edges) / len(true_tree.edges)
out_data['edge_sensitivity'] = len(intersect_edges) / len(mst.edges)
out_data['edge_precision'] = len(intersect_edges) / len(true_tree.edges)

qt_pmf = victree.q.t.get_pmf_estimate(True, n=50)
qt_pmf = victree.q.t.get_pmf_estimate(True, n=100, desc_sorted=True)
true_tree_newick = tree_to_newick(true_tree)
mst_newick = tree_to_newick(mst)
out_data['qt-true'] = qt_pmf[true_tree_newick].item() if true_tree_newick in qt_pmf.keys() else 0.
out_data['qt-mst'] = qt_pmf[mst_newick].item() if mst_newick in qt_pmf.keys() else 0.
out_data['qt_support'] = len(qt_pmf.keys())
out_data['qt_true_rank'] = -1
if true_tree_newick in qt_pmf.keys():
out_data['qt_true'] = qt_pmf[true_tree_newick].item()
rank = 0
for nwk in qt_pmf.keys():
if true_tree_newick == nwk:
break
else:
rank += 1
out_data['qt_true_rank'] = rank
else:
out_data['qt_true'] = 0.

out_data['qt_mst'] = qt_pmf[mst_newick].item() if mst_newick in qt_pmf.keys() else 0.
pmf_arr = np.array(list(qt_pmf.values()))
out_data['pt-true'] = np.nan
out_data['pt-mst'] = np.nan
out_data['pt_true'] = np.nan
out_data['pt_mst'] = np.nan
if tree_enumeration:
try:
pt = victree.q.t.enumerate_trees()
pt_dict = {tree_to_newick(nwk): math.exp(logp) for nwk, logp in zip(pt[0], pt[1].tolist())}
out_data['pt-true'] = pt_dict[true_tree_newick]
out_data['pt-mst'] = pt_dict[mst_newick]
out_data['pt_true'] = pt_dict[true_tree_newick]
out_data['pt_mst'] = pt_dict[mst_newick]
except BaseException:
print(traceback.format_exc())

# normalized entropy
out_data['qt-entropy'] = - np.sum(pmf_arr * np.log(pmf_arr)) / np.log(pmf_arr.size)
out_data['qt_entropy'] = - np.sum(pmf_arr * np.log(pmf_arr)) / np.log(pmf_arr.size)

if df is None:
df = pd.DataFrame()
Expand Down
7 changes: 5 additions & 2 deletions src/utils/math_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,13 @@ def inverse_decay_function(x: torch.Tensor, a, b, c):
z = torch.max(torch.tensor(1.), x - b)
return a * z ** (-c)

def inverse_decay_function_calculate_c(a, b, d, x, extend=1.):
    """
    Returns the value c which solves the equation: d = f(x * extend) = a * 1 / (x * extend - b)^c
    Given by: c = (log(a) - log(d)) / log(x * extend - b)
    Can be used to calculate the c needed for f(max_iter * extend) = 1. when tempering.
    Params:
        a: starting value of the decay (tensor, passed to torch.log)
        b: offset subtracted from the iteration number
        d: target value to be reached (tensor, passed to torch.log)
        x: iteration at which the target is reached when extend = 1. (tensor)
        extend: extends the temperature cooling process
            so that temp(max_iter * extend) = 1.
    Note: the result is not finite when x * extend - b equals 1 (log is 0).
    """
    return (torch.log(a) - torch.log(d)) / torch.log(x * extend - b)
2 changes: 1 addition & 1 deletion src/utils/tree_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def parse_newick(tree_file, config=None):
und_tree_nx = Phylo.to_networkx(tree)
# Phylo names add unwanted information in unstructured way
# find node numbers and relabel nx tree
names_string = ''.join(str(cl.confidence) if cl.name is None else cl.name for cl in und_tree_nx.nodes)
names_string = list(str(cl.confidence) if cl.name is None else cl.name for cl in und_tree_nx.nodes)
mapping = dict(zip(und_tree_nx, names_string))
relabeled_tree = nx.relabel_nodes(und_tree_nx, mapping)
tree_nx = nx.DiGraph()
Expand Down
24 changes: 22 additions & 2 deletions src/variational_distributions/var_dists.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ class qZ(VariationalDistribution):
def __init__(self, config: Config, true_params=None):
super().__init__(config, true_params is not None)

self.temp = 1.0
self._temp = 1.0
self._pi = torch.empty((config.n_cells, config.n_nodes))

self.kmeans_labels = torch.empty(config.n_cells, dtype=torch.long)
Expand All @@ -890,6 +890,16 @@ def pi(self, pi):
logging.warning('Trying to re-set qc attribute when it should be fixed')
self._pi[...] = pi

@property
def temp(self):
    """Temperature value currently stored in `_temp`."""
    return self._temp

@temp.setter
def temp(self, t):
    """Store the temperature; a torch tensor is unwrapped to a Python scalar with `.item()`."""
    self._temp = t.item() if isinstance(t, torch.Tensor) else t

def get_params_as_dict(self):
return {
'pi': self.pi.numpy()
Expand Down Expand Up @@ -1072,7 +1082,7 @@ def __init__(self, config: Config, true_params=None, norm_method='stochastic', s
self._norm_method = norm_method
self._sampling_method = sampling_method

self.temp = 1.0
self._temp = 1.0
self.g_temp = 1.0

if true_params is not None:
Expand All @@ -1083,6 +1093,16 @@ def __init__(self, config: Config, true_params=None, norm_method='stochastic', s
self.params_history["trees_sample_newick"] = []
self.params_history["trees_sample_weights"] = []

@property
def temp(self):
    """Temperature value currently stored in `_temp`."""
    return self._temp

@temp.setter
def temp(self, t):
    """Store the temperature; a torch tensor is unwrapped to a Python scalar with `.item()`."""
    self._temp = t.item() if isinstance(t, torch.Tensor) else t

@property
def weighted_graph(self):
return self._weighted_graph
Expand Down

0 comments on commit ba8aa99

Please sign in to comment.