diff --git a/sourmash/commands.py b/sourmash/commands.py
index 9c4db3f408..9e82b35a6e 100644
--- a/sourmash/commands.py
+++ b/sourmash/commands.py
@@ -753,7 +753,7 @@ def index(args):
             scaleds.add(ss.minhash.scaled)

             leaf = SigLeaf(ss.md5sum(), ss)
-            tree.add_node(leaf)
+            tree.add_node(leaf, update_internal=False)
             n += 1

         if not ss:
diff --git a/sourmash/sbt.py b/sourmash/sbt.py
index 760941866d..ab6eec7f2a 100644
--- a/sourmash/sbt.py
+++ b/sourmash/sbt.py
@@ -106,7 +106,7 @@ class SBT(object):
     """A Sequence Bloom Tree implementation allowing generic internal nodes and leaves.
     The default node and leaf format is a Bloom Filter (like the original implementation),
-    but we also provide a MinHash leaf class (in the sourmash.sbtmh.Leaf
+    but we also provide a MinHash leaf class (in the sourmash.sbtmh.SigLeaf
     class)

     Parameters
     ----------
@@ -114,36 +114,60 @@ class SBT(object):
         Callable for generating new datastores for internal nodes.
     d: int
         Number of children for each internal node. Defaults to 2 (a binary tree)
-    n_tables: int
-        number of nodegraph tables to be used.
-
+    storage: Storage, default: None
+        A Storage is any place where we can save and load data for the nodes.
+        If set to None, will use a FSStorage.

     Notes
     -----
-    We use a defaultdict to store the tree structure. Nodes are numbered
-    specific node they are numbered
+    We use two dicts to store the tree structure: One for the internal nodes,
+    and another for the leaves.
     """

     def __init__(self, factory, d=2, storage=None):
         self.factory = factory
-        self.nodes = defaultdict(lambda: None)
-        self.missing_nodes = set()
+        self._nodes = {}
+        self._missing_nodes = set()
+        self._leaves = {}
         self.d = d
         self.next_node = 0
         self.storage = storage
+        self.is_ready = False

     def new_node_pos(self, node):
-        while self.nodes.get(self.next_node, None) is not None:
-            self.next_node += 1
+        if not self._nodes:
+            self.next_node = 1
+            return 0
+
+        if not self._leaves:
+            self.next_node = 2
+            return 1
+
+        min_leaf = min(self._leaves.keys())
+
+        next_internal_node = None
+        if self.next_node <= min_leaf:
+            for i in range(min_leaf):
+                if all((self._nodes.get(i, None) is None,
+                        self._leaves.get(i, None) is None,
+                        i not in self._missing_nodes)):
+                    next_internal_node = i
+                    break
+
+        if next_internal_node is None:
+            self.next_node = max(self._leaves.keys()) + 1
+        else:
+            self.next_node = next_internal_node
+
         return self.next_node

-    def add_node(self, node):
-        pos = self.new_node_pos(node)
+    def add_node(self, leaf, update_internal=True):
+        pos = self.new_node_pos(leaf)

         if pos == 0:  # empty tree; initialize w/node.
             n = Node(self.factory, name="internal." + str(pos))
-            self.nodes[0] = n
-            pos = self.new_node_pos(node)
+            self._nodes[0] = n
+            pos = self.new_node_pos(leaf)

         # Cases:
         # 1) parent is a Leaf (already covered)
@@ -158,36 +182,52 @@ def add_node(self, node):
             # Create a new internal node
             # node and parent are children of new internal node
             n = Node(self.factory, name="internal." + str(p.pos))
-            self.nodes[p.pos] = n
+            self._nodes[p.pos] = n
             c1, c2 = self.children(p.pos)[:2]
-            self.nodes[c1.pos] = p.node
-            self.nodes[c2.pos] = node
+            self._leaves[c1.pos] = p.node
+            self._leaves[c2.pos] = leaf
+            del self._leaves[p.pos]

-            for child in (p.node, node):
-                child.update(n)
+            if update_internal:
+                for child in (p.node, leaf):
+                    child.update(n)
+            else:
+                self.is_ready = False
         elif isinstance(p.node, Node):
-            self.nodes[pos] = node
-            node.update(p.node)
+            self._leaves[pos] = leaf
+            if update_internal:
+                leaf.update(p.node)
+            else:
+                self.is_ready = False
         elif p.node is None:
             n = Node(self.factory, name="internal." + str(p.pos))
-            self.nodes[p.pos] = n
+            self._nodes[p.pos] = n
             c1 = self.children(p.pos)[0]
-            self.nodes[c1.pos] = node
-            node.update(n)
-
-        # update all parents!
-        p = self.parent(p.pos)
-        while p:
-            self._rebuild_node(p.pos)
-            node.update(self.nodes[p.pos])
+            self._leaves[c1.pos] = leaf
+            if update_internal:
+                leaf.update(n)
+            else:
+                self.is_ready = False
+
+        if update_internal:
+            # update all parents!
             p = self.parent(p.pos)
+            while p:
+                self._rebuild_node(p.pos)
+                leaf.update(self._nodes[p.pos])
+                p = self.parent(p.pos)
+        else:
+            self.is_ready = False

     def find(self, search_fn, *args, **kwargs):
         "Search the tree using `search_fn`."

         # initialize search queue with top node of tree
+        if not self.is_ready:
+            self._fill_internal()
+
         matches = []
         visited, queue = set(), [0]

@@ -195,13 +235,16 @@
         # function.
         while queue:
             node_p = queue.pop(0)
-            node_g = self.nodes.get(node_p, None)

             # repair while searching.
+            node_g = self._leaves.get(node_p, None)
             if node_g is None:
-                if node_p in self.missing_nodes:
+                node_g = self._nodes.get(node_p, None)
+
+            if node_g is None:
+                if node_p in self._missing_nodes:
                     self._rebuild_node(node_p)
-                    node_g = self.nodes[node_p]
+                    node_g = self._nodes[node_p]
                 else:
                     continue

@@ -235,22 +278,20 @@ def _rebuild_node(self, pos=0):
             (the default).
         """

-        node = self.nodes.get(pos, None)
+        node = self._nodes.get(pos, None)
        if node is not None:
             # this node was already build, skip
             return

         node = Node(self.factory, name="internal.{}".format(pos))
-        self.nodes[pos] = node
+        self._nodes[pos] = node

         for c in self.children(pos):
-            if c.pos in self.missing_nodes or isinstance(c.node, Leaf):
-                if c.node is None:
+            if c.pos in self._missing_nodes or isinstance(c.node, Leaf):
+                cnode = c.node
+                if cnode is None:
                     self._rebuild_node(c.pos)
-                    c_node = self.nodes[c.pos]
-                    if c_node is not None:
-                        c_node.update(node)
-        self.missing_nodes.remove(pos)
-
+                    cnode = self._nodes[c.pos]
+                cnode.update(node)

     def parent(self, pos):
         """Return the parent of the node at position ``pos``.
@@ -271,7 +312,10 @@ def parent(self, pos):
         if pos == 0:
             return None
         p = int(math.floor((pos - 1) / self.d))
-        node = self.nodes.get(p, None)
+        if p in self._leaves:
+            return NodePos(p, self._leaves[p])
+
+        node = self._nodes.get(p, None)
         return NodePos(p, node)

     def children(self, pos):
@@ -309,7 +353,10 @@ def child(self, parent, pos):
             child node.
         """
         cd = self.d * parent + pos + 1
-        node = self.nodes.get(cd, None)
+        if cd in self._leaves:
+            return NodePos(cd, self._leaves[cd])
+
+        node = self._nodes.get(cd, None)
         return NodePos(cd, node)

     def save(self, path, storage=None, sparseness=0.0, structure_only=False):
@@ -335,7 +382,7 @@
         str
             full path to the new SBT description
         """
-        version = 4
+        version = 5

         if path.endswith('.sbt.json'):
             path = path[:-9]
@@ -363,8 +410,12 @@ def save(self, path, storage=None, sparseness=0.0, structure_only=False):
             'args': self.factory.init_args()
         }

-        structure = {}
-        total_nodes = len(self.nodes)
+        if not self.is_ready:
+            self._fill_internal()
+
+        nodes = {}
+        leaves = {}
+        total_nodes = len(self)
         for n, (i, node) in enumerate(self):
             if node is None:
                 continue
@@ -394,12 +445,18 @@ def save(self, path, storage=None, sparseness=0.0, structure_only=False):

                 data['filename'] = node.save(data['filename'])

-            structure[i] = data
+            node.storage = storage
+            data['filename'] = node.save(data['filename'])

+            if isinstance(node, Node):
+                nodes[i] = data
+            else:
+                leaves[i] = data
             notify("{} of {} nodes saved".format(n+1, total_nodes), end='\r')

         notify("\nFinished saving nodes, now saving SBT json file.")
-        info['nodes'] = structure
+        info['nodes'] = nodes
+        info['leaves'] = leaves

         with open(fn, 'w') as fp:
             json.dump(info, fp)
@@ -434,6 +491,7 @@ def load(cls, location, leaf_loader=None, storage=None, print_version_warning=Tr
             2: cls._load_v2,
             3: cls._load_v3,
             4: cls._load_v4,
+            5: cls._load_v5,
         }

         # @CTB hack: check to make sure khmer Nodegraph supports the
@@ -469,7 +527,7 @@ def _load_v1(jnodes, leaf_loader, dirname, storage, print_version_warning=True):
         if jnodes[0] is None:
             raise ValueError("Empty tree!")

-        sbt_nodes = defaultdict(lambda: None)
+        sbt_nodes = {}

         sample_bf = os.path.join(dirname, jnodes[0]['filename'])
         ksize, tablesize, ntables = khmer.extract_nodegraph_info(sample_bf)[:3]
@@ -490,7 +548,7 @@ def _load_v1(jnodes, leaf_loader, dirname, storage, print_version_warning=True):
             sbt_nodes[i] = sbt_node

         tree = SBT(factory)
-        tree.nodes = sbt_nodes
+        tree._nodes = sbt_nodes

         return tree

@@ -501,7 +559,8 @@ def _load_v2(cls, info, leaf_loader, dirname, storage, print_version_warning=Tru
         if nodes[0] is None:
             raise ValueError("Empty tree!")

-        sbt_nodes = defaultdict(lambda: None)
+        sbt_nodes = {}
+        sbt_leaves = {}

         sample_bf = os.path.join(dirname, nodes[0]['filename'])
         k, size, ntables = khmer.extract_nodegraph_info(sample_bf)[:3]
@@ -516,13 +575,14 @@ def _load_v2(cls, info, leaf_loader, dirname, storage, print_version_warning=Tru
             if 'internal' in node['name']:
                 node['factory'] = factory
                 sbt_node = Node.load(node, storage)
+                sbt_nodes[k] = sbt_node
             else:
                 sbt_node = leaf_loader(node, storage)
-
-            sbt_nodes[k] = sbt_node
+                sbt_leaves[k] = sbt_node

         tree = cls(factory, d=info['d'])
-        tree.nodes = sbt_nodes
+        tree._nodes = sbt_nodes
+        tree._leaves = sbt_leaves

         return tree

@@ -533,7 +593,8 @@ def _load_v3(cls, info, leaf_loader, dirname, storage, print_version_warning=Tru
         if not nodes:
             raise ValueError("Empty tree!")

-        sbt_nodes = defaultdict(lambda: None)
+        sbt_nodes = {}
+        sbt_leaves = {}

         klass = STORAGES[info['storage']['backend']]
         if info['storage']['backend'] == "FSStorage":
@@ -551,22 +612,23 @@ def _load_v3(cls, info, leaf_loader, dirname, storage, print_version_warning=Tru
             if 'internal' in node['name']:
                 node['factory'] = factory
                 sbt_node = Node.load(node, storage)
+                sbt_nodes[k] = sbt_node
             else:
                 sbt_node = leaf_loader(node, storage)
+                sbt_leaves[k] = sbt_node

-            sbt_nodes[k] = sbt_node
             max_node = max(max_node, k)

         tree = cls(factory, d=info['d'], storage=storage)
-        tree.nodes = sbt_nodes
-        tree.missing_nodes = {i for i in range(max_node)
-                              if i not in sbt_nodes}
-        # TODO: this might not be true with combine...
-        tree.next_node = max_node
+        tree._nodes = sbt_nodes
+        tree._leaves = sbt_leaves
+        tree._missing_nodes = {i for i in range(max_node)
+                               if i not in sbt_nodes and i not in sbt_leaves}

         if print_version_warning:
             error("WARNING: this is an old index version, please run `sourmash migrate` to update it.")
             error("WARNING: proceeding with execution, but it will take longer to finish!")
+        tree._fill_min_n_below()

         return tree

@@ -609,6 +671,52 @@ def _load_v4(cls, info, leaf_loader, dirname, storage, print_version_warning=Tru
         # TODO: this might not be true with combine...
         tree.next_node = max_node

+        if print_version_warning:
+            error("WARNING: this is an old index version, please run `sourmash migrate` to update it.")
+            error("WARNING: proceeding with execution, but it will take longer to finish!")
+
+        tree._fill_min_n_below()
+
+        return tree
+
+    @classmethod
+    def _load_v5(cls, info, leaf_loader, dirname, storage, print_version_warning=True):
+        nodes = {int(k): v for (k, v) in info['nodes'].items()}
+        leaves = {int(k): v for (k, v) in info['leaves'].items()}
+
+        if not leaves:
+            raise ValueError("Empty tree!")
+
+        sbt_nodes = {}
+        sbt_leaves = {}
+
+        klass = STORAGES[info['storage']['backend']]
+        if info['storage']['backend'] == "FSStorage":
+            storage = FSStorage(dirname, info['storage']['args']['path'])
+        elif storage is None:
+            storage = klass(**info['storage']['args'])
+
+        factory = GraphFactory(*info['factory']['args'])
+
+        max_node = 0
+        for k, node in nodes.items():
+            node['factory'] = factory
+            sbt_node = Node.load(node, storage)
+
+            sbt_nodes[k] = sbt_node
+            max_node = max(max_node, k)
+
+        for k, node in leaves.items():
+            sbt_leaf = leaf_loader(node, storage)
+            sbt_leaves[k] = sbt_leaf
+            max_node = max(max_node, k)
+
+        tree = cls(factory, d=info['d'], storage=storage)
+        tree._nodes = sbt_nodes
+        tree._leaves = sbt_leaves
+        tree._missing_nodes = {i for i in range(max_node)
+                               if i not in sbt_nodes and i not in sbt_leaves}
+
         return tree

     def _fill_min_n_below(self):
@@ -637,8 +745,20 @@ def fill_min_n_below(node, *args, **kwargs):

         self._fill_up(fill_min_n_below)

+    def _fill_internal(self):
+
+        def fill_nodegraphs(node, *args, **kwargs):
+            children = kwargs['children']
+            for child in children:
+                if child.node is not None:
+                    child.node.update(node)
+            return True
+
+        self._fill_up(fill_nodegraphs)
+        self.is_ready = True
+
     def _fill_up(self, search_fn, *args, **kwargs):
-        visited, queue = set(), [i[0] for i in reversed(sorted(self._leaves()))]
+        visited, queue = set(), list(reversed(sorted(self._leaves.keys())))
         debug("started filling up")
         processed = 0
         while queue:
@@ -652,7 +772,7 @@ def _fill_up(self, search_fn, *args, **kwargs):

             was_missing = False
             if parent.node is None:
-                if parent.pos in self.missing_nodes:
+                if parent.pos in self._missing_nodes:
                     self._rebuild_node(parent.pos)
                     parent = self.parent(node_p)
                     was_missing = True
@@ -677,6 +797,10 @@ def _fill_up(self, search_fn, *args, **kwargs):
             if processed % 100 == 0:
                 debug("processed {}, in queue {}", processed, len(queue), sep='\r')

+    def __len__(self):
+        internal_nodes = set(self._nodes).union(self._missing_nodes)
+        return len(internal_nodes) + len(self._leaves)
+
     def print_dot(self):
         print("""
         digraph G {
@@ -687,7 +811,7 @@ def print_dot(self):
         edge [arrowsize=0.8];
         """)

-        for i, node in list(self.nodes.items()):
+        for i, node in self._nodes.items():
             if isinstance(node, Node):
                 print('"{}" [shape=box fillcolor=gray style=filled]'.format(
                         node.name))
@@ -700,7 +824,7 @@ def print(self):
         visited, stack = set(), [0]
         while stack:
             node_p = stack.pop()
-            node_g = self.nodes.get(node_p, None)
+            node_g = self._nodes.get(node_p, None)
             if node_p not in visited and node_g is not None:
                 visited.add(node_p)
                 depth = int(math.floor(math.log(node_p + 1, self.d)))
@@ -710,7 +834,9 @@ def print(self):
                           if c.pos not in visited)

     def __iter__(self):
-        for i, node in self.nodes.items():
+        for i, node in self._nodes.items():
+            yield (i, node)
+        for i, node in self._leaves.items():
             yield (i, node)

     def _parents(self, pos=0):
@@ -722,53 +848,48 @@
             yield p.pos
             p = self.parent(p.pos)

-
-    def _leaves(self, pos=0):
-        for i, node in self:
-            if isinstance(node, Leaf):
-                if pos in self._parents(i):
-                    yield (i, node)
-
-    def leaves(self):
-        for c in self.nodes.values():
-            if isinstance(c, Leaf):
-                yield c
+    def leaves(self, with_pos=False):
+        for pos, data in self._leaves.items():
+            if with_pos:
+                yield (pos, data)
+            else:
+                yield data

     def combine(self, other):
         larger, smaller = self, other
-        if len(other.nodes) > len(self.nodes):
+        if len(other) > len(self):
             larger, smaller = other, self

         n = Node(self.factory, name="internal.0", storage=self.storage)
-        larger.nodes[0].update(n)
-        smaller.nodes[0].update(n)
-        new_nodes = defaultdict(lambda: None)
+        larger._nodes[0].update(n)
+        smaller._nodes[0].update(n)
+        new_nodes = {}
         new_nodes[0] = n

-        levels = int(math.ceil(math.log(len(larger.nodes), self.d))) + 1
+        new_leaves = {}
+
+        levels = int(math.ceil(math.log(len(larger), self.d))) + 1
         current_pos = 1
         n_previous = 0
         n_next = 1
         for level in range(1, levels + 1):
             for tree in (larger, smaller):
                 for pos in range(n_previous, n_next):
-                    if tree.nodes.get(pos, None) is not None:
-                        new_node = copy(tree.nodes[pos])
-                        if isinstance(new_node, Node):
-                            # An internal node, we need to update the name
-                            new_node.name = "internal.{}".format(current_pos)
+                    if tree._nodes.get(pos, None) is not None:
+                        new_node = copy(tree._nodes[pos])
+                        new_node.name = "internal.{}".format(current_pos)
                         new_nodes[current_pos] = new_node
+                    elif tree._leaves.get(pos, None) is not None:
+                        new_node = copy(tree._leaves[pos])
+                        new_leaves[current_pos] = new_node
                     current_pos += 1
             n_previous = n_next
             n_next = n_previous + int(self.d ** level)
             current_pos = n_next

-        # reset next_node, next time we add a node it will find the next
-        # empty position
-        self.next_node = 2
-        # TODO: do we want to return a new tree, or merge into this one?
-        self.nodes = new_nodes
+        self._nodes = new_nodes
+        self._leaves = new_leaves
         return self

@@ -827,11 +948,12 @@ def load(info, storage=None):

     def update(self, parent):
         parent.data.update(self.data)
-        min_n_below = min(parent.metadata.get('min_n_below', sys.maxsize),
-                          self.metadata.get('min_n_below'))
-        if min_n_below == 0:
-            min_n_below = 1
-        parent.metadata['min_n_below'] = min_n_below
+        if 'min_n_below' in self.metadata:
+            min_n_below = min(parent.metadata.get('min_n_below', sys.maxsize),
+                              self.metadata.get('min_n_below'))
+            if min_n_below == 0:
+                min_n_below = 1
+            parent.metadata['min_n_below'] = min_n_below


 class Leaf(object):
diff --git a/tests/test_sbt.py b/tests/test_sbt.py
index 8589c6aa71..139e2f25a4 100644
--- a/tests/test_sbt.py
+++ b/tests/test_sbt.py
@@ -285,7 +285,7 @@ def test_sbt_combine(n_children):

     # check if adding a new node will use the next empty position
     next_empty = 0
-    for n, d in enumerate(tree_1.nodes):
+    for n, (d, _) in enumerate(tree_1):
         if n != d:
             next_empty = n
             break
@@ -468,7 +468,7 @@ def test_tree_repair_add_node():
         leaf = SigLeaf(os.path.basename(f), sig)
         tree_repair.add_node(leaf)

-    for pos, node in list(tree_repair.nodes.items()):
+    for pos, node in tree_repair:
         # Every parent of a node must be an internal node (and not a leaf),
         # except for node 0 (the root), whose parent is None.
         if pos != 0:
@@ -499,7 +499,7 @@ def test_save_sparseness(n_children):
         tree.save(os.path.join(location, 'demo'), sparseness=1.0)
         tree_loaded = SBT.load(os.path.join(location, 'demo'),
                                leaf_loader=SigLeaf.load)
-        assert all(not isinstance(n, Node) for n in tree_loaded.nodes.values())
+        assert all(not isinstance(n, Node) for _, n in tree_loaded)

         print('*' * 60)
         print("{}:".format(to_search.metadata))
@@ -509,7 +509,7 @@ def test_save_sparseness(n_children):

         assert old_result == new_result

-        for pos, node in list(tree_loaded.nodes.items()):
+        for pos, node in tree_loaded:
             # Every parent of a node must be an internal node (and not a leaf),
             # except for node 0 (the root), whose parent is None.
             if pos != 0:
diff --git a/tests/test_sourmash.py b/tests/test_sourmash.py
index f85d9f5afc..dcfe719a4a 100644
--- a/tests/test_sourmash.py
+++ b/tests/test_sourmash.py
@@ -1507,7 +1507,7 @@ def test_do_sourmash_sbt_search_check_bug():
         assert '1 matches:' in out

         tree = load_sbt_index(os.path.join(location, 'zzz.sbt.json'))
-        assert tree.nodes[0].metadata['min_n_below'] == 431
+        assert tree._nodes[0].metadata['min_n_below'] == 431


 def test_do_sourmash_sbt_search_empty_sig():
@@ -1531,7 +1531,7 @@ def test_do_sourmash_sbt_search_empty_sig():
         assert '1 matches:' in out

         tree = load_sbt_index(os.path.join(location, 'zzz.sbt.json'))
-        assert tree.nodes[0].metadata['min_n_below'] == 1
+        assert tree._nodes[0].metadata['min_n_below'] == 1


 def test_do_sourmash_sbt_move_and_search_output():
@@ -3538,10 +3538,9 @@ def test_storage_convert():

         ipfs = SBT.load(testsbt, leaf_loader=SigLeaf.load)

-        assert len(original.nodes) == len(ipfs.nodes)
+        assert len(original) == len(ipfs)
         assert all(n1[1].name == n2[1].name
-                   for (n1, n2) in zip(sorted(original.nodes.items()),
-                                       sorted(ipfs.nodes.items())))
+                   for (n1, n2) in zip(sorted(original), sorted(ipfs)))

         args = ['storage', 'convert', '-b',
                 """'TarStorage("{}")'""".format(
@@ -3551,10 +3550,10 @@
                                            in_directory=location)

         tar = SBT.load(testsbt, leaf_loader=SigLeaf.load)

-        assert len(original.nodes) == len(tar.nodes)
+        assert len(original) == len(tar)
         assert all(n1[1].name == n2[1].name
-                   for (n1, n2) in zip(sorted(original.nodes.items()),
-                                       sorted(tar.nodes.items())))
+                   for (n1, n2) in zip(sorted(original), sorted(tar)))
+

 def test_storage_convert_identity():
@@ -3572,10 +3571,9 @@ def test_storage_convert_identity():

         identity = SBT.load(testsbt, leaf_loader=SigLeaf.load)

-        assert len(original.nodes) == len(identity.nodes)
+        assert len(original) == len(identity)
         assert all(n1[1].name == n2[1].name
-                   for (n1, n2) in zip(sorted(original.nodes.items()),
-                                       sorted(identity.nodes.items())))
+                   for (n1, n2) in zip(sorted(original), sorted(identity)))


 def test_storage_convert_fsstorage_newpath():
@@ -3596,10 +3594,9 @@ def test_storage_convert_fsstorage_newpath():

         identity = SBT.load(testsbt, leaf_loader=SigLeaf.load)

-        assert len(original.nodes) == len(identity.nodes)
+        assert len(original) == len(identity)
         assert all(n1[1].name == n2[1].name
-                   for (n1, n2) in zip(sorted(original.nodes.items()),
-                                       sorted(identity.nodes.items())))
+                   for (n1, n2) in zip(sorted(original), sorted(identity)))


 def test_migrate():
@@ -3617,14 +3614,14 @@ def test_migrate():

         identity = SBT.load(testsbt, leaf_loader=SigLeaf.load)

-        assert len(original.nodes) == len(identity.nodes)
+        assert len(original) == len(identity)
         assert all(n1[1].name == n2[1].name
-                   for (n1, n2) in zip(sorted(original.nodes.items()),
-                                       sorted(identity.nodes.items())))
+                   for (n1, n2) in zip(sorted(original),
+                                       sorted(identity)))

         assert "this is an old index version" not in err

         assert all('min_n_below' in node.metadata
-                   for node in identity.nodes.values()
+                   for node in identity
                    if isinstance(node, Node))