From 497e76911a68fc33b903064d30c83e516fd2bf1a Mon Sep 17 00:00:00 2001 From: Vittorio Date: Sat, 28 Oct 2023 18:50:12 +0200 Subject: [PATCH 1/2] [feat] new params for qt temperature scheduling and evaluation --- .../run_single_dataset.py | 9 ++- src/inference/victree.py | 6 +- src/utils/config.py | 3 +- src/utils/evaluation.py | 71 ++++++++++++++----- src/utils/math_utils.py | 7 +- src/utils/tree_utils.py | 2 +- src/variational_distributions/var_dists.py | 24 ++++++- 7 files changed, 96 insertions(+), 26 deletions(-) diff --git a/src/experiments/var_tree_experiments/run_single_dataset.py b/src/experiments/var_tree_experiments/run_single_dataset.py index fba7e65..37296a1 100755 --- a/src/experiments/var_tree_experiments/run_single_dataset.py +++ b/src/experiments/var_tree_experiments/run_single_dataset.py @@ -8,7 +8,7 @@ from utils.evaluation import sample_dataset_generation, evaluate_victree_to_df -def run_dataset(K, M, N, seed): +def run_dataset(K, M, N, seed, extend_qt_temp=1.): out_path = f"./dat{seed}_K{K}M{M}N{N}" if not os.path.exists(out_path): os.mkdir(out_path) @@ -25,6 +25,7 @@ def run_dataset(K, M, N, seed): # sieving=(3, 3), split='ELBO', debug=True) + config.qT_temp_extend = extend_qt_temp victree = VICTree(config, jq, data_handler=dh, draft=True, elbo_rtol=1e-4) victree.run(100) @@ -38,7 +39,11 @@ def run_dataset(K, M, N, seed): if __name__ == '__main__': # read dataset path dat_path = sys.argv[1] + qt_temp_extend = 1. + if len(sys.argv) > 2: + qt_temp_extend = float(sys.argv[2]) + print(f"setting qt temp extend to {qt_temp_extend}") params_re = re.compile(r'^.*/K(?P\d+)M(?P\d+)N(?P\d+)/(?P\d+)\.png$') re_match = params_re.match(dat_path) run_dataset(int(re_match.group('K')), int(re_match.group('M')), int(re_match.group('N')), - int(re_match.group('seed'))) + int(re_match.group('seed')), extend_qt_temp=qt_temp_extend) diff --git a/src/inference/victree.py b/src/inference/victree.py index e44d1b1..6f4346f 100644 --- a/src/inference/victree.py +++ b/src/inference/victree.py @@ -173,7 +173,8 @@ def run(self, n_iter=-1, args=None): 'elbo': self.elbo, 'diff': f"{rel_change * 100:.3f}%", 'll': f"{self.q.total_log_likelihood:.3f}", - 'ss': self.config.step_size + 'ss': self.config.step_size, + 'qttemp': self.q.t.temp }) # early-stopping @@ -383,7 +384,8 @@ def set_temperature(self, it, n_iter): b = torch.tensor(int(n_iter * 0.2)) d = torch.tensor(1.) a = torch.tensor(self.config.qT_temp) - c = math_utils.inverse_decay_function_calculate_c(a, b, d, torch.tensor(n_iter)) + c = math_utils.inverse_decay_function_calculate_c(a, b, d, torch.tensor(n_iter), + extend=self.config.qT_temp_extend) self.q.t.temp = math_utils.inverse_decay_function(it, a, b, c) if self.config.qZ_temp != 1.: diff --git a/src/utils/config.py b/src/utils/config.py index 1d60c94..238aba1 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -33,8 +33,9 @@ def __init__(self, step_size_forgetting_rate=0.7, step_size_delay=2., merge_and_split_interval=5, - qT_temp=1., + qT_temp=1., qT_temp_extend = 1., qZ_temp=1.) -> None: + self.qT_temp_extend = qT_temp_extend self.qZ_temp = qZ_temp self.qT_temp = qT_temp self.merge_and_split_interval = merge_and_split_interval diff --git a/src/utils/evaluation.py b/src/utils/evaluation.py index 1bc071c..a2d79f7 100644 --- a/src/utils/evaluation.py +++ b/src/utils/evaluation.py @@ -96,6 +96,32 @@ def best_mapping(gt_z: np.ndarray, vi_z: np.ndarray, with_score=False): else: return perms[best_perm_idx] +def best_vi_map(vi_z, ref_vi_z): + # FIXME: maybe too intensive for K > 10 + k = vi_z.shape[1] + kref = ref_vi_z.shape[1] + if k <= kref: + ext_vi_z = np.zeros_like(ref_vi_z) + ext_vi_z[:, :k] = vi_z + perms = [list((0,) + p) for p in itertools.combinations(range(1, kref), k)] + scores = [] + for p in perms: + score = np.sum(ref_vi_z * ext_vi_z[:, p]) + scores.append(score) + best_perm = perms[np.argmax(scores)] + else: + # kref < k + # need to find best and then append extra labels + ext_ref_vi_z = np.zeros_like(vi_z) + ext_ref_vi_z[:, :kref] = ref_vi_z + perms = [list((0,) + p) for p in itertools.permutations(range(1, k))] + scores = [] + for p in perms: + score = np.sum(ext_ref_vi_z * vi_z[:, p]) + scores.append(score) + best_perm = perms[np.argmax(scores)] + return best_perm + def evaluate_victree_to_df(true_joint, victree, dataset_id, df=None, tree_enumeration=False): """ @@ -112,14 +138,14 @@ def evaluate_victree_to_df(true_joint, victree, dataset_id, df=None, tree_enumer """ out_data = {} - out_data['dataset-id'] = dataset_id + out_data['dataset_id'] = dataset_id out_data['K'] = true_joint.config.n_nodes out_data['vK'] = victree.config.n_nodes out_data['M'] = true_joint.config.chain_length out_data['N'] = true_joint.config.n_cells - out_data['true-ll'] = true_joint.total_log_likelihood - out_data['vi-ll'] = victree.q.total_log_likelihood - out_data['vi-diff'] = out_data['true-ll'] - out_data['vi-ll'] + out_data['true_ll'] = true_joint.total_log_likelihood + out_data['vi_ll'] = victree.q.total_log_likelihood + out_data['vi_diff'] = out_data['true_ll'] - out_data['vi_ll'] out_data['elbo'] = victree.elbo out_data['iters'] = victree.it_counter out_data['time'] = victree.exec_time_ @@ -128,13 +154,13 @@ def evaluate_victree_to_df(true_joint, victree, dataset_id, df=None, tree_enumer true_lab = true_joint.z.true_params['z'] vi_lab = victree.q.z.best_assignment() out_data['ari'] = adjusted_rand_score(true_lab, vi_lab) - out_data['v-meas'] = v_measure_score(true_lab, vi_lab) + out_data['v_meas'] = v_measure_score(true_lab, vi_lab) # copy number calling eval true_c = true_joint.c.true_params['c'][true_lab].numpy() pred_c = victree.q.c.get_viterbi()[vi_lab].numpy() cn_mad = np.abs(pred_c - true_c).mean() - out_data['cn-mad'] = cn_mad + out_data['cn_mad'] = cn_mad # tree eval if victree.config.n_nodes == true_joint.config.n_nodes: @@ -142,28 +168,41 @@ def evaluate_victree_to_df(true_joint, victree, dataset_id, df=None, tree_enumer true_tree = tree_utils.relabel_nodes(true_joint.t.true_params['tree'], best_map) mst = nx.maximum_spanning_arborescence(victree.q.t.weighted_graph) intersect_edges = nx.intersection(true_tree, mst).edges - out_data['edge-sensitivity'] = len(intersect_edges) / len(mst.edges) - out_data['edge-precision'] = len(intersect_edges) / len(true_tree.edges) + out_data['edge_sensitivity'] = len(intersect_edges) / len(mst.edges) + out_data['edge_precision'] = len(intersect_edges) / len(true_tree.edges) - qt_pmf = victree.q.t.get_pmf_estimate(True, n=50) + qt_pmf = victree.q.t.get_pmf_estimate(True, n=100, desc_sorted=True) true_tree_newick = tree_to_newick(true_tree) mst_newick = tree_to_newick(mst) - out_data['qt-true'] = qt_pmf[true_tree_newick].item() if true_tree_newick in qt_pmf.keys() else 0. - out_data['qt-mst'] = qt_pmf[mst_newick].item() if mst_newick in qt_pmf.keys() else 0. + out_data['qt_support'] = len(qt_pmf.keys()) + out_data['qt_true_rank'] = -1 + if true_tree_newick in qt_pmf.keys(): + out_data['qt_true'] = qt_pmf[true_tree_newick].item() + rank = 0 + for nwk in qt_pmf.keys(): + if true_tree_newick == nwk: + break + else: + rank += 1 + out_data['qt_true_rank'] = rank + else: + out_data['qt_true'] = 0. + + out_data['qt_mst'] = qt_pmf[mst_newick].item() if mst_newick in qt_pmf.keys() else 0. pmf_arr = np.array(list(qt_pmf.values())) - out_data['pt-true'] = np.nan - out_data['pt-mst'] = np.nan + out_data['pt_true'] = np.nan + out_data['pt_mst'] = np.nan if tree_enumeration: try: pt = victree.q.t.enumerate_trees() pt_dict = {tree_to_newick(nwk): math.exp(logp) for nwk, logp in zip(pt[0], pt[1].tolist())} - out_data['pt-true'] = pt_dict[true_tree_newick] - out_data['pt-mst'] = pt_dict[mst_newick] + out_data['pt_true'] = pt_dict[true_tree_newick] + out_data['pt_mst'] = pt_dict[mst_newick] except BaseException: print(traceback.format_exc()) # normalized entropy - out_data['qt-entropy'] = - np.sum(pmf_arr * np.log(pmf_arr)) / np.log(pmf_arr.size) + out_data['qt_entropy'] = - np.sum(pmf_arr * np.log(pmf_arr)) / np.log(pmf_arr.size) if df is None: df = pd.DataFrame() diff --git a/src/utils/math_utils.py b/src/utils/math_utils.py index e59a00d..70167b7 100644 --- a/src/utils/math_utils.py +++ b/src/utils/math_utils.py @@ -79,10 +79,13 @@ def inverse_decay_function(x: torch.Tensor, a, b, c): z = torch.max(torch.tensor(1.), x - b) return a * z ** (-c) -def inverse_decay_function_calculate_c(a, b, d, x): +def inverse_decay_function_calculate_c(a, b, d, x, extend=1.): """ Returns the value c which solves the equation: d = f(x) = a * 1 / (x - b)^c Given by: c = (log(a) - log(d)) / log(x - b) Can be used to calculate the required c for needed for f(max_iter) = 1. when tempering. + Params: + extend: extends the temperature cooling process + so that temp(max_iter * extend) = 1. """ - return (torch.log(a) - torch.log(d)) / torch.log(x - b) \ No newline at end of file + return (torch.log(a) - torch.log(d)) / torch.log(x * extend - b) \ No newline at end of file diff --git a/src/utils/tree_utils.py b/src/utils/tree_utils.py index 311fa9c..19dd1f4 100644 --- a/src/utils/tree_utils.py +++ b/src/utils/tree_utils.py @@ -291,7 +291,7 @@ def parse_newick(tree_file, config=None): und_tree_nx = Phylo.to_networkx(tree) # Phylo names add unwanted information in unstructured way # find node numbers and relabel nx tree - names_string = ''.join(str(cl.confidence) if cl.name is None else cl.name for cl in und_tree_nx.nodes) + names_string = list(str(cl.confidence) if cl.name is None else cl.name for cl in und_tree_nx.nodes) mapping = dict(zip(und_tree_nx, names_string)) relabeled_tree = nx.relabel_nodes(und_tree_nx, mapping) tree_nx = nx.DiGraph() diff --git a/src/variational_distributions/var_dists.py b/src/variational_distributions/var_dists.py index 7faff5e..7405721 100644 --- a/src/variational_distributions/var_dists.py +++ b/src/variational_distributions/var_dists.py @@ -866,7 +866,7 @@ class qZ(VariationalDistribution): def __init__(self, config: Config, true_params=None): super().__init__(config, true_params is not None) - self.temp = 1.0 + self._temp = 1.0 self._pi = torch.empty((config.n_cells, config.n_nodes)) self.kmeans_labels = torch.empty(config.n_cells, dtype=torch.long) @@ -890,6 +890,16 @@ def pi(self, pi): logging.warning('Trying to re-set qc attribute when it should be fixed') self._pi[...] = pi + @property + def temp(self): + return self._temp + + @temp.setter + def temp(self, t): + if isinstance(t, torch.Tensor): + t = t.item() + self._temp = t + def get_params_as_dict(self): return { 'pi': self.pi.numpy() @@ -1072,7 +1082,7 @@ def __init__(self, config: Config, true_params=None, norm_method='stochastic', s self._norm_method = norm_method self._sampling_method = sampling_method - self.temp = 1.0 + self._temp = 1.0 if true_params is not None: assert 'tree' in true_params @@ -1082,6 +1092,16 @@ def __init__(self, config: Config, true_params=None, norm_method='stochastic', s self.params_history["trees_sample_newick"] = [] self.params_history["trees_sample_weights"] = [] + @property + def temp(self): + return self._temp + + @temp.setter + def temp(self, t): + if isinstance(t, torch.Tensor): + t = t.item() + self._temp = t + @property def weighted_graph(self): return self._weighted_graph From 842b81ab6ba7e28098e5002840cd7264471e6df4 Mon Sep 17 00:00:00 2001 From: Vittorio Date: Sat, 28 Oct 2023 18:50:43 +0200 Subject: [PATCH 2/2] [feat] scicone preprocessing --- src/experiments/scicone_10x.ipynb | 65 +++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 src/experiments/scicone_10x.ipynb diff --git a/src/experiments/scicone_10x.ipynb b/src/experiments/scicone_10x.ipynb new file mode 100644 index 0000000..6cac8c7 --- /dev/null +++ b/src/experiments/scicone_10x.ipynb @@ -0,0 +1,65 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": true + }, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'scicone'", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mModuleNotFoundError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[0;32mIn[1], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mscicone\u001B[39;00m\n\u001B[1;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01manndata\u001B[39;00m\n\u001B[1;32m 3\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mpandas\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mpd\u001B[39;00m\n", + "\u001B[0;31mModuleNotFoundError\u001B[0m: No module named 'scicone'" + ] + } + ], + "source": [ + "import scicone\n", + "import anndata\n", + "import pandas as pd\n", + "import numpy\n", + "\n", + "dat10x_path = \"/home/zemp/scilife/scicone/breast_tissue_A_2k_cnv_data.h5\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file