Skip to content

Commit

Permalink
Use same files names for peak and event level pos-rec (#1160)
Browse files Browse the repository at this point in the history
* Tidy-up styles

* Unify plugin and files name to `positions`
  • Loading branch information
dachengx authored May 11, 2023
1 parent b706a85 commit 5fa734a
Show file tree
Hide file tree
Showing 12 changed files with 26 additions and 28 deletions.
16 changes: 8 additions & 8 deletions straxen/plugins/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@
from . import veto_proximity
from .veto_proximity import *

from . import event_s2_position_mlp
from .event_s2_position_mlp import *
from . import event_s2_positions_mlp
from .event_s2_positions_mlp import *

from . import event_s2_position_cnn
from .event_s2_position_cnn import *
from . import event_s2_positions_cnn
from .event_s2_positions_cnn import *

from . import event_s2_position_gcn
from .event_s2_position_gcn import *
from . import event_s2_positions_gcn
from .event_s2_positions_gcn import *

from . import event_s1_position_cnn
from .event_s1_position_cnn import *
from . import event_s1_positions_cnn
from .event_s1_positions_cnn import *

from . import local_minimum_info
from .local_minimum_info import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import strax
import straxen
from straxen.plugins.events._event_s1_position_base import EventS1PositionBase
from straxen.plugins.events._event_s1_positions_base import EventS1PositionBase


export, __all__ = strax.exporter()
Expand All @@ -12,8 +12,8 @@ class EventS1PositionCNN(EventS1PositionBase):
CNN for (x,y,z) position S1 reconstruction at event level
"""
algorithm = "s1_cnn"
provides = "event_s1_position_cnn"
# tf_event_model_s1_cnn = straxen.URLConfig.evaluate_dry(f'tf:///project2/lgrandi/guidam/CNN_S1_XYZ_SAVED_MODELS/xnt_s1_posrec_cnn_datadriven_00_080921.tar.gz')
provides = "event_s1_positions_cnn"

tf_event_model_s1_cnn = straxen.URLConfig(
default=f'tf://'
f'resource://'
Expand All @@ -26,4 +26,4 @@ class EventS1PositionCNN(EventS1PositionBase):
help='s1 position 3d reconstruction cnn model. Should be opened using the "tf" descriptor. '
'Set to "None" to skip computation',
cache=3,
)
)
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import strax
import straxen
from straxen.plugins.events._event_s2_position_base import EventS2PositionBase
from straxen.plugins.events._event_s2_positions_base import EventS2PositionBase


export, __all__ = strax.exporter()
Expand All @@ -12,7 +12,7 @@ class EventS2PositionCNN(EventS2PositionBase):
CNN for position S2 reconstruction at event level
"""
algorithm = "cnn"
provides = "event_s2_position_cnn"
provides = "event_s2_positions_cnn"

tf_event_model_cnn = straxen.URLConfig(
default=f'tf://'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import strax
import straxen
from straxen.plugins.events._event_s2_position_base import EventS2PositionBase
from straxen.plugins.events._event_s2_positions_base import EventS2PositionBase


export, __all__ = strax.exporter()
Expand All @@ -12,7 +12,7 @@ class EventS2PositionGCN(EventS2PositionBase):
GCN net for position S2 reconstruction at event level
"""
algorithm = "gcn"
provides = "event_s2_position_gcn"
provides = "event_s2_positions_gcn"

tf_event_model_gcn = straxen.URLConfig(
default=f'tf://'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import strax
import straxen
from straxen.plugins.events._event_s2_position_base import EventS2PositionBase
from straxen.plugins.events._event_s2_positions_base import EventS2PositionBase


export, __all__ = strax.exporter()
Expand All @@ -12,7 +12,7 @@ class EventS2PositionMLP(EventS2PositionBase):
MLP neural net for S2 position reconstruction at event level
"""
algorithm = "mlp"
provides = "event_s2_position_mlp"
provides = "event_s2_positions_mlp"

tf_event_model_mlp = straxen.URLConfig(
default=f'tf://'
Expand Down
4 changes: 2 additions & 2 deletions straxen/plugins/peaks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from . import peak_classification_bayes
from .peak_classification_bayes import *

from . import peak_s1_position_cnn
from .peak_s1_position_cnn import *
from . import peak_s1_positions_cnn
from .peak_s1_positions_cnn import *

from . import peak_positions
from .peak_positions import *
Expand Down
1 change: 0 additions & 1 deletion straxen/plugins/peaks/peak_positions_cnn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import strax
import straxen

from straxen.plugins.peaks._peak_positions_base import PeakPositionsBaseNT


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import strax
import straxen
from straxen.plugins.peaks._peak_s1_position_base import PeakS1PositionBase
from straxen.plugins.peaks._peak_s1_positions_base import PeakS1PositionBase


export, __all__ = strax.exporter()
Expand All @@ -11,11 +11,10 @@ class PeakS1PositionCNN(PeakS1PositionBase):
"""
S1 CNN for (x,y,z) position S1 reconstruction at peak level
"""
provides = "peak_s1_position_cnn"
provides = "peak_s1_positions_cnn"
algorithm = "s1_cnn"
__version__ = '0.0.1'

# tf_peak_model_s1_cnn = straxen.URLConfig.evaluate_dry(f'tf:///project2/lgrandi/guidam/CNN_S1_XYZ_SAVED_MODELS/xnt_s1_posrec_cnn_datadriven_00_080921.tar.gz')
tf_peak_model_s1_cnn = straxen.URLConfig(
default=f'tf://'
f'resource://'
Expand All @@ -28,4 +27,4 @@ class PeakS1PositionCNN(PeakS1PositionBase):
help='s1 position 3d reconstruction cnn model. Should be opened using the "tf" descriptor. '
'Set to "None" to skip computation',
cache=3,
)
)
6 changes: 3 additions & 3 deletions tests/plugins/s1_posrec_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@PluginTestAccumulator.register('test_posrec_set_path')
def test_posrec_set_path(self,
target='event_s1_position_cnn',
target='event_s1_positions_cnn',
config_name='tf_event_model_s1_cnn',
field='event_x_s1_cnn',
):
Expand Down Expand Up @@ -41,7 +41,7 @@ def test_posrec_set_path(self,

@PluginTestAccumulator.register('test_posrec_set_to_none')
def test_posrec_set_to_none(self,
target='event_s1_position_cnn',
target='event_s1_positions_cnn',
config_name='tf_event_model_s1_cnn',
field='event_x_s1_cnn',
):
Expand All @@ -54,7 +54,7 @@ def test_posrec_set_to_none(self,

@PluginTestAccumulator.register('test_posrec_bad_configs_raising_errors')
def test_posrec_bad_configs_raising_errors(self,
target='event_s1_position_cnn',
target='event_s1_positions_cnn',
config_name='tf_event_model_s1_cnn',
):
"""Test that we get the right errors when we set invalid options"""
Expand Down

0 comments on commit 5fa734a

Please sign in to comment.