Skip to content

Commit

Permalink
feat: network_lite with minimal DiGraph class
Browse files Browse the repository at this point in the history
This DiGraph class is meant to replace our use of networkx.
While networkx is a great library, it loads a great many modules and
takes about 30MB of RAM just upon doing `import networkx`. We use maybe
a hundred lines of code from this library, so let's rewrite our own
directed graph class with just the algorithms we use.
  • Loading branch information
joanise committed Sep 12, 2024
1 parent aa9de1c commit c70f30f
Show file tree
Hide file tree
Showing 3 changed files with 327 additions and 1 deletion.
164 changes: 164 additions & 0 deletions g2p/mappings/langs/network_lite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from collections import deque
from typing import Any, Iterable, Iterator


class DiGraph:
    """A simple directed graph class, stored as one successor list per node.

    Edges are unweighted, and adding an edge that already exists is a no-op.

    Most functions raise KeyError if called with a node u or v not in the graph.
    """

    def __init__(self) -> None:
        """Create an empty graph."""
        # Maps every node to the list of its direct successors, in the order
        # the edges were added.
        self._edges: dict = {}

    def clear(self):
        """Remove every node and edge from the graph."""
        self._edges.clear()

    def update(self, edges: Iterable, nodes: Iterable):
        """Add the given nodes, then the given (u, v) edges, to the graph."""
        for node in nodes:
            self.add_node(node)
        self.add_edges_from(edges)

    def add_node(self, u):
        """Add node u to the graph; does nothing if u is already present."""
        if u not in self._edges:
            self._edges[u] = []

    def add_edge(self, u, v):
        """Add a directed edge from u to v, creating either node if needed."""
        self.add_node(u)
        self.add_node(v)
        successors = self._edges[u]
        if v not in successors:  # keep at most one edge per (u, v) pair
            successors.append(v)

    def add_edges_from(self, edges: Iterable):
        """Add one directed edge for each (u, v) pair in edges."""
        for u, v in edges:
            self.add_edge(u, v)

    @property  # read-only
    def nodes(self):
        """Return a view of the nodes in the graph."""
        return self._edges.keys()

    @property  # read-only
    def edges(self) -> Iterator:
        """Iterate over all edges as (u, v) tuples."""
        for u, neighbours in self._edges.items():
            for v in neighbours:
                yield u, v

    def __contains__(self, u) -> bool:
        """Check if u is a node in the graph."""
        return u in self._edges

    def _reachable(self, u) -> Iterator:
        """Yield u and every node reachable from u, each exactly once.

        Uses an explicit stack rather than recursion, so long chains of nodes
        cannot exhaust the interpreter's recursion limit.

        Raises:
            KeyError: if u is not in the graph (raised on first iteration)
        """
        if u not in self._edges:
            raise KeyError(f"Node {u} not in graph")
        seen = {u}
        stack = [u]
        while stack:
            node = stack.pop()
            yield node
            for neighbour in self._edges[node]:
                if neighbour not in seen:
                    seen.add(neighbour)
                    stack.append(neighbour)

    def has_path(self, u, v) -> bool:
        """Check if there is a path from u to v

        A node always has a (zero-length) path to itself.

        Raises:
            KeyError: if u or v is not in the graph
        """
        if v not in self._edges:
            raise KeyError(f"Node {v} not in graph")
        # _reachable is lazy, so the search stops as soon as v is found.
        return any(node == v for node in self._reachable(u))

    def successors(self, u) -> Iterator:
        """Return an iterator over the direct successors of u

        Raises:
            KeyError: if u is not in the graph
        """
        return iter(self._edges[u])

    def descendants(self, u) -> set:
        """Return the set of nodes reachable from u, excluding u itself

        Raises:
            KeyError: if u is not in the graph
        """
        found = set(self._reachable(u))
        found.discard(u)
        return found

    def ancestors(self, u):
        """Return the set of nodes from which u is reachable, excluding u itself

        Raises:
            KeyError: if u is not in the graph
        """
        # The ancestors of u are its descendants in the edge-reversed graph.
        # Passing the nodes too keeps isolated nodes (and u itself) present.
        reversed_graph = DiGraph()
        reversed_graph.update(
            ((target, source) for source, target in self.edges), self.nodes
        )
        return reversed_graph.descendants(u)

    def shortest_path(self, u, v) -> list:
        """Return the shortest path from u to v

        Algorithm: Dijkstra's algorithm for unweighted graphs, which is just BFS

        Returns:
            list: the shortest path from u to v, including both endpoints

        Raises:
            KeyError: if u or v is not in the graph
            ValueError: if there is no path from u to v
        """
        if v not in self._edges:
            raise KeyError(f"Node {v} not in graph")
        if u not in self._edges:
            raise KeyError(f"Node {u} not in graph")
        # {node: predecessor on a shortest path from u}; also the visited set.
        # The entry for u only marks it as visited and is never followed.
        predecessor = {u: None}
        queue: deque = deque([u])
        while queue:
            node = queue.popleft()
            if node == v:
                # Walk the predecessor links back to the source, then flip.
                rev_path = [node]
                while rev_path[-1] != u:
                    rev_path.append(predecessor[rev_path[-1]])
                rev_path.reverse()
                return rev_path
            for neighbour in self._edges[node]:
                if neighbour not in predecessor:
                    predecessor[neighbour] = node
                    queue.append(neighbour)
        # u is never reassigned, so the message names the requested source.
        raise ValueError(f"No path from {u} to {v}")


def node_link_graph(data: dict):
    """Replacement for networkx.node_link_graph

    Builds a DiGraph from a node-link dict: one node per data["nodes"] entry
    (using its "id") and one edge per data["links"] entry (using its "source"
    and "target").

    Raises:
        ValueError: if data is not marked as directed, is not explicitly
            marked as a non-multigraph, or if "nodes" or "links" is not a list
    """
    is_directed = data.get("directed", False)
    if not is_directed:
        raise ValueError("Graph must be directed")
    # A missing "multigraph" key is treated as a multigraph and rejected.
    is_multigraph = data.get("multigraph", True)
    if is_multigraph:
        raise ValueError("Graph must not be a multigraph")
    node_entries = data.get("nodes")
    if not isinstance(node_entries, list):
        raise ValueError('data["nodes"] must be a list')
    link_entries = data.get("links")
    if not isinstance(link_entries, list):
        raise ValueError('data["links"] must be a list')

    result = DiGraph()
    for entry in node_entries:
        result.add_node(entry["id"])
    for entry in link_entries:
        result.add_edge(entry["source"], entry["target"])
    return result


def node_link_data(graph: DiGraph):
    """Replacement for networkx.node_link_data

    Returns a dict describing graph in node-link form: a list of {"id": ...}
    entries for the nodes and a list of {"source": ..., "target": ...} entries
    for the edges, flagged as directed and not a multigraph.
    """
    payload: dict = {"directed": True, "graph": {}}
    payload["links"] = [
        {"source": source, "target": target} for source, target in graph.edges
    ]
    payload["multigraph"] = False
    payload["nodes"] = [{"id": name} for name in graph.nodes]
    return payload
3 changes: 2 additions & 1 deletion g2p/tests/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from g2p.tests.test_langs import LangTest
from g2p.tests.test_lexicon_transducer import LexiconTransducerTest
from g2p.tests.test_mappings import MappingTest
from g2p.tests.test_network import NetworkTest
from g2p.tests.test_network import NetworkLiteTest, NetworkTest
from g2p.tests.test_tokenize_and_map import TokenizeAndMapTest
from g2p.tests.test_tokenizer import TokenizerTest
from g2p.tests.test_transducer import TransducerTest
Expand Down Expand Up @@ -60,6 +60,7 @@
MappingCreationTest,
MappingTest,
NetworkTest,
NetworkLiteTest,
UtilsTest,
TokenizerTest,
TokenizeAndMapTest,
Expand Down
161 changes: 161 additions & 0 deletions g2p/tests/test_network.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
#!/usr/bin/env python

import gzip
import json
from unittest import TestCase, main

from g2p import make_g2p
from g2p.exceptions import InvalidLanguageCode, NoPath
from g2p.log import LOGGER
from g2p.mappings.langs import LANGS_NWORK_PATH
from g2p.mappings.langs.network_lite import DiGraph, node_link_data, node_link_graph
from g2p.transducer import CompositeTransducer, Transducer


Expand Down Expand Up @@ -37,5 +41,162 @@ def test_valid_transducer(self):
self.assertEqual("niɡiɡw", transducer("nikikw").output_string)


class NetworkLiteTest(TestCase):
    """Tests for DiGraph and the node-link helpers in network_lite."""

    @classmethod
    def setUpClass(cls):
        """Load the gzipped language network JSON once for the whole class."""
        with gzip.open(LANGS_NWORK_PATH, "rt", encoding="utf8") as network_file:
            cls.data = json.load(network_file)

    @staticmethod
    def _graph_of(*pairs):
        """Return a DiGraph with the given (u, v) edges added in order."""
        graph = DiGraph()
        for source, target in pairs:
            graph.add_edge(source, target)
        return graph

    def test_has_path(self):
        graph = self._graph_of(
            ("a", "b"),
            ("b", "a"),  # cycle
            ("a", "c"),
            ("c", "d"),
            ("e", "f"),
        )
        for source, target in [("a", "c"), ("a", "d"), ("b", "a")]:
            self.assertTrue(graph.has_path(source, target))
        for source, target in [("a", "e"), ("a", "f"), ("c", "a")]:
            self.assertFalse(graph.has_path(source, target))
        # An unknown node at either end must raise.
        for source, target in [("a", "y"), ("x", "b")]:
            with self.assertRaises(KeyError):
                graph.has_path(source, target)

    def test_g2p_path(self):
        graph = node_link_graph(self.data)
        for target in ("eng-ipa", "atj-ipa"):
            self.assertTrue(graph.has_path("atj", target))
        self.assertFalse(graph.has_path("hei", "git"))

    def test_successors(self):
        graph = self._graph_of(("a", "b"), ("b", "a"), ("a", "c"))
        expected = [("a", {"b", "c"}), ("b", {"a"}), ("c", set())]
        for node, successors in expected:
            self.assertEqual(set(graph.successors(node)), successors)

    def test_descendants(self):
        graph = self._graph_of(
            ("a", "b"),
            ("b", "a"),  # cycle
            ("a", "c"),
            ("c", "d"),
            ("e", "f"),
        )
        expected = [
            ("a", {"b", "c", "d"}),
            ("b", {"a", "c", "d"}),
            ("c", {"d"}),
            ("d", set()),
            ("e", {"f"}),
            ("f", set()),
        ]
        for node, descendants in expected:
            self.assertEqual(graph.descendants(node), descendants)
        with self.assertRaises(KeyError):
            graph.descendants("x")

    def test_g2p_descendants(self):
        graph = node_link_graph(self.data)
        expected = [
            ("atj", {"eng-ipa", "atj-ipa", "eng-arpabet"}),
            ("eng-ipa", {"eng-arpabet"}),
            ("atj-ipa", {"eng-ipa", "eng-arpabet"}),
            ("eng-arpabet", set()),
        ]
        for node, descendants in expected:
            self.assertEqual(graph.descendants(node), descendants)

    def test_ancestors(self):
        graph = self._graph_of(
            ("a", "b"),
            ("a", "c"),
            ("d", "a"),  # cycle
            ("c", "d"),
            ("e", "f"),
        )
        expected = [
            ("a", {"c", "d"}),
            ("b", {"a", "d", "c"}),
            ("c", {"a", "d"}),
            ("d", {"a", "c"}),
            ("e", set()),
            ("f", {"e"}),
        ]
        for node, ancestors in expected:
            self.assertEqual(graph.ancestors(node), ancestors)
        with self.assertRaises(KeyError):
            graph.ancestors("x")

    def test_g2p_ancestors(self):
        graph = node_link_graph(self.data)
        self.assertEqual(graph.ancestors("atj"), set())
        eng_ipa_ancestors = graph.ancestors("eng-ipa")
        self.assertGreater(len(eng_ipa_ancestors), 50)

    def test_shortest_path(self):
        graph = self._graph_of(
            ("a", "e"),
            ("e", "f"),
            ("f", "g"),
            ("g", "d"),
            ("f", "d"),
            ("a", "b"),
            ("b", "a"),  # Cycle
            ("a", "c"),
            ("c", "d"),
            ("a", "d"),
            ("b", "d"),
        )
        expected = [
            (("a", "d"), ["a", "d"]),
            (("c", "d"), ["c", "d"]),
            (("a", "a"), ["a"]),
        ]
        for (source, target), path in expected:
            self.assertEqual(graph.shortest_path(source, target), path)
        # Both nodes exist but there is no route between them.
        with self.assertRaises(ValueError):
            graph.shortest_path("c", "a")
        for source, target in [("a", "y"), ("x", "b")]:
            with self.assertRaises(KeyError):
                graph.shortest_path(source, target)

    def test_g2p_shortest_path(self):
        graph = node_link_graph(self.data)
        path = graph.shortest_path("atj", "eng-arpabet")
        self.assertEqual(path, ["atj", "atj-ipa", "eng-ipa", "eng-arpabet"])

    def test_contains(self):
        graph = self._graph_of(("a", "b"))
        self.assertIn("a", graph)
        self.assertIn("b", graph)
        self.assertNotIn("c", graph)

    def test_node_link_data(self):
        # Loading and re-serializing must reproduce the input exactly.
        round_tripped = node_link_data(node_link_graph(self.data))
        self.assertEqual(round_tripped, self.data)

    def test_node_link_graph_errors(self):
        invalid_overrides = [
            {"directed": False},
            {"multigraph": True},
            {"nodes": "not a list"},
            {"links": "not a list"},
        ]
        for override in invalid_overrides:
            with self.assertRaises(ValueError):
                node_link_graph({**self.data, **override})
        for missing_key in ("nodes", "links"):
            with self.assertRaises(ValueError):
                data = self.data.copy()
                del data[missing_key]
                node_link_graph(data)

    def test_no_duplicates(self):
        # ("a", "b") is added twice and must only be stored once.
        graph = self._graph_of(("a", "b"), ("b", "c"), ("a", "c"), ("a", "b"))
        self.assertEqual(len(list(graph.edges)), 3)
        self.assertEqual(len(graph.nodes), 3)
        for node, count in [("a", 2), ("b", 1), ("c", 0)]:
            self.assertEqual(len(list(graph.successors(node))), count)


# Run this module's test cases when it is executed directly as a script.
if __name__ == "__main__":
    main()

0 comments on commit c70f30f

Please sign in to comment.