From 2deb8afd509e93c8d5f38f101768e72d669b41fd Mon Sep 17 00:00:00 2001 From: v hewes Date: Thu, 24 Aug 2023 10:30:00 -0400 Subject: [PATCH 1/5] tiny syntax tweak --- pynuml/plot/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pynuml/plot/graph.py b/pynuml/plot/graph.py index 06781a3..8d551ad 100644 --- a/pynuml/plot/graph.py +++ b/pynuml/plot/graph.py @@ -10,7 +10,7 @@ def __init__(self, self._planes = planes self._classes = classes self._labels = pd.CategoricalDtype(classes, ordered=True) - self._cmap = { c: px.colors.qualitative.Plotly[i] for i, c in enumerate(self._classes) } + self._cmap = { c: px.colors.qualitative.Plotly[i] for i, c in enumerate(classes) } self._data = None self._df = None From 56a2ae2ad37a809a2b3649a9f62acd5a5b2ce535 Mon Sep 17 00:00:00 2001 From: v hewes Date: Thu, 24 Aug 2023 10:31:23 -0400 Subject: [PATCH 2/5] add check for casting batch to dataframe --- pynuml/plot/graph.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pynuml/plot/graph.py b/pynuml/plot/graph.py index 8d551ad..dac5831 100644 --- a/pynuml/plot/graph.py +++ b/pynuml/plot/graph.py @@ -1,5 +1,5 @@ import pandas as pd -from torch_geometric.data import HeteroData +from torch_geometric.data import Batch, HeteroData import plotly.express as px from plotly.graph_objects import FigureWidget @@ -17,6 +17,8 @@ def __init__(self, def to_dataframe(self, data: HeteroData): def to_categorical(arr): return pd.Categorical.from_codes(codes=arr, dtype=self._labels) + if isinstance(data, Batch): + raise Exception('to_dataframe does not support batches!') dfs = [] for p in self._planes: plane = data[p].to_dict() From 98e472bbde00d616793daa6c8cd26ddc948837b9 Mon Sep 17 00:00:00 2001 From: v hewes Date: Thu, 24 Aug 2023 10:31:46 -0400 Subject: [PATCH 3/5] add event metadata to dataframe --- pynuml/plot/graph.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pynuml/plot/graph.py b/pynuml/plot/graph.py index dac5831..8dc8ca4 100644 --- a/pynuml/plot/graph.py +++ b/pynuml/plot/graph.py @@ -35,7 +35,12 @@ def to_categorical(arr): if 'x_filter' in plane.keys(): df['x_filter'] = plane['x_filter'].detach() dfs.append(df) - return pd.concat(dfs) + df = pd.concat(dfs) + md = data['metadata'] + df['run'] = md.run.item() + df['subrun'] = md.subrun.item() + df['event'] = md.event.item() + return df def plot(self, data: HeteroData, From 693cd08bc5baa3b1b35beab0204e06af01fbd7b1 Mon Sep 17 00:00:00 2001 From: v hewes Date: Thu, 24 Aug 2023 10:33:08 -0400 Subject: [PATCH 4/5] configure graph plot size --- pynuml/plot/graph.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pynuml/plot/graph.py b/pynuml/plot/graph.py index 8dc8ca4..c914afa 100644 --- a/pynuml/plot/graph.py +++ b/pynuml/plot/graph.py @@ -46,7 +46,9 @@ def plot(self, data: HeteroData, target: str = 'hits', how: str = 'none', - filter: str = 'none') -> FigureWidget: + filter: str = 'none', + width: int = None, + height: int = None) -> FigureWidget: if data is not self._data: self._data = data @@ -134,7 +136,8 @@ def plot(self, else: raise Exception('"filter" must be one of "none", "true" or "pred.') - fig = px.scatter(df, x='wire', y='time', facet_col='plane', **opts) + fig = px.scatter(df, x='wire', y='time', facet_col='plane', + width=width, height=height, **opts) fig.update_yaxes(matches=None) fig.update_xaxes(matches=None) for a in fig.layout.annotations: From 586b45e4a523a273ba5e9cc912bfef6a8bea5bd5 Mon Sep 17 00:00:00 2001 From: v hewes Date: Thu, 24 Aug 2023 10:33:25 -0400 Subject: [PATCH 5/5] use same axis range in time dimension --- pynuml/plot/graph.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pynuml/plot/graph.py b/pynuml/plot/graph.py index c914afa..8b35470 100644 --- a/pynuml/plot/graph.py +++ b/pynuml/plot/graph.py @@ -138,7 +138,6 @@ def plot(self, fig = px.scatter(df, x='wire', y='time', facet_col='plane', width=width, height=height, **opts) - fig.update_yaxes(matches=None) fig.update_xaxes(matches=None) for a in fig.layout.annotations: a.text = a.text.replace('plane=', '')