Skip to content

Commit

Permalink
Merge pull request #73 from nugraph/feature/plotting-updates
Browse files Browse the repository at this point in the history
Updates to graph processing and plotting
  • Loading branch information
vhewes authored Jun 4, 2024
2 parents 4203a2d + 5fc0eab commit 85c58cb
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 28 deletions.
72 changes: 65 additions & 7 deletions pynuml/plot/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self,
# which we don't have any power to fix but will presumably
# be fixed on their end at some point
warnings.filterwarnings("ignore", ".*The default of observed=False is deprecated and will be changed to True in a future version of pandas.*")
self._truth_cols = ( 'g4_id', 'parent_id', 'pdg' )

def to_dataframe(self, data: HeteroData):
def to_categorical(arr):
Expand All @@ -32,15 +33,24 @@ def to_categorical(arr):
df = pd.DataFrame(plane['id'], columns=['id'])
df['plane'] = p
df[['wire','time']] = plane['pos']
df[["x", "y", "z"]] = plane["c"]
df['y_filter'] = plane['y_semantic'] != -1
mask = df.y_filter.values
df['y_semantic'] = to_categorical(plane['y_semantic'])
df['y_instance'] = plane['y_instance'].numpy()
df['y_instance'] = plane['y_instance'].numpy().astype(str)

# add detailed truth information if it's available
for col in self._truth_cols:
if col in plane.keys():
df[col] = plane[col].numpy()

# add model prediction if it's available
if 'x_semantic' in plane.keys():
df['x_semantic'] = to_categorical(plane['x_semantic'].argmax(dim=-1).detach())
df[self._classes] = plane['x_semantic'].detach()
if 'x_filter' in plane.keys():
df['x_filter'] = plane['x_filter'].detach()

dfs.append(df)
df = pd.concat(dfs)
md = data['metadata']
Expand All @@ -54,8 +64,10 @@ def plot(self,
target: str = 'hits',
how: str = 'none',
filter: str = 'show',
xyz: bool = False,
width: int = None,
height: int = None) -> FigureWidget:
height: int = None,
title: bool = True) -> FigureWidget:

df = self.to_dataframe(data)

Expand Down Expand Up @@ -98,12 +110,16 @@ def plot(self,
'title': 'True instance labels',
'labels': { 'y_instance': 'Instance label' },
'color': 'y_instance',
'symbol': 'y_semantic',
'color_discrete_map': self._cmap,
}
elif how == 'pred':
opts = {
'title': 'Predicted instance labels',
'labels': { 'x_instance': 'Instance label' },
'color': 'x_instance',
'symbol': 'x_semantic',
'color_discrete_map': self._cmap,
}
else:
raise Exception('for instance labels, "how" must be one of "true" or "pred".')
Expand Down Expand Up @@ -148,9 +164,51 @@ def plot(self,
else:
raise Exception('"filter" must be one of "none", "show", "true" or "pred".')

fig = px.scatter(df, x='wire', y='time', facet_col='plane',
width=width, height=height, **opts)
fig.update_xaxes(matches=None)
for a in fig.layout.annotations:
a.text = a.text.replace('plane=', '')
if not title:
opts.pop('title')

# set hover data
opts['hover_data'] = {
'y_semantic': True,
'wire': ':.1f',
'time': ':.1f',
}
opts['labels'] = {
'y_filter': 'filter truth',
'y_semantic': 'semantic truth',
'y_instance': 'instance truth',
}
if 'x_filter' in df:
opts['hover_data']['x_filter'] = True
opts['labels']['x_filter'] = 'filter prediction'
if 'x_semantic' in df:
opts['hover_data']['x_semantic'] = True
opts['labels']['x_semantic'] = 'semantic prediction'
if 'x_instance' in df:
opts['hover_data']['x_instance'] = ':.4f'
opts['labels']['x_instance'] = 'instance prediction'
for col in self._truth_cols:
if col in df:
opts['hover_data'][col] = True

if xyz:
fig = px.scatter_3d(df, x="x", y="y", z="z",
width=width, height=height, **opts)
fig.update_traces(marker_size=1)
else:
fig = px.scatter(df, x='wire', y='time', facet_col='plane',
width=width, height=height, **opts)
fig.update_xaxes(matches=None)
for a in fig.layout.annotations:
a.text = a.text.replace('plane=', '')

# set the legend to horizontal
fig.update_layout(
legend_orientation='h',
legend_yanchor='bottom', legend_y=1.05,
legend_xanchor='right', legend_x=1,
margin_l=20, margin_r=20, margin_t=20, margin_b=20,
title_automargin=title,
)

return FigureWidget(fig)
57 changes: 36 additions & 21 deletions pynuml/process/hitgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,24 @@ def __init__(self,
semantic_labeller: Callable = None,
event_labeller: Callable = None,
label_vertex: bool = False,
label_position: bool = False,
planes: list[str] = ['u','v','y'],
node_pos: list[str] = ['local_wire','local_time'],
pos_norm: list[float] = [0.3,0.055],
node_feats: list[str] = ['integral','rms'],
lower_bound: int = 20,
filter_hits: bool = False):
store_detailed_truth: bool = False):

self.semantic_labeller = semantic_labeller
self.event_labeller = event_labeller
self.label_vertex = label_vertex
self.label_position = label_position
self.planes = planes
self.node_pos = node_pos
self.pos_norm = torch.tensor(pos_norm).float()
self.node_feats = node_feats
self.lower_bound = lower_bound
self.filter_hits = filter_hits
self.store_detailed_truth = store_detailed_truth

self.transform = pyg.transforms.Compose((
pyg.transforms.Delaunay(),
Expand All @@ -55,6 +57,8 @@ def columns(self) -> dict[str, list[str]]:
groups['event_table'].extend(keys)
else:
groups['event_table'] = keys
if self.label_position:
groups["hit_table"].extend(["x_position", "y_position", "z_position"])
return groups

@property
Expand All @@ -68,9 +72,6 @@ def metadata(self):

def __call__(self, evt: 'pynuml.io.Event') -> tuple[str, Any]:

event_id = evt.event_id
name = f'r{event_id[0]}_sr{event_id[1]}_evt{event_id[2]}'

if self.event_labeller or self.label_vertex:
event = evt['event_table'].squeeze()

Expand All @@ -84,23 +85,28 @@ def __call__(self, evt: 'pynuml.io.Event') -> tuple[str, Any]:
return evt.name, None

# handle energy depositions
if self.filter_hits or self.semantic_labeller:
if self.semantic_labeller:
edeps = evt['edep_table']
energy_col = 'energy' if 'energy' in edeps.columns else 'energy_fraction' # for backwards compatibility
edeps = edeps.sort_values(by=[energy_col],

# get ID of max particle
g4_id = edeps[[energy_col, 'g4_id', 'hit_id']]
g4_id = g4_id.sort_values(by=[energy_col],
ascending=False,
kind='mergesort').drop_duplicates('hit_id')
hits = edeps.merge(hits, on='hit_id', how='right')

# if we're filtering out data hits, do that
if self.filter_hits:
hitmask = hits[energy_col].isnull()
filtered_hits = hits[hitmask].hit_id.tolist()
hits = hits[~hitmask].reset_index(drop=True)
# filter spacepoints from noise
cols = [ f'hit_id_{p}' for p in self.planes ]
spmask = spacepoints[cols].isin(filtered_hits).any(axis='columns')
spacepoints = spacepoints[~spmask].reset_index(drop=True)
hits = g4_id.merge(hits, on='hit_id', how='right')

# charge-weighted average of 3D position
if self.label_position:
edeps = edeps.drop("g4_id", axis="columns")
edeps["x_position"] = edeps.x_position * edeps.energy
edeps["y_position"] = edeps.y_position * edeps.energy
edeps["z_position"] = edeps.z_position * edeps.energy
edeps = edeps.groupby("hit_id").sum()
edeps["x_position"] = edeps.x_position / edeps.energy
edeps["y_position"] = edeps.y_position / edeps.energy
edeps["z_position"] = edeps.z_position / edeps.energy
hits = edeps.merge(hits, on="hit_id", how="right")

hits['filter_label'] = ~hits[energy_col].isnull()
hits = hits.drop(energy_col, axis='columns')
Expand Down Expand Up @@ -137,9 +143,10 @@ def __call__(self, evt: 'pynuml.io.Event') -> tuple[str, Any]:
data = pyg.data.HeteroData()

# event metadata
data['metadata'].run = event_id[0]
data['metadata'].subrun = event_id[1]
data['metadata'].event = event_id[2]
r, sr, e = evt.event_id
data['metadata'].run = r
data['metadata'].subrun = sr
data['metadata'].event = e

# spacepoint nodes
data['sp'].num_nodes = spacepoints.shape[0]
Expand All @@ -157,6 +164,10 @@ def __call__(self, evt: 'pynuml.io.Event') -> tuple[str, Any]:
# node features
data[p].x = torch.tensor(plane_hits[self.node_feats].values).float()

# node true position
if self.label_position:
data[p].c = torch.tensor(plane_hits[["x_position", "y_position", "z_position"]].values).float()

# hit indices
data[p].id = torch.tensor(plane_hits['hit_id'].values).long()

Expand All @@ -175,6 +186,10 @@ def __call__(self, evt: 'pynuml.io.Event') -> tuple[str, Any]:
if self.semantic_labeller:
data[p].y_semantic = torch.tensor(plane_hits['semantic_label'].fillna(-1).values).long()
data[p].y_instance = torch.tensor(plane_hits['instance_label'].fillna(-1).values).long()
if self.store_detailed_truth:
data[p].g4_id = torch.tensor(plane_hits['g4_id'].fillna(-1).values).long()
data[p].parent_id = torch.tensor(plane_hits['parent_id'].fillna(-1).values).long()
data[p].pdg = torch.tensor(plane_hits['type'].fillna(-1).values).long()
if self.label_vertex:
vtx_2d = torch.tensor([ event[f'nu_vtx_wire_pos_{i}'], event.nu_vtx_wire_time ]).float()
data[p].y_vtx = vtx_2d * self.pos_norm[None,:]
Expand Down
15 changes: 15 additions & 0 deletions tests/test_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Test pynuml graph processing"""
import pynuml

def test_process_uboone():
    """Test graph processing with MicroBooNE open data release"""
    # NOTE(review): hard-coded absolute path — this test only runs on hosts
    # with the MicroBooNE open-data release mounted at this location.
    evt_file = pynuml.io.File("/raid/nugraph/uboone-opendata/uboone-opendata.evt.h5")
    producer = pynuml.process.HitGraphProducer(
        file=evt_file,
        semantic_labeller=pynuml.labels.StandardLabels(),
        event_labeller=pynuml.labels.FlavorLabels(),
        label_vertex=True)
    # load the first 100 events, then run each assembled event through
    # the graph producer to exercise the full processing path
    evt_file.read_data(0, 100)
    for event in evt_file.build_evt():
        producer(event)

0 comments on commit 85c58cb

Please sign in to comment.