Skip to content

Commit

Permalink
feat: add full type signatures to DiGraph
Browse files Browse the repository at this point in the history
This commit combines the following commits, all by dhd:
 - feat: add full type signatures to DiGraph
 - fix: support python 3.7
 - fix: one more type annotation
   There's always one more type annotation!
 - fix: one more type annotation (again!)
   There's always one more type annotation!
 - fix: nope! it does not just work
 - fix: one more type annotation
   There's always one more type annotation!
  • Loading branch information
dhdaines authored and joanise committed Sep 12, 2024
1 parent 96abff3 commit 123e27b
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 35 deletions.
4 changes: 2 additions & 2 deletions g2p/mappings/langs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def load_langs(path: str = LANGS_PKL):
return {}


def load_network(path: str = LANGS_NWORK_PATH) -> DiGraph:
def load_network(path: str = LANGS_NWORK_PATH) -> DiGraph[str]:
try:
with gzip.open(path, "rt", encoding="utf8") as f:
data = json.load(f)
Expand Down Expand Up @@ -56,7 +56,7 @@ def get_available_mappings(langs: dict) -> list:
return mappings_available


LANGS_NETWORK: DiGraph = load_network()
LANGS_NETWORK = load_network()
# Making private because it should be imported from g2p.mappings instead
_LANGS = load_langs()
LANGS_AVAILABLE = get_available_languages(_LANGS)
Expand Down
99 changes: 67 additions & 32 deletions g2p/mappings/langs/network_lite.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,57 @@
from collections import deque
from typing import Any, Iterable, Iterator


class DiGraph:
from typing import (
Any,
Dict,
Generic,
Hashable,
Iterable,
Iterator,
List,
Set,
Tuple,
TypeVar,
Union,
)

from typing_extensions import TypedDict

T = TypeVar("T", bound=Hashable)


class DiGraph(Generic[T]):
"""A simple directed graph class
Most functions raise KeyError if called with a node u or v not in the graph.
"""

def __init__(self) -> None:
"""Constructor: create an empty graph with no nodes or edges"""
# Adjacency map: each node is a key, its value is the list of its successors
self._edges: dict = {}
self._edges: Dict[T, List[T]] = {}

def clear(self):
"""Clear the graph, removing all nodes and edges"""
self._edges.clear()

def update(self, edges: Iterable, nodes: Iterable):
def update(self, edges: Iterable[Tuple[T, T]], nodes: Iterable[T]):
"""Update the graph with new edges and nodes

All nodes are added first, then each (u, v) pair in edges is added as a
directed edge from u to v.
"""
for node in nodes:
self.add_node(node)
for u, v in edges:
self.add_edge(u, v)

def add_node(self, u):
def add_node(self, u: T):
"""Add a node to the graph; does nothing if u is already present"""
if u not in self._edges:
self._edges[u] = []

def add_edge(self, u, v):
def add_edge(self, u: T, v: T):
"""Add a directed edge from u to v

Both u and v are added as nodes if they are missing. An edge that already
exists is not added a second time.
"""
self.add_node(u)
self.add_node(v)
if v not in self._edges[u]:
self._edges[u].append(v)

def add_edges_from(self, edges: Iterable):
def add_edges_from(self, edges: Iterable[Tuple[T, T]]):
"""Add a directed edge from u to v for each (u, v) pair in edges"""
for u, v in edges:
self.add_edge(u, v)
Expand All @@ -46,24 +62,24 @@ def nodes(self):
return self._edges.keys()

@property # read-only
def edges(self) -> Iterator:
def edges(self) -> Iterator[Tuple[T, T]]:
"""Iterate over all edges, yielding each one as a (u, v) tuple"""
for u, neighbours in self._edges.items():
for v in neighbours:
yield u, v

def __contains__(self, u) -> bool:
def __contains__(self, u: T) -> bool:
"""Check if a node is in the graph (supports the `u in graph` syntax)"""
return u in self._edges

def has_path(self, u, v) -> bool:
def has_path(self, u: T, v: T) -> bool:
"""Check if there is a path from u to v

Raises KeyError if v is not in the graph.
"""
if v not in self._edges:
raise KeyError(f"Node {v} not in graph")
visited: set = set()
visited: Set[T] = set()
return self._has_path(u, v, visited)

def _has_path(self, u, v, visited: set) -> bool:
def _has_path(self, u: T, v: T, visited: Set[T]) -> bool:
"""Helper function for has_path"""
visited.add(u)
if u == v:
Expand All @@ -74,33 +90,33 @@ def _has_path(self, u, v, visited: set) -> bool:
return True
return False

def successors(self, u) -> Iterator:
def successors(self, u: T) -> Iterator[T]:
"""Return an iterator over the successors of u

Raises KeyError if u is not in the graph.
"""
return iter(self._edges[u])

def descendants(self, u) -> set:
def descendants(self, u: T) -> Set[T]:
"""Return the descendants of u: every node reachable from u, excluding u itself"""
visited: set = set()
visited: Set[T] = set()
self._descendants(u, visited)
# The helper always records u itself, so drop it from the result
visited.remove(u)
return visited

def _descendants(self, u, visited: set):
def _descendants(self, u: T, visited: Set[T]):
"""Helper function for descendants

Recursively adds u and every node reachable from u to visited.
"""
visited.add(u)
for neighbour in self._edges[u]:
if neighbour not in visited:
self._descendants(neighbour, visited)

def ancestors(self, u):
def ancestors(self, u: T) -> Set[T]:
"""Return the ancestors of u: every node from which u can be reached, excluding u itself

Computed as the descendants of u in a copy of the graph with every edge reversed.
"""
reversed_graph = DiGraph()
reversed_graph: DiGraph[T] = DiGraph()
reversed_graph.add_edges_from((v, u) for u, v in self.edges)
# Nodes that have no edges are not created by add_edges_from, so add every node explicitly
for node in self.nodes:
reversed_graph.add_node(node)
return reversed_graph.descendants(u)

def shortest_path(self, u, v) -> list:
def shortest_path(self, u: T, v: T) -> List[T]:
"""Return the shortest path from u to v
Algorithm: Dijkstra's algorithm for unweighted graphs, which is just BFS
Expand All @@ -114,14 +130,17 @@ def shortest_path(self, u, v) -> list:

if v not in self._edges:
raise KeyError(f"Node {v} not in graph")
visited = {u: None} # dict of {node: predecessor on shortest path from u}
queue: deque[Any] = deque()
visited: Dict[T, Union[T, None]] = {
u: None
} # dict of {node: predecessor on shortest path from u}
queue: deque[T] = deque()
while True:
if u == v:
rev_path = []
while u is not None:
rev_path.append(u)
u = visited[u]
rev_path: List[T] = []
nextu: Union[T, None] = u
while nextu is not None:
rev_path.append(nextu)
nextu = visited[nextu]
return list(reversed(rev_path))
for neighbour in self._edges[u]:
if neighbour not in visited:
Expand All @@ -132,7 +151,23 @@ def shortest_path(self, u, v) -> list:
u = queue.popleft()


def node_link_graph(data: dict):
# One entry of the "nodes" list in node-link data: a dict with an "id" key holding the node
NodeDict = TypedDict("NodeDict", {"id": Any})


class NodeLinkDict(TypedDict, Generic[T]):
"""One entry of the "links" list in node-link data: a directed edge from source to target"""
source: T
target: T


class NodeLinkDataDict(TypedDict, Generic[T]):
"""Node-link representation of a graph, as accepted by node_link_graph and returned by node_link_data"""
directed: bool
graph: Dict
links: List[NodeLinkDict[T]]
multigraph: bool
nodes: List[NodeDict]


def node_link_graph(data: NodeLinkDataDict[T]) -> DiGraph[T]:
"""Replacement for networkx.node_link_graph"""
if not data.get("directed", False):
raise ValueError("Graph must be directed")
Expand All @@ -143,18 +178,18 @@ def node_link_graph(data: dict):
if not isinstance(data.get("links", None), list):
raise ValueError('data["links"] must be a list')

graph = DiGraph()
graph: DiGraph[T] = DiGraph()
for node in data["nodes"]:
graph.add_node(node["id"])
for edge in data["links"]:
graph.add_edge(edge["source"], edge["target"])
return graph


def node_link_data(graph: DiGraph):
def node_link_data(graph: DiGraph[T]) -> NodeLinkDataDict[T]:
"""Replacement for networkx.node_link_data"""
nodes = [{"id": node} for node in graph.nodes]
links = [{"source": u, "target": v} for u, v in graph.edges]
nodes: List[NodeDict] = [{"id": node} for node in graph.nodes]
links: List[NodeLinkDict[T]] = [{"source": u, "target": v} for u, v in graph.edges]
return {
"directed": True,
"graph": {},
Expand Down
2 changes: 1 addition & 1 deletion g2p/mappings/langs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def cache_langs(
langs[code] = mapping_config.export_to_dict()

# Save as a Directional Graph
lang_network = DiGraph()
lang_network: DiGraph[str] = DiGraph()
lang_network.add_edges_from(mappings_legal_pairs)
write_json_gz(network_path, node_link_data(lang_network))
write_json_gz(langs_path, langs)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
"regex",
"text_unidecode",
"tqdm",
"typing_extensions",
]

[project.optional-dependencies]
Expand Down

0 comments on commit 123e27b

Please sign in to comment.