Skip to content

Commit

Permalink
Merge branch 'master' into harald/qT-tempering
Browse files Browse the repository at this point in the history
  • Loading branch information
HaraldMelin committed Oct 19, 2023
2 parents b906d48 + a5634c3 commit e92f854
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/sampling/laris.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ def sample_arborescence_from_weighted_graph(graph: nx.DiGraph,

def _sample_feasible_arc(weighted_arcs):
# weighted_arcs is a list of 3-tuples (u, v, weight)
unnorm_probs = torch.stack([w for u, v, w in weighted_arcs])
# weights are negative: need transformation
unnorm_probs = 1 / (-torch.stack([w for u, v, w in weighted_arcs]))
probs = unnorm_probs / unnorm_probs.sum()
c = np.random.choice(np.arange(len(weighted_arcs)), p=probs.numpy())
return weighted_arcs[c][:2], probs[c]
Expand Down
2 changes: 1 addition & 1 deletion src/simul.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def make_anndata(obs, raw_counts, chr_dataframe, c, z, mu, tree, obs_names: list
return adata


def simulate_full_dataset(config: Config, eps_a=5., eps_b=50., mu0=1., lambda0=10.,
def simulate_full_dataset(config: Config, eps_a=500., eps_b=50000., mu0=1., lambda0=1000.,
alpha0=500., beta0=50., dir_delta: [float | list[float]] = 1., tree=None, raw_reads=True,
chr_df: pd.DataFrame | None = None, nans: bool = False,
fixed_z:torch.Tensor = None, cne_length_factor: int = 0):
Expand Down
2 changes: 2 additions & 0 deletions src/variational_distributions/var_dists.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,8 @@ def _normalize_graph_weights(self, w_matrix):
norm_w_matrix[...] = w_matrix - math_utils.nanlogsumexp(w_matrix, dim=1, keepdim=True)
else:
raise ValueError(f"normalization method {self._norm_method} not recognized")
# weight shouldn't be equal to 0 for stability
norm_w_matrix = torch.clamp(norm_w_matrix, max=-1e-8)

return norm_w_matrix

Expand Down
5 changes: 3 additions & 2 deletions tests/tests_utils/test_data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import simul
from inference.victree import VICTree
from utils.config import Config
from utils.config import Config, set_seed
from utils.data_handling import DataHandler
from tests.data.generate_data import generate_2chr_adata
from variational_distributions.joint_dists import VarTreeJointDist
Expand All @@ -18,6 +18,7 @@ class dataHandlingTestCase(unittest.TestCase):

def setUp(self) -> None:

set_seed(0)
self.output_dir = "./test_output"
if not os.path.exists(self.output_dir):
os.mkdir(self.output_dir)
Expand All @@ -39,7 +40,7 @@ def test_write_output(self):
out_file = os.path.join(self.output_dir, 'out_test.h5')
# run victree
config = Config(n_nodes=4, n_cells=20, n_states=4,
n_run_iter=3, sieving_size=2, n_sieving_iter=2)
n_run_iter=3)
adata = generate_2chr_adata(config)
data_handler = DataHandler(adata=adata)
obs = data_handler.norm_reads
Expand Down

0 comments on commit e92f854

Please sign in to comment.