Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement propagate on cc generation level #188

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions big_scape/cli/cluster_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def cluster(ctx, *args, **kwargs):
# get context parameters
ctx.obj.update(ctx.params)
ctx.obj["query_bgc_path"] = None
ctx.obj["propagate"] = True # compatibility with query wrt cc generation
ctx.obj["mode"] = "Cluster"

# workflow validations
Expand Down
10 changes: 9 additions & 1 deletion big_scape/network/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import random
import string
import tqdm
import click
from typing import Optional, Generator, cast
from sqlalchemy import (
Column,
Expand Down Expand Up @@ -133,13 +134,20 @@ def dfs(adj_list, start):
Returns:
set: set of visited nodes
"""
click_context = click.get_current_context()

stack = [start]
visited = set()
while stack:
node = stack.pop()
if node not in visited:
visited.add(node)
stack.extend([n for n in adj_list[node] if n not in visited])

# in query mode, only expand beyond the query if propagate flag is given
if click_context and click_context.obj["propagate"]:
stack.extend([n for n in adj_list[node] if n not in visited])
else:
visited.update([n for n in adj_list[node]])
return visited


Expand Down
6 changes: 6 additions & 0 deletions test/integration/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

# from python
from unittest import TestCase
from unittest.mock import MagicMock
from click.globals import push_context
from pathlib import Path
from itertools import combinations

Expand Down Expand Up @@ -1471,6 +1473,8 @@ def test_query_generators_workflow(self):
self.assertEqual(missing_pair_generator.num_pairs(), 0)

def test_generate_bins_query_workflow(self):
ctx = MagicMock(obj={"propagate": True})
push_context(ctx)
bs_data.DB.create_in_mem()

run = {
Expand Down Expand Up @@ -1608,6 +1612,8 @@ def test_generate_bins_query_workflow(self):
self.assertEqual(rows, 6)

def test_calculate_distances_query(self):
ctx = MagicMock(obj={"propagate": True})
push_context(ctx)
bs_data.DB.create_in_mem()

run = {
Expand Down
14 changes: 14 additions & 0 deletions test/integration/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

# from python
from unittest import TestCase
from unittest.mock import MagicMock
from click.globals import push_context
from pathlib import Path
from itertools import combinations

Expand Down Expand Up @@ -98,6 +100,9 @@ def __init__(self, methodName: str = "runTest") -> None:
self.addCleanup(self.clean_db)

def test_get_connected_components_two_cutoffs(self):
ctx = MagicMock(obj={"propagate": True})
push_context(ctx)

bs_data.DB.create_in_mem()

run = {
Expand Down Expand Up @@ -142,6 +147,9 @@ def test_get_connected_components_two_cutoffs(self):
self.assertEqual(len(ccs), 2)

def test_get_connected_components_all_cutoffs(self):
ctx = MagicMock(obj={"propagate": True})
push_context(ctx)

bs_data.DB.create_in_mem()

run = {
Expand Down Expand Up @@ -395,6 +403,8 @@ def test_get_connected_components_all_cutoffs(self):

def test_get_connected_components_no_ref_to_ref_ccs(self):
"""Tests whether ref only ccs are correctly excluded from the analysis"""
ctx = MagicMock(obj={"propagate": True})
push_context(ctx)
bs_data.DB.create_in_mem()

run = {
Expand Down Expand Up @@ -493,6 +503,8 @@ def test_get_connected_components_no_ref_to_ref_ccs(self):
self.assertEqual(rows, len(gbks_b + [gbks_c[0]]))

def test_get_connected_components_two_bins(self):
ctx = MagicMock(obj={"propagate": True})
push_context(ctx)
bs_data.DB.create_in_mem()

run = {
Expand Down Expand Up @@ -551,6 +563,8 @@ def test_get_connected_components_two_bins(self):
# bs_data.DB.save_to_disk(Path("after_legacy.db"))

def test_get_connected_components_two_bins_different_edge_params(self):
ctx = MagicMock(obj={"propagate": True})
push_context(ctx)
bs_data.DB.create_in_mem()

run = {
Expand Down
70 changes: 70 additions & 0 deletions test/network/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# from python
import pathlib
from unittest import TestCase
from unittest.mock import MagicMock
from click.globals import push_context
from itertools import combinations

# from dependencies
Expand Down Expand Up @@ -142,6 +144,8 @@ def test_get_edges(self):

def test_generate_cc(self):
"""Test the generate_connected_components function"""
ctx = MagicMock(obj={"propagate": True})
push_context(ctx)
bs_data.DB.create_in_mem()

# create a bunch of gbk files
Expand Down Expand Up @@ -189,6 +193,8 @@ def test_generate_cc(self):

def test_get_cc_ids(self):
"""Test the get_connected_component_ids function"""
ctx = MagicMock(obj={"propagate": True})
push_context(ctx)
bs_data.DB.create_in_mem()

# create a bunch of gbk files
Expand Down Expand Up @@ -364,6 +370,8 @@ def test_get_nodes_from_cc(self):
self.assertEqual(len(cc_nodes), len(gbks_a))

def test_is_ref_only_connected_component(self):
ctx = MagicMock(obj={"propagate": True})
push_context(ctx)
bs_data.DB.create_in_mem()

run = {
Expand Down Expand Up @@ -425,6 +433,8 @@ def test_is_ref_only_connected_component(self):

def test_get_connected_component_id(self):
"""Tests the get_connected_component_id function"""
ctx = MagicMock(obj={"propagate": True})
push_context(ctx)
bs_data.DB.create_in_mem()

run = {
Expand Down Expand Up @@ -460,6 +470,8 @@ def test_get_connected_component_id(self):
self.assertEqual(expected_data, cc_id)

def test_remove_connected_component(self):
ctx = MagicMock(obj={"propagate": True})
push_context(ctx)
bs_data.DB.create_in_mem()

run = {
Expand Down Expand Up @@ -527,3 +539,61 @@ def test_remove_connected_component(self):
expected_data = {"pre": 2, "post": 1}

self.assertEqual(expected_data, cc_status)

def test_query_no_propagate_cc_generation(self):
    """Tests generation of connected components does not propagate

    Saves three sets of mock edges (query-first, first-second, second-third)
    and generates the query connected component twice: once with the click
    context "propagate" flag set to False and once with it set to True. The
    component is expected to hold 1 entry in the first case and 3 in the
    second.
    """
    # local import so this test can undo the contexts it pushes
    from click.globals import pop_context

    ctx = MagicMock(obj={"propagate": False})
    push_context(ctx)
    # pop the mocked context after the test so it cannot leak into other tests
    self.addCleanup(pop_context)

    bs_data.DB.create_in_mem()

    run = {
        "record_type": bs_enums.RECORD_TYPE.REGION,
        "query_record_number": 1,
        "alignment_mode": bs_enums.ALIGNMENT_MODE.AUTO,
        "extend_strategy": bs_enums.EXTEND_STRATEGY.LEGACY,
    }

    query_gbk = create_mock_gbk(1, bs_enums.SOURCE_TYPE.QUERY)
    first_layer_gbk = create_mock_gbk(2, bs_enums.SOURCE_TYPE.REFERENCE)
    second_layer_gbk = create_mock_gbk(3, bs_enums.SOURCE_TYPE.REFERENCE)
    third_layer_gbk = create_mock_gbk(4, bs_enums.SOURCE_TYPE.REFERENCE)
    gbks = [query_gbk, first_layer_gbk, second_layer_gbk, third_layer_gbk]
    for gbk in gbks:
        gbk.save_all()

    include_records, q_record = bs_files.get_all_bgc_records_query(run, gbks)

    # the query is directly connected to the first layer only; the second and
    # third layers are reachable from the query only through the first layer
    qf_edges = gen_mock_edge_list([query_gbk, first_layer_gbk], 0.5)
    fs_edges = gen_mock_edge_list([first_layer_gbk, second_layer_gbk], 0.5)
    st_edges = gen_mock_edge_list([second_layer_gbk, third_layer_gbk], 0.5)

    edges = qf_edges + fs_edges + st_edges

    # save the edges
    for edge in edges:
        bs_comparison.save_edge_to_db(edge)

    query_bin = bs_comparison.QueryRecordPairGenerator(
        "Query", 1, "mix", run["record_type"]
    )
    query_bin.add_records(include_records)

    query_cc = next(
        bs_network.get_connected_components(1, 1, query_bin, 1, q_record), None
    )

    # fail with a clear message instead of a TypeError on len(None)
    self.assertIsNotNone(query_cc)

    # without propagation, the cc should contain only one edge: between the
    # query and the first layer
    self.assertEqual(len(query_cc), 1)

    ctx = MagicMock(obj={"propagate": True})
    push_context(ctx)
    self.addCleanup(pop_context)

    # NOTE(review): a different cutoff (0.99 instead of 1) is used here,
    # presumably so the component is generated anew rather than reusing the
    # one stored for the first call -- confirm
    query_cc = next(
        bs_network.get_connected_components(0.99, 1, query_bin, 1, q_record), None
    )

    self.assertIsNotNone(query_cc)

    # with propagation, the second and third layers are included in the cc: 3 edges
    self.assertEqual(len(query_cc), 3)
Loading