Skip to content

Commit

Permalink
Some code improvements for better performance (#125)
Browse files Browse the repository at this point in the history
* Minor fixes

* Minor improvements for code style & performance

* Remove useless comments

* bump plugin versions

---------

Co-authored-by: Henning Schulze Eißing <h_schu55@uni-muenster.de>
  • Loading branch information
mhliu0001 and HenningSE authored Feb 13, 2024
1 parent 7f4eb42 commit e6aea36
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 15 deletions.
2 changes: 1 addition & 1 deletion fuse/plugins/detector_physics/s1_photon_hits.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,4 @@ def get_n_photons(self, n_photons, positions):

n_photon_hits = self.rng.binomial(n=n_photons, p=ly)

return n_photon_hits
return n_photon_hits
3 changes: 1 addition & 2 deletions fuse/plugins/detector_physics/s1_photon_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def setup(self):
)

self.pmt_mask = np.array(self.gains) > 0 # Converted from to pe (from cmt by default)
self.turned_off_pmts = np.arange(len(self.gains))[np.array(self.gains) == 0]
self.turned_off_pmts = np.nonzero(np.array(self.gains) == 0)[0]

self.spe_scaling_factor_distributions = init_spe_scaling_factor_distributions(self.photon_area_distribution)

Expand Down Expand Up @@ -307,7 +307,6 @@ def photon_timings(self,
:param local_field: local field in the point of the deposit, 1d array of floats
returns photon timing array"""
_photon_timings = np.repeat(t, n_photon_hits)
_n_hits_total = len(_photon_timings)

z_positions = np.repeat(positions[:, 2], n_photon_hits)

Expand Down
5 changes: 2 additions & 3 deletions fuse/plugins/detector_physics/s2_photon_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def setup(self):
)

self.pmt_mask = np.array(self.gains) > 0 # Converted from to pe (from cmt by default)
self.turned_off_pmts = np.arange(len(self.gains))[np.array(self.gains) == 0]
self.turned_off_pmts = np.nonzero(np.array(self.gains) == 0)[0]

self.spe_scaling_factor_distributions = init_spe_scaling_factor_distributions(self.photon_area_distribution)

Expand Down Expand Up @@ -450,8 +450,7 @@ def photon_channels(self, n_electron, z_obs, positions, drift_time_mean, n_photo

channels = np.arange(self.n_tpc_pmts).astype(np.int64)
top_index = np.arange(self.n_top_pmts)
channels_bottom = np.arange(self.n_top_pmts, self.n_tpc_pmts)
bottom_index = np.array(channels_bottom)
bottom_index = np.arange(self.n_top_pmts, self.n_tpc_pmts)

if self.diffusion_constant_transverse > 0:
pattern = self.s2_pattern_map_diffuse(n_electron, z_obs, positions, drift_time_mean) # [position, pmt]
Expand Down
7 changes: 4 additions & 3 deletions fuse/plugins/detector_physics/secondary_scintillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
@export
class SecondaryScintillation(FuseBasePlugin):

__version__ = "0.1.3"
__version__ = "0.1.4"

depends_on = ("drifted_electrons","extracted_electrons" ,"electron_time")
provides = ("s2_photons", "s2_photons_sum")
Expand Down Expand Up @@ -188,9 +188,10 @@ def compute(self, interactions_in_roi, individual_electrons):
if self.s2_gain_spread:
n_photons_per_ele += self.rng.normal(0, self.s2_gain_spread, len(n_photons_per_ele)).astype(np.int64)

sum_photons_per_interaction = [np.sum(x) for x in np.split(n_photons_per_ele, np.cumsum(interactions_in_roi[mask]["n_electron_extracted"]))[:-1]]
electron_indices = np.cumsum(interactions_in_roi[mask]["n_electron_extracted"])
sum_photons_per_interaction = np.add.reduceat(n_photons_per_ele, np.r_[0, electron_indices[:-1]])

n_photons_per_ele[n_photons_per_ele < 0] = 0
n_photons_per_ele = np.clip(n_photons_per_ele, 0, None)

reorder_electrons = np.argsort(individual_electrons, order = ["order_index", "time"])

Expand Down
12 changes: 6 additions & 6 deletions fuse/plugins/micro_physics/find_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
@export
class FindCluster(FuseBasePlugin):

__version__ = "0.1.2"
__version__ = "0.1.3"

depends_on = ("geant4_interactions")

Expand Down Expand Up @@ -72,13 +72,13 @@ def find_cluster(interactions, cluster_size_space, cluster_size_time):
time_cluster = simple_1d_clustering(interactions["time"], cluster_size_time)

# Splitting into time cluster and apply space clustering space:
cluster_id = np.zeros(len(interactions), dtype=np.int32)
spacial_cluster = np.zeros(len(interactions), dtype=np.int32)

_t_clusters = np.unique(time_cluster)
for _t in _t_clusters:
_cl = _find_cluster(interactions[time_cluster == _t], cluster_size_space=cluster_size_space)
spacial_cluster[time_cluster == _t] = _cl
time_cluster_mask = time_cluster == _t
_cl = _find_cluster(interactions[time_cluster_mask], cluster_size_space=cluster_size_space)
spacial_cluster[time_cluster_mask] = _cl
_, cluster_id = np.unique((time_cluster, spacial_cluster), axis=1, return_inverse=True)

return cluster_id
Expand All @@ -95,8 +95,8 @@ def _find_cluster(x, cluster_size_space):
"""
db_cluster = DBSCAN(eps=cluster_size_space, min_samples=1)

#Conversion from numpy structured array to regular array
xprime = np.array(x[['x', 'y', 'z']].tolist())
# Conversion from numpy structured array to regular array
xprime = np.stack((x['x'], x['y'], x['z']), axis=1)

return db_cluster.fit_predict(xprime)

Expand Down

0 comments on commit e6aea36

Please sign in to comment.