Skip to content

Commit

Permalink
adds an interactive legend for the edges
Browse files Browse the repository at this point in the history
  • Loading branch information
avivko committed May 13, 2022
1 parent c22e73e commit 9be13eb
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions graphein/protein/visualisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,7 @@ def asteroid_plot(
height: int = 500,
use_plotly: bool = True,
show_edges: bool = False,
show_legend: bool = True,
node_size_multiplier: float = 10,
) -> Union[plotly.graph_objects.Figure, matplotlib.figure.Figure]:
"""Plots a k-hop subgraph around a node as concentric shells.
Expand Down Expand Up @@ -703,8 +704,10 @@ def asteroid_plot(
:type height: int
:param use_plotly: Use plotly to render the graph. Defaults to ``True``.
:type use_plotly: bool
:param show_edges: Whether or not to show edges in the plot. Defaults to ``False``.
:param show_edges: Whether to show edges in the plot. Defaults to ``False``.
:type show_edges: bool
:param show_legend: Whether to show the legend of the edges. Fefaults to `True``.
:type show_legend: bool
:param node_size_multiplier: Multiplier for the size of the nodes. Defaults to ``10``.
:type node_size_multiplier: float.
:returns: Plotly figure or matplotlib figure.
Expand Down Expand Up @@ -735,17 +738,23 @@ def asteroid_plot(
subgraph, colour_map=edge_colour_map, colour_by=colour_edges_by,
set_alpha=edge_alpha, return_as_rgba=True
)
show_legend_bools = [(True if x not in edge_colors[:i] else False)
for i, x in enumerate(edge_colors)]
edge_trace = []
for i, (u, v) in enumerate(subgraph.edges()):
x0, y0 = subgraph.nodes[u]["pos"]
x1, y1 = subgraph.nodes[v]["pos"]
bond_kind = " / ".join(list(subgraph[u][v]["kind"]))
tr = go.Scatter(
x=(x0, x1),
y=(y0, y1),
mode="lines",
line=dict(width=1, color=edge_colors[i]),
hoverinfo="text",
text=[" / ".join(list(subgraph[u][v]["kind"]))],
text=[bond_kind],
name=bond_kind,
legendgroup=bond_kind,
showlegend=show_legend_bools[i],
)
edge_trace.append(tr)

Expand Down Expand Up @@ -778,6 +787,7 @@ def asteroid_plot(
mode="markers+text" if show_labels else "markers",
hoverinfo="text",
textposition="bottom center",
showlegend=False,
marker=dict(
colorscale="YlGnBu",
reversescale=True,
Expand All @@ -802,7 +812,13 @@ def asteroid_plot(
width=width,
height=height,
titlefont_size=16,
showlegend=False,
legend=dict(
yanchor="top",
y=1,
xanchor="left",
x=1.10
),
showlegend=True if show_legend else False,
hovermode="closest",
margin=dict(b=20, l=5, r=5, t=40),
xaxis=dict(
Expand Down

0 comments on commit 9be13eb

Please sign in to comment.