Skip to content

Commit

Permalink
SCHED-406: GreedyMax cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
sraaphorst committed Jul 29, 2023
1 parent 1525b23 commit d0c61ae
Show file tree
Hide file tree
Showing 10 changed files with 116 additions and 94 deletions.
3 changes: 1 addition & 2 deletions scheduler/core/components/collector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,7 @@ def time_accounting(self, site_plans: Plans) -> None:
observation.status = ObservationStatus.ONGOING

# Update by atom in the sequence
for atom_idx in range (v.atom_start_idx, v.atom_end_idx):

for atom_idx in range(v.atom_start_idx, v.atom_end_idx):
obs_seq[atom_idx].program_used = obs_seq[atom_idx].prog_time
obs_seq[atom_idx].partner_used = obs_seq[atom_idx].part_time

Expand Down
4 changes: 0 additions & 4 deletions scheduler/core/components/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,8 @@
from scheduler.core.calculations.selection import Selection
from scheduler.core.plans import Plans

import numpy.typing as npt
from lucupy.minimodel import Program

# Convenient type alias for Interval
Interval = npt.NDArray[int]


class Optimizer:
"""
Expand Down
10 changes: 4 additions & 6 deletions scheduler/core/components/optimizer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import timedelta
from typing import Mapping, List, Optional, Union
from typing import Mapping, List
from lucupy.minimodel.program import ProgramID
from lucupy.types import Interval

from scheduler.core.calculations.groupinfo import GroupData
from scheduler.core.calculations.programinfo import ProgramInfo
from scheduler.core.plans import Plan, Plans

from . import Interval
from scheduler.core.plans import Plans


@dataclass(frozen=True)
Expand Down Expand Up @@ -54,6 +53,5 @@ def setup(self, program_info: Mapping[ProgramID, ProgramInfo]):
...

@abstractmethod
def add(self, night: int, max_group_info: Union[GroupData, MaxGroup]):
def add(self, night: int, max_group_info: GroupData | MaxGroup):
...

16 changes: 10 additions & 6 deletions scheduler/core/components/optimizer/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@

from __future__ import annotations

from datetime import datetime
import random
from datetime import datetime
from typing import Optional, Tuple

from scheduler.core.calculations.selection import Selection
from lucupy.types import Interval

from scheduler.core.calculations import GroupData
from scheduler.core.calculations.selection import Selection
from scheduler.core.plans import Plan, Plans
from scheduler.services import logger_factory
from .base import BaseOptimizer, Interval

from .base import BaseOptimizer

logger = logger_factory.create_logger(__name__)

Expand Down Expand Up @@ -60,9 +61,12 @@ def add(self, group: GroupData, plans: Plans, interval: Optional[Interval] = Non
if not plan.is_full and plan.site == observation.site:
obs_len = plan.time2slots(plan.time_slot_length, observation.exec_time())
if plan.time_left() >= obs_len and observation not in plan:
atom_start = 0
atom_end = len(observation.sequence) - 1
start, start_time_slot = DummyOptimizer._first_free_time(plan)
visit_score = sum(group.group_info.scores[plans.night_idx][start_time_slot:start_time_slot + obs_len])
plan.add(observation, start, start_time_slot, obs_len, visit_score)
end_time_slot = start_time_slot + obs_len
visit_score = sum(group.group_info.scores[plans.night_idx][start_time_slot:end_time_slot])
plan.add(observation, start, atom_start, atom_end, start_time_slot, obs_len, visit_score)
return True
else:
# TODO: DO a partial insert
Expand Down
99 changes: 61 additions & 38 deletions scheduler/core/components/optimizer/greedymax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,37 @@

from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Dict, FrozenSet, List, Optional, Tuple
from enum import Enum
from typing import final, Dict, FrozenSet, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from lucupy.minimodel import NIR_INSTRUMENTS, Group, NightIndex, Observation, Program, Sequence
from lucupy.minimodel import ObservationID, Site, UniqueGroupID, QAState, ObservationClass, ObservationStatus
from lucupy.minimodel.resource import Resource
from lucupy.types import Interval, ZeroTime

from scheduler.core.calculations import GroupData, NightTimeSlotScores
from scheduler.core.calculations.selection import Selection
from scheduler.core.calculations import GroupData
from scheduler.core.plans import Plan, Plans
from scheduler.core.components.optimizer.timeline import Timelines
from scheduler.core.plans import Plan, Plans
from scheduler.services import logger_factory
from .base import BaseOptimizer, MaxGroup
from . import Interval

from lucupy.minimodel import Group, NightIndex, Observation, Program, Sequence
from lucupy.minimodel import ObservationID, Site, UniqueGroupID, QAState, ObservationClass, ObservationStatus
from lucupy.minimodel.resource import Resource
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
logger = logger_factory.create_logger(__name__)


logger = logger_factory.create_logger(__name__)
@final
class Mode(str, Enum):
"""
TODO: Get rid of this later as per GSCHED-413.
"""
SPECTROSCOPY = 'spectroscopy'
IMAGING = 'imaging'


@final
@dataclass(frozen=True)
class ObsPlanData:
"""
Expand All @@ -39,12 +49,15 @@ class ObsPlanData:
visit_score: float


@final
class GreedyMaxOptimizer(BaseOptimizer):
"""
GreedyMax is an optimizer that schedules the visits for the rest of the night in a greedy fashion.
"""

def __init__(self, min_visit_len: timedelta = timedelta(minutes=30), show_plots: bool = False):
def __init__(self,
min_visit_len: timedelta = timedelta(minutes=30),
show_plots: bool = False):
self.selection: Optional[Selection] = None
self.group_data_list: List[GroupData] = []
self.group_ids: List[UniqueGroupID] = []
Expand Down Expand Up @@ -74,12 +87,15 @@ def setup(self, selection: Selection) -> GreedyMaxOptimizer:
return self

@staticmethod
def non_zero_intervals(scores: npt.NDArray[float]) -> npt.NDArray[int]:
def non_zero_intervals(scores: NightTimeSlotScores) -> npt.NDArray[int]:
"""
Calculate the non-zero intervals in the data.
This consists of an array with entries of the form [a, b] where
the non-zero interval runs from a (inclusive) to b (exclusive).
See test_greedymax.py for an example.
The array returned here contains multiple Intervals and thus we leave the return type
instead of using Interval.
"""
# Create an array that is 1 where the score is greater than 0, and pad each end with an extra 0.
not_zero = np.concatenate(([0], np.greater(scores, 0), [0]))
Expand All @@ -89,10 +105,10 @@ def non_zero_intervals(scores: npt.NDArray[float]) -> npt.NDArray[int]:
return np.where(abs_diff == 1)[0].reshape(-1, 2)

@staticmethod
def cumulative_seq_exec_times(sequence: Sequence) -> list:
def cumulative_seq_exec_times(sequence: Sequence) -> List[timedelta]:
"""Cumulative series of execution times for the unobserved atoms in a sequence, excluding acquisition time"""
cumul_seq = []
total_exec = timedelta(0.0)
total_exec = ZeroTime
for atom in sequence:
if not atom.observed:
total_exec += atom.exec_time
Expand All @@ -110,33 +126,37 @@ def first_nonzero_time(inlist: List) -> int:
"""
idx = 0
value = inlist[idx]
while value == timedelta(0):
while value == ZeroTime:
idx += 1
value = inlist[idx]
return idx

@staticmethod
def num_nir_standards(exec_sci, wavelengths=None, mode='spectroscopy') -> int:
def num_nir_standards(exec_sci: timedelta,
wavelengths=None,
mode: Mode = Mode.SPECTROSCOPY) -> int:
"""
Calculated the number of NIR standards from the length of the NIR science and the mode
"""
n_std = 0

# TODO: need mode or other info to distinguish imaging from spectroscopy
if mode == 'imaging':
if mode == Mode.IMAGING:
time_per_standard = timedelta(hours=2.0)
else:
if all(wave <= 2.5 for wave in wavelengths):
time_per_standard = timedelta(hours=1.5)
else:
time_per_standard = timedelta(hours=1.0)

if time_per_standard > timedelta(0):
if time_per_standard > ZeroTime:
n_std = max(1, int(exec_sci // time_per_standard)) # TODO: confirm this

return n_std

def _exec_time_remaining(self, group: Group, verbose=False) -> Tuple[timedelta, timedelta, timedelta, int]:
def _exec_time_remaining(self,
group: Group,
verbose: bool = False) -> Tuple[timedelta, timedelta, timedelta, int]:
"""Determine the total and minimum remaining execution times.
If an observation can't be split, then there should only be one atom, so min time is the full time.
"""
Expand All @@ -148,15 +168,12 @@ def _exec_time_remaining(self, group: Group, verbose=False) -> Tuple[timedelta,
print(f"\t {group.required_resources()}")
print(f"\t {group.wavelengths()}")

nir_inst = [Resource('Flamingos2'), Resource('GNIRS'), Resource('NIRI'), Resource('NIFS'),
Resource('IGRINS')]

nsci = nprt = 0

exec_sci_min = exec_sci_nir = timedelta(0)
exec_prt = timedelta(0)
time_per_standard = timedelta(0)
sci_times = timedelta(0)
exec_sci_min = exec_sci_nir = ZeroTime
exec_prt = ZeroTime
time_per_standard = ZeroTime
sci_times = ZeroTime
n_std = 0
part_times = []
sci_times_min = []
Expand All @@ -169,7 +186,7 @@ def _exec_time_remaining(self, group: Group, verbose=False) -> Tuple[timedelta,
f"{next(iter(obs.wavelengths()))} {cumul_seq[-1]}")
# f"{next(iter(obs.required_resources())).id} {next(iter(obs.wavelengths()))}")

if cumul_seq[-1] > timedelta(0):
if cumul_seq[-1] > ZeroTime:
# total time remaining
time_remain = obs.acq_overhead + cumul_seq[-1]
# Min time remaining (acq + first non-zero atom)
Expand All @@ -185,7 +202,7 @@ def _exec_time_remaining(self, group: Group, verbose=False) -> Tuple[timedelta,
sci_times_min.append(time_remain)

# NIR science time for to determine the number of tellurics
if any(inst in group.required_resources() for inst in nir_inst):
if any(inst in group.required_resources() for inst in NIR_INSTRUMENTS):
exec_sci_nir += time_remain
elif obs.obs_class == ObservationClass.PARTNERCAL:
# Partner calibration time, no splitting of partner cals
Expand All @@ -204,8 +221,8 @@ def _exec_time_remaining(self, group: Group, verbose=False) -> Tuple[timedelta,

# How many standards are needed?
# TODO: need mode or other info to distinguish imaging from spectroscopy
if exec_sci_nir > timedelta(0) and len(part_times) > 0:
n_std = self.num_nir_standards(exec_sci_nir, wavelengths=group.wavelengths(), mode='spectroscopy')
if exec_sci_nir > ZeroTime and len(part_times) > 0:
n_std = self.num_nir_standards(exec_sci_nir, wavelengths=group.wavelengths(), mode=Mode.SPECTROSCOPY)

# if only partner standards, set n_std to the number of standards in group (e.g. specphots)
if nprt > 0 and nsci == 0:
Expand Down Expand Up @@ -301,6 +318,7 @@ def _find_max_group(self, plans: Plans) -> Optional[MaxGroup]:
# interval is a numpy array that indexes into the scores for the night to return a sub-array.
check_interval = group_data.group_info.scores[plans.night_idx][interval]
group_intervals = self.non_zero_intervals(check_interval)

max_score_on_interval = 0.0
max_interval = None
for group_interval in group_intervals:
Expand Down Expand Up @@ -331,7 +349,7 @@ def _find_max_group(self, plans: Plans) -> Optional[MaxGroup]:
max_n_min = None
max_slots_remaining = None
max_n_std = None
max_exec_nir = timedelta(0)
max_exec_nir = ZeroTime

if len(max_scores) > 0:
# sort scores from high to low
Expand Down Expand Up @@ -452,7 +470,7 @@ def _find_group_position(self, night_idx: NightIndex, max_group_info: MaxGroup)

def nir_slots(self, science_obs, n_slots_filled, len_interval) -> Tuple[int, int, ObservationID]:
"""
Return the starting and ending timeline slots (indices) for the NIR science observations
Return the starting and ending timeline slots (indices) for the NIR science observations.
"""
# TODO: This should probably be moved to a more general location
nir_inst = [Resource('Flamingos2'), Resource('GNIRS'), Resource('NIRI'), Resource('NIFS'),
Expand Down Expand Up @@ -799,9 +817,12 @@ def _add_visit(self,
night_idx: NightIndex,
obs: Observation,
max_group_info: GroupData | MaxGroup,
best_interval,
n_slots_filled) -> int:
"""Add and observation to the timeline and do pseudo-time accounting"""
best_interval: Interval,
n_slots_filled: int) -> int:
"""
Add an observation to the timeline and do pseudo-time accounting.
Returns the number of time slots filled.
"""

site = max_group_info.group_data.group.observations()[0].site
timeline = self.timelines[night_idx][site]
Expand All @@ -825,7 +846,9 @@ def _add_visit(self,
start_time_slot, start = timeline.add(iobs, visit_length, best_interval)

# Get visit score and store information for the output plans
visit_score = sum(max_group_info.group_data.group_info.scores[night_idx][start_time_slot:start_time_slot + visit_length])
end_time_slot = start_time_slot + visit_length
visit_score = sum(max_group_info.group_data.group_info.scores[night_idx][start_time_slot:end_time_slot])

self.obs_in_plan[site][start_time_slot] = ObsPlanData(
obs=obs,
obs_start=start,
Expand Down Expand Up @@ -884,7 +907,7 @@ def add(self, night_idx: NightIndex, max_group_info: GroupData | MaxGroup) -> bo
part_obs = max_group_info.group_data.group.partner_observations()

if max_group_info.n_std > 0:
if max_group_info.exec_sci_nir > timedelta(0):
if max_group_info.exec_sci_nir > ZeroTime:
standards, place_before = self.place_standards(night_idx, best_interval, prog_obs, part_obs,
max_group_info.n_std)
for ii, std in enumerate(standards):
Expand Down
4 changes: 2 additions & 2 deletions scheduler/core/components/optimizer/timeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from lucupy.minimodel import NightIndex, Observation, ObservationID, Site

from scheduler.core.calculations.nightevents import NightEvents
from . import Interval
from lucupy.types import Interval, ZeroTime


@dataclass
Expand Down Expand Up @@ -69,7 +69,7 @@ def add(self, obs_idx: int, required_time_slots: int, interval: Interval) -> Tup
# TODO: What if there are no empty slots in the interval?
# TODO: What if there are not enough time slots that are empty to accommodate the observation?
start_time_slot = None
start = timedelta(0)
start = ZeroTime

# Get first non-zero slot in given interval.
interval_empty_slots = np.where(self.time_slots[interval] == Timeline.EMPTY)[0]
Expand Down
Loading

0 comments on commit d0c61ae

Please sign in to comment.