Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to graph processing and plotting #73

Merged
merged 22 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 65 additions & 7 deletions pynuml/plot/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self,
# which we don't have any power to fix but will presumably
# be fixed on their end at some point
warnings.filterwarnings("ignore", ".*The default of observed=False is deprecated and will be changed to True in a future version of pandas.*")
self._truth_cols = ( 'g4_id', 'parent_id', 'pdg' )

def to_dataframe(self, data: HeteroData):
def to_categorical(arr):
Expand All @@ -32,15 +33,24 @@ def to_categorical(arr):
df = pd.DataFrame(plane['id'], columns=['id'])
df['plane'] = p
df[['wire','time']] = plane['pos']
df[["x", "y", "z"]] = plane["c"]
df['y_filter'] = plane['y_semantic'] != -1
mask = df.y_filter.values
df['y_semantic'] = to_categorical(plane['y_semantic'])
df['y_instance'] = plane['y_instance'].numpy()
df['y_instance'] = plane['y_instance'].numpy().astype(str)

# add detailed truth information if it's available
for col in self._truth_cols:
if col in plane.keys():
df[col] = plane[col].numpy()

# add model prediction if it's available
if 'x_semantic' in plane.keys():
df['x_semantic'] = to_categorical(plane['x_semantic'].argmax(dim=-1).detach())
df[self._classes] = plane['x_semantic'].detach()
if 'x_filter' in plane.keys():
df['x_filter'] = plane['x_filter'].detach()

dfs.append(df)
df = pd.concat(dfs)
md = data['metadata']
Expand All @@ -54,8 +64,10 @@ def plot(self,
target: str = 'hits',
how: str = 'none',
filter: str = 'show',
xyz: bool = False,
width: int = None,
height: int = None) -> FigureWidget:
height: int = None,
title: bool = True) -> FigureWidget:

df = self.to_dataframe(data)

Expand Down Expand Up @@ -98,12 +110,16 @@ def plot(self,
'title': 'True instance labels',
'labels': { 'y_instance': 'Instance label' },
'color': 'y_instance',
'symbol': 'y_semantic',
'color_discrete_map': self._cmap,
}
elif how == 'pred':
opts = {
'title': 'Predicted instance labels',
'labels': { 'x_instance': 'Instance label' },
'color': 'x_instance',
'symbol': 'x_semantic',
'color_discrete_map': self._cmap,
}
else:
raise Exception('for instance labels, "how" must be one of "true" or "pred".')
Expand Down Expand Up @@ -148,9 +164,51 @@ def plot(self,
else:
raise Exception('"filter" must be one of "none", "show", "true" or "pred".')

fig = px.scatter(df, x='wire', y='time', facet_col='plane',
width=width, height=height, **opts)
fig.update_xaxes(matches=None)
for a in fig.layout.annotations:
a.text = a.text.replace('plane=', '')
if not title:
opts.pop('title')

# set hover data
opts['hover_data'] = {
'y_semantic': True,
'wire': ':.1f',
'time': ':.1f',
}
opts['labels'] = {
'y_filter': 'filter truth',
'y_semantic': 'semantic truth',
'y_instance': 'instance truth',
}
if 'x_filter' in df:
opts['hover_data']['x_filter'] = True
opts['labels']['x_filter'] = 'filter prediction'
if 'x_semantic' in df:
opts['hover_data']['x_semantic'] = True
opts['labels']['x_semantic'] = 'semantic prediction'
if 'x_instance' in df:
opts['hover_data']['x_instance'] = ':.4f'
opts['labels']['x_instance'] = 'instance prediction'
for col in self._truth_cols:
if col in df:
opts['hover_data'][col] = True

if xyz:
fig = px.scatter_3d(df, x="x", y="y", z="z",
width=width, height=height, **opts)
fig.update_traces(marker_size=1)
else:
fig = px.scatter(df, x='wire', y='time', facet_col='plane',
width=width, height=height, **opts)
fig.update_xaxes(matches=None)
for a in fig.layout.annotations:
a.text = a.text.replace('plane=', '')

# set the legend to horizontal
fig.update_layout(
legend_orientation='h',
legend_yanchor='bottom', legend_y=1.05,
legend_xanchor='right', legend_x=1,
margin_l=20, margin_r=20, margin_t=20, margin_b=20,
title_automargin=title,
)

return FigureWidget(fig)
57 changes: 36 additions & 21 deletions pynuml/process/hitgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,24 @@ def __init__(self,
semantic_labeller: Callable = None,
event_labeller: Callable = None,
label_vertex: bool = False,
label_position: bool = False,
planes: list[str] = ['u','v','y'],
node_pos: list[str] = ['local_wire','local_time'],
pos_norm: list[float] = [0.3,0.055],
node_feats: list[str] = ['integral','rms'],
lower_bound: int = 20,
filter_hits: bool = False):
store_detailed_truth: bool = False):

self.semantic_labeller = semantic_labeller
self.event_labeller = event_labeller
self.label_vertex = label_vertex
self.label_position = label_position
self.planes = planes
self.node_pos = node_pos
self.pos_norm = torch.tensor(pos_norm).float()
self.node_feats = node_feats
self.lower_bound = lower_bound
self.filter_hits = filter_hits
self.store_detailed_truth = store_detailed_truth

self.transform = pyg.transforms.Compose((
pyg.transforms.Delaunay(),
Expand All @@ -55,6 +57,8 @@ def columns(self) -> dict[str, list[str]]:
groups['event_table'].extend(keys)
else:
groups['event_table'] = keys
if self.label_position:
groups["hit_table"].extend(["x_position", "y_position", "z_position"])
return groups

@property
Expand All @@ -68,9 +72,6 @@ def metadata(self):

def __call__(self, evt: 'pynuml.io.Event') -> tuple[str, Any]:

event_id = evt.event_id
name = f'r{event_id[0]}_sr{event_id[1]}_evt{event_id[2]}'

if self.event_labeller or self.label_vertex:
event = evt['event_table'].squeeze()

Expand All @@ -84,23 +85,28 @@ def __call__(self, evt: 'pynuml.io.Event') -> tuple[str, Any]:
return evt.name, None

# handle energy depositions
if self.filter_hits or self.semantic_labeller:
if self.semantic_labeller:
edeps = evt['edep_table']
energy_col = 'energy' if 'energy' in edeps.columns else 'energy_fraction' # for backwards compatibility
edeps = edeps.sort_values(by=[energy_col],

# get ID of max particle
g4_id = edeps[[energy_col, 'g4_id', 'hit_id']]
g4_id = g4_id.sort_values(by=[energy_col],
ascending=False,
kind='mergesort').drop_duplicates('hit_id')
hits = edeps.merge(hits, on='hit_id', how='right')

# if we're filtering out data hits, do that
if self.filter_hits:
hitmask = hits[energy_col].isnull()
filtered_hits = hits[hitmask].hit_id.tolist()
hits = hits[~hitmask].reset_index(drop=True)
# filter spacepoints from noise
cols = [ f'hit_id_{p}' for p in self.planes ]
spmask = spacepoints[cols].isin(filtered_hits).any(axis='columns')
spacepoints = spacepoints[~spmask].reset_index(drop=True)
hits = g4_id.merge(hits, on='hit_id', how='right')

# energy-weighted average of 3D position over each hit's energy deposits
if self.label_position:
edeps = edeps.drop("g4_id", axis="columns")
edeps["x_position"] = edeps.x_position * edeps.energy
edeps["y_position"] = edeps.y_position * edeps.energy
edeps["z_position"] = edeps.z_position * edeps.energy
edeps = edeps.groupby("hit_id").sum()
edeps["x_position"] = edeps.x_position / edeps.energy
edeps["y_position"] = edeps.y_position / edeps.energy
edeps["z_position"] = edeps.z_position / edeps.energy
hits = edeps.merge(hits, on="hit_id", how="right")

hits['filter_label'] = ~hits[energy_col].isnull()
hits = hits.drop(energy_col, axis='columns')
Expand Down Expand Up @@ -137,9 +143,10 @@ def __call__(self, evt: 'pynuml.io.Event') -> tuple[str, Any]:
data = pyg.data.HeteroData()

# event metadata
data['metadata'].run = event_id[0]
data['metadata'].subrun = event_id[1]
data['metadata'].event = event_id[2]
r, sr, e = evt.event_id
data['metadata'].run = r
data['metadata'].subrun = sr
data['metadata'].event = e

# spacepoint nodes
data['sp'].num_nodes = spacepoints.shape[0]
Expand All @@ -157,6 +164,10 @@ def __call__(self, evt: 'pynuml.io.Event') -> tuple[str, Any]:
# node features
data[p].x = torch.tensor(plane_hits[self.node_feats].values).float()

# node true position
if self.label_position:
data[p].c = torch.tensor(plane_hits[["x_position", "y_position", "z_position"]].values).float()

# hit indices
data[p].id = torch.tensor(plane_hits['hit_id'].values).long()

Expand All @@ -175,6 +186,10 @@ def __call__(self, evt: 'pynuml.io.Event') -> tuple[str, Any]:
if self.semantic_labeller:
data[p].y_semantic = torch.tensor(plane_hits['semantic_label'].fillna(-1).values).long()
data[p].y_instance = torch.tensor(plane_hits['instance_label'].fillna(-1).values).long()
if self.store_detailed_truth:
data[p].g4_id = torch.tensor(plane_hits['g4_id'].fillna(-1).values).long()
data[p].parent_id = torch.tensor(plane_hits['parent_id'].fillna(-1).values).long()
data[p].pdg = torch.tensor(plane_hits['type'].fillna(-1).values).long()
if self.label_vertex:
vtx_2d = torch.tensor([ event[f'nu_vtx_wire_pos_{i}'], event.nu_vtx_wire_time ]).float()
data[p].y_vtx = vtx_2d * self.pos_norm[None,:]
Expand Down
15 changes: 15 additions & 0 deletions tests/test_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Test pynuml graph processing"""
import pynuml

def test_process_uboone():
    """Run the hit-graph producer over events from the MicroBooNE open data file.

    Opens the open data release file, builds a ``HitGraphProducer`` with
    standard semantic labels, flavor event labels and vertex labelling
    enabled, loads data with ``read_data(0, 100)`` and feeds every built
    event through the producer. There are no explicit assertions: the
    test passes as long as none of these calls raises.
    """
    evt_file = pynuml.io.File("/raid/nugraph/uboone-opendata/uboone-opendata.evt.h5")

    # The producer is constructed before any data is read from the file.
    producer_opts = {
        "file": evt_file,
        "semantic_labeller": pynuml.labels.StandardLabels(),
        "event_labeller": pynuml.labels.FlavorLabels(),
        "label_vertex": True,
    }
    producer = pynuml.process.HitGraphProducer(**producer_opts)

    evt_file.read_data(0, 100)
    for event in evt_file.build_evt():
        producer(event)