Skip to content

Commit

Permalink
Merge pull request #51 from vhewes/feature/plotting-updates
Browse files Browse the repository at this point in the history
updates to plotting
  • Loading branch information
vhewes authored Aug 24, 2023
2 parents eaa6df8 + 586b45e commit 1a25e42
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions pynuml/plot/graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pandas as pd
from torch_geometric.data import HeteroData
from torch_geometric.data import Batch, HeteroData
import plotly.express as px
from plotly.graph_objects import FigureWidget

Expand All @@ -10,13 +10,15 @@ def __init__(self,
self._planes = planes
self._classes = classes
self._labels = pd.CategoricalDtype(classes, ordered=True)
self._cmap = { c: px.colors.qualitative.Plotly[i] for i, c in enumerate(self._classes) }
self._cmap = { c: px.colors.qualitative.Plotly[i] for i, c in enumerate(classes) }
self._data = None
self._df = None

def to_dataframe(self, data: HeteroData):
def to_categorical(arr):
return pd.Categorical.from_codes(codes=arr, dtype=self._labels)
if isinstance(data, Batch):
raise Exception('to_dataframe does not support batches!')
dfs = []
for p in self._planes:
plane = data[p].to_dict()
Expand All @@ -33,13 +35,20 @@ def to_categorical(arr):
if 'x_filter' in plane.keys():
df['x_filter'] = plane['x_filter'].detach()
dfs.append(df)
return pd.concat(dfs)
df = pd.concat(dfs)
md = data['metadata']
df['run'] = md.run.item()
df['subrun'] = md.subrun.item()
df['event'] = md.event.item()
return df

def plot(self,
data: HeteroData,
target: str = 'hits',
how: str = 'none',
filter: str = 'none') -> FigureWidget:
filter: str = 'none',
width: int = None,
height: int = None) -> FigureWidget:

if data is not self._data:
self._data = data
Expand Down Expand Up @@ -127,8 +136,8 @@ def plot(self,
else:
raise Exception('"filter" must be one of "none", "true" or "pred.')

fig = px.scatter(df, x='wire', y='time', facet_col='plane', **opts)
fig.update_yaxes(matches=None)
fig = px.scatter(df, x='wire', y='time', facet_col='plane',
width=width, height=height, **opts)
fig.update_xaxes(matches=None)
for a in fig.layout.annotations:
a.text = a.text.replace('plane=', '')
Expand Down

0 comments on commit 1a25e42

Please sign in to comment.