diff --git a/OTAnalytics/adapter_ui/abstract_frame_track_statistics.py b/OTAnalytics/adapter_ui/abstract_frame_track_statistics.py new file mode 100644 index 000000000..93e1d6172 --- /dev/null +++ b/OTAnalytics/adapter_ui/abstract_frame_track_statistics.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod + +from OTAnalytics.application.use_cases.track_statistics import TrackStatistics + + +class AbstractFrameTrackStatistics(ABC): + @abstractmethod + def introduce_to_viewmodel(self) -> None: + raise NotImplementedError + + @abstractmethod + def update_track_statistics(self, track_statistics: TrackStatistics) -> None: + raise NotImplementedError diff --git a/OTAnalytics/adapter_ui/view_model.py b/OTAnalytics/adapter_ui/view_model.py index 9a6bfef72..6dd535920 100644 --- a/OTAnalytics/adapter_ui/view_model.py +++ b/OTAnalytics/adapter_ui/view_model.py @@ -17,12 +17,15 @@ from OTAnalytics.adapter_ui.abstract_frame_track_plotting import ( AbstractFrameTrackPlotting, ) +from OTAnalytics.adapter_ui.abstract_frame_track_statistics import ( + AbstractFrameTrackStatistics, +) from OTAnalytics.adapter_ui.abstract_frame_tracks import AbstractFrameTracks from OTAnalytics.adapter_ui.abstract_main_window import AbstractMainWindow from OTAnalytics.adapter_ui.abstract_treeview_interface import AbstractTreeviewInterface from OTAnalytics.adapter_ui.text_resources import ColumnResources from OTAnalytics.domain.date import DateRange -from OTAnalytics.domain.flow import Flow +from OTAnalytics.domain.flow import Flow, FlowId from OTAnalytics.domain.section import Section from OTAnalytics.domain.video import Video @@ -85,6 +88,10 @@ def set_filter_frame(self, filter_frame: AbstractFrameFilter) -> None: def set_frame_project(self, project_frame: AbstractFrameProject) -> None: pass + @abstractmethod + def set_frame_track_statistics(self, frame: AbstractFrameTrackStatistics) -> None: + pass + @abstractmethod def set_button_quick_save_config( self, button_quick_save_config: AbstractButtonQuickSaveConfig @@ -426,3 +433,7 @@ def set_svz_metadata_frame(self, frame: AbstractFrameSvzMetadata) -> None: @abstractmethod def get_save_path_suggestion(self, file_type: str, context_file_type: str) -> Path: raise NotImplementedError + + @abstractmethod + def get_tracks_assigned_to_each_flow(self) -> dict[FlowId, int]: + raise NotImplementedError diff --git a/OTAnalytics/application/application.py b/OTAnalytics/application/application.py index 22f48adb7..6513f01b2 100644 --- a/OTAnalytics/application/application.py +++ b/OTAnalytics/application/application.py @@ -36,6 +36,9 @@ EnableFilterTrackByDate, ) from OTAnalytics.application.use_cases.flow_repository import AddFlow +from OTAnalytics.application.use_cases.flow_statistics import ( + NumberOfTracksAssignedToEachFlow, +) from OTAnalytics.application.use_cases.generate_flows import GenerateFlows from OTAnalytics.application.use_cases.load_otconfig import LoadOtconfig from OTAnalytics.application.use_cases.load_otflow import LoadOtflow @@ -59,6 +62,10 @@ GetAllTrackFiles, TrackRepositorySize, ) +from OTAnalytics.application.use_cases.track_statistics import ( + CalculateTrackStatistics, + TrackStatistics, +) from OTAnalytics.application.use_cases.update_project import ProjectUpdater from OTAnalytics.domain.date import DateRange from OTAnalytics.domain.filter import FilterElement, FilterElementSettingRestorer @@ -131,6 +138,8 @@ def __init__( config_has_changed: ConfigHasChanged, export_road_user_assignments: ExportRoadUserAssignments, file_name_suggester: SavePathSuggester, + calculate_track_statistics: CalculateTrackStatistics, + number_of_tracks_assigned_to_each_flow: NumberOfTracksAssignedToEachFlow, ) -> None: self._datastore: Datastore = datastore self.track_state: TrackState = track_state @@ -171,6 +180,10 @@ def __init__( self._config_has_changed = config_has_changed self._export_road_user_assignments = export_road_user_assignments self._file_name_suggester = file_name_suggester + self._calculate_track_statistics = calculate_track_statistics + self._number_of_tracks_assigned_to_each_flow = ( + number_of_tracks_assigned_to_each_flow + ) def connect_observers(self) -> None: """ @@ -667,6 +680,12 @@ def suggest_save_path(self, file_type: str, context_file_type: str = "") -> Path """ return self._file_name_suggester.suggest(file_type, context_file_type) + def calculate_track_statistics(self) -> TrackStatistics: + return self._calculate_track_statistics.get_statistics() + + def number_of_tracks_assigned_to_each_flow(self) -> dict[FlowId, int]: + return self._number_of_tracks_assigned_to_each_flow.get() + class MissingTracksError(Exception): pass diff --git a/OTAnalytics/application/use_cases/flow_statistics.py b/OTAnalytics/application/use_cases/flow_statistics.py new file mode 100644 index 000000000..ae747c9e6 --- /dev/null +++ b/OTAnalytics/application/use_cases/flow_statistics.py @@ -0,0 +1,25 @@ +from OTAnalytics.application.use_cases.get_road_user_assignments import ( + GetRoadUserAssignments, +) +from OTAnalytics.domain.flow import FlowId, FlowRepository + + +class NumberOfTracksAssignedToEachFlow: + def __init__( + self, get_assignments: GetRoadUserAssignments, flow_repository: FlowRepository + ) -> None: + self._get_assignments = get_assignments + self._flow_repository = flow_repository + + def get(self) -> dict[FlowId, int]: + result = self._tracks_assigned_to_flows() + for road_user_assignment in self._get_assignments.get(): + flow_id = road_user_assignment.assignment.id + result[flow_id] += 1 + return result + + def _tracks_assigned_to_flows(self) -> dict[FlowId, int]: + flows = {} + for flow in self._flow_repository.get_all(): + flows[flow.id] = 0 + return flows diff --git a/OTAnalytics/application/use_cases/get_road_user_assignments.py b/OTAnalytics/application/use_cases/get_road_user_assignments.py new file mode 100644 index 000000000..0577ae853 --- /dev/null +++ b/OTAnalytics/application/use_cases/get_road_user_assignments.py @@ -0,0 +1,23 @@ +from OTAnalytics.application.analysis.traffic_counting import ( + RoadUserAssigner, + RoadUserAssignment, +) +from OTAnalytics.domain.event import EventRepository +from OTAnalytics.domain.flow import FlowRepository + + +class GetRoadUserAssignments: + def __init__( + self, + flow_repository: FlowRepository, + event_repository: EventRepository, + assigner: RoadUserAssigner, + ) -> None: + self._flow_repository = flow_repository + self._event_repository = event_repository + self._assigner = assigner + + def get(self) -> list[RoadUserAssignment]: + return self._assigner.assign( + self._event_repository.get_all(), self._flow_repository.get_all() + ).as_list() diff --git a/OTAnalytics/application/use_cases/highlight_intersections.py b/OTAnalytics/application/use_cases/highlight_intersections.py index 93fa95dc1..4dc26b1e5 100644 --- a/OTAnalytics/application/use_cases/highlight_intersections.py +++ b/OTAnalytics/application/use_cases/highlight_intersections.py @@ -5,7 +5,11 @@ from OTAnalytics.application.analysis.intersect import TracksIntersectingSections from OTAnalytics.application.analysis.traffic_counting import RoadUserAssigner from OTAnalytics.application.state import FlowState, SectionState, TrackViewState -from OTAnalytics.application.use_cases.section_repository import GetSectionsById +from OTAnalytics.application.use_cases.section_repository import ( + GetAllSections, + GetCuttingSections, + GetSectionsById, +) from OTAnalytics.domain.event import EventRepository from OTAnalytics.domain.flow import FlowId, FlowRepository from OTAnalytics.domain.section import SectionId @@ -59,6 +63,76 @@ def get_ids(self) -> set[TrackId]: ).get_ids() +class TracksIntersectingAllNonCuttingSections(TrackIdProvider): + """Returns track ids intersecting all sections which are not a cutting section. + + Args: + get_all_sections (GetAllSections): the use case to get all sections. + tracks_intersecting_sections (TracksIntersectingSections): get track ids + intersecting sections. + get_section_by_id (GetSectionsById): use case to get sections by id. + """ + + def __init__( + self, + get_cutting_sections: GetCuttingSections, + get_all_sections: GetAllSections, + tracks_intersecting_sections: TracksIntersectingSections, + get_section_by_id: GetSectionsById, + intersection_repository: IntersectionRepository, + ) -> None: + self._get_cutting_sections = get_cutting_sections + self._get_all_sections = get_all_sections + self._tracks_intersecting_sections = tracks_intersecting_sections + self._get_section_by_id = get_section_by_id + self._intersection_repository = intersection_repository + + def get_ids(self) -> set[TrackId]: + ids_non_cutting_sections = { + section.id + for section in self._get_all_sections() + if section not in self._get_cutting_sections() + } + return TracksIntersectingGivenSections( + ids_non_cutting_sections, + self._tracks_intersecting_sections, + self._get_section_by_id, + self._intersection_repository, + ).get_ids() + + +class TracksIntersectingAllSections(TrackIdProvider): + """Returns track ids intersecting all sections. + + Args: + get_all_sections (GetAllSections): the use case to get all sections. + tracks_intersecting_sections (TracksIntersectingSections): get track ids + intersecting sections. + get_section_by_id (GetSectionsById): use case to get sections by id. + """ + + def __init__( + self, + get_all_sections: GetAllSections, + tracks_intersecting_sections: TracksIntersectingSections, + get_section_by_id: GetSectionsById, + intersection_repository: IntersectionRepository, + ) -> None: + self._get_all_sections = get_all_sections + self._tracks_intersecting_sections = tracks_intersecting_sections + self._get_section_by_id = get_section_by_id + self._intersection_repository = intersection_repository + + def get_ids(self) -> set[TrackId]: + ids_all_sections = {section.id for section in self._get_all_sections()} + return TracksIntersectingGivenSections( + ids_all_sections, + self._tracks_intersecting_sections, + self._get_section_by_id, + self._intersection_repository, + ).get_ids() + + class TracksIntersectingGivenSections(TrackIdProvider): """Returns track ids intersecting given sections. @@ -158,6 +232,32 @@ def get_ids(self) -> Iterable[TrackId]: return ids +class TracksAssignedToAllFlows(TrackIdProvider): + """Returns track ids that are assigned to all flows. + + Args: + assigner (RoadUserAssigner): to assign tracks to flows. + event_repository (EventRepository): the event repository. + flow_repository (FlowRepository): the track repository. + """ + + def __init__( + self, + assigner: RoadUserAssigner, + event_repository: EventRepository, + flow_repository: FlowRepository, + ) -> None: + self._assigner = assigner + self._event_repository = event_repository + self._flow_repository = flow_repository + + def get_ids(self) -> Iterable[TrackId]: + all_flow_ids = [flow.id for flow in self._flow_repository.get_all()] + return TracksAssignedToGivenFlows( + self._assigner, self._event_repository, self._flow_repository, all_flow_ids + ).get_ids() + + class TracksAssignedToGivenFlows(TrackIdProvider): """Returns track ids that are assigned to the given flows. @@ -165,7 +265,7 @@ class TracksAssignedToGivenFlows(TrackIdProvider): assigner (RoadUserAssigner): to assign tracks to flows. event_repository (EventRepository): the event repository. flow_repository (FlowRepository): the track repository. - flow_ids (list[FlowId]): the flows fo identify assigned tracks for. + flow_ids (list[FlowId]): the flows to identify assigned tracks for. """ def __init__( diff --git a/OTAnalytics/application/use_cases/inside_cutting_section.py b/OTAnalytics/application/use_cases/inside_cutting_section.py new file mode 100644 index 000000000..48cf1fc31 --- /dev/null +++ b/OTAnalytics/application/use_cases/inside_cutting_section.py @@ -0,0 +1,44 @@ +from OTAnalytics.application.use_cases.section_repository import GetCuttingSections +from OTAnalytics.application.use_cases.track_repository import GetAllTracks +from OTAnalytics.domain.section import SectionId +from OTAnalytics.domain.track import TrackId +from OTAnalytics.domain.types import EventType + + +class TrackIdsInsideCuttingSections: + def __init__( + self, get_tracks: GetAllTracks, get_cutting_sections: GetCuttingSections + ): + self._get_tracks = get_tracks + self._get_cutting_sections = get_cutting_sections + + def __call__(self) -> set[TrackId]: + track_dataset = self._get_tracks.as_dataset() + cutting_sections = self._get_cutting_sections() + if not cutting_sections: + return set() + + results: set[TrackId] = set() + for cutting_section in cutting_sections: + offset = cutting_section.get_offset(EventType.SECTION_ENTER) + # set of all tracks where at least one coordinate is contained + # by at least one cutting section + results.update( + set( + track_id + for track_id, section_data in ( + track_dataset.contained_by_sections( + [cutting_section], offset + ).items() + ) + if contains_true(section_data) + ) + ) + return results + + +def contains_true(section_data: list[tuple[SectionId, list[bool]]]) -> bool: + for _, bool_list in section_data: + if any(bool_list): + return True + return False diff --git a/OTAnalytics/application/use_cases/section_repository.py b/OTAnalytics/application/use_cases/section_repository.py index b9243970a..ed8814b94 100644 --- a/OTAnalytics/application/use_cases/section_repository.py +++ b/OTAnalytics/application/use_cases/section_repository.py @@ -1,7 +1,13 @@ from typing import Iterable +from OTAnalytics.application.config import CLI_CUTTING_SECTION_MARKER from OTAnalytics.domain.geometry import RelativeOffsetCoordinate -from OTAnalytics.domain.section import Section, SectionId, SectionRepository +from OTAnalytics.domain.section import ( + Section, + SectionId, + SectionRepository, + SectionType, +) from OTAnalytics.domain.types import EventType @@ -23,6 +29,25 @@ def __call__(self) -> list[Section]: return self._section_repository.get_all() +class GetCuttingSections: + """Get all cutting sections from the repository.""" + + def __init__(self, section_repository: SectionRepository) -> None: + self._section_repository = section_repository + + def __call__(self) -> list[Section]: + cutting_sections = sorted( + [ + section + for section in self._section_repository.get_all() + if section.get_type() == SectionType.CUTTING + or section.name.startswith(CLI_CUTTING_SECTION_MARKER) + ], + key=lambda section: section.id.id, + ) + return cutting_sections + + class GetSectionsById: """Get sections by their id. diff --git a/OTAnalytics/application/use_cases/track_statistics.py b/OTAnalytics/application/use_cases/track_statistics.py new file mode 100644 index 000000000..8e4665782 --- /dev/null +++ b/OTAnalytics/application/use_cases/track_statistics.py @@ -0,0 +1,91 @@ +from dataclasses import dataclass + +from OTAnalytics.application.use_cases.highlight_intersections import ( + TracksAssignedToAllFlows, + TracksIntersectingAllNonCuttingSections, +) +from OTAnalytics.application.use_cases.inside_cutting_section import ( + TrackIdsInsideCuttingSections, +) +from OTAnalytics.application.use_cases.track_repository import GetAllTrackIds + +START_OF_CUTTING_SECTION_NAME: str = "#clicut" + + +@dataclass +class TrackStatistics: + track_count: int + track_count_outside: int + track_count_inside: int + track_count_inside_not_intersecting: int + track_count_inside_intersecting_but_unassigned: int + track_count_inside_assigned: int + percentage_inside_assigned: float + percentage_inside_not_intersection: float + percentage_inside_intersecting_but_unassigned: float + + +class CalculateTrackStatistics: + def __init__( + self, + intersection_all_non_cutting_sections: TracksIntersectingAllNonCuttingSections, + assigned_to_all_flows: TracksAssignedToAllFlows, + get_all_track_ids: GetAllTrackIds, + track_ids_inside_cutting_sections: TrackIdsInsideCuttingSections, + ) -> None: + self._intersection_all_non_cutting_sections = ( + intersection_all_non_cutting_sections + ) + self._assigned_to_all_flows = assigned_to_all_flows + self._get_all_track_ids = get_all_track_ids + self._track_ids_inside_cutting_sections = track_ids_inside_cutting_sections + + def get_statistics(self) -> TrackStatistics: + ids_all = set(self._get_all_track_ids()) + ids_inside_cutting_sections = self._track_ids_inside_cutting_sections() + + track_count_inside = len(ids_inside_cutting_sections) + track_count = len(ids_all) + + track_count_outside = track_count - track_count_inside + + track_count_inside_not_intersecting = len( + ids_inside_cutting_sections.difference( + self._intersection_all_non_cutting_sections.get_ids() + ) + ) + track_count_inside_assigned = len( + ids_inside_cutting_sections.intersection( + self._assigned_to_all_flows.get_ids() + ) + ) + track_count_inside_intersecting_but_unassigned = ( + track_count_inside + - track_count_inside_not_intersecting + - track_count_inside_assigned + ) + percentage_inside_assigned = self.__percentage( + track_count_inside_assigned, track_count_inside + ) + percentage_inside_not_intersection = self.__percentage( + track_count_inside_not_intersecting, track_count_inside + ) + percentage_inside_intersecting_but_unassigned = self.__percentage( + track_count_inside_intersecting_but_unassigned, track_count_inside + ) + return TrackStatistics( + track_count, + track_count_outside, + track_count_inside, + track_count_inside_not_intersecting, + track_count_inside_intersecting_but_unassigned, + track_count_inside_assigned, + percentage_inside_assigned, + percentage_inside_not_intersection, + percentage_inside_intersecting_but_unassigned, + ) + + def __percentage(self, track_count: int, all_tracks: int) -> float: + if all_tracks == 0: + return 1.0 + return track_count / all_tracks diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py b/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py index b6d0ee3f4..e4b502dcf 100644 --- a/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py +++ b/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py @@ -20,6 +20,9 @@ from OTAnalytics.adapter_ui.abstract_frame_track_plotting import ( AbstractFrameTrackPlotting, ) +from OTAnalytics.adapter_ui.abstract_frame_track_statistics import ( + AbstractFrameTrackStatistics, +) from OTAnalytics.adapter_ui.abstract_frame_tracks import AbstractFrameTracks from OTAnalytics.adapter_ui.abstract_main_window import AbstractMainWindow from OTAnalytics.adapter_ui.abstract_treeview_interface import AbstractTreeviewInterface @@ -108,6 +111,7 @@ validate_minute, validate_second, ) +from OTAnalytics.domain.event import EventRepositoryEvent from OTAnalytics.domain.files import DifferentDrivesException from OTAnalytics.domain.filter import FilterElement from OTAnalytics.domain.flow import Flow, FlowId, FlowListObserver @@ -176,12 +180,6 @@ LINE_SECTION: str = "line_section" TO_SECTION = "to_section" FROM_SECTION = "from_section" -MISSING_TRACK_FRAME_MESSAGE = "tracks frame" -MISSING_VIDEO_FRAME_MESSAGE = "videos frame" -MISSING_VIDEO_CONTROL_FRAME_MESSAGE = "video control frame" -MISSING_SECTION_FRAME_MESSAGE = "sections frame" -MISSING_FLOW_FRAME_MESSAGE = "flows frame" -MISSING_ANALYSIS_FRAME_MESSAGE = "analysis frame" class MissingInjectedInstanceError(Exception): @@ -218,6 +216,119 @@ class DummyViewModel( SectionListObserver, FlowListObserver, ): + @property + def window(self) -> AbstractMainWindow: + if self._window is None: + raise MissingInjectedInstanceError("window") + return self._window + + @property + def frame_project(self) -> AbstractFrameProject: + if self._frame_project is None: + raise MissingInjectedInstanceError("frame project") + return self._frame_project + + @property + def frame_tracks(self) -> AbstractFrameTracks: + if self._frame_tracks is None: + raise MissingInjectedInstanceError("frame tracks") + return self._frame_tracks + + @property + def frame_videos(self) -> AbstractFrame: + if self._frame_videos is None: + raise MissingInjectedInstanceError("frame videos") + return self._frame_videos + + @property + def frame_canvas(self) -> AbstractFrameCanvas: + if self._frame_canvas is None: + raise MissingInjectedInstanceError("frame canvas") + return self._frame_canvas + + @property + def frame_video_control(self) -> AbstractFrame: + if self._frame_video_control is None: + raise MissingInjectedInstanceError("frame video control") + return self._frame_video_control + + @property + def frame_sections(self) -> AbstractFrame: + if self._frame_sections is None: + raise MissingInjectedInstanceError("frame sections") + return self._frame_sections + + @property + def frame_flows(self) -> AbstractFrame: + if self._frame_flows is None: + raise MissingInjectedInstanceError("frame flows") + return self._frame_flows + + @property + def frame_filter(self) -> AbstractFrameFilter: + if self._frame_filter is None: + raise MissingInjectedInstanceError("frame filter") + return self._frame_filter + + @property + def frame_analysis(self) -> AbstractFrame: + if self._frame_analysis is None: + raise MissingInjectedInstanceError("frame analysis") + return self._frame_analysis + + @property + def canvas(self) -> AbstractCanvas: + if self._canvas is None: + raise MissingInjectedInstanceError("frame canvas") + return self._canvas + + @property + def frame_track_plotting(self) -> AbstractFrameTrackPlotting: + if self._frame_track_plotting is None: + raise MissingInjectedInstanceError("frame track plotting") + return self._frame_track_plotting + + @property + def frame_svz_metadata(self) -> AbstractFrameSvzMetadata: + if self._frame_svz_metadata is None: + raise MissingInjectedInstanceError("frame svz metadata") + return self._frame_svz_metadata + + @property + def treeview_videos(self) -> AbstractTreeviewInterface: + if self._treeview_videos is None: + raise MissingInjectedInstanceError("treeview videos") + return self._treeview_videos + + @property + def treeview_files(self) -> AbstractTreeviewInterface: + if self._treeview_files is None: + raise MissingInjectedInstanceError("treeview files") + return self._treeview_files + + @property + def treeview_sections(self) -> AbstractTreeviewInterface: + if self._treeview_sections is None: + raise MissingInjectedInstanceError("treeview sections") + return self._treeview_sections + + @property + def treeview_flows(self) -> AbstractTreeviewInterface: + if self._treeview_flows is None: + raise MissingInjectedInstanceError("treeview flows") + return self._treeview_flows + + @property + def button_quick_save_config(self) -> AbstractButtonQuickSaveConfig: + if self._button_quick_save_config is None: + raise MissingInjectedInstanceError("button quick save config") + return self._button_quick_save_config + + @property + def frame_track_statistics(self) -> AbstractFrameTrackStatistics: + if self._frame_track_statistics is None: + raise MissingInjectedInstanceError("frame track statistics") + return self._frame_track_statistics def __init__( self, @@ -245,25 +356,24 @@ def __init__( self._canvas: Optional[AbstractCanvas] = None self._frame_track_plotting: Optional[AbstractFrameTrackPlotting] = None self._frame_svz_metadata: Optional[AbstractFrameSvzMetadata] = None - self._treeview_sections: Optional[AbstractTreeviewInterface] - self._treeview_flows: Optional[AbstractTreeviewInterface] + self._treeview_videos: Optional[AbstractTreeviewInterface] = None + self._treeview_files: Optional[AbstractTreeviewInterface] = None + self._treeview_sections: Optional[AbstractTreeviewInterface] = None + self._treeview_flows: Optional[AbstractTreeviewInterface] = None self._button_quick_save_config: AbstractButtonQuickSaveConfig | None = None + self._frame_track_statistics: Optional[AbstractFrameTrackStatistics] = None self._new_section: dict = {} def show_svz(self) -> bool: return self._show_svz def notify_videos(self, videos: list[Video]) -> None: - if self._treeview_videos is None: - raise MissingInjectedInstanceError(type(self._treeview_videos).__name__) self.update_quick_save_button(videos) - self._treeview_videos.update_items() + self.treeview_videos.update_items() self._update_enabled_buttons() def notify_files(self) -> None: - if self._treeview_files is None: - raise MissingInjectedInstanceError(type(self._treeview_files).__name__) - self._treeview_files.update_items() + self.treeview_files.update_items() self._update_enabled_buttons() def _update_enabled_buttons(self) -> None: @@ -279,43 +389,35 @@ def _update_enabled_general_buttons(self) -> None: action_running = self._application.action_state.action_running.get() general_buttons_enabled = not action_running for frame in frames: - if frame is None: - raise MissingInjectedInstanceError(type(frame).__name__) frame.set_enabled_general_buttons(general_buttons_enabled) - def _get_frames(self) -> list: + def _get_frames(self) -> list[AbstractFrame | AbstractFrameProject]: return [ - self._frame_tracks, - self._frame_videos, - self._frame_project, - self._frame_sections, - self._frame_flows, - self._frame_analysis, + self.frame_tracks, + self.frame_videos, + self.frame_project, + self.frame_sections, + self.frame_flows, + self.frame_analysis, ] def _update_enabled_track_buttons(self) -> None: - if self._frame_tracks is None: - raise MissingInjectedInstanceError(MISSING_TRACK_FRAME_MESSAGE) action_running = self._application.action_state.action_running.get() selected_section_ids = self.get_selected_section_ids() single_section_selected = len(selected_section_ids) == 1 single_track_enabled = (not action_running) and single_section_selected - self._frame_tracks.set_enabled_change_single_item_buttons(single_track_enabled) + self.frame_tracks.set_enabled_change_single_item_buttons(single_track_enabled) def _update_enabled_video_buttons(self) -> None: - if self._frame_videos is None: - raise MissingInjectedInstanceError(MISSING_VIDEO_FRAME_MESSAGE) action_running = self._application.action_state.action_running.get() selected_videos: list[Video] = self._application.get_selected_videos() any_video_selected = len(selected_videos) > 0 multiple_videos_enabled = (not action_running) and any_video_selected - self._frame_videos.set_enabled_change_multiple_items_buttons( + self.frame_videos.set_enabled_change_multiple_items_buttons( multiple_videos_enabled ) def _update_enabled_section_buttons(self) -> None: - if self._frame_sections is None: - raise MissingInjectedInstanceError(MISSING_SECTION_FRAME_MESSAGE) action_running = self._application.action_state.action_running.get() videos_exist = len(self._application.get_all_videos()) > 0 selected_section_ids = self.get_selected_section_ids() @@ -326,17 +428,15 @@ def _update_enabled_section_buttons(self) -> None: single_section_enabled = add_section_enabled and single_section_selected multiple_sections_enabled = add_section_enabled and any_section_selected - self._frame_sections.set_enabled_add_buttons(add_section_enabled) - self._frame_sections.set_enabled_change_single_item_buttons( + self.frame_sections.set_enabled_add_buttons(add_section_enabled) + self.frame_sections.set_enabled_change_single_item_buttons( single_section_enabled ) - self._frame_sections.set_enabled_change_multiple_items_buttons( + self.frame_sections.set_enabled_change_multiple_items_buttons( multiple_sections_enabled ) def _update_enabled_flow_buttons(self) -> None: - if self._frame_flows is None: - raise MissingInjectedInstanceError(MISSING_FLOW_FRAME_MESSAGE) action_running = self._application.action_state.action_running.get() two_sections_exist = len(self._application.get_all_sections()) > 1 flows_exist = len(self._application.get_all_flows()) > 0 @@ -348,19 +448,17 @@ def _update_enabled_flow_buttons(self) -> None: single_flow_enabled = add_flow_enabled and single_flow_selected and flows_exist multiple_flows_enabled = add_flow_enabled and any_flow_selected and flows_exist - self._frame_flows.set_enabled_add_buttons(add_flow_enabled) - self._frame_flows.set_enabled_change_single_item_buttons(single_flow_enabled) - self._frame_flows.set_enabled_change_multiple_items_buttons( + self.frame_flows.set_enabled_add_buttons(add_flow_enabled) + self.frame_flows.set_enabled_change_single_item_buttons(single_flow_enabled) + self.frame_flows.set_enabled_change_multiple_items_buttons( multiple_flows_enabled ) def _update_enabled_video_control_buttons(self) -> None: - if self._frame_video_control is None: - raise MissingInjectedInstanceError(MISSING_VIDEO_CONTROL_FRAME_MESSAGE) action_running = self._application.action_state.action_running.get() videos_exist = len(self._application.get_all_videos()) > 0 general_activated = not action_running and videos_exist - self._frame_video_control.set_enabled_general_buttons(general_activated) + self.frame_video_control.set_enabled_general_buttons(general_activated) def _on_section_changed(self, section: SectionId) -> None: self._refresh_sections_in_ui() @@ -369,18 +467,12 @@ def _on_flow_changed(self, flow_id: FlowId) -> None: self.notify_flows([flow_id]) def _on_background_updated(self, image: Optional[TrackImage]) -> None: - if self._frame_canvas is None: - raise MissingInjectedInstanceError(AbstractFrameCanvas.__name__) - if image: - self._frame_canvas.update_background(image) + self.frame_canvas.update_background(image) else: - self._frame_canvas.clear_image() + self.frame_canvas.clear_image() def _update_date_range(self, filter_element: FilterElement) -> None: - if self._frame_filter is None: - raise MissingInjectedInstanceError(AbstractFrameFilter.__name__) - date_range = filter_element.date_range start_date = ( date_range.start_date.strftime(DATETIME_FORMAT) @@ -391,31 +483,26 @@ def _update_date_range(self, filter_element: FilterElement) -> None: end_date = ( date_range.end_date.strftime(DATETIME_FORMAT) if date_range.end_date else "" ) - self._frame_filter.update_date_range( + self.frame_filter.update_date_range( {"start_date": start_date, "end_date": end_date} ) def update_quick_save_button(self, _: Any) -> None: - if self._button_quick_save_config is None: - raise MissingInjectedInstanceError(AbstractButtonQuickSaveConfig.__name__) try: if self._application.config_has_changed(): - self._button_quick_save_config.set_state_changed_color() + self.button_quick_save_config.set_state_changed_color() else: - self._button_quick_save_config.set_default_color() + self.button_quick_save_config.set_default_color() except NoExistingConfigFound: - self._button_quick_save_config.set_default_color() + self.button_quick_save_config.set_default_color() def notify_tracks(self, track_event: TrackRepositoryEvent) -> None: self.notify_files() def _intersect_tracks_with_sections(self) -> None: - if self._window is None: - raise MissingInjectedInstanceError(type(self._window).__name__) - start_msg_popup = MinimalInfoBox( message="Create events...", - initial_position=self._window.get_position(), + initial_position=self.window.get_position(), ) self._application.intersect_tracks_with_sections() start_msg_popup.update_message(message="Creating events completed") @@ -426,17 +513,13 @@ def notify_sections(self, section_event: SectionRepositoryEvent) -> None: self.update_quick_save_button(section_event) def _refresh_sections_in_ui(self) -> None: - if self._treeview_sections is None: - raise MissingInjectedInstanceError(type(self._treeview_sections).__name__) self.refresh_items_on_canvas() - self._treeview_sections.update_items() + self.treeview_sections.update_items() self._update_enabled_buttons() def notify_flows(self, flows: list[FlowId]) -> None: - if self._treeview_flows is None: - raise MissingInjectedInstanceError(type(self._treeview_flows).__name__) self.refresh_items_on_canvas() - self._treeview_flows.update_items() + self.treeview_flows.update_items() self.update_quick_save_button(flow_id) def _notify_action_running_state(self, running: bool) -> None: @@ -470,9 +553,7 @@ def set_window(self, window: AbstractMainWindow) -> None: def _update_selected_videos(self, videos: list[Video]) -> None: current_paths = [str(video.get_path()) for video in videos] self._selected_videos = current_paths - if self._treeview_videos is None: - raise MissingInjectedInstanceError(type(self._treeview_sections).__name__) - self._treeview_videos.update_selected_items(current_paths) + self.treeview_videos.update_selected_items(current_paths) self._update_enabled_video_buttons() def add_video(self) -> None: @@ -514,10 +595,8 @@ def set_frame_project(self, project_frame: AbstractFrameProject) -> None: self.show_current_project() def show_current_project(self, _: Any = None) -> None: - if self._frame_project is None: - raise MissingInjectedInstanceError(type(self._frame_project).__name__) project = self._application._datastore.project - self._frame_project.update(name=project.name, start_date=project.start_date) + self.frame_project.update(name=project.name, start_date=project.start_date) def save_otconfig(self) -> None: suggested_save_path = self._application.suggest_save_path(OTCONFIG_FILE_TYPE) @@ -536,43 +615,35 @@ def _save_otconfig(self, otconfig_file: Path) -> None: logger().info(f"Config file to save: {otconfig_file}") try: self._application.save_otconfig(otconfig_file) - except NoSectionsToSave as cause: + except NoSectionsToSave: message = ( f"{MESSAGE_CONFIGURATION_NOT_SAVED}" f"No sections to save, please add new sections first." ) - self.__show_error(cause, message) + self.__show_error(message) return - except DifferentDrivesException as cause: + except DifferentDrivesException: message = ( f"{MESSAGE_CONFIGURATION_NOT_SAVED}" f"Configuration and video files are located on different drives." ) - self.__show_error(cause, message) + self.__show_error(message) return - except MissingDate as cause: + except MissingDate: message = ( f"{MESSAGE_CONFIGURATION_NOT_SAVED}" f"Start date is missing or invalid. Please add a valid start date." ) - self.__show_error(cause, message) + self.__show_error(message) return def _get_window_position(self) -> tuple[int, int]: - if self._window is None: - raise MissingInjectedInstanceError(type(self._window).__name__) - return self._window.get_position() - - def __show_error(self, cause: Exception, message: str) -> None: - if self._treeview_sections is None: - raise MissingInjectedInstanceError( - type(self._treeview_sections).__name__ - ) from cause - position = self._treeview_sections.get_position() + return self.window.get_position() + def __show_error(self, message: str) -> None: InfoBox( message=message, - initial_position=position, + initial_position=self.treeview_sections.get_position(), ) def load_otconfig(self) -> None: @@ -642,26 +713,20 @@ def _update_selected_sections(self, section_ids: list[SectionId]) -> None: self.update_section_offset_button_state() def _update_selected_section_items(self) -> None: - if self._treeview_sections is None: - raise MissingInjectedInstanceError(type(self._treeview_sections).__name__) - new_section_ids = self.get_selected_section_ids() - self._treeview_sections.update_selected_items(new_section_ids) + self.treeview_sections.update_selected_items(new_section_ids) self.refresh_items_on_canvas() def update_section_offset_button_state(self) -> None: - if self._frame_tracks is None: - raise MissingInjectedInstanceError(type(self._frame_tracks).__name__) - currently_selected_sections = ( self._application.section_state.selected_sections.get() ) - default_color = self._frame_tracks.get_default_offset_button_color() + default_color = self.frame_tracks.get_default_offset_button_color() single_section_selected = len(currently_selected_sections) == 1 if not single_section_selected: - self._frame_tracks.configure_offset_button(default_color, False) + self.frame_tracks.configure_offset_button(default_color, False) return section_offset = self._application.get_section_offset( @@ -673,20 +738,18 @@ def update_section_offset_button_state(self) -> None: visualization_offset = self._application.track_view_state.track_offset.get() if section_offset == visualization_offset: - self._frame_tracks.configure_offset_button(default_color, False) + self.frame_tracks.configure_offset_button(default_color, False) else: - self._frame_tracks.configure_offset_button(COLOR_ORANGE, True) + self.frame_tracks.configure_offset_button(COLOR_ORANGE, True) def _update_selected_flows(self, flow_ids: list[FlowId]) -> None: self._update_selected_flow_items() self._update_enabled_buttons() def _update_selected_flow_items(self) -> None: - if self._treeview_flows is None: - raise MissingInjectedInstanceError(type(self._treeview_flows).__name__) new_selected_flow_ids = self.get_selected_flow_ids() - self._treeview_flows.update_selected_items(new_selected_flow_ids) + self.treeview_flows.update_selected_items(new_selected_flow_ids) self.refresh_items_on_canvas() def set_selected_flow_ids(self, ids: list[str]) -> None: @@ -800,12 +863,8 @@ def _save_otflow(self, otflow_file: Path) -> None: logger().info(f"Sections file to save: {otflow_file}") try: self._application.save_otflow(Path(otflow_file)) - except NoSectionsToSave as cause: - if self._treeview_sections is None: - raise MissingInjectedInstanceError( - type(self._treeview_sections).__name__ - ) from cause - position = self._treeview_sections.get_position() + except NoSectionsToSave: + position = self.treeview_sections.get_position() InfoBox( message="No sections to save, please add new sections first", initial_position=position, @@ -817,19 +876,15 @@ def cancel_action(self) -> None: def add_line_section(self) -> None: self.set_selected_section_ids([]) - if self._canvas is None: - raise MissingInjectedInstanceError(AbstractCanvas.__name__) self._start_action() - SectionBuilder(viewmodel=self, canvas=self._canvas, style=EDITED_SECTION_STYLE) + SectionBuilder(viewmodel=self, canvas=self.canvas, style=EDITED_SECTION_STYLE) def add_area_section(self) -> None: self.set_selected_section_ids([]) - if self._canvas is None: - raise MissingInjectedInstanceError(AbstractCanvas.__name__) self._start_action() SectionBuilder( viewmodel=self, - canvas=self._canvas, + canvas=self.canvas, is_area_section=True, style=EDITED_SECTION_STYLE, ) @@ -973,17 +1028,15 @@ def edit_section_geometry(self) -> None: "Multiple sections are selected. Unable to edit section geometry!" ) - if self._canvas is None: - raise MissingInjectedInstanceError(AbstractCanvas.__name__) self._start_action() - CanvasElementDeleter(canvas=self._canvas).delete(tag_or_id=TAG_SELECTED_SECTION) + CanvasElementDeleter(canvas=self.canvas).delete(tag_or_id=TAG_SELECTED_SECTION) if selected_section_ids: if current_section := self._application.get_section_for( SectionId(selected_section_ids[0]) ): SectionGeometryEditor( viewmodel=self, - canvas=self._canvas, + canvas=self.canvas, section=current_section, edited_section_style=EDITED_SECTION_STYLE, pre_edit_section_style=PRE_EDIT_SECTION_STYLE, @@ -992,11 +1045,8 @@ def edit_section_geometry(self) -> None: ) def edit_selected_section_metadata(self) -> None: - if self._treeview_sections is None: - raise MissingInjectedInstanceError(type(self._treeview_sections).__name__) - if not (selected_section_ids := self.get_selected_section_ids()): - position = self._treeview_sections.get_position() + position = self.treeview_sections.get_position() InfoBox( message="Please select a section to edit", initial_position=position ) @@ -1015,9 +1065,7 @@ def edit_selected_section_metadata(self) -> None: @action def _update_metadata(self, selected_section: Section) -> None: current_data = selected_section.to_dict() - if self._canvas is None: - raise MissingInjectedInstanceError(AbstractCanvas.__name__) - position = self._canvas.get_position() + position = self.canvas.get_position() with contextlib.suppress(CancelAddSection): self.__update_section_metadata(selected_section, current_data, position) @@ -1037,20 +1085,15 @@ def __update_section_metadata( logger().info(f"Updated line_section Metadata: {updated_section_data}") def _set_section_data(self, id: SectionId, data: dict) -> None: - if self._treeview_sections is None: - raise MissingInjectedInstanceError(AbstractTreeviewInterface.__name__) section = self._flow_parser.parse_section(data) self._application.update_section(section) if not section.name.startswith(CUTTING_SECTION_MARKER): - self._treeview_sections.update_selected_items([id.serialize()]) + self.treeview_sections.update_selected_items([id.serialize()]) @action def remove_sections(self) -> None: - if self._treeview_sections is None: - raise MissingInjectedInstanceError(type(self._treeview_sections).__name__) - if not (selected_section_ids := self.get_selected_section_ids()): - position = self._treeview_sections.get_position() + position = self.treeview_sections.get_position() InfoBox( message="Please select one or more sections to remove", initial_position=position, @@ -1066,7 +1109,7 @@ def remove_sections(self) -> None: ) for flow in self._application.flows_using_section(section_id): message += flow.name + "\n" - position = self._treeview_sections.get_position() + position = self.treeview_sections.get_position() InfoBox( message=message, initial_position=position, @@ -1082,13 +1125,9 @@ def refresh_items_on_canvas(self) -> None: self._draw_items_on_canvas() def _remove_items_from_canvas(self) -> None: - if self._canvas is None: - raise MissingInjectedInstanceError(AbstractCanvas.__name__) - CanvasElementDeleter(canvas=self._canvas).delete(tag_or_id=LINE_SECTION) + CanvasElementDeleter(canvas=self.canvas).delete(tag_or_id=LINE_SECTION) def _draw_items_on_canvas(self) -> None: - if self._canvas is None: - raise MissingInjectedInstanceError(AbstractCanvas.__name__) sections_to_highlight = self._get_sections_to_highlight() self._draw_sections(sections_to_highlight) if self._application.flow_state.selected_flows.get(): @@ -1100,9 +1139,7 @@ def _get_sections_to_highlight(self) -> list[str]: return [] def _draw_sections(self, sections_to_highlight: list[str]) -> None: - if self._canvas is None: - raise MissingInjectedInstanceError(AbstractCanvas.__name__) - section_painter = SectionPainter(canvas=self._canvas) + section_painter = SectionPainter(canvas=self.canvas) for section in self._get_sections(): tags = [LINE_SECTION] if section[ID] in sections_to_highlight: @@ -1121,8 +1158,6 @@ def _draw_sections(self, sections_to_highlight: list[str]) -> None: ) def _draw_arrow_for_selected_flows(self) -> None: - if self._canvas is None: - raise MissingInjectedInstanceError(AbstractCanvas.__name__) for flow in self._get_selected_flows(): if start_section := self._application.get_section_for(flow.start): if end_section := self._application.get_section_for(flow.end): @@ -1132,7 +1167,7 @@ def _draw_arrow_for_selected_flows(self) -> None: end_refpt_calculator = self._get_section_refpt_calculator( end_section ) - ArrowPainter(self._canvas).draw( + ArrowPainter(self.canvas).draw( start_section=start_section, end_section=end_section, start_refpt_calculator=start_refpt_calculator, @@ -1175,7 +1210,7 @@ def _transform_coordinates(self, section: dict) -> dict: return section def _to_coordinate_tuple(self, coordinate: dict) -> tuple[int, int]: - return (coordinate[geometry.X], coordinate[geometry.Y]) + return coordinate[geometry.X], coordinate[geometry.Y] def get_all_sections(self) -> Iterable[Section]: return self._application.get_all_sections() @@ -1212,9 +1247,7 @@ def _show_flow_popup( input_values: dict | None = None, title: str = "Add flow", ) -> dict: - if self._treeview_flows is None: - raise MissingInjectedInstanceError(type(self._treeview_flows).__name__) - position = self._treeview_flows.get_position() + position = self.treeview_flows.get_position() sections = list(self.get_all_sections()) if len(sections) < 2: InfoBox( @@ -1274,9 +1307,7 @@ def __try_add_flow(self, flow: Flow) -> None: try: self._application.add_flow(flow) except FlowAlreadyExists as cause: - if self._treeview_flows is None: - raise MissingInjectedInstanceError(type(self._treeview_flows).__name__) - position = self._treeview_flows.get_position() + position = self.treeview_flows.get_position() InfoBox(message=str(cause), initial_position=position) raise CancelAddFlow() @@ -1316,11 +1347,7 @@ def edit_selected_flow(self) -> None: ) self._edit_flow(flows[0]) else: - if self._treeview_flows is None: - raise MissingInjectedInstanceError( - type(self._treeview_flows).__name__ - ) - position = self._treeview_flows.get_position() + position = self.treeview_flows.get_position() InfoBox( message="Please select a flow to edit", initial_position=position ) @@ -1342,29 +1369,28 @@ def _edit_flow(self, flow: Flow) -> None: @action def remove_flows(self) -> None: - if self._treeview_flows is None: - raise MissingInjectedInstanceError(type(self._treeview_flows).__name__) if flow_ids := self._application.flow_state.selected_flows.get(): for flow_id in flow_ids: self._application.remove_flow(flow_id) self.refresh_items_on_canvas() else: - position = self._treeview_flows.get_position() + position = self.treeview_flows.get_position() InfoBox(message="Please select a flow to remove", initial_position=position) def create_events(self) -> None: - if self._window is None: - raise MissingInjectedInstanceError(type(self._window).__name__) - start_msg_popup = MinimalInfoBox( message="Create events...", - initial_position=self._window.get_position(), + initial_position=self.window.get_position(), ) self._application.create_events() + self.notify_flows(self.get_all_flow_ids()) start_msg_popup.update_message(message="Creating events completed") sleep(1) start_msg_popup.close() + def get_all_flow_ids(self) -> list[FlowId]: + return [flow.id for flow in self.get_all_flows()] + def save_events(self, file: str) -> None: self._application.save_events(Path(file)) logger().info(f"Eventlist file saved to '{file}'") @@ -1413,12 +1439,9 @@ def _configure_event_exporter( return event_list_exporter, file def set_track_offset(self, offset_x: float, offset_y: float) -> None: - if self._window is None: - raise MissingInjectedInstanceError(type(self._window).__name__) - start_msg_popup = MinimalInfoBox( message="Apply offset...", - initial_position=self._window.get_position(), + initial_position=self.window.get_position(), ) offset = geometry.RelativeOffsetCoordinate(offset_x, offset_y) self._application.track_view_state.track_offset.set(offset) @@ -1434,16 +1457,11 @@ def get_track_offset(self) -> Optional[tuple[float, float]]: def _update_offset( self, offset: Optional[geometry.RelativeOffsetCoordinate] ) -> None: - if self._frame_tracks is None: - raise MissingInjectedInstanceError(AbstractFrameTracks.__name__) - if offset: - self._frame_tracks.update_offset(offset.x, offset.y) + self.frame_tracks.update_offset(offset.x, offset.y) def change_track_offset_to_section_offset(self) -> None: self._application.change_track_offset_to_section_offset() - if self._frame_tracks is None: - raise MissingInjectedInstanceError(type(self._frame_tracks).__name__) self.update_section_offset_button_state() def next_frame(self) -> None: @@ -1489,33 +1507,19 @@ def validate_second(self, second: str) -> bool: def apply_filter_tracks_by_date(self, date_range: DateRange) -> None: self._application.update_date_range_tracks_filter(date_range) - if self._frame_filter is None: - raise MissingInjectedInstanceError(AbstractFrameFilter.__name__) - - self._frame_filter.set_active_color_on_filter_by_date_button() + self.frame_filter.set_active_color_on_filter_by_date_button() def apply_filter_tracks_by_class(self, classes: list[str]) -> None: self._application.update_class_tracks_filter(set(classes)) - if self._frame_filter is None: - raise MissingInjectedInstanceError(AbstractFrameFilter.__name__) - - self._frame_filter.set_active_color_on_filter_by_class_button() + self.frame_filter.set_active_color_on_filter_by_class_button() def reset_filter_tracks_by_date(self) -> None: self._application.update_date_range_tracks_filter(DateRange(None, None)) - - if self._frame_filter is None: - raise MissingInjectedInstanceError(AbstractFrameFilter.__name__) - - self._frame_filter.set_inactive_color_on_filter_by_date_button() + self.frame_filter.set_inactive_color_on_filter_by_date_button() def reset_filter_tracks_by_class(self) -> None: self._application.update_class_tracks_filter(None) - - if self._frame_filter is None: - raise MissingInjectedInstanceError(AbstractFrameFilter.__name__) - - self._frame_filter.set_inactive_color_on_filter_by_class_button() + self.frame_filter.set_inactive_color_on_filter_by_class_button() def get_first_detection_occurrence(self) -> Optional[datetime]: return self._application._tracks_metadata.first_detection_occurrence @@ -1552,17 +1556,15 @@ def enable_filter_track_by_date(self) -> None: self.__enable_filter_track_by_date_button() def __enable_filter_track_by_date_button(self) -> None: - if self._frame_filter is None: - raise MissingInjectedInstanceError(AbstractFrameFilter.__name__) current_date_range = ( self._application.track_view_state.filter_element.get().date_range ) self._application.track_view_state.filter_date_active.set(True) - self._frame_filter.enable_filter_by_date_button() + self.frame_filter.enable_filter_by_date_button() if current_date_range != DateRange(None, None): - self._frame_filter.set_active_color_on_filter_by_date_button() + self.frame_filter.set_active_color_on_filter_by_date_button() else: - self._frame_filter.set_inactive_color_on_filter_by_date_button() + self.frame_filter.set_inactive_color_on_filter_by_date_button() def disable_filter_track_by_date(self) -> None: self._application.disable_filter_track_by_date() @@ -1570,10 +1572,8 @@ def disable_filter_track_by_date(self) -> None: self.__disable_filter_track_by_date_button() def __disable_filter_track_by_date_button(self) -> None: - if self._frame_filter is None: - raise MissingInjectedInstanceError(AbstractFrameFilter.__name__) self._application.track_view_state.filter_date_active.set(False) - self._frame_filter.disable_filter_by_date_button() + self.frame_filter.disable_filter_by_date_button() def switch_to_prev_date_range(self) -> None: self._application.switch_to_prev_date_range() @@ -1583,26 +1583,18 @@ def switch_to_next_date_range(self) -> None: def enable_filter_track_by_class(self) -> None: self._application.enable_filter_track_by_class() - - if self._frame_filter is None: - raise MissingInjectedInstanceError(AbstractFrameFilter.__name__) - - self._frame_filter.enable_filter_by_class_button() + self.frame_filter.enable_filter_by_class_button() current_classes = ( self._application.track_view_state.filter_element.get().classifications ) if current_classes is not None: - self._frame_filter.set_active_color_on_filter_by_class_button() + self.frame_filter.set_active_color_on_filter_by_class_button() else: - self._frame_filter.set_inactive_color_on_filter_by_class_button() + self.frame_filter.set_inactive_color_on_filter_by_class_button() def disable_filter_track_by_class(self) -> None: self._application.disable_filter_track_by_class() - - if self._frame_filter is None: - raise MissingInjectedInstanceError(AbstractFrameFilter.__name__) - - self._frame_filter.disable_filter_by_class_button() + self.frame_filter.disable_filter_by_class_button() def export_counts(self) -> None: if len(self._application.get_all_flows()) == 0: @@ -1678,24 +1670,9 @@ def update_project_start_date(self, start_date: Optional[datetime]) -> None: self._application.update_project_start_date(start_date) def on_start_new_project(self, _: None) -> None: - self._reset_filters() - self._reset_plotting_layer() - self._display_preview_image() - - def _reset_filters(self) -> None: - if self._frame_filter is None: - raise MissingInjectedInstanceError(AbstractFrameFilter.__name__) - self._frame_filter.reset() - - def _display_preview_image(self) -> None: - if self._canvas is None: - raise MissingInjectedInstanceError(AbstractCanvas.__name__) - self._canvas.add_preview_image() - - def _reset_plotting_layer(self) -> None: - if self._frame_track_plotting is None: - raise MissingInjectedInstanceError(AbstractFrameTrackPlotting.__name__) - self._frame_track_plotting.reset_layers() + self.frame_filter.reset() + self.frame_track_plotting.reset_layers() + self.canvas.add_preview_image() def set_frame_track_plotting( self, frame_track_plotting: AbstractFrameTrackPlotting @@ -1831,13 +1808,21 @@ def set_svz_metadata_frame(self, frame: AbstractFrameSvzMetadata) -> None: self.update_svz_metadata_view() def update_svz_metadata_view(self, _: Any = None) -> None: - if self._frame_svz_metadata is None: - raise MissingInjectedInstanceError(type(self._frame_svz_metadata).__name__) project = self._application._datastore.project if metadata := project.metadata: - self._frame_svz_metadata.update(metadata=metadata.to_dict()) + self.frame_svz_metadata.update(metadata=metadata.to_dict()) else: - self._frame_svz_metadata.update({}) + self.frame_svz_metadata.update({}) def get_save_path_suggestion(self, file_type: str, context_file_type: str) -> Path: return self._application.suggest_save_path(file_type, context_file_type) + + def set_frame_track_statistics(self, frame: AbstractFrameTrackStatistics) -> None: + self._frame_track_statistics = frame + + def update_track_statistics(self, _: EventRepositoryEvent) -> None: + statistics = self._application.calculate_track_statistics() + self.frame_track_statistics.update_track_statistics(statistics) + + def get_tracks_assigned_to_each_flow(self) -> dict[FlowId, int]: + return self._application.number_of_tracks_assigned_to_each_flow() diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/frame_flows.py b/OTAnalytics/plugin_ui/customtkinter_gui/frame_flows.py index a4140a966..ee301a908 100644 --- a/OTAnalytics/plugin_ui/customtkinter_gui/frame_flows.py +++ b/OTAnalytics/plugin_ui/customtkinter_gui/frame_flows.py @@ -129,15 +129,21 @@ def _on_double_click(self, event: Any) -> None: def update_items(self) -> None: self.delete(*self.get_children()) + tracks_assigned_to_each_flow = ( + self._viewmodel.get_tracks_assigned_to_each_flow() + ) + flows = [] + for flow in self._viewmodel.get_all_flows(): + flows.append( + self.__to_resource(flow, tracks_assigned_to_each_flow[flow.id]) + ) item_ids = ColumnResources( - sorted( - [self.__to_resource(flow) for flow in self._viewmodel.get_all_flows()] - ), + sorted(flows), lookup_column=COLUMN_FLOW, ) self.add_items(item_ids=item_ids) @staticmethod - def __to_resource(flow: Flow) -> ColumnResource: - values = {COLUMN_FLOW: flow.name} + def __to_resource(flow: Flow, tracks_assigned: int) -> ColumnResource: + values = {COLUMN_FLOW: f"{flow.name} ({tracks_assigned})"} return ColumnResource(id=flow.id.id, values=values) diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/frame_track_statistics.py b/OTAnalytics/plugin_ui/customtkinter_gui/frame_track_statistics.py new file mode 100644 index 000000000..f54ae5f9b --- /dev/null +++ b/OTAnalytics/plugin_ui/customtkinter_gui/frame_track_statistics.py @@ -0,0 +1,123 @@ +from typing import Any + +from customtkinter import CTkLabel, ThemeManager + +from OTAnalytics.adapter_ui.abstract_frame_track_statistics import ( + AbstractFrameTrackStatistics, +) +from OTAnalytics.adapter_ui.view_model import ViewModel +from OTAnalytics.application.use_cases.track_statistics import TrackStatistics +from OTAnalytics.plugin_ui.customtkinter_gui.abstract_ctk_frame import AbstractCTkFrame +from OTAnalytics.plugin_ui.customtkinter_gui.constants import PADX, STICKY + + +class FrameTrackStatistics(AbstractCTkFrame, AbstractFrameTrackStatistics): + def __init__(self, viewmodel: ViewModel, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._is_initialized = False + self._viewmodel = viewmodel + self.default_border_color = ThemeManager.theme["CTkEntry"]["border_color"] + self._get_widgets() + self._place_widgets() + self.introduce_to_viewmodel() + self._is_initialized = True + + def introduce_to_viewmodel(self) -> None: + self._viewmodel.set_frame_track_statistics(self) + + def _get_widgets(self) -> None: + self._label_all_tracks = CTkLabel( + master=self, text="All tracks:", anchor="nw", justify="right" + ) + self._label_all_tracks_value = CTkLabel( + master=self, + text="Please update flow highlighting", + anchor="nw", + justify="right", + ) + self._label_inside_tracks = CTkLabel( + master=self, text="Inside cutting section:", anchor="nw", justify="right" + ) + self._label_inside_tracks_value = CTkLabel( + master=self, + text="Please update flow highlighting", + anchor="nw", + justify="right", + ) + self._label_assigned_tracks = CTkLabel( + master=self, text="Tracks assigned to flows:", anchor="nw", justify="right" + ) + self._label_assigned_tracks_value = CTkLabel( + master=self, + text="Please update flow highlighting", + anchor="nw", + justify="right", + ) + self._label_not_intersection_tracks = CTkLabel( + master=self, + text="Tracks not intersecting sections:", + anchor="nw", + justify="right", + ) + self._label_not_intersection_tracks_value = CTkLabel( + master=self, + text="Please update flow highlighting", + anchor="nw", + justify="right", + ) + self._label_intersecting_not_assigned_tracks = CTkLabel( + master=self, + text="Tracks intersecting not assigned:", + anchor="nw", + justify="right", + ) + self._label_intersecting_not_assigned_tracks_value = CTkLabel( + master=self, + text="Please update flow highlighting", + anchor="nw", + justify="right", + ) + + def _place_widgets(self) -> None: + self.grid_rowconfigure(0, weight=1) + self.grid_columnconfigure(0, weight=0) + self.grid_columnconfigure(1, weight=1) + for index, label in enumerate( + [ + self._label_all_tracks, + self._label_all_tracks_value, + self._label_inside_tracks, + self._label_inside_tracks_value, + self._label_assigned_tracks, + self._label_assigned_tracks_value, + self._label_not_intersection_tracks, + self._label_not_intersection_tracks_value, + self._label_intersecting_not_assigned_tracks, + self._label_intersecting_not_assigned_tracks_value, + ] + ): + label.grid( + row=int(index / 2), + column=int(index % 2), + padx=PADX, + pady=0, + sticky=STICKY, + ) + + def update_track_statistics(self, track_statistics: TrackStatistics) -> None: + self._label_all_tracks_value.configure(text=f"{track_statistics.track_count}") + self._label_inside_tracks_value.configure( + text=f"{track_statistics.track_count_inside}" + ) + self._label_assigned_tracks_value.configure( + text=f"{track_statistics.track_count_inside_assigned} " + f"({track_statistics.percentage_inside_assigned:.1%}) " + ) + self._label_not_intersection_tracks_value.configure( + text=f"{track_statistics.track_count_inside_not_intersecting} " + f"({track_statistics.percentage_inside_not_intersection:.1%}) " + ) + self._label_intersecting_not_assigned_tracks_value.configure( + text=f"{track_statistics.track_count_inside_intersecting_but_unassigned} " + f"({track_statistics.percentage_inside_intersecting_but_unassigned:.1%}) " + ) diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/gui.py b/OTAnalytics/plugin_ui/customtkinter_gui/gui.py index d61ddefd8..1425e00bb 100644 --- a/OTAnalytics/plugin_ui/customtkinter_gui/gui.py +++ b/OTAnalytics/plugin_ui/customtkinter_gui/gui.py @@ -2,7 +2,13 @@ from functools import partial from typing import Any, Sequence -from customtkinter import CTk, CTkFrame, set_appearance_mode, set_default_color_theme +from customtkinter import ( + CTk, + CTkFrame, + CTkScrollableFrame, + set_appearance_mode, + set_default_color_theme, +) from OTAnalytics.adapter_ui.abstract_main_window import AbstractMainWindow from OTAnalytics.adapter_ui.view_model import ViewModel @@ -29,6 +35,9 @@ from OTAnalytics.plugin_ui.customtkinter_gui.frame_track_plotting import ( FrameTrackPlotting, ) +from OTAnalytics.plugin_ui.customtkinter_gui.frame_track_statistics import ( + FrameTrackStatistics, +) from OTAnalytics.plugin_ui.customtkinter_gui.frame_tracks import TracksFrame from OTAnalytics.plugin_ui.customtkinter_gui.frame_video_control import ( FrameVideoControl, @@ -124,28 +133,56 @@ def __init__( viewmodel=viewmodel, layers=layers, ) - self._frame_filter = SingleFrameTabview( - master=self, - title="Visualization Filters", - frame_factory=(partial(FrameFilter, viewmodel=viewmodel)), - ) self._frame_canvas = FrameCanvas( master=self, viewmodel=self._viewmodel, ) + self._frame_canvas_controls = FrameCanvasControls( + master=self, viewmodel=self._viewmodel + ) + self.grid_rowconfigure(0, weight=0) + self.grid_rowconfigure(1, weight=10) + self.grid_columnconfigure(0, weight=0) + self.grid_columnconfigure(1, weight=1) + self._frame_canvas.grid(row=0, column=0, pady=PADY, sticky=STICKY) + self._frame_track_plotting.grid(row=0, column=1, pady=PADY, sticky=STICKY) + self._frame_canvas_controls.grid(row=1, column=0, pady=PADY, sticky=STICKY) + + +class FrameCanvasControls(CTkScrollableFrame): + def __init__( + self, + master: Any, + viewmodel: ViewModel, + **kwargs: Any, + ) -> None: + super().__init__(master=master, bg_color="transparent", **kwargs) + self._viewmodel = viewmodel + self._get_widgets() + self._place_widgets() + + def _get_widgets(self) -> None: + self._frame_track_statistics = SingleFrameTabview( + master=self, + title="Track Statistics", + frame_factory=partial(FrameTrackStatistics, viewmodel=self._viewmodel), + ) + self._frame_filter = SingleFrameTabview( + master=self, + title="Visualization Filters", + frame_factory=partial(FrameFilter, viewmodel=self._viewmodel), + ) self._frame_video_control = SingleFrameTabview( master=self, title="Video Control", - frame_factory=(partial(FrameVideoControl, viewmodel=viewmodel)), + frame_factory=partial(FrameVideoControl, viewmodel=self._viewmodel), ) + + def _place_widgets(self) -> None: self.grid_rowconfigure(0, weight=0) self.grid_rowconfigure(1, weight=0) self.grid_rowconfigure(2, weight=0) - self.grid_rowconfigure(3, weight=1) - self.grid_columnconfigure(0, weight=0) - self.grid_columnconfigure(1, weight=1) - self._frame_canvas.grid(row=0, column=0, pady=PADY, sticky=STICKY) - self._frame_track_plotting.grid(row=0, column=1, pady=PADY, sticky=STICKY) + self._frame_track_statistics.grid(row=0, column=0, pady=PADY, sticky=STICKY) self._frame_filter.grid(row=1, column=0, pady=PADY, sticky=STICKY) self._frame_video_control.grid(row=2, column=0, pady=PADY, sticky=STICKY) diff --git a/OTAnalytics/plugin_ui/main_application.py b/OTAnalytics/plugin_ui/main_application.py index b8df3a2a8..53f929ff3 100644 --- a/OTAnalytics/plugin_ui/main_application.py +++ b/OTAnalytics/plugin_ui/main_application.py @@ -90,6 +90,9 @@ ClearAllFlows, GetAllFlows, ) +from OTAnalytics.application.use_cases.flow_statistics import ( + NumberOfTracksAssignedToEachFlow, +) from OTAnalytics.application.use_cases.generate_flows import ( ArrowFlowNameGenerator, CrossProductFlowGenerator, @@ -100,8 +103,16 @@ RepositoryFlowIdGenerator, ) from OTAnalytics.application.use_cases.get_current_project import GetCurrentProject +from OTAnalytics.application.use_cases.get_road_user_assignments import ( + GetRoadUserAssignments, +) from OTAnalytics.application.use_cases.highlight_intersections import ( IntersectionRepository, + TracksAssignedToAllFlows, + TracksIntersectingAllNonCuttingSections, +) +from OTAnalytics.application.use_cases.inside_cutting_section import ( + TrackIdsInsideCuttingSections, ) from OTAnalytics.application.use_cases.intersection_repository import ( ClearAllIntersections, @@ -123,6 +134,7 @@ AddSection, ClearAllSections, GetAllSections, + GetCuttingSections, GetSectionsById, RemoveSection, ) @@ -138,6 +150,7 @@ RemoveTracks, TrackRepositorySize, ) +from OTAnalytics.application.use_cases.track_statistics import CalculateTrackStatistics from OTAnalytics.application.use_cases.track_to_video_repository import ( ClearAllTrackToVideos, ) @@ -354,7 +367,7 @@ def start_gui(self, run_config: RunConfiguration) -> None: remove_tracks = RemoveTracks(track_repository) clear_all_tracks = ClearAllTracks(track_repository) - get_sections_bv_id = GetSectionsById(section_repository) + get_sections_by_id = GetSectionsById(section_repository) add_section = AddSection(section_repository) remove_section = RemoveSection(section_repository) clear_all_sections = ClearAllSections(section_repository) @@ -395,7 +408,7 @@ def start_gui(self, run_config: RunConfiguration) -> None: event_repository, flow_repository, track_repository, - get_sections_bv_id, + get_sections_by_id, create_events, ) load_otflow = self._create_use_case_load_otflow( @@ -431,7 +444,7 @@ def start_gui(self, run_config: RunConfiguration) -> None: clear_repositories, reset_project_config, track_view_state, file_state ) cut_tracks_intersecting_section = self._create_cut_tracks_intersecting_section( - get_sections_bv_id, + get_sections_by_id, get_all_tracks, add_all_tracks, remove_tracks, @@ -499,6 +512,27 @@ def start_gui(self, run_config: RunConfiguration) -> None: save_path_suggester = SavePathSuggester( file_state, get_all_track_files, get_all_videos, get_current_project ) + tracks_intersecting_sections = self._create_tracks_intersecting_sections( + get_all_tracks + ) + calculate_track_statistics = self._create_calculate_track_statistics( + get_sections, + tracks_intersecting_sections, + get_sections_by_id, + intersection_repository, + road_user_assigner, + event_repository, + flow_repository, + track_repository, + section_repository, + get_all_tracks, + ) + get_road_user_assignments = GetRoadUserAssignments( + flow_repository, event_repository, road_user_assigner + ) + number_of_tracks_assigned_to_each_flow = NumberOfTracksAssignedToEachFlow( + get_road_user_assignments, flow_repository + ) application = OTAnalyticsApplication( datastore, track_state, @@ -533,6 +567,8 @@ def start_gui(self, run_config: RunConfiguration) -> None: config_has_changed, export_road_user_assignments, save_path_suggester, + calculate_track_statistics, + number_of_tracks_assigned_to_each_flow, ) section_repository.register_sections_observer(cut_tracks_intersecting_section) section_repository.register_section_changed_observer( @@ -596,6 +632,7 @@ def start_gui(self, run_config: RunConfiguration) -> None: ) start_new_project.register(dummy_viewmodel.on_start_new_project) event_repository.register_observer(image_updater.notify_events) + event_repository.register_observer(dummy_viewmodel.update_track_statistics) load_otflow.register(file_state.last_saved_config.set) load_otconfig.register(file_state.last_saved_config.set) project_updater.register(dummy_viewmodel.update_quick_save_button) @@ -1074,3 +1111,38 @@ def create_export_road_user_assignments( FilterBySectionEnterEvent(SimpleRoadUserAssigner()), SimpleRoadUserAssignmentExporterFactory(section_repository, get_all_tracks), ) + + def _create_calculate_track_statistics( + self, + get_all_sections: GetAllSections, + tracks_intersecting_sections: TracksIntersectingSections, + get_section_by_id: GetSectionsById, + intersection_repository: IntersectionRepository, + road_user_assigner: RoadUserAssigner, + event_repository: EventRepository, + flow_repository: FlowRepository, + track_repository: TrackRepository, + section_repository: SectionRepository, + get_all_tracks: GetAllTracks, + ) -> CalculateTrackStatistics: + get_cutting_sections = GetCuttingSections(section_repository) + tracksIntersectingAllSections = TracksIntersectingAllNonCuttingSections( + get_cutting_sections, + get_all_sections, + tracks_intersecting_sections, + get_section_by_id, + intersection_repository, + ) + tracksAssignedToAllFlows = TracksAssignedToAllFlows( + road_user_assigner, event_repository, flow_repository + ) + track_ids_inside_cutting_sections = TrackIdsInsideCuttingSections( + get_all_tracks, get_cutting_sections + ) + get_all_track_ids = GetAllTrackIds(track_repository) + return CalculateTrackStatistics( + tracksIntersectingAllSections, + tracksAssignedToAllFlows, + get_all_track_ids, + track_ids_inside_cutting_sections, + ) diff --git a/tests/OTAnalytics/application/use_cases/test_flows_statistics.py b/tests/OTAnalytics/application/use_cases/test_flows_statistics.py new file mode 100644 index 000000000..a172d8228 --- /dev/null +++ b/tests/OTAnalytics/application/use_cases/test_flows_statistics.py @@ -0,0 +1,68 @@ +from unittest.mock import Mock + +import pytest + +from OTAnalytics.application.analysis.traffic_counting import RoadUserAssignment +from OTAnalytics.application.use_cases.flow_statistics import ( + NumberOfTracksAssignedToEachFlow, +) +from OTAnalytics.domain.flow import Flow, FlowId + + +def create_flow(flow_id: FlowId) -> Flow: + flow = Mock() + flow.id = flow_id + return flow + + +def create_assignment(road_user: str, assignment: Flow) -> RoadUserAssignment: + return RoadUserAssignment(road_user=road_user, assignment=assignment, events=Mock()) + + +FIRST_FLOW = create_flow(FlowId("first flow")) +SECOND_FLOW = create_flow(FlowId("second flow")) +FLOW_WITH_NO_ASSIGNMENTS = FlowId("flow with no assignments") + +FIRST_ASSIGNMENT = create_assignment("road-user-1", FIRST_FLOW) +SECOND_ASSIGNMENT = create_assignment("road-user-2", SECOND_FLOW) +THIRD_ASSIGNMENT = create_assignment("road-user-3", FIRST_FLOW) +FOURTH_ASSIGNMENT = create_assignment("road-user-3", FIRST_FLOW) + + +@pytest.fixture +def get_road_user_assignments() -> Mock: + assignments = [ + FIRST_ASSIGNMENT, + SECOND_ASSIGNMENT, + THIRD_ASSIGNMENT, + FOURTH_ASSIGNMENT, + ] + get_assignments = Mock() + get_assignments.get.return_value = assignments + return get_assignments + + +@pytest.fixture +def flow_repository() -> Mock: + repository = Mock() + repository.get_all.return_value = [ + FIRST_FLOW, + SECOND_FLOW, + FLOW_WITH_NO_ASSIGNMENTS, + ] + return repository + + +class TestNumberOfTracksAssignedToEachFlow: + def test_get(self, get_road_user_assignments: Mock, flow_repository: Mock) -> None: + number_of_tracks_assigned_to_each_flow = NumberOfTracksAssignedToEachFlow( + get_road_user_assignments, flow_repository + ) + actual = number_of_tracks_assigned_to_each_flow.get() + assert actual == { + FIRST_FLOW.id: 3, + SECOND_FLOW.id: 1, + FLOW_WITH_NO_ASSIGNMENTS.id: 0, + } + get_road_user_assignments.get.assert_called_once() + flow_repository.get_all.assert_called_once() diff --git a/tests/OTAnalytics/application/use_cases/test_get_road_user_assignments.py b/tests/OTAnalytics/application/use_cases/test_get_road_user_assignments.py new file mode 100644 index 000000000..73862844b --- /dev/null +++ b/tests/OTAnalytics/application/use_cases/test_get_road_user_assignments.py @@ -0,0 +1,58 @@ +from unittest.mock import Mock + +import pytest + +from OTAnalytics.application.use_cases.get_road_user_assignments import ( + GetRoadUserAssignments, +) + +events = [Mock(), Mock()] +flows = [Mock(), Mock()] +assignments_as_list = [Mock(), Mock()] + + +@pytest.fixture +def flow_repository() -> Mock: + repository = Mock() + repository.get_all.return_value = flows + return repository + + +@pytest.fixture +def event_repository() -> Mock: + repository = Mock() + repository.get_all.return_value = events + return repository + + +@pytest.fixture +def assignments() -> Mock: + assignments = Mock() + assignments.as_list.return_value = assignments_as_list + return assignments + + +@pytest.fixture +def road_user_assigner(assignments: Mock) -> Mock: + assigner = Mock() + assigner.assign.return_value = assignments + return assigner + + +class TestGetRoadUserAssignments: + def test_get( + self, + flow_repository: Mock, + event_repository: Mock, + road_user_assigner: Mock, + assignments: Mock, + ) -> None: + get_assignments = GetRoadUserAssignments( + flow_repository, event_repository, road_user_assigner + ) + actual = get_assignments.get() + assert actual == assignments_as_list + event_repository.get_all.assert_called_once() + flow_repository.get_all.assert_called_once() + road_user_assigner.assign.assert_called_once_with(events, flows) + assignments.as_list.assert_called_once() diff --git a/tests/OTAnalytics/application/use_cases/test_highlight_intersections.py b/tests/OTAnalytics/application/use_cases/test_highlight_intersections.py index cb3a6c622..a644cea63 100644 --- a/tests/OTAnalytics/application/use_cases/test_highlight_intersections.py +++ b/tests/OTAnalytics/application/use_cases/test_highlight_intersections.py @@ -19,12 +19,16 @@ from OTAnalytics.application.use_cases.highlight_intersections import ( IntersectionRepository, TracksAssignedToSelectedFlows, + TracksIntersectingAllSections, TracksIntersectingGivenSections, TracksIntersectingSelectedSections, TracksNotIntersectingSelection, TracksOverlapOccurrenceWindow, ) -from OTAnalytics.application.use_cases.section_repository import GetSectionsById +from OTAnalytics.application.use_cases.section_repository import ( + GetAllSections, + GetSectionsById, +) from OTAnalytics.domain.date import DateRange from OTAnalytics.domain.event import Event, EventRepository from OTAnalytics.domain.filter import FilterElement @@ -101,6 +105,11 @@ def intersection_repository() -> Mock: return Mock(spec=IntersectionRepository) +@pytest.fixture +def get_all_sections() -> Mock: + return Mock(spec=GetAllSections) + + class TestTracksIntersectingSelectedSections: def test_get_ids( self, @@ -133,6 +142,41 @@ def test_get_ids( tracks_intersecting_sections.assert_called_once_with([section_2]) +class TestTracksIntersectingAllSections: + def test_get_ids( + self, + track_id_1: TrackId, + track_id_2: TrackId, + section_1: Mock, + section_2: Mock, + get_all_sections: Mock, + tracks_intersecting_sections: Mock, + get_section_by_id: Mock, + intersection_repository: Mock, + ) -> None: + sections = [section_1, section_2] + section_1_tracks = {track_id_1} + section_2_tracks = {track_id_2} + original_track_ids = { + section_1.id: section_1_tracks, + section_2.id: section_2_tracks, + } + get_all_sections.return_value = sections + tracks_intersecting_sections.return_value = original_track_ids + get_section_by_id.return_value = sections + intersection_repository.get.return_value = {} + provider = TracksIntersectingAllSections( + get_all_sections, + tracks_intersecting_sections, + get_section_by_id, + intersection_repository, + ) + + track_ids = provider.get_ids() + + assert track_ids == {track_id_1, track_id_2} + + class TestTracksIntersectingGivenSections: def test_get_ids( self, diff --git a/tests/OTAnalytics/application/use_cases/test_track_statistics.py b/tests/OTAnalytics/application/use_cases/test_track_statistics.py new file mode 100644 index 000000000..a3b47797e --- /dev/null +++ b/tests/OTAnalytics/application/use_cases/test_track_statistics.py @@ -0,0 +1,86 @@ +from unittest.mock import Mock + +import pytest + +from OTAnalytics.application.use_cases.highlight_intersections import ( + TracksAssignedToAllFlows, + TracksIntersectingAllNonCuttingSections, +) +from OTAnalytics.application.use_cases.inside_cutting_section import ( + TrackIdsInsideCuttingSections, +) +from OTAnalytics.application.use_cases.track_repository import GetAllTrackIds +from OTAnalytics.application.use_cases.track_statistics import CalculateTrackStatistics +from OTAnalytics.domain.track import TrackId + +CUTTING_SECTION_NAME: str = "#clicut 0815" + + +@pytest.fixture +def intersection_all_non_cutting_sections() -> Mock: + return Mock(spec=TracksIntersectingAllNonCuttingSections) + + +@pytest.fixture +def assigned_to_all_flows() -> Mock: + return Mock(spec=TracksAssignedToAllFlows) + + +@pytest.fixture +def get_all_track_ids() -> Mock: + return Mock(spec=GetAllTrackIds) + + +@pytest.fixture +def track_ids_inside_cutting_sections() -> Mock: + return Mock(spec=TrackIdsInsideCuttingSections) + + +def create_trackids_set_with_list_of_ids(ids: list[str]) -> set[TrackId]: + return set([TrackId(id) for id in ids]) + + +class TestCalculateTrackStatistics: + def test_get_statistics( + self, + intersection_all_non_cutting_sections: Mock, + assigned_to_all_flows: Mock, + get_all_track_ids: Mock, + track_ids_inside_cutting_sections: Mock, + ) -> None: + intersection_all_non_cutting_sections.get_ids.return_value = ( + create_trackids_set_with_list_of_ids( + ["1", "2", "3", "4", "5", "6", "7", "8", "10"] + ) + ) + assigned_to_all_flows.get_ids.return_value = ( + create_trackids_set_with_list_of_ids(["1", "2", "3", "4", "5"]) + ) + get_all_track_ids.return_value = create_trackids_set_with_list_of_ids( + ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"] + ) + track_ids_inside_cutting_sections.return_value = ( + create_trackids_set_with_list_of_ids( + ["1", "2", "3", "4", "5", "6", "7", "8", "9"] + ) + ) + calculator = CalculateTrackStatistics( + intersection_all_non_cutting_sections, + assigned_to_all_flows, + get_all_track_ids, + track_ids_inside_cutting_sections, + ) + + trackStatistics = calculator.get_statistics() + + assert trackStatistics.track_count == 10 + assert ( + trackStatistics.track_count + == trackStatistics.track_count_outside + trackStatistics.track_count_inside + ) + assert trackStatistics.track_count_outside == 1 + assert trackStatistics.track_count_inside == 9 + assert trackStatistics.track_count_inside_not_intersecting == 1 + assert trackStatistics.track_count_inside_intersecting_but_unassigned == 3 + assert trackStatistics.track_count_inside_assigned == 5 + assert trackStatistics.percentage_inside_assigned == 5.0 / 9