Skip to content

Commit

Permalink
[feat] changes tailored for figures
Browse files Browse the repository at this point in the history
  • Loading branch information
toyo97 committed Nov 1, 2023
1 parent 3ce24f5 commit 0fccc70
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
15 changes: 11 additions & 4 deletions src/experiments/var_tree_experiments/run_single_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import os.path
import re
import sys
from pathlib import Path

from inference.victree import make_input, VICTree
from utils.evaluation import sample_dataset_generation, evaluate_victree_to_df


def run_dataset(K, M, N, seed, extend_qt_temp=1., gt_temp_mult=1., final_step=False):
def run_dataset(K, M, N, seed, extend_qt_temp=1., gt_temp_mult=1., final_step=False, save_out=False):
print(f"running dat: K {K} M {M} N {N}, dataset {seed}")
jq_true, ad = sample_dataset_generation(K, M, N, seed)
n_nodes = jq_true.config.n_nodes
Expand All @@ -22,10 +23,17 @@ def run_dataset(K, M, N, seed, extend_qt_temp=1., gt_temp_mult=1., final_step=Fa
# sieving=(3, 3),
split='ELBO',
debug=True)
out_path = f"./dat{seed}_K{K}M{M}N{N}"
if not os.path.exists(out_path):
os.mkdir(out_path)
ad.write_h5ad(Path(os.path.join(out_path, "adata.h5ad")))

config.out_dir = out_path
config.gT_temp = gt_temp_mult * config.qT_temp
config.temp_extend = extend_qt_temp
victree = VICTree(config, jq, data_handler=dh, draft=True, elbo_rtol=1e-4)
victree = VICTree(config, jq, data_handler=dh, draft=not save_out, elbo_rtol=1e-4)
victree.run(100, final_step=final_step)
victree.write()

# save results
results_df = evaluate_victree_to_df(jq_true, victree, dataset_id=seed, tree_enumeration=n_nodes < 7)
Expand All @@ -34,7 +42,6 @@ def run_dataset(K, M, N, seed, extend_qt_temp=1., gt_temp_mult=1., final_step=Fa
out_suff += f"gtm{gt_temp_mult}" if gt_temp_mult != 1. else ""
out_suff += f"fs" if final_step else ""

out_path = './'
out_csv = os.path.join(out_path, "score" + out_suff + ".csv")
results_df['extend_temp'] = extend_qt_temp
results_df['gt_temp_mult'] = gt_temp_mult
Expand All @@ -59,4 +66,4 @@ def run_dataset(K, M, N, seed, extend_qt_temp=1., gt_temp_mult=1., final_step=Fa
re_match = params_re.match(dat_path)
run_dataset(int(re_match.group('K')), int(re_match.group('M')), int(re_match.group('N')),
int(re_match.group('seed')), extend_qt_temp=qt_temp_extend, gt_temp_mult=gt_temp_mult,
final_step=True)
final_step=True, save_out=True)
1 change: 1 addition & 0 deletions src/utils/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def evaluate_victree_to_df(true_joint, victree, dataset_id, df=None, tree_enumer
# tree eval
if victree.config.n_nodes == true_joint.config.n_nodes:
best_map = best_mapping(true_lab, victree.q.z.pi.numpy())
print(best_map)
true_tree = tree_utils.relabel_nodes(true_joint.t.true_params['tree'], best_map)
mst = nx.maximum_spanning_arborescence(victree.q.t.weighted_graph)
intersect_edges = nx.intersection(true_tree, mst).edges
Expand Down

0 comments on commit 0fccc70

Please sign in to comment.